Skip to content

Commit 4c7e9b9

Browse files
davidklassenantonmedv
authored andcommitted
Allow overloading binary node with a function from env
1 parent 28f5482 commit 4c7e9b9

File tree

7 files changed

+261
-17
lines changed

7 files changed

+261
-17
lines changed

checker/checker.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@ func Check(tree *parser.Tree, config *conf.Config) (t reflect.Type, err error) {
2525
}
2626
if config != nil {
2727
v.types = config.Types
28+
v.operators = config.Operators
2829
v.expect = config.Expect
2930
}
3031

3132
t = v.visit(tree.Node)
33+
patchOperators(tree, config)
3234

3335
if v.expect != reflect.Invalid {
3436
switch v.expect {
@@ -50,6 +52,7 @@ okay:
5052

5153
type visitor struct {
5254
types conf.TypesTable
55+
operators conf.OperatorsTable
5356
expect reflect.Kind
5457
collections []reflect.Type
5558
}
@@ -165,6 +168,20 @@ func (v *visitor) BinaryNode(node *ast.BinaryNode) reflect.Type {
165168
l := v.visit(node.Left)
166169
r := v.visit(node.Right)
167170

171+
// check operator overloading
172+
if fns, ok := v.operators[node.Operator]; ok {
173+
for _, fn := range fns {
174+
fnType := v.types[fn]
175+
176+
firstArgType := fnType.Type.In(0)
177+
secondArgType := fnType.Type.In(1)
178+
179+
if l == firstArgType && r == secondArgType {
180+
return fnType.Type.Out(0)
181+
}
182+
}
183+
}
184+
168185
switch node.Operator {
169186
case "==", "!=":
170187
if isNumber(l) && isNumber(r) {

checker/patcher.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package checker
2+
3+
import (
4+
"github.com/antonmedv/expr/ast"
5+
"github.com/antonmedv/expr/internal/conf"
6+
"github.com/antonmedv/expr/parser"
7+
)
8+
9+
type operatorPatcher struct {
10+
ast.BaseVisitor
11+
ops map[string][]string
12+
types conf.TypesTable
13+
}
14+
15+
func (p *operatorPatcher) Node(node *ast.Node) {
16+
binaryNode, ok := (*node).(*ast.BinaryNode)
17+
if !ok {
18+
return
19+
}
20+
21+
fns, ok := p.ops[binaryNode.Operator]
22+
if !ok {
23+
return
24+
}
25+
26+
leftType := binaryNode.Left.GetType()
27+
rightType := binaryNode.Right.GetType()
28+
for _, fn := range fns {
29+
fnType := p.types[fn]
30+
31+
firstArgType := fnType.Type.In(0)
32+
secondArgType := fnType.Type.In(1)
33+
34+
if leftType == firstArgType && rightType == secondArgType {
35+
*node = &ast.FunctionNode{
36+
Name: fn,
37+
Arguments: []ast.Node{binaryNode.Left, binaryNode.Right},
38+
}
39+
}
40+
}
41+
}
42+
43+
func patchOperators(tree *parser.Tree, config *conf.Config) {
44+
if len(config.Operators) == 0 {
45+
return
46+
}
47+
patcher := &operatorPatcher{ops: config.Operators, types: config.Types}
48+
ast.Walk(&tree.Node, patcher)
49+
}

expr.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ func Env(i interface{}) conf.Option {
4343
}
4444
}
4545

46+
func Operator(operator string, fns ...string) conf.Option {
47+
return func(c *conf.Config) {
48+
c.Operators[operator] = append(c.Operators[operator], fns...)
49+
}
50+
}
51+
4652
// AsBool tells the compiler to expect boolean result.
4753
func AsBool() conf.Option {
4854
return func(c *conf.Config) {
@@ -66,12 +72,16 @@ func AsFloat64() conf.Option {
6672

6773
// Compile parses and compiles given input expression to bytecode program.
6874
func Compile(input string, ops ...conf.Option) (*vm.Program, error) {
69-
config := &conf.Config{}
75+
config := &conf.Config{Operators: make(map[string][]string)}
7076

7177
for _, op := range ops {
7278
op(config)
7379
}
7480

81+
if err := config.Check(); err != nil {
82+
return nil, err
83+
}
84+
7585
tree, err := parser.Parse(input)
7686
if err != nil {
7787
return nil, err

expr_test.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ import (
55
"fmt"
66
"github.com/antonmedv/expr"
77
"github.com/antonmedv/expr/vm"
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
810
"strings"
11+
"testing"
912
)
1013

1114
func ExampleEval() {
@@ -183,6 +186,49 @@ func ExampleEnv() {
183186
// Output: true
184187
}
185188

189+
func ExampleOperator() {
190+
type Place struct {
191+
Code string
192+
}
193+
type Segment struct {
194+
Origin Place
195+
}
196+
type Helpers struct {
197+
PlaceEq func(p Place, s string) bool
198+
}
199+
type Request struct {
200+
Segments []*Segment
201+
Helpers
202+
}
203+
204+
code := `Segments[0].Origin == "MOW" && PlaceEq(Segments[0].Origin, "MOW")`
205+
206+
program, err := expr.Compile(code, expr.Env(&Request{}), expr.Operator("==", "PlaceEq"))
207+
if err != nil {
208+
fmt.Printf("%v", err)
209+
return
210+
}
211+
212+
request := &Request{
213+
Segments: []*Segment{
214+
{Origin: Place{Code: "MOW"}},
215+
},
216+
Helpers: Helpers{PlaceEq: func(p Place, s string) bool {
217+
return p.Code == s
218+
}},
219+
}
220+
221+
output, err := expr.Run(program, request)
222+
if err != nil {
223+
fmt.Printf("%v", err)
224+
return
225+
}
226+
227+
fmt.Printf("%v", output)
228+
229+
// Output: true
230+
}
231+
186232
func ExampleEval_marshal() {
187233
env := map[string]int{
188234
"foo": 1,
@@ -218,3 +264,93 @@ func ExampleEval_marshal() {
218264

219265
// Output: 3
220266
}
267+
268+
func TestExpr(t *testing.T) {
269+
type mockEnv struct {
270+
One, Two, Three int
271+
IntArray []int
272+
MultiDimArray [][]int
273+
Sum func(list []int) int
274+
Inc func(int) int
275+
}
276+
277+
request := mockEnv{
278+
One: 1,
279+
Two: 2,
280+
Three: 3,
281+
IntArray: []int{1, 2, 3},
282+
MultiDimArray: [][]int{{1, 2, 3}, {1, 2, 3}},
283+
Sum: func(list []int) int {
284+
var ret int
285+
for _, el := range list {
286+
ret += el
287+
}
288+
return ret
289+
},
290+
Inc: func(a int) int { return a + 1 },
291+
}
292+
293+
tests := []struct {
294+
name string
295+
code string
296+
want interface{}
297+
}{
298+
{
299+
name: "+ operator",
300+
code: "1 + 1",
301+
want: 2,
302+
},
303+
{
304+
name: "associativity",
305+
code: "(One * Two) * Three == One * (Two * Three)",
306+
want: true,
307+
},
308+
{
309+
name: "indexing",
310+
code: "IntArray[0]",
311+
want: 1,
312+
},
313+
{
314+
name: "helpers",
315+
code: "Sum(IntArray)",
316+
want: 6,
317+
},
318+
{
319+
name: "binary with indexing",
320+
code: "IntArray[0] < IntArray[1]",
321+
want: true,
322+
},
323+
{
324+
name: "helpers with indexing",
325+
code: "Sum(MultiDimArray[0])",
326+
want: 6,
327+
},
328+
{
329+
name: "helpers with indexing in binary operations",
330+
code: "Sum(MultiDimArray[0]) + Sum(MultiDimArray[1])",
331+
want: 12,
332+
},
333+
{
334+
name: "binary operations in function call arguments",
335+
code: "Inc(IntArray[0] + IntArray[1])",
336+
want: 4,
337+
},
338+
{
339+
name: "binary operations with indexing",
340+
code: "IntArray[0] + IntArray[1]",
341+
want: 3,
342+
},
343+
}
344+
345+
for _, tt := range tests {
346+
t.Run(tt.name, func(t *testing.T) {
347+
program, err := expr.Compile(tt.code, expr.Env(mockEnv{}))
348+
require.NoError(t, err, "Compile() error")
349+
350+
got, err := expr.Run(program, request)
351+
require.NoError(t, err, "Run() error")
352+
353+
assert.Equal(t, tt.want, got)
354+
})
355+
}
356+
}

internal/conf/config.go

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,46 @@
11
package conf
22

3-
import "reflect"
3+
import (
4+
"fmt"
5+
"reflect"
6+
)
47

58
type Config struct {
6-
MapEnv bool
7-
Types TypesTable
8-
Expect reflect.Kind
9+
MapEnv bool
10+
Types TypesTable
11+
Operators OperatorsTable
12+
Expect reflect.Kind
13+
}
14+
15+
func New(i interface{}) *Config {
16+
var mapEnv bool
17+
if _, ok := i.(map[string]interface{}); ok {
18+
mapEnv = true
19+
}
20+
21+
return &Config{
22+
MapEnv: mapEnv,
23+
Types: CreateTypesTable(i),
24+
}
25+
}
26+
27+
// Check validates the compiler configuration.
28+
func (c *Config) Check() error {
29+
// check that all functions that define operator overloading
30+
// exist in environment and have correct signatures.
31+
for op, fns := range c.Operators {
32+
for _, fn := range fns {
33+
fnType, ok := c.Types[fn]
34+
if !ok || fnType.Type.Kind() != reflect.Func {
35+
return fmt.Errorf("function %s for %s operator does not exist in environment", fn, op)
36+
}
37+
38+
if fnType.Type.NumIn() != 2 || fnType.Type.NumOut() != 1 {
39+
return fmt.Errorf("function %s for %s operator does not have a correct signature", fn, op)
40+
}
41+
}
42+
}
43+
return nil
944
}
1045

1146
// Option for configuring config.

internal/conf/operators_table.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package conf
2+
3+
// OperatorsTable maps binary operators to corresponding list of functions.
4+
// Functions should be provided in the environment to allow operator overloading.
5+
type OperatorsTable map[string][]string

internal/conf/types_table.go

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,17 @@ type Tag struct {
99

1010
type TypesTable map[string]Tag
1111

12-
func New(i interface{}) *Config {
13-
var mapEnv bool
14-
if _, ok := i.(map[string]interface{}); ok {
15-
mapEnv = true
16-
}
17-
18-
return &Config{
19-
MapEnv: mapEnv,
20-
Types: CreateTypesTable(i),
21-
}
22-
}
23-
2412
// CreateTypesTable creates types table for type checks during parsing.
2513
// If struct is passed, all fields will be treated as variables,
2614
// as well as all fields of embedded structs and struct itself.
2715
//
2816
// If map is passed, all items will be treated as variables
2917
// (key as name, value as type).
3018
func CreateTypesTable(i interface{}) TypesTable {
19+
if i == nil {
20+
return nil
21+
}
22+
3123
types := make(TypesTable)
3224
v := reflect.ValueOf(i)
3325
t := reflect.TypeOf(i)

0 commit comments

Comments
 (0)