Skip to content

Commit 0ec4995

Browse files
committed
added support for LUA modules
1 parent f050b31 commit 0ec4995

File tree

11 files changed

+153
-64
lines changed

11 files changed

+153
-64
lines changed

fixtures/demo.lua

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
local demo = require("demo_mod")
2+
3+
function main(input)
4+
return demo.Mult(5, 5)
5+
end

fixtures/module.lua

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
local demo_mod = {} -- The main table
2+
3+
function demo_mod.Mult(a, b)
4+
return a * b
5+
end
6+
7+
return demo_mod

gen/binary_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
//go:generate genny -in=$GOFILE -out=../z_binary_test.go gen "TIn=String,Number,Bool TOut=String,Number,Bool"
1212

1313
func Test_TInTOut(t *testing.T) {
14-
m := &Module{Name: "test"}
14+
m := &NativeModule{Name: "test"}
1515
m.Register("test1", func(v TIn) (TOut, error) {
1616
return newTestValue(TypeTOut).(TOut), nil
1717
})

gen/unary_in_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
//go:generate genny -in=$GOFILE -out=../z_unary_test.go gen "TIn=String,Number,Bool"
1212

1313
func Test_In_TIn(t *testing.T) {
14-
m := &Module{Name: "test"}
14+
m := &NativeModule{Name: "test"}
1515
m.Register("test1", func(v TIn) error {
1616
return nil
1717
})
@@ -45,7 +45,7 @@ func Test_In_TIn(t *testing.T) {
4545
}
4646

4747
func Test_Out_TIn(t *testing.T) {
48-
m := &Module{Name: "test"}
48+
m := &NativeModule{Name: "test"}
4949
m.Register("test1", func() (TIn, error) {
5050
return newTestValue(TypeTIn).(TIn), nil
5151
})

json/json.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ import (
1212
"github.com/yuin/gopher-lua"
1313
)
1414

15+
var (
16+
errNested = errors.New("cannot encode recursively nested tables to JSON")
17+
errSparseArray = errors.New("cannot encode sparse array")
18+
errInvalidKeys = errors.New("cannot encode mixed or invalid key types")
19+
)
20+
1521
// Loader is the module loader function.
1622
func Loader(L *lua.LState) int {
1723
t := L.NewTable()
@@ -51,11 +57,7 @@ func apiEncode(L *lua.LState) int {
5157
return 1
5258
}
5359

54-
var (
55-
errNested = errors.New("cannot encode recursively nested tables to JSON")
56-
errSparseArray = errors.New("cannot encode sparse array")
57-
errInvalidKeys = errors.New("cannot encode mixed or invalid key types")
58-
)
60+
// --------------------------------------------------------------------
5961

6062
type invalidTypeError lua.LValueType
6163

module.go

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,43 @@ var (
1919

2020
var builtin = make(map[reflect.Type]func(interface{}) lua.LGFunction, 8)
2121

22-
// Module represents a loadable module
23-
type Module struct {
22+
// Module represents a loadable module.
23+
type Module interface {
24+
inject(state *lua.LState) error
25+
}
26+
27+
// --------------------------------------------------------------------
28+
29+
// ScriptModule represents a loadable module written in LUA itself.
30+
type ScriptModule struct {
31+
Script *Script // The script that contains the module
32+
Name string // The name of the module
33+
Version string // The module version string
34+
}
35+
36+
// Inject loads the module into the state
37+
func (m *ScriptModule) inject(runtime *lua.LState) error {
38+
39+
// Inject the prerequisite modules of the module
40+
if err := m.Script.loadModules(runtime); err != nil {
41+
return err
42+
}
43+
44+
// Push the function to the runtime
45+
codeFn := runtime.NewFunctionFromProto(m.Script.code)
46+
preload := runtime.GetField(runtime.GetField(runtime.Get(lua.EnvironIndex), "package"), "preload")
47+
if _, ok := preload.(*lua.LTable); !ok {
48+
return errors.New("package.preload must be a table")
49+
50+
}
51+
runtime.SetField(preload, m.Name, codeFn)
52+
return nil
53+
}
54+
55+
// --------------------------------------------------------------------
56+
57+
// NativeModule represents a loadable native module.
58+
type NativeModule struct {
2459
lock sync.Mutex
2560
funcs map[string]fngen
2661
Name string // The name of the module
@@ -87,7 +122,7 @@ func (g *fngen) generate() lua.LGFunction {
87122
}
88123

89124
// Register registers a function into the module.
90-
func (m *Module) Register(name string, function interface{}) error {
125+
func (m *NativeModule) Register(name string, function interface{}) error {
91126
m.lock.Lock()
92127
defer m.lock.Unlock()
93128

@@ -106,14 +141,14 @@ func (m *Module) Register(name string, function interface{}) error {
106141
}
107142

108143
// Unregister unregisters a function from the module.
109-
func (m *Module) Unregister(name string) {
144+
func (m *NativeModule) Unregister(name string) {
110145
m.lock.Lock()
111146
defer m.lock.Unlock()
112147
delete(m.funcs, name)
113148
}
114149

115150
// Inject loads the module into the state
116-
func (m *Module) inject(state *lua.LState) {
151+
func (m *NativeModule) inject(state *lua.LState) error {
117152
table := make(map[string]lua.LGFunction, len(m.funcs))
118153
for name, g := range m.funcs {
119154
table[name] = g.generate()
@@ -125,6 +160,7 @@ func (m *Module) inject(state *lua.LState) {
125160
state.Push(mod)
126161
return 1
127162
})
163+
return nil
128164
}
129165

130166
// validate validates the function type

module_test.go

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ import (
1111
"github.com/stretchr/testify/assert"
1212
)
1313

14-
func testModule() *Module {
15-
m := &Module{
14+
func testModule() Module {
15+
m := &NativeModule{
1616
Name: "test",
1717
Version: "1.0.0",
1818
}
@@ -58,7 +58,7 @@ func Test_Sum(t *testing.T) {
5858
}
5959

6060
func Test_NotAFunc(t *testing.T) {
61-
m := &Module{
61+
m := &NativeModule{
6262
Name: "test",
6363
Version: "1.0.0",
6464
}
@@ -68,3 +68,25 @@ func Test_NotAFunc(t *testing.T) {
6868
m.Unregister("hash")
6969
assert.Equal(t, 0, len(m.funcs))
7070
}
71+
72+
func Test_ScriptModule(t *testing.T) {
73+
74+
m, err := newScript("fixtures/module.lua")
75+
assert.NoError(t, err)
76+
77+
s, err := newScript("fixtures/demo.lua", &ScriptModule{
78+
Script: m,
79+
Name: "demo_mod",
80+
Version: "1.0.0",
81+
})
82+
assert.NoError(t, err)
83+
84+
out, err := s.Run(context.Background(), 10, m)
85+
assert.NoError(t, err)
86+
assert.Equal(t, TypeNumber, out.Type())
87+
assert.Equal(t, Number(25), out.(Number))
88+
assert.Equal(t, "25", out.String())
89+
90+
err = s.Close()
91+
assert.NoError(t, err)
92+
}

script.go

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import (
1515
"github.com/kelindar/lua/json"
1616
lua "github.com/yuin/gopher-lua"
1717
"github.com/yuin/gopher-lua/parse"
18-
//"layeh.com/gopher-json"
1918
"layeh.com/gopher-luar"
2019
)
2120

@@ -26,15 +25,16 @@ var (
2625
// Script represents a LUA script
2726
type Script struct {
2827
lock sync.Mutex
29-
name string // The name of the script
30-
argn int // The number of arguments
31-
exec *lua.LState // The runtime for the script
32-
main *lua.LFunction // The main function
33-
mods []*Module // The injected modules
28+
name string // The name of the script
29+
argn int // The number of arguments
30+
exec *lua.LState // The runtime for the script
31+
main *lua.LFunction // The main function
32+
mods []Module // The injected modules
33+
code *lua.FunctionProto // The precompiled code
3434
}
3535

3636
// FromReader reads a script fron an io.Reader
37-
func FromReader(name string, r io.Reader, modules ...*Module) (*Script, error) {
37+
func FromReader(name string, r io.Reader, modules ...Module) (*Script, error) {
3838
script := &Script{
3939
name: name,
4040
mods: modules,
@@ -43,7 +43,7 @@ func FromReader(name string, r io.Reader, modules ...*Module) (*Script, error) {
4343
}
4444

4545
// FromString reads a script fron a string
46-
func FromString(name, code string, modules ...*Module) (*Script, error) {
46+
func FromString(name, code string, modules ...Module) (*Script, error) {
4747
return FromReader(name, bytes.NewBufferString(code), modules...)
4848
}
4949

@@ -80,40 +80,38 @@ func (s *Script) Run(ctx context.Context, args ...interface{}) (Value, error) {
8080
}
8181

8282
// Update updates the content of the script.
83-
func (s *Script) Update(r io.Reader) error {
83+
func (s *Script) Update(r io.Reader) (err error) {
8484
runtime := newVM()
85-
fn, err := s.compile(r)
85+
s.code, err = s.compile(r)
8686
if err != nil {
8787
return err
8888
}
89-
9089
// Push the function to the runtime
91-
codeFn := runtime.NewFunctionFromProto(fn)
90+
codeFn := runtime.NewFunctionFromProto(s.code)
9291
runtime.Push(codeFn)
9392

9493
// Inject the modules
95-
runtime.PreloadModule("json", json.Loader)
96-
for _, m := range s.mods {
97-
m.inject(runtime)
94+
if err := s.loadModules(runtime); err != nil {
95+
return err
9896
}
9997

10098
// Initialize by calling the script
10199
if err := runtime.PCall(0, lua.MultRet, nil); err != nil {
102100
return err
103101
}
104102

105-
// Get the main function
106-
mainFn, err := findFunction(runtime, "main")
107-
if err != nil {
108-
return err
109-
}
110-
111-
// Make sure the most recent code is present in the state
103+
// Update the fields
112104
s.lock.Lock()
113105
defer s.lock.Unlock()
114-
s.argn = int(mainFn.Proto.NumParameters)
115106
s.exec = runtime
116-
s.main = mainFn
107+
s.argn = 0
108+
s.main = nil
109+
110+
// If we have a main function, set it
111+
if mainFn, err := findFunction(runtime, "main"); err == nil {
112+
s.argn = int(mainFn.Proto.NumParameters)
113+
s.main = mainFn
114+
}
117115
return nil
118116
}
119117

@@ -129,6 +127,18 @@ func (s *Script) compile(r io.Reader) (*lua.FunctionProto, error) {
129127
return lua.Compile(chunk, s.name)
130128
}
131129

130+
// LoadModules loads in the prerequisite modules
131+
func (s *Script) loadModules(runtime *lua.LState) error {
132+
runtime.PreloadModule("json", json.Loader)
133+
for _, m := range s.mods {
134+
if err := m.inject(runtime); err != nil {
135+
return err
136+
}
137+
}
138+
139+
return nil
140+
}
141+
132142
// Close closes the script and cleanly disposes of its resources.
133143
func (s *Script) Close() error {
134144
s.lock.Lock()
@@ -140,15 +150,13 @@ func (s *Script) Close() error {
140150

141151
// newVM creates a new LUA state
142152
func newVM() *lua.LState {
143-
state := lua.NewState(lua.Options{
153+
return lua.NewState(lua.Options{
144154
RegistrySize: 1024 * 20, // this is the initial size of the registry
145155
RegistryMaxSize: 1024 * 80, // this is the maximum size that the registry can grow to. If set to `0` (the default) then the registry will not auto grow
146156
RegistryGrowStep: 32, // this is how much to step up the registry by each time it runs out of space. The default is `32`.
147157
CallStackSize: 120, // this is the maximum callstack size of this LState
148158
MinimizeStackMemory: true, // Defaults to `false` if not specified. If set, the callstack will auto grow and shrink as needed up to a max of `CallStackSize`. If not set, the callstack will be fixed at `CallStackSize`.
149159
})
150-
//state.PreloadModule("relay", vm.mod.loadModule)
151-
return state
152160
}
153161

154162
// findFunction extracts a global function

script_test.go

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ import (
1111
"github.com/stretchr/testify/assert"
1212
)
1313

14-
func newScript(file string) (*Script, error) {
14+
func newScript(file string, mods ...Module) (*Script, error) {
1515
f, _ := os.Open(file)
16-
return FromReader("test.lua", f, testModule())
16+
mods = append(mods, testModule())
17+
return FromReader("test.lua", f, mods...)
1718
}
1819

1920
type Person struct {
@@ -24,7 +25,6 @@ type Person struct {
2425
// Benchmark_Serial/fib-8 5870025 203 ns/op 16 B/op 2 allocs/op
2526
// Benchmark_Serial/empty-8 8592448 137 ns/op 0 B/op 0 allocs/op
2627
// Benchmark_Serial/update-8 1000000 1069 ns/op 224 B/op 14 allocs/op
27-
// Benchmark_Serial/module-8 1900801 629 ns/op 160 B/op 8 allocs/op
2828
func Benchmark_Serial(b *testing.B) {
2929
b.Run("fib", func(b *testing.B) {
3030
s, _ := newScript("fixtures/fib.lua")
@@ -157,23 +157,32 @@ func Test_NoMain(t *testing.T) {
157157
}
158158

159159
{
160-
_, err := FromString("", `main = 1`)
160+
s, err := FromString("", `main = 1`)
161+
assert.NoError(t, err)
162+
163+
_, err = s.Run(context.Background())
161164
assert.Error(t, err)
162165
}
163166

164167
{
165-
_, err := FromString("", `
168+
s, err := FromString("", `
166169
function notmain()
167170
local x = 1
168171
end`)
172+
assert.NoError(t, err)
173+
174+
_, err = s.Run(context.Background())
169175
assert.Error(t, err)
170176
}
171177

172178
{
173-
_, err := FromString("", `
179+
s, err := FromString("", `
174180
function xxx()
175181
local x = 1
176182
end`)
183+
assert.NoError(t, err)
184+
185+
_, err = s.Run(context.Background())
177186
assert.Error(t, err)
178187
}
179188
}

0 commit comments

Comments
 (0)