Skip to content

Commit 48b12f1

Browse files
committed
Add tests
1 parent 0b7b47f commit 48b12f1

File tree

6 files changed

+131
-32
lines changed

6 files changed

+131
-32
lines changed

ast/node.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ type CallNode struct {
124124
Arguments []Node
125125
Typed int
126126
Fast bool
127+
Name string
127128
Func func(params ...interface{}) (interface{}, error)
128129
}
129130

checker/checker.go

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -131,17 +131,15 @@ func (v *visitor) NilNode(*ast.NilNode) (reflect.Type, info) {
131131
}
132132

133133
func (v *visitor) IdentifierNode(node *ast.IdentifierNode) (reflect.Type, info) {
134-
if v.config.Types == nil {
135-
node.Deref = true
136-
return anyType, info{}
137-
}
138134
if fn, ok := v.config.Functions[node.Value]; ok {
139135
// Return anyType instead of func type as we don't know the arguments yet.
140136
// The func type can be one of the fn.Types. The type will be resolved
141137
// when the arguments are known in CallNode.
142138
return anyType, info{fn: fn}
143139
}
144-
if t, ok := v.config.Types[node.Value]; ok {
140+
if v.config.Types == nil {
141+
node.Deref = true
142+
} else if t, ok := v.config.Types[node.Value]; ok {
145143
if t.Ambiguous {
146144
return v.error(node, "ambiguous identifier %v", node.Value)
147145
}
@@ -152,13 +150,13 @@ func (v *visitor) IdentifierNode(node *ast.IdentifierNode) (reflect.Type, info)
152150
node.FieldIndex = t.FieldIndex
153151
return d, info{method: t.Method}
154152
}
155-
if !v.config.Strict {
156-
if v.config.DefaultType != nil {
157-
return v.config.DefaultType, info{}
158-
}
159-
return anyType, info{}
153+
if v.config.Strict {
154+
return v.error(node, "unknown name %v", node.Value)
155+
}
156+
if v.config.DefaultType != nil {
157+
return v.config.DefaultType, info{}
160158
}
161-
return v.error(node, "unknown name %v", node.Value)
159+
return anyType, info{}
162160
}
163161

164162
func (v *visitor) IntegerNode(*ast.IntegerNode) (reflect.Type, info) {
@@ -475,25 +473,24 @@ func (v *visitor) CallNode(node *ast.CallNode) (reflect.Type, info) {
475473

476474
if fnInfo.fn != nil {
477475
f := fnInfo.fn
476+
node.Name = f.Name
478477
node.Func = f.Func
479478
if len(f.Types) == 0 {
480479
// No type was specified, so we assume the function returns any.
481480
return anyType, info{}
482481
}
483-
var firstErr *file.Error
482+
var lastErr *file.Error
484483
for _, t := range f.Types {
485484
outType, err := v.checkFunc(f.Name, t, false, node)
486485
if err != nil {
487-
if firstErr == nil {
488-
firstErr = err
489-
}
486+
lastErr = err
490487
continue
491488
}
492489
return outType, info{}
493490
}
494-
if firstErr != nil {
491+
if lastErr != nil {
495492
if v.err == nil {
496-
v.err = firstErr
493+
v.err = lastErr
497494
}
498495
return anyType, info{}
499496
}

checker/checker_test.go

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ func TestCheck_NoConfig(t *testing.T) {
600600
tree, err := parser.Parse(`any`)
601601
require.NoError(t, err)
602602

603-
_, err = checker.Check(tree, nil)
603+
_, err = checker.Check(tree, conf.CreateNew())
604604
assert.NoError(t, err)
605605
}
606606

@@ -801,3 +801,74 @@ func TestCheck_operator_in_works_with_interfaces(t *testing.T) {
801801
_, err = checker.Check(tree, config)
802802
require.NoError(t, err)
803803
}
804+
805+
func TestCheck_Function_types_are_checked(t *testing.T) {
806+
add := expr.Function(
807+
"add",
808+
func(p ...interface{}) (interface{}, error) {
809+
out := 0
810+
for _, each := range p {
811+
out += each.(int)
812+
}
813+
return out, nil
814+
},
815+
new(func(int) int),
816+
new(func(int, int) int),
817+
new(func(int, int, int) int),
818+
new(func(...int) int),
819+
)
820+
821+
config := conf.CreateNew()
822+
add(config)
823+
824+
tests := []string{
825+
"add(1)",
826+
"add(1, 2)",
827+
"add(1, 2, 3)",
828+
"add(1, 2, 3, 4)",
829+
}
830+
for _, test := range tests {
831+
t.Run(test, func(t *testing.T) {
832+
tree, err := parser.Parse(test)
833+
require.NoError(t, err)
834+
835+
_, err = checker.Check(tree, config)
836+
require.NoError(t, err)
837+
require.Equal(t, "add", tree.Node.(*ast.CallNode).Name)
838+
require.NotNil(t, tree.Node.(*ast.CallNode).Func)
839+
})
840+
}
841+
842+
t.Run("errors", func(t *testing.T) {
843+
tree, err := parser.Parse("add(1, '2')")
844+
require.NoError(t, err)
845+
846+
_, err = checker.Check(tree, config)
847+
require.Error(t, err)
848+
require.Equal(t, "cannot use string as argument (type int) to call add (1:8)\n | add(1, '2')\n | .......^", err.Error())
849+
})
850+
}
851+
852+
func TestCheck_Function_without_types(t *testing.T) {
853+
add := expr.Function(
854+
"add",
855+
func(p ...interface{}) (interface{}, error) {
856+
out := 0
857+
for _, each := range p {
858+
out += each.(int)
859+
}
860+
return out, nil
861+
},
862+
)
863+
864+
tree, err := parser.Parse("add(1, 2, 3)")
865+
require.NoError(t, err)
866+
867+
config := conf.CreateNew()
868+
add(config)
869+
870+
_, err = checker.Check(tree, config)
871+
require.NoError(t, err)
872+
require.Equal(t, "add", tree.Node.(*ast.CallNode).Name)
873+
require.NotNil(t, tree.Node.(*ast.CallNode).Func)
874+
}

compiler/compiler.go

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -126,14 +126,19 @@ func (c *compiler) addConstant(constant interface{}) int {
126126
return p
127127
}
128128

129-
func (c *compiler) addFunction(f Function) int {
130-
addr := fmt.Sprintf("%v", f)
131-
if p, ok := c.functionsIndex[addr]; ok {
129+
func (c *compiler) addFunction(node *ast.CallNode) int {
130+
if node.Name == "" {
131+
panic("function name is empty")
132+
}
133+
if node.Func == nil {
134+
panic("function is nil")
135+
}
136+
if p, ok := c.functionsIndex[node.Name]; ok {
132137
return p
133138
}
134139
p := len(c.functions)
135-
c.functions = append(c.functions, f)
136-
c.functionsIndex[addr] = p
140+
c.functions = append(c.functions, node.Func)
141+
c.functionsIndex[node.Name] = p
137142
return p
138143
}
139144

@@ -527,15 +532,15 @@ func (c *compiler) CallNode(node *ast.CallNode) {
527532
if node.Func != nil {
528533
switch len(node.Arguments) {
529534
case 0:
530-
c.emit(OpCall0, c.addFunction(node.Func))
535+
c.emit(OpCall0, c.addFunction(node))
531536
case 1:
532-
c.emit(OpCall1, c.addFunction(node.Func))
537+
c.emit(OpCall1, c.addFunction(node))
533538
case 2:
534-
c.emit(OpCall2, c.addFunction(node.Func))
539+
c.emit(OpCall2, c.addFunction(node))
535540
case 3:
536-
c.emit(OpCall3, c.addFunction(node.Func))
541+
c.emit(OpCall3, c.addFunction(node))
537542
default:
538-
c.emit(OpLoadFunc, c.addFunction(node.Func))
543+
c.emit(OpLoadFunc, c.addFunction(node))
539544
c.emit(OpCallN, len(node.Arguments))
540545
}
541546
return

expr_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1659,6 +1659,27 @@ func TestRun_custom_func_returns_an_error_as_second_arg(t *testing.T) {
16591659
assert.Equal(t, true, out)
16601660
}
16611661

1662+
func TestFunction(t *testing.T) {
1663+
add := expr.Function(
1664+
"add",
1665+
func(p ...interface{}) (interface{}, error) {
1666+
out := 0
1667+
for _, each := range p {
1668+
out += each.(int)
1669+
}
1670+
return out, nil
1671+
},
1672+
new(func(...int) int),
1673+
)
1674+
1675+
p, err := expr.Compile(`add() + add(1) + add(1, 2) + add(1, 2, 3) + add(1, 2, 3, 4)`, add)
1676+
assert.NoError(t, err)
1677+
1678+
out, err := expr.Run(p, nil)
1679+
assert.NoError(t, err)
1680+
assert.Equal(t, 20, out)
1681+
}
1682+
16621683
// Mock types
16631684

16641685
type mockEnv struct {

vm/vm_test.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,15 @@ func TestRun_Helpers(t *testing.T) {
115115
"b": b,
116116
}
117117

118+
config := conf.CreateNew()
119+
118120
tree, err := parser.Parse(input)
119121
require.NoError(t, err)
120122

121-
_, err = checker.Check(tree, nil)
123+
_, err = checker.Check(tree, config)
122124
require.NoError(t, err)
123125

124-
program, err := compiler.Compile(tree, nil)
126+
program, err := compiler.Compile(tree, config)
125127
require.NoError(t, err)
126128

127129
_, err = vm.Run(program, env)
@@ -194,13 +196,15 @@ func TestRun_Helpers_Time(t *testing.T) {
194196
"b": tt.b,
195197
}
196198

199+
config := conf.CreateNew()
200+
197201
tree, err := parser.Parse(input)
198202
require.NoError(t, err)
199203

200-
_, err = checker.Check(tree, nil)
204+
_, err = checker.Check(tree, config)
201205
require.NoError(t, err)
202206

203-
program, err := compiler.Compile(tree, nil)
207+
program, err := compiler.Compile(tree, config)
204208
require.NoError(t, err)
205209

206210
got, err := vm.Run(program, env)

0 commit comments

Comments
 (0)