Skip to content

Commit 33a808e

Browse files
committed
Add interface support to operator overloading
1 parent 5256953 commit 33a808e

File tree

4 files changed

+76
-28
lines changed

4 files changed

+76
-28
lines changed

checker/checker.go

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -174,18 +174,9 @@ func (v *visitor) BinaryNode(node *ast.BinaryNode) reflect.Type {
174174

175175
// check operator overloading
176176
if fns, ok := v.operators[node.Operator]; ok {
177-
for _, fn := range fns {
178-
fnType := v.types[fn]
179-
firstInIndex := 0
180-
if fnType.Method {
181-
firstInIndex = 1 // As first argument to method is receiver.
182-
}
183-
firstArgType := fnType.Type.In(firstInIndex)
184-
secondArgType := fnType.Type.In(firstInIndex + 1)
185-
186-
if l == firstArgType && r == secondArgType {
187-
return fnType.Type.Out(0)
188-
}
177+
t, _, ok := conf.FindSuitableOperatorOverload(fns, v.types, l, r)
178+
if ok {
179+
return t
189180
}
190181
}
191182

checker/patcher.go

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,24 +25,16 @@ func (p *operatorPatcher) Exit(node *ast.Node) {
2525

2626
leftType := binaryNode.Left.GetType()
2727
rightType := binaryNode.Right.GetType()
28-
for _, fn := range fns {
29-
fnType := p.types[fn]
30-
firstInIndex := 0
31-
if fnType.Method {
32-
firstInIndex = 1 // As first argument to method is receiver.
33-
}
34-
firstArgType := fnType.Type.In(firstInIndex)
35-
secondArgType := fnType.Type.In(firstInIndex + 1)
3628

37-
if leftType == firstArgType && rightType == secondArgType {
38-
newNode := &ast.FunctionNode{
39-
Name: fn,
40-
Arguments: []ast.Node{binaryNode.Left, binaryNode.Right},
41-
}
42-
newNode.SetType((*node).GetType())
43-
newNode.SetLocation((*node).GetLocation())
44-
*node = newNode
29+
_, fn, ok := conf.FindSuitableOperatorOverload(fns, p.types, leftType, rightType)
30+
if ok {
31+
newNode := &ast.FunctionNode{
32+
Name: fn,
33+
Arguments: []ast.Node{binaryNode.Left, binaryNode.Right},
4534
}
35+
newNode.SetType((*node).GetType())
36+
newNode.SetLocation((*node).GetLocation())
37+
*node = newNode
4638
}
4739
}
4840

expr_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,26 @@ func TestOperator_struct(t *testing.T) {
416416
require.Equal(t, true, output)
417417
}
418418

419+
func TestOperator_interface(t *testing.T) {
420+
env := &mockEnv{
421+
Ticket: &ticket{Price: 100},
422+
}
423+
424+
code := `Ticket == "$100" && "$100" == Ticket && Now != Ticket && Now == Now`
425+
426+
program, err := expr.Compile(
427+
code,
428+
expr.Env(&mockEnv{}),
429+
expr.Operator("==", "StringerStringEqual", "StringStringerEqual", "StringerStringerEqual"),
430+
expr.Operator("!=", "NotStringerStringEqual", "NotStringStringerEqual", "NotStringerStringerEqual"),
431+
)
432+
require.NoError(t, err)
433+
434+
output, err := expr.Run(program, env)
435+
require.NoError(t, err)
436+
require.Equal(t, true, output)
437+
}
438+
419439
func TestExpr(t *testing.T) {
420440
env := &mockEnv{
421441
Any: "any",
@@ -836,6 +856,30 @@ func (*mockEnv) DateEqual(date time.Time, s string) bool {
836856
return date.Format("2006-01-02") == s
837857
}
838858

859+
func (*mockEnv) StringerStringEqual(f fmt.Stringer, s string) bool {
860+
return f.String() == s
861+
}
862+
863+
func (*mockEnv) StringStringerEqual(s string, f fmt.Stringer) bool {
864+
return s == f.String()
865+
}
866+
867+
func (*mockEnv) StringerStringerEqual(f fmt.Stringer, g fmt.Stringer) bool {
868+
return f.String() == g.String()
869+
}
870+
871+
func (*mockEnv) NotStringerStringEqual(f fmt.Stringer, s string) bool {
872+
return f.String() != s
873+
}
874+
875+
func (*mockEnv) NotStringStringerEqual(s string, f fmt.Stringer) bool {
876+
return s != f.String()
877+
}
878+
879+
func (*mockEnv) NotStringerStringerEqual(f fmt.Stringer, g fmt.Stringer) bool {
880+
return f.String() != g.String()
881+
}
882+
839883
type ticket struct {
840884
Price int
841885
}

internal/conf/operators_table.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,26 @@
11
package conf
22

3+
import "reflect"
4+
35
// OperatorsTable maps binary operators to corresponding list of functions.
46
// Functions should be provided in the environment to allow operator overloading.
57
type OperatorsTable map[string][]string
8+
9+
func FindSuitableOperatorOverload(fns []string, types TypesTable, l, r reflect.Type) (reflect.Type, string, bool) {
10+
for _, fn := range fns {
11+
fnType := types[fn]
12+
firstInIndex := 0
13+
if fnType.Method {
14+
firstInIndex = 1 // As first argument to method is receiver.
15+
}
16+
firstArgType := fnType.Type.In(firstInIndex)
17+
secondArgType := fnType.Type.In(firstInIndex + 1)
18+
19+
firstArgumentFit := l == firstArgType || (firstArgType.Kind() == reflect.Interface && l.Implements(firstArgType))
20+
secondArgumentFit := r == secondArgType || (secondArgType.Kind() == reflect.Interface && r.Implements(secondArgType))
21+
if firstArgumentFit && secondArgumentFit {
22+
return fnType.Type.Out(0), fn, true
23+
}
24+
}
25+
return nil, "", false
26+
}

0 commit comments

Comments
 (0)