Skip to content

Commit 677dfb5

Browse files
committed
Better type checks for env in get()
1 parent be61690 commit 677dfb5

File tree

5 files changed

+141
-52
lines changed

5 files changed

+141
-52
lines changed

builtin/builtin.go

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ var Functions = []*Function{
3939
Builtin1: Len,
4040
Validate: func(args []reflect.Type) (reflect.Type, error) {
4141
if len(args) != 1 {
42-
return anyType, fmt.Errorf("invalid number of arguments for len (expected 1, got %d)", len(args))
42+
return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
4343
}
4444
switch kind(args[0]) {
4545
case reflect.Array, reflect.Map, reflect.Slice, reflect.String, reflect.Interface:
@@ -53,7 +53,7 @@ var Functions = []*Function{
5353
Builtin1: Abs,
5454
Validate: func(args []reflect.Type) (reflect.Type, error) {
5555
if len(args) != 1 {
56-
return anyType, fmt.Errorf("invalid number of arguments for abs (expected 1, got %d)", len(args))
56+
return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
5757
}
5858
switch kind(args[0]) {
5959
case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface:
@@ -67,7 +67,7 @@ var Functions = []*Function{
6767
Builtin1: Int,
6868
Validate: func(args []reflect.Type) (reflect.Type, error) {
6969
if len(args) != 1 {
70-
return anyType, fmt.Errorf("invalid number of arguments for int (expected 1, got %d)", len(args))
70+
return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
7171
}
7272
switch kind(args[0]) {
7373
case reflect.Interface:
@@ -85,7 +85,7 @@ var Functions = []*Function{
8585
Builtin1: Float,
8686
Validate: func(args []reflect.Type) (reflect.Type, error) {
8787
if len(args) != 1 {
88-
return anyType, fmt.Errorf("invalid number of arguments for float (expected 1, got %d)", len(args))
88+
return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
8989
}
9090
switch kind(args[0]) {
9191
case reflect.Interface:
@@ -398,7 +398,7 @@ var Functions = []*Function{
398398
},
399399
Validate: func(args []reflect.Type) (reflect.Type, error) {
400400
if len(args) != 1 {
401-
return anyType, fmt.Errorf("invalid number of arguments for first (expected 1, got %d)", len(args))
401+
return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
402402
}
403403
switch kind(args[0]) {
404404
case reflect.Interface:
@@ -421,7 +421,7 @@ var Functions = []*Function{
421421
},
422422
Validate: func(args []reflect.Type) (reflect.Type, error) {
423423
if len(args) != 1 {
424-
return anyType, fmt.Errorf("invalid number of arguments for last (expected 1, got %d)", len(args))
424+
return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
425425
}
426426
switch kind(args[0]) {
427427
case reflect.Interface:
@@ -442,19 +442,5 @@ var Functions = []*Function{
442442
}()
443443
return runtime.Fetch(args[0], args[1]), nil
444444
},
445-
Validate: func(args []reflect.Type) (reflect.Type, error) {
446-
if len(args) != 2 {
447-
return anyType, fmt.Errorf("invalid number of arguments for get (expected 2, got %d)", len(args))
448-
}
449-
switch kind(args[0]) {
450-
case reflect.Interface:
451-
return anyType, nil
452-
case reflect.Slice, reflect.Array:
453-
return args[0].Elem(), nil
454-
case reflect.Map:
455-
return args[0].Elem(), nil
456-
}
457-
return anyType, fmt.Errorf("cannot get %s from %s", args[1], args[0])
458-
},
459445
},
460446
}

builtin/builtin_test.go

Lines changed: 65 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
package builtin_test
22

33
import (
4+
"fmt"
5+
"reflect"
46
"testing"
57
"time"
68

79
"github.com/antonmedv/expr"
10+
"github.com/antonmedv/expr/checker"
11+
"github.com/antonmedv/expr/conf"
12+
"github.com/antonmedv/expr/parser"
813
"github.com/stretchr/testify/assert"
914
"github.com/stretchr/testify/require"
1015
)
@@ -96,32 +101,31 @@ func TestBuiltin(t *testing.T) {
96101
}
97102
}
98103

99-
var errorTests = []struct {
100-
input string
101-
err string
102-
}{
103-
{`len()`, `invalid number of arguments for len (expected 1, got 0)`},
104-
{`len(1)`, `invalid argument for len (type int)`},
105-
{`abs()`, `invalid number of arguments for abs (expected 1, got 0)`},
106-
{`abs(1, 2)`, `invalid number of arguments for abs (expected 1, got 2)`},
107-
{`abs("foo")`, `invalid argument for abs (type string)`},
108-
{`int()`, `invalid number of arguments for int (expected 1, got 0)`},
109-
{`int(1, 2)`, `invalid number of arguments for int (expected 1, got 2)`},
110-
{`float()`, `invalid number of arguments for float (expected 1, got 0)`},
111-
{`float(1, 2)`, `invalid number of arguments for float (expected 1, got 2)`},
112-
{`string(1, 2)`, `too many arguments to call string`},
113-
{`trim()`, `not enough arguments to call trim`},
114-
{`max()`, `not enough arguments to call max`},
115-
{`max(1, "2")`, `invalid argument for max (type string)`},
116-
{`min()`, `not enough arguments to call min`},
117-
{`min(1, "2")`, `invalid argument for min (type string)`},
118-
{`duration("error")`, `invalid duration`},
119-
{`date("error")`, `invalid date`},
120-
{`get()`, `invalid number of arguments for get (expected 2, got 0)`},
121-
{`get(1, 2)`, `cannot get int from int`},
122-
}
123-
124-
func TestBuiltinErrors(t *testing.T) {
104+
func TestBuiltin_errors(t *testing.T) {
105+
var errorTests = []struct {
106+
input string
107+
err string
108+
}{
109+
{`len()`, `invalid number of arguments (expected 1, got 0)`},
110+
{`len(1)`, `invalid argument for len (type int)`},
111+
{`abs()`, `invalid number of arguments (expected 1, got 0)`},
112+
{`abs(1, 2)`, `invalid number of arguments (expected 1, got 2)`},
113+
{`abs("foo")`, `invalid argument for abs (type string)`},
114+
{`int()`, `invalid number of arguments (expected 1, got 0)`},
115+
{`int(1, 2)`, `invalid number of arguments (expected 1, got 2)`},
116+
{`float()`, `invalid number of arguments (expected 1, got 0)`},
117+
{`float(1, 2)`, `invalid number of arguments (expected 1, got 2)`},
118+
{`string(1, 2)`, `too many arguments to call string`},
119+
{`trim()`, `not enough arguments to call trim`},
120+
{`max()`, `not enough arguments to call max`},
121+
{`max(1, "2")`, `invalid argument for max (type string)`},
122+
{`min()`, `not enough arguments to call min`},
123+
{`min(1, "2")`, `invalid argument for min (type string)`},
124+
{`duration("error")`, `invalid duration`},
125+
{`date("error")`, `invalid date`},
126+
{`get()`, `invalid number of arguments (expected 2, got 0)`},
127+
{`get(1, 2)`, `type int does not support indexing`},
128+
}
125129
for _, test := range errorTests {
126130
t.Run(test.input, func(t *testing.T) {
127131
_, err := expr.Eval(test.input, nil)
@@ -130,3 +134,38 @@ func TestBuiltinErrors(t *testing.T) {
130134
})
131135
}
132136
}
137+
138+
func TestBuiltin_types(t *testing.T) {
139+
env := map[string]interface{}{
140+
"num": 42,
141+
"str": "foo",
142+
"ArrayOfString": []string{"foo", "bar", "baz"},
143+
"ArrayOfInt": []int{1, 2, 3},
144+
}
145+
146+
tests := []struct {
147+
input string
148+
want reflect.Kind
149+
}{
150+
{`get(ArrayOfString, 0)`, reflect.String},
151+
{`get(ArrayOfInt, 0)`, reflect.Int},
152+
{`first(ArrayOfString)`, reflect.String},
153+
{`first(ArrayOfInt)`, reflect.Int},
154+
{`last(ArrayOfString)`, reflect.String},
155+
{`last(ArrayOfInt)`, reflect.Int},
156+
{`get(env, 'str')`, reflect.String},
157+
{`get(env, 'num')`, reflect.Int},
158+
{`get(env, 'ArrayOfString')`, reflect.Slice},
159+
}
160+
161+
for _, test := range tests {
162+
t.Run(test.input, func(t *testing.T) {
163+
tree, err := parser.Parse(test.input)
164+
require.NoError(t, err)
165+
166+
rtype, err := checker.Check(tree, conf.New(env))
167+
require.NoError(t, err)
168+
require.True(t, rtype.Kind() == test.want, fmt.Sprintf("expected %s, got %s", test.want, rtype.Kind()))
169+
})
170+
}
171+
}

checker/checker.go

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ func (v *visitor) NilNode(*ast.NilNode) (reflect.Type, info) {
135135
}
136136

137137
func (v *visitor) IdentifierNode(node *ast.IdentifierNode) (reflect.Type, info) {
138+
if node.Value == "env" {
139+
return mapType, info{}
140+
}
138141
if fn, ok := v.config.Functions[node.Value]; ok {
139142
// Return anyType instead of func type as we don't know the arguments yet.
140143
// The func type can be one of the fn.Types. The type will be resolved
@@ -393,7 +396,9 @@ func (v *visitor) ChainNode(node *ast.ChainNode) (reflect.Type, info) {
393396
}
394397

395398
func (v *visitor) MemberNode(node *ast.MemberNode) (reflect.Type, info) {
399+
base, _ := v.visit(node.Node)
396400
prop, _ := v.visit(node.Property)
401+
397402
if an, ok := node.Node.(*ast.IdentifierNode); ok && an.Value == "env" {
398403
// If the index is a constant string, can save some
399404
// cycles later by finding the type of its referent
@@ -404,7 +409,6 @@ func (v *visitor) MemberNode(node *ast.MemberNode) (reflect.Type, info) {
404409
}
405410
return anyType, info{}
406411
}
407-
base, _ := v.visit(node.Node)
408412

409413
if name, ok := node.Property.(*ast.StringNode); ok {
410414
if base == nil {
@@ -503,7 +507,7 @@ func (v *visitor) CallNode(node *ast.CallNode) (reflect.Type, info) {
503507

504508
if fnInfo.fn != nil {
505509
node.Func = fnInfo.fn
506-
return v.checkFunction(fnInfo.fn, node.Arguments, node)
510+
return v.checkFunction(fnInfo.fn, node, node.Arguments)
507511
}
508512

509513
fnName := "function"
@@ -639,13 +643,44 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
639643
}
640644

641645
if id, ok := builtin.Index[node.Name]; ok {
642-
return v.checkFunction(builtin.Functions[id], node.Arguments, node)
646+
switch node.Name {
647+
case "get":
648+
return v.checkBuiltinGet(node)
649+
}
650+
return v.checkFunction(builtin.Functions[id], node, node.Arguments)
643651
}
644652

645653
return v.error(node, "unknown builtin %v", node.Name)
646654
}
647655

648-
func (v *visitor) checkFunction(f *builtin.Function, arguments []ast.Node, node ast.Node) (reflect.Type, info) {
656+
func (v *visitor) checkBuiltinGet(node *ast.BuiltinNode) (reflect.Type, info) {
657+
if len(node.Arguments) != 2 {
658+
return v.error(node, "invalid number of arguments (expected 2, got %d)", len(node.Arguments))
659+
}
660+
661+
val := node.Arguments[0]
662+
prop := node.Arguments[1]
663+
if id, ok := val.(*ast.IdentifierNode); ok && id.Value == "env" {
664+
if s, ok := prop.(*ast.StringNode); ok {
665+
return v.config.Types[s.Value].Type, info{}
666+
}
667+
return anyType, info{}
668+
}
669+
670+
t, _ := v.visit(val)
671+
672+
switch t.Kind() {
673+
case reflect.Interface:
674+
return anyType, info{}
675+
case reflect.Slice, reflect.Array:
676+
return t.Elem(), info{}
677+
case reflect.Map:
678+
return t.Elem(), info{}
679+
}
680+
return v.error(val, "type %v does not support indexing", t)
681+
}
682+
683+
func (v *visitor) checkFunction(f *builtin.Function, node ast.Node, arguments []ast.Node) (reflect.Type, info) {
649684
if f.Validate != nil {
650685
args := make([]reflect.Type, len(arguments))
651686
for i, arg := range arguments {
@@ -656,8 +691,7 @@ func (v *visitor) checkFunction(f *builtin.Function, arguments []ast.Node, node
656691
return v.error(node, "%v", err)
657692
}
658693
return t, info{}
659-
}
660-
if len(f.Types) == 0 {
694+
} else if len(f.Types) == 0 {
661695
t, err := v.checkArguments(f.Name, functionType, false, arguments, node)
662696
if err != nil {
663697
if v.err == nil {

checker/checker_test.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -934,3 +934,31 @@ func TestCheck_do_not_override_params_for_functions(t *testing.T) {
934934
require.Contains(t, err.Error(), "cannot use int as argument")
935935
})
936936
}
937+
938+
func TestCheck_env_keyword(t *testing.T) {
939+
env := map[string]interface{}{
940+
"num": 42,
941+
"str": "foo",
942+
"name": "str",
943+
}
944+
945+
tests := []struct {
946+
input string
947+
want reflect.Kind
948+
}{
949+
{`env['str']`, reflect.String},
950+
{`env['num']`, reflect.Int},
951+
{`env[name]`, reflect.Interface},
952+
}
953+
954+
for _, test := range tests {
955+
t.Run(test.input, func(t *testing.T) {
956+
tree, err := parser.Parse(test.input)
957+
require.NoError(t, err)
958+
959+
rtype, err := checker.Check(tree, conf.New(env))
960+
require.NoError(t, err)
961+
require.True(t, rtype.Kind() == test.want, fmt.Sprintf("expected %s, got %s", test.want, rtype.Kind()))
962+
})
963+
}
964+
}

expr_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1704,6 +1704,8 @@ func TestEnv_keyword(t *testing.T) {
17041704
{"env.mylist[1]", 2},
17051705
{"env?.OtherMap?.a", "b"},
17061706
{"env?.OtherMap?.d", ""},
1707+
{"'num' in env", true},
1708+
{"get(env, 'num')", 10},
17071709
}
17081710

17091711
for _, tt := range tests {

0 commit comments

Comments
 (0)