Skip to content

Commit 5002d98

Browse files
committed
Fix get/first/last
1 parent 9f2ffce commit 5002d98

File tree

2 files changed

+114
-71
lines changed

2 files changed

+114
-71
lines changed

builtin/builtin.go

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,15 +389,48 @@ var Functions = []*Function{
389389
{
390390
Name: "first",
391391
Func: func(args ...interface{}) (interface{}, error) {
392+
defer func() {
393+
if r := recover(); r != nil {
394+
return
395+
}
396+
}()
392397
return runtime.Fetch(args[0], 0), nil
393398
},
394-
Types: types(new(func([]interface{}) interface{})),
399+
Validate: func(args []reflect.Type) (reflect.Type, error) {
400+
if len(args) != 1 {
401+
return anyType, fmt.Errorf("invalid number of arguments for first (expected 1, got %d)", len(args))
402+
}
403+
switch kind(args[0]) {
404+
case reflect.Interface:
405+
return anyType, nil
406+
case reflect.Slice, reflect.Array:
407+
return args[0].Elem(), nil
408+
}
409+
return anyType, fmt.Errorf("cannot get first element from %s", args[0])
410+
},
395411
},
396412
{
397413
Name: "last",
398414
Func: func(args ...interface{}) (interface{}, error) {
415+
defer func() {
416+
if r := recover(); r != nil {
417+
return
418+
}
419+
}()
399420
return runtime.Fetch(args[0], -1), nil
400421
},
422+
Validate: func(args []reflect.Type) (reflect.Type, error) {
423+
if len(args) != 1 {
424+
return anyType, fmt.Errorf("invalid number of arguments for last (expected 1, got %d)", len(args))
425+
}
426+
switch kind(args[0]) {
427+
case reflect.Interface:
428+
return anyType, nil
429+
case reflect.Slice, reflect.Array:
430+
return args[0].Elem(), nil
431+
}
432+
return anyType, fmt.Errorf("cannot get last element from %s", args[0])
433+
},
401434
},
402435
{
403436
Name: "get",
@@ -414,8 +447,12 @@ var Functions = []*Function{
414447
return anyType, fmt.Errorf("invalid number of arguments for get (expected 2, got %d)", len(args))
415448
}
416449
switch kind(args[0]) {
417-
case reflect.Map, reflect.Struct, reflect.Slice, reflect.Array, reflect.Interface:
450+
case reflect.Interface:
418451
return anyType, nil
452+
case reflect.Slice, reflect.Array:
453+
return args[0].Elem(), nil
454+
case reflect.Map:
455+
return args[0].Elem(), nil
419456
}
420457
return anyType, fmt.Errorf("cannot get %s from %s", args[1], args[0])
421458
},

builtin/builtin_test.go

Lines changed: 75 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -6,85 +6,91 @@ import (
66

77
"github.com/antonmedv/expr"
88
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
910
)
1011

11-
var tests = []struct {
12-
input string
13-
want interface{}
14-
}{
15-
{`len(1..10)`, 10},
16-
{`len({foo: 1, bar: 2})`, 2},
17-
{`len("hello")`, 5},
18-
{`abs(-5)`, 5},
19-
{`abs(.5)`, .5},
20-
{`abs(-.5)`, .5},
21-
{`int(5.5)`, 5},
22-
{`int(5)`, 5},
23-
{`int("5")`, 5},
24-
{`float(5)`, 5.0},
25-
{`float(5.5)`, 5.5},
26-
{`float("5.5")`, 5.5},
27-
{`string(5)`, "5"},
28-
{`string(5.5)`, "5.5"},
29-
{`string("5.5")`, "5.5"},
30-
{`trim(" foo ")`, "foo"},
31-
{`trim("__foo___", "_")`, "foo"},
32-
{`trimPrefix("prefix_foo", "prefix_")`, "foo"},
33-
{`trimSuffix("foo_suffix", "_suffix")`, "foo"},
34-
{`upper("foo")`, "FOO"},
35-
{`lower("FOO")`, "foo"},
36-
{`split("foo,bar,baz", ",")`, []string{"foo", "bar", "baz"}},
37-
{`splitN("foo,bar,baz", ",", 2)`, []string{"foo", "bar,baz"}},
38-
{`splitAfter("foo,bar,baz", ",")`, []string{"foo,", "bar,", "baz"}},
39-
{`splitAfterN("foo,bar,baz", ",", 2)`, []string{"foo,", "bar,baz"}},
40-
{`replace("foo,bar,baz", ",", ";")`, "foo;bar;baz"},
41-
{`replace("foo,bar,baz,goo", ",", ";", 2)`, "foo;bar;baz,goo"},
42-
{`repeat("foo", 3)`, "foofoofoo"},
43-
{`join(ArrayOfString, ",")`, "foo,bar,baz"},
44-
{`join(ArrayOfString)`, "foobarbaz"},
45-
{`join(["foo", "bar", "baz"], ",")`, "foo,bar,baz"},
46-
{`join(["foo", "bar", "baz"])`, "foobarbaz"},
47-
{`indexOf("foo,bar,baz", ",")`, 3},
48-
{`lastIndexOf("foo,bar,baz", ",")`, 7},
49-
{`hasPrefix("foo,bar,baz", "foo")`, true},
50-
{`hasSuffix("foo,bar,baz", "baz")`, true},
51-
{`max(1, 2, 3)`, 3},
52-
{`max(1.5, 2.5, 3.5)`, 3.5},
53-
{`min(1, 2, 3)`, 1},
54-
{`min(1.5, 2.5, 3.5)`, 1.5},
55-
{`toJSON({foo: 1, bar: 2})`, "{\n \"bar\": 2,\n \"foo\": 1\n}"},
56-
{`fromJSON("[1, 2, 3]")`, []interface{}{1.0, 2.0, 3.0}},
57-
{`toBase64("hello")`, "aGVsbG8="},
58-
{`fromBase64("aGVsbG8=")`, "hello"},
59-
{`now().Format("2006-01-02T15:04:05Z")`, time.Now().Format("2006-01-02T15:04:05Z")},
60-
{`duration("1h")`, time.Hour},
61-
{`date("2006-01-02T15:04:05Z")`, time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC)},
62-
{`date("2006.01.02", "2006.01.02")`, time.Date(2006, 1, 2, 0, 0, 0, 0, time.UTC)},
63-
{`first(ArrayOfString)`, "foo"},
64-
{`first(ArrayOfInt)`, 1},
65-
{`first(ArrayOfAny)`, 1},
66-
{`last(ArrayOfString)`, "baz"},
67-
{`last(ArrayOfInt)`, 3},
68-
{`last(ArrayOfAny)`, true},
69-
{`get(ArrayOfString, 1)`, "bar"},
70-
{`get(ArrayOfString, 99)`, nil},
71-
{`get(ArrayOfInt, 1)`, 2},
72-
{`get(ArrayOfInt, -1)`, 3},
73-
{`get(ArrayOfAny, 1)`, "2"},
74-
{`get({foo: 1, bar: 2}, "foo")`, 1},
75-
{`get({foo: 1, bar: 2}, "unknown")`, nil},
76-
}
77-
7812
func TestBuiltin(t *testing.T) {
13+
var tests = []struct {
14+
input string
15+
want interface{}
16+
}{
17+
{`len(1..10)`, 10},
18+
{`len({foo: 1, bar: 2})`, 2},
19+
{`len("hello")`, 5},
20+
{`abs(-5)`, 5},
21+
{`abs(.5)`, .5},
22+
{`abs(-.5)`, .5},
23+
{`int(5.5)`, 5},
24+
{`int(5)`, 5},
25+
{`int("5")`, 5},
26+
{`float(5)`, 5.0},
27+
{`float(5.5)`, 5.5},
28+
{`float("5.5")`, 5.5},
29+
{`string(5)`, "5"},
30+
{`string(5.5)`, "5.5"},
31+
{`string("5.5")`, "5.5"},
32+
{`trim(" foo ")`, "foo"},
33+
{`trim("__foo___", "_")`, "foo"},
34+
{`trimPrefix("prefix_foo", "prefix_")`, "foo"},
35+
{`trimSuffix("foo_suffix", "_suffix")`, "foo"},
36+
{`upper("foo")`, "FOO"},
37+
{`lower("FOO")`, "foo"},
38+
{`split("foo,bar,baz", ",")`, []string{"foo", "bar", "baz"}},
39+
{`splitN("foo,bar,baz", ",", 2)`, []string{"foo", "bar,baz"}},
40+
{`splitAfter("foo,bar,baz", ",")`, []string{"foo,", "bar,", "baz"}},
41+
{`splitAfterN("foo,bar,baz", ",", 2)`, []string{"foo,", "bar,baz"}},
42+
{`replace("foo,bar,baz", ",", ";")`, "foo;bar;baz"},
43+
{`replace("foo,bar,baz,goo", ",", ";", 2)`, "foo;bar;baz,goo"},
44+
{`repeat("foo", 3)`, "foofoofoo"},
45+
{`join(ArrayOfString, ",")`, "foo,bar,baz"},
46+
{`join(ArrayOfString)`, "foobarbaz"},
47+
{`join(["foo", "bar", "baz"], ",")`, "foo,bar,baz"},
48+
{`join(["foo", "bar", "baz"])`, "foobarbaz"},
49+
{`indexOf("foo,bar,baz", ",")`, 3},
50+
{`lastIndexOf("foo,bar,baz", ",")`, 7},
51+
{`hasPrefix("foo,bar,baz", "foo")`, true},
52+
{`hasSuffix("foo,bar,baz", "baz")`, true},
53+
{`max(1, 2, 3)`, 3},
54+
{`max(1.5, 2.5, 3.5)`, 3.5},
55+
{`min(1, 2, 3)`, 1},
56+
{`min(1.5, 2.5, 3.5)`, 1.5},
57+
{`toJSON({foo: 1, bar: 2})`, "{\n \"bar\": 2,\n \"foo\": 1\n}"},
58+
{`fromJSON("[1, 2, 3]")`, []interface{}{1.0, 2.0, 3.0}},
59+
{`toBase64("hello")`, "aGVsbG8="},
60+
{`fromBase64("aGVsbG8=")`, "hello"},
61+
{`now().Format("2006-01-02T15:04:05Z")`, time.Now().Format("2006-01-02T15:04:05Z")},
62+
{`duration("1h")`, time.Hour},
63+
{`date("2006-01-02T15:04:05Z")`, time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC)},
64+
{`date("2006.01.02", "2006.01.02")`, time.Date(2006, 1, 2, 0, 0, 0, 0, time.UTC)},
65+
{`first(ArrayOfString)`, "foo"},
66+
{`first(ArrayOfInt)`, 1},
67+
{`first(ArrayOfAny)`, 1},
68+
{`first([])`, nil},
69+
{`last(ArrayOfString)`, "baz"},
70+
{`last(ArrayOfInt)`, 3},
71+
{`last(ArrayOfAny)`, true},
72+
{`last([])`, nil},
73+
{`get(ArrayOfString, 1)`, "bar"},
74+
{`get(ArrayOfString, 99)`, nil},
75+
{`get(ArrayOfInt, 1)`, 2},
76+
{`get(ArrayOfInt, -1)`, 3},
77+
{`get(ArrayOfAny, 1)`, "2"},
78+
{`get({foo: 1, bar: 2}, "foo")`, 1},
79+
{`get({foo: 1, bar: 2}, "unknown")`, nil},
80+
}
81+
7982
env := map[string]interface{}{
8083
"ArrayOfString": []string{"foo", "bar", "baz"},
8184
"ArrayOfInt": []int{1, 2, 3},
8285
"ArrayOfAny": []interface{}{1, "2", true},
8386
}
8487
for _, test := range tests {
8588
t.Run(test.input, func(t *testing.T) {
86-
out, err := expr.Eval(test.input, env)
87-
assert.NoError(t, err)
89+
program, err := expr.Compile(test.input, expr.Env(env))
90+
require.NoError(t, err)
91+
92+
out, err := expr.Run(program, env)
93+
require.NoError(t, err)
8894
assert.Equal(t, test.want, out)
8995
})
9096
}

0 commit comments

Comments
 (0)