Skip to content

Commit 625cbab

Browse files
committed
Add abs() builtin
1 parent d3ff814 commit 625cbab

File tree

3 files changed

+133
-6
lines changed

3 files changed

+133
-6
lines changed

builtin/builtin.go

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,98 @@ import (
77
"github.com/antonmedv/expr/vm/runtime"
88
)
99

10+
var (
11+
anyType = reflect.TypeOf(new(interface{})).Elem()
12+
integerType = reflect.TypeOf(0)
13+
)
14+
1015
type Function struct {
1116
Name string
12-
Func func(params ...interface{}) (interface{}, error)
17+
Func func(args ...interface{}) (interface{}, error)
1318
Types []reflect.Type
14-
Validate func(args []reflect.Type) error
19+
Validate func(args []reflect.Type) (reflect.Type, error)
1520
}
1621

1722
var Builtins = []*Function{
1823
{
1924
Name: "len",
2025
Func: runtime.Len,
21-
Validate: func(args []reflect.Type) error {
26+
Validate: func(args []reflect.Type) (reflect.Type, error) {
27+
if len(args) != 1 {
28+
return anyType, fmt.Errorf("invalid number of arguments for len (expected 1, got %d)", len(args))
29+
}
2230
switch args[0].Kind() {
2331
case reflect.Array, reflect.Map, reflect.Slice, reflect.String, reflect.Interface:
24-
return nil
32+
return integerType, nil
33+
}
34+
return anyType, fmt.Errorf("invalid argument for len (type %s)", args[0])
35+
},
36+
},
37+
{
38+
Name: "abs",
39+
Func: func(args ...interface{}) (interface{}, error) {
40+
x := args[0]
41+
switch x.(type) {
42+
case float32:
43+
if x.(float32) < 0 {
44+
return -x.(float32), nil
45+
}
46+
case float64:
47+
if x.(float64) < 0 {
48+
return -x.(float64), nil
49+
}
50+
case int:
51+
if x.(int) < 0 {
52+
return -x.(int), nil
53+
}
54+
case int8:
55+
if x.(int8) < 0 {
56+
return -x.(int8), nil
57+
}
58+
case int16:
59+
if x.(int16) < 0 {
60+
return -x.(int16), nil
61+
}
62+
case int32:
63+
if x.(int32) < 0 {
64+
return -x.(int32), nil
65+
}
66+
case int64:
67+
if x.(int64) < 0 {
68+
return -x.(int64), nil
69+
}
70+
case uint:
71+
if x.(uint) < 0 {
72+
return -x.(uint), nil
73+
}
74+
case uint8:
75+
if x.(uint8) < 0 {
76+
return -x.(uint8), nil
77+
}
78+
case uint16:
79+
if x.(uint16) < 0 {
80+
return -x.(uint16), nil
81+
}
82+
case uint32:
83+
if x.(uint32) < 0 {
84+
return -x.(uint32), nil
85+
}
86+
case uint64:
87+
if x.(uint64) < 0 {
88+
return -x.(uint64), nil
89+
}
90+
}
91+
return nil, fmt.Errorf("invalid argument for abs (type %T)", x)
92+
},
93+
Validate: func(args []reflect.Type) (reflect.Type, error) {
94+
if len(args) != 1 {
95+
return anyType, fmt.Errorf("invalid number of arguments for abs (expected 1, got %d)", len(args))
96+
}
97+
switch args[0].Kind() {
98+
case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface:
99+
return args[0], nil
25100
}
26-
return fmt.Errorf("invalid argument for len (type %s)", args[0])
101+
return anyType, fmt.Errorf("invalid argument for abs (type %s)", args[0])
27102
},
28103
},
29104
}

builtin/builtin_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package builtin_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/antonmedv/expr"
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
var tests = []struct {
11+
input string
12+
want interface{}
13+
}{
14+
{`len(1..10)`, 10},
15+
{`len({foo: 1, bar: 2})`, 2},
16+
{`len("hello")`, 5},
17+
{`abs(-5)`, 5},
18+
{`abs(-.5)`, .5},
19+
}
20+
21+
func TestBuiltin(t *testing.T) {
22+
for _, test := range tests {
23+
t.Run(test.input, func(t *testing.T) {
24+
out, err := expr.Eval(test.input, nil)
25+
assert.NoError(t, err)
26+
assert.Equal(t, test.want, out)
27+
})
28+
}
29+
}
30+
31+
var errorTests = []struct {
32+
input string
33+
err string
34+
}{
35+
{`len()`, `invalid number of arguments for len (expected 1, got 0)`},
36+
{`len(1)`, `invalid argument for len (type int)`},
37+
{`abs()`, `invalid number of arguments for abs (expected 1, got 0)`},
38+
{`abs(1, 2)`, `invalid number of arguments for abs (expected 1, got 2)`},
39+
{`abs("foo")`, `invalid argument for abs (type string)`},
40+
}
41+
42+
func TestBuiltinErrors(t *testing.T) {
43+
for _, test := range errorTests {
44+
t.Run(test.input, func(t *testing.T) {
45+
_, err := expr.Eval(test.input, nil)
46+
assert.Error(t, err)
47+
assert.Contains(t, err.Error(), test.err)
48+
})
49+
}
50+
}

checker/checker.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,9 +480,11 @@ func (v *visitor) CallNode(node *ast.CallNode) (reflect.Type, info) {
480480
for i, arg := range node.Arguments {
481481
args[i], _ = v.visit(arg)
482482
}
483-
if err := f.Validate(args); err != nil {
483+
t, err := f.Validate(args)
484+
if err != nil {
484485
return v.error(node, "%v", err)
485486
}
487+
return t, info{}
486488
}
487489
if len(f.Types) == 0 {
488490
t, err := v.checkFunc(f.Name, functionType, false, node)

0 commit comments

Comments
 (0)