Skip to content

Commit 4629871

Browse files
committed
Fix float folding optimization
1 parent e81f517 commit 4629871

File tree

5 files changed

+89
-78
lines changed

5 files changed

+89
-78
lines changed

checker/checker.go

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -894,9 +894,9 @@ func (v *checker) checkArguments(name string, fn reflect.Type, method bool, argu
894894
in = fn.In(i + fnInOffset)
895895
}
896896

897-
if isIntegerOrArithmeticOperation(arg) && (isInteger(in) || isFloat(in)) {
898-
t = in
899-
setTypeForIntegers(arg, t)
897+
if isFloat(in) {
898+
t = floatType
899+
traverseAndReplaceIntegerNodesWithFloatNodes(&arg)
900900
}
901901

902902
if t == nil {
@@ -914,6 +914,23 @@ func (v *checker) checkArguments(name string, fn reflect.Type, method bool, argu
914914
return fn.Out(0), nil
915915
}
916916

917+
func traverseAndReplaceIntegerNodesWithFloatNodes(node *ast.Node) {
918+
switch (*node).(type) {
919+
case *ast.IntegerNode:
920+
*node = &ast.FloatNode{Value: float64((*node).(*ast.IntegerNode).Value)}
921+
case *ast.UnaryNode:
922+
unaryNode := (*node).(*ast.UnaryNode)
923+
traverseAndReplaceIntegerNodesWithFloatNodes(&unaryNode.Node)
924+
case *ast.BinaryNode:
925+
binaryNode := (*node).(*ast.BinaryNode)
926+
switch binaryNode.Operator {
927+
case "+", "-", "*":
928+
traverseAndReplaceIntegerNodesWithFloatNodes(&binaryNode.Left)
929+
traverseAndReplaceIntegerNodesWithFloatNodes(&binaryNode.Right)
930+
}
931+
}
932+
}
933+
917934
func (v *checker) ClosureNode(node *ast.ClosureNode) (reflect.Type, info) {
918935
t, _ := v.visit(node.Node)
919936
if t == nil {

checker/types.go

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"reflect"
55
"time"
66

7-
"github.com/antonmedv/expr/ast"
87
"github.com/antonmedv/expr/conf"
98
)
109

@@ -228,39 +227,3 @@ func kind(t reflect.Type) reflect.Kind {
228227
}
229228
return t.Kind()
230229
}
231-
232-
func isIntegerOrArithmeticOperation(node ast.Node) bool {
233-
switch n := node.(type) {
234-
case *ast.IntegerNode:
235-
return true
236-
case *ast.UnaryNode:
237-
switch n.Operator {
238-
case "+", "-":
239-
return true
240-
}
241-
case *ast.BinaryNode:
242-
switch n.Operator {
243-
case "+", "-", "*":
244-
return true
245-
}
246-
}
247-
return false
248-
}
249-
250-
func setTypeForIntegers(node ast.Node, t reflect.Type) {
251-
switch n := node.(type) {
252-
case *ast.IntegerNode:
253-
n.SetType(t)
254-
case *ast.UnaryNode:
255-
switch n.Operator {
256-
case "+", "-":
257-
setTypeForIntegers(n.Node, t)
258-
}
259-
case *ast.BinaryNode:
260-
switch n.Operator {
261-
case "+", "/", "-", "*":
262-
setTypeForIntegers(n.Left, t)
263-
setTypeForIntegers(n.Right, t)
264-
}
265-
}
266-
}

expr_test.go

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,18 @@ func TestExpr(t *testing.T) {
10281028
`first(filter(map(1..9, # * 2), # % 2 == 0))`,
10291029
2,
10301030
},
1031+
{
1032+
`first(map(filter(1..9, # % 2 == 0), # * 2))`,
1033+
4,
1034+
},
1035+
{
1036+
`2^3 == 8`,
1037+
true,
1038+
},
1039+
{
1040+
`4/2 == 2`,
1041+
true,
1042+
},
10311043
}
10321044

10331045
for _, tt := range tests {
@@ -1198,13 +1210,13 @@ func TestExpr_calls_with_nil(t *testing.T) {
11981210
require.Equal(t, true, out)
11991211
}
12001212

1201-
func TestExpr_call_floatarg_func_with_int(t *testing.T) {
1213+
func TestExpr_call_float_arg_func_with_int(t *testing.T) {
12021214
env := map[string]interface{}{
12031215
"cnv": func(f float64) interface{} {
12041216
return f
12051217
},
12061218
}
1207-
for _, each := range []struct {
1219+
tests := []struct {
12081220
input string
12091221
expected float64
12101222
}{
@@ -1214,15 +1226,18 @@ func TestExpr_call_floatarg_func_with_int(t *testing.T) {
12141226
{"1-1", 0.0},
12151227
{"1/1", 1.0},
12161228
{"1*1", 1.0},
1217-
} {
1218-
p, err := expr.Compile(
1219-
fmt.Sprintf("cnv(%s)", each.input),
1220-
expr.Env(env))
1221-
require.NoError(t, err)
1229+
{"1^1", 1.0},
1230+
}
1231+
for _, tt := range tests {
1232+
t.Run(tt.input, func(t *testing.T) {
1233+
p, err := expr.Compile(fmt.Sprintf("cnv(%s)", tt.input), expr.Env(env))
1234+
require.NoError(t, err)
12221235

1223-
out, err := expr.Run(p, env)
1224-
require.NoError(t, err)
1225-
require.Equal(t, each.expected, out)
1236+
p.Disassemble()
1237+
out, err := expr.Run(p, env)
1238+
require.NoError(t, err)
1239+
require.Equal(t, tt.expected, out)
1240+
})
12261241
}
12271242
}
12281243

optimizer/fold.go

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
package optimizer
22

33
import (
4+
"fmt"
45
"math"
56
"reflect"
67

78
. "github.com/antonmedv/expr/ast"
89
"github.com/antonmedv/expr/file"
910
)
1011

12+
var (
13+
integerType = reflect.TypeOf(0)
14+
floatType = reflect.TypeOf(float64(0))
15+
stringType = reflect.TypeOf("")
16+
)
17+
1118
type fold struct {
1219
applied bool
1320
err *file.Error
@@ -18,29 +25,36 @@ func (fold *fold) Visit(node *Node) {
1825
fold.applied = true
1926
Patch(node, newNode)
2027
}
21-
// for IntegerNode the type may have been changed from int->float
22-
// preserve this information by setting the type after the Patch
23-
patchWithType := func(newNode Node, leafType reflect.Type) {
28+
patchWithType := func(newNode Node) {
2429
patch(newNode)
25-
newNode.SetType(leafType)
30+
switch newNode.(type) {
31+
case *IntegerNode:
32+
newNode.SetType(integerType)
33+
case *FloatNode:
34+
newNode.SetType(floatType)
35+
case *StringNode:
36+
newNode.SetType(stringType)
37+
default:
38+
panic(fmt.Sprintf("unknown type %T", newNode))
39+
}
2640
}
2741

2842
switch n := (*node).(type) {
2943
case *UnaryNode:
3044
switch n.Operator {
3145
case "-":
3246
if i, ok := n.Node.(*IntegerNode); ok {
33-
patchWithType(&IntegerNode{Value: -i.Value}, n.Node.Type())
47+
patchWithType(&IntegerNode{Value: -i.Value})
3448
}
3549
if i, ok := n.Node.(*FloatNode); ok {
36-
patchWithType(&FloatNode{Value: -i.Value}, n.Node.Type())
50+
patchWithType(&FloatNode{Value: -i.Value})
3751
}
3852
case "+":
3953
if i, ok := n.Node.(*IntegerNode); ok {
40-
patchWithType(&IntegerNode{Value: i.Value}, n.Node.Type())
54+
patchWithType(&IntegerNode{Value: i.Value})
4155
}
4256
if i, ok := n.Node.(*FloatNode); ok {
43-
patchWithType(&FloatNode{Value: i.Value}, n.Node.Type())
57+
patchWithType(&FloatNode{Value: i.Value})
4458
}
4559
case "!", "not":
4660
if a := toBool(n.Node); a != nil {
@@ -55,28 +69,28 @@ func (fold *fold) Visit(node *Node) {
5569
a := toInteger(n.Left)
5670
b := toInteger(n.Right)
5771
if a != nil && b != nil {
58-
patchWithType(&IntegerNode{Value: a.Value + b.Value}, a.Type())
72+
patchWithType(&IntegerNode{Value: a.Value + b.Value})
5973
}
6074
}
6175
{
6276
a := toInteger(n.Left)
6377
b := toFloat(n.Right)
6478
if a != nil && b != nil {
65-
patchWithType(&FloatNode{Value: float64(a.Value) + b.Value}, a.Type())
79+
patchWithType(&FloatNode{Value: float64(a.Value) + b.Value})
6680
}
6781
}
6882
{
6983
a := toFloat(n.Left)
7084
b := toInteger(n.Right)
7185
if a != nil && b != nil {
72-
patchWithType(&FloatNode{Value: a.Value + float64(b.Value)}, a.Type())
86+
patchWithType(&FloatNode{Value: a.Value + float64(b.Value)})
7387
}
7488
}
7589
{
7690
a := toFloat(n.Left)
7791
b := toFloat(n.Right)
7892
if a != nil && b != nil {
79-
patchWithType(&FloatNode{Value: a.Value + b.Value}, a.Type())
93+
patchWithType(&FloatNode{Value: a.Value + b.Value})
8094
}
8195
}
8296
{
@@ -91,86 +105,86 @@ func (fold *fold) Visit(node *Node) {
91105
a := toInteger(n.Left)
92106
b := toInteger(n.Right)
93107
if a != nil && b != nil {
94-
patchWithType(&IntegerNode{Value: a.Value - b.Value}, a.Type())
108+
patchWithType(&IntegerNode{Value: a.Value - b.Value})
95109
}
96110
}
97111
{
98112
a := toInteger(n.Left)
99113
b := toFloat(n.Right)
100114
if a != nil && b != nil {
101-
patchWithType(&FloatNode{Value: float64(a.Value) - b.Value}, a.Type())
115+
patchWithType(&FloatNode{Value: float64(a.Value) - b.Value})
102116
}
103117
}
104118
{
105119
a := toFloat(n.Left)
106120
b := toInteger(n.Right)
107121
if a != nil && b != nil {
108-
patchWithType(&FloatNode{Value: a.Value - float64(b.Value)}, a.Type())
122+
patchWithType(&FloatNode{Value: a.Value - float64(b.Value)})
109123
}
110124
}
111125
{
112126
a := toFloat(n.Left)
113127
b := toFloat(n.Right)
114128
if a != nil && b != nil {
115-
patchWithType(&FloatNode{Value: a.Value - b.Value}, a.Type())
129+
patchWithType(&FloatNode{Value: a.Value - b.Value})
116130
}
117131
}
118132
case "*":
119133
{
120134
a := toInteger(n.Left)
121135
b := toInteger(n.Right)
122136
if a != nil && b != nil {
123-
patchWithType(&IntegerNode{Value: a.Value * b.Value}, a.Type())
137+
patchWithType(&IntegerNode{Value: a.Value * b.Value})
124138
}
125139
}
126140
{
127141
a := toInteger(n.Left)
128142
b := toFloat(n.Right)
129143
if a != nil && b != nil {
130-
patchWithType(&FloatNode{Value: float64(a.Value) * b.Value}, a.Type())
144+
patchWithType(&FloatNode{Value: float64(a.Value) * b.Value})
131145
}
132146
}
133147
{
134148
a := toFloat(n.Left)
135149
b := toInteger(n.Right)
136150
if a != nil && b != nil {
137-
patchWithType(&FloatNode{Value: a.Value * float64(b.Value)}, a.Type())
151+
patchWithType(&FloatNode{Value: a.Value * float64(b.Value)})
138152
}
139153
}
140154
{
141155
a := toFloat(n.Left)
142156
b := toFloat(n.Right)
143157
if a != nil && b != nil {
144-
patchWithType(&FloatNode{Value: a.Value * b.Value}, a.Type())
158+
patchWithType(&FloatNode{Value: a.Value * b.Value})
145159
}
146160
}
147161
case "/":
148162
{
149163
a := toInteger(n.Left)
150164
b := toInteger(n.Right)
151165
if a != nil && b != nil {
152-
patchWithType(&FloatNode{Value: float64(a.Value) / float64(b.Value)}, a.Type())
166+
patchWithType(&FloatNode{Value: float64(a.Value) / float64(b.Value)})
153167
}
154168
}
155169
{
156170
a := toInteger(n.Left)
157171
b := toFloat(n.Right)
158172
if a != nil && b != nil {
159-
patchWithType(&FloatNode{Value: float64(a.Value) / b.Value}, a.Type())
173+
patchWithType(&FloatNode{Value: float64(a.Value) / b.Value})
160174
}
161175
}
162176
{
163177
a := toFloat(n.Left)
164178
b := toInteger(n.Right)
165179
if a != nil && b != nil {
166-
patchWithType(&FloatNode{Value: a.Value / float64(b.Value)}, a.Type())
180+
patchWithType(&FloatNode{Value: a.Value / float64(b.Value)})
167181
}
168182
}
169183
{
170184
a := toFloat(n.Left)
171185
b := toFloat(n.Right)
172186
if a != nil && b != nil {
173-
patchWithType(&FloatNode{Value: a.Value / b.Value}, a.Type())
187+
patchWithType(&FloatNode{Value: a.Value / b.Value})
174188
}
175189
}
176190
case "%":
@@ -191,28 +205,28 @@ func (fold *fold) Visit(node *Node) {
191205
a := toInteger(n.Left)
192206
b := toInteger(n.Right)
193207
if a != nil && b != nil {
194-
patchWithType(&FloatNode{Value: math.Pow(float64(a.Value), float64(b.Value))}, a.Type())
208+
patchWithType(&FloatNode{Value: math.Pow(float64(a.Value), float64(b.Value))})
195209
}
196210
}
197211
{
198212
a := toInteger(n.Left)
199213
b := toFloat(n.Right)
200214
if a != nil && b != nil {
201-
patchWithType(&FloatNode{Value: math.Pow(float64(a.Value), b.Value)}, a.Type())
215+
patchWithType(&FloatNode{Value: math.Pow(float64(a.Value), b.Value)})
202216
}
203217
}
204218
{
205219
a := toFloat(n.Left)
206220
b := toInteger(n.Right)
207221
if a != nil && b != nil {
208-
patchWithType(&FloatNode{Value: math.Pow(a.Value, float64(b.Value))}, a.Type())
222+
patchWithType(&FloatNode{Value: math.Pow(a.Value, float64(b.Value))})
209223
}
210224
}
211225
{
212226
a := toFloat(n.Left)
213227
b := toFloat(n.Right)
214228
if a != nil && b != nil {
215-
patchWithType(&FloatNode{Value: math.Pow(a.Value, b.Value)}, a.Type())
229+
patchWithType(&FloatNode{Value: math.Pow(a.Value, b.Value)})
216230
}
217231
}
218232
case "and", "&&":

optimizer/optimizer_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package optimizer_test
22

33
import (
4+
"reflect"
45
"strings"
56
"testing"
67

@@ -39,6 +40,7 @@ func TestOptimize_constant_folding_with_floats(t *testing.T) {
3940
expected := &ast.FloatNode{Value: 3.0}
4041

4142
assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
43+
assert.Equal(t, reflect.Float64, tree.Node.Type().Kind())
4244
}
4345

4446
func TestOptimize_constant_folding_with_bools(t *testing.T) {

0 commit comments

Comments
 (0)