1
1
package patcher
2
2
3
3
import (
4
+ "fmt"
5
+ "reflect"
6
+
4
7
"github.com/expr-lang/expr/ast"
8
+ "github.com/expr-lang/expr/builtin"
5
9
"github.com/expr-lang/expr/conf"
6
10
)
7
11
8
- type Operator struct {
9
- Operators conf.OperatorsTable
10
- Types conf.TypesTable
11
- Functions conf.FunctionTable
12
+ type OperatorOverride struct {
13
+ Operator string // Operator token to override.
14
+ Overrides []string // List of function names to override operator with.
15
+ Types conf.TypesTable // Env types.
16
+ Functions conf.FunctionsTable // Env functions.
12
17
}
13
18
14
- func (p * Operator ) Visit (node * ast.Node ) {
19
+ func (p * OperatorOverride ) Visit (node * ast.Node ) {
15
20
binaryNode , ok := (* node ).(* ast.BinaryNode )
16
21
if ! ok {
17
22
return
18
23
}
19
24
20
- fns , ok := p .Operators [binaryNode .Operator ]
21
- if ! ok {
25
+ if binaryNode .Operator != p .Operator {
22
26
return
23
27
}
24
28
25
29
leftType := binaryNode .Left .Type ()
26
30
rightType := binaryNode .Right .Type ()
27
31
28
- ret , fn , ok := conf .FindSuitableOperatorOverload (fns , p . Types , p . Functions , leftType , rightType )
32
+ ret , fn , ok := p .FindSuitableOperatorOverload (leftType , rightType )
29
33
if ok {
30
34
newNode := & ast.CallNode {
31
35
Callee : & ast.IdentifierNode {Value : fn },
@@ -35,3 +39,97 @@ func (p *Operator) Visit(node *ast.Node) {
35
39
ast .Patch (node , newNode )
36
40
}
37
41
}
42
+
43
+ func (p * OperatorOverride ) FindSuitableOperatorOverload (l , r reflect.Type ) (reflect.Type , string , bool ) {
44
+ t , fn , ok := p .findSuitableOperatorOverloadInFunctions (l , r )
45
+ if ! ok {
46
+ t , fn , ok = p .findSuitableOperatorOverloadInTypes (l , r )
47
+ }
48
+ return t , fn , ok
49
+ }
50
+
51
+ func (p * OperatorOverride ) findSuitableOperatorOverloadInTypes (l , r reflect.Type ) (reflect.Type , string , bool ) {
52
+ for _ , fn := range p .Overrides {
53
+ fnType , ok := p .Types [fn ]
54
+ if ! ok {
55
+ continue
56
+ }
57
+ firstInIndex := 0
58
+ if fnType .Method {
59
+ firstInIndex = 1 // As first argument to method is receiver.
60
+ }
61
+ ret , done := checkTypeSuits (fnType .Type , l , r , firstInIndex )
62
+ if done {
63
+ return ret , fn , true
64
+ }
65
+ }
66
+ return nil , "" , false
67
+ }
68
+
69
+ func (p * OperatorOverride ) findSuitableOperatorOverloadInFunctions (l , r reflect.Type ) (reflect.Type , string , bool ) {
70
+ for _ , fn := range p .Overrides {
71
+ fnType , ok := p .Functions [fn ]
72
+ if ! ok {
73
+ continue
74
+ }
75
+ firstInIndex := 0
76
+ for _ , overload := range fnType .Types {
77
+ ret , done := checkTypeSuits (overload , l , r , firstInIndex )
78
+ if done {
79
+ return ret , fn , true
80
+ }
81
+ }
82
+ }
83
+ return nil , "" , false
84
+ }
85
+
86
+ func checkTypeSuits (t reflect.Type , l reflect.Type , r reflect.Type , firstInIndex int ) (reflect.Type , bool ) {
87
+ firstArgType := t .In (firstInIndex )
88
+ secondArgType := t .In (firstInIndex + 1 )
89
+
90
+ firstArgumentFit := l == firstArgType || (firstArgType .Kind () == reflect .Interface && (l == nil || l .Implements (firstArgType )))
91
+ secondArgumentFit := r == secondArgType || (secondArgType .Kind () == reflect .Interface && (r == nil || r .Implements (secondArgType )))
92
+ if firstArgumentFit && secondArgumentFit {
93
+ return t .Out (0 ), true
94
+ }
95
+ return nil , false
96
+ }
97
+
98
+ func (p * OperatorOverride ) Check () {
99
+ for _ , fn := range p .Overrides {
100
+ fnType , foundType := p .Types [fn ]
101
+ fnFunc , foundFunc := p .Functions [fn ]
102
+ if ! foundFunc && (! foundType || fnType .Type .Kind () != reflect .Func ) {
103
+ panic (fmt .Errorf ("function %s for %s operator does not exist in the environment" , fn , p .Operator ))
104
+ }
105
+
106
+ if foundType {
107
+ checkType (fnType , fn , p .Operator )
108
+ }
109
+
110
+ if foundFunc {
111
+ checkFunc (fnFunc , fn , p .Operator )
112
+ }
113
+ }
114
+ }
115
+
116
+ func checkType (fnType conf.Tag , fn string , operator string ) {
117
+ requiredNumIn := 2
118
+ if fnType .Method {
119
+ requiredNumIn = 3 // As first argument of method is receiver.
120
+ }
121
+ if fnType .Type .NumIn () != requiredNumIn || fnType .Type .NumOut () != 1 {
122
+ panic (fmt .Errorf ("function %s for %s operator does not have a correct signature" , fn , operator ))
123
+ }
124
+ }
125
+
126
+ func checkFunc (fn * builtin.Function , name string , operator string ) {
127
+ if len (fn .Types ) == 0 {
128
+ panic (fmt .Errorf ("function %s for %s operator misses types" , name , operator ))
129
+ }
130
+ for _ , t := range fn .Types {
131
+ if t .NumIn () != 2 || t .NumOut () != 1 {
132
+ panic (fmt .Errorf ("function %s for %s operator does not have a correct signature" , name , operator ))
133
+ }
134
+ }
135
+ }
0 commit comments