Skip to content

Commit 36418e3

Browse files
committed
Improve constant folding algorithm
1 parent 261808b commit 36418e3

File tree

8 files changed

+83
-37
lines changed

8 files changed

+83
-37
lines changed

ast/node.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ type MapNode struct {
169169
l file.Location
170170
t reflect.Type
171171

172-
Pairs []*PairNode
172+
Pairs []Node
173173
}
174174

175175
type PairNode struct {

ast/visitor.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,17 +89,17 @@ func (w *walker) walk(node *Node) {
8989
w.walk(&n.Exp2)
9090
w.visitor.Exit(node)
9191
case *ArrayNode:
92-
for _, node := range n.Nodes {
93-
w.walk(&node)
92+
for i := range n.Nodes {
93+
w.walk(&n.Nodes[i])
9494
}
9595
w.visitor.Exit(node)
9696
case *MapNode:
97-
var pair Node
98-
for _, pair = range n.Pairs {
99-
w.walk(&pair)
97+
for i := range n.Pairs {
98+
w.walk(&n.Pairs[i])
10099
}
101100
w.visitor.Exit(node)
102101
case *PairNode:
102+
w.walk(&n.Key)
103103
w.walk(&n.Value)
104104
w.visitor.Exit(node)
105105
default:

checker/checker.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ func (v *visitor) visit(node ast.Node) reflect.Type {
9999
t = v.ArrayNode(n)
100100
case *ast.MapNode:
101101
t = v.MapNode(n)
102+
case *ast.PairNode:
103+
t = v.PairNode(n)
102104
default:
103105
panic(fmt.Sprintf("undefined node type (%T)", node))
104106
}
@@ -550,7 +552,13 @@ func (v *visitor) ArrayNode(node *ast.ArrayNode) reflect.Type {
550552

551553
func (v *visitor) MapNode(node *ast.MapNode) reflect.Type {
552554
for _, pair := range node.Pairs {
553-
v.visit(pair.Value)
555+
v.visit(pair)
554556
}
555557
return mapType
556558
}
559+
560+
func (v *visitor) PairNode(node *ast.PairNode) reflect.Type {
561+
v.visit(node.Key)
562+
v.visit(node.Value)
563+
return nilType
564+
}

cmd/exe/dot.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func (v *visitor) Exit(ref *Node) {
7070

7171
case *UnaryNode:
7272
n := v.pop()
73-
v.push("-")
73+
v.push(node.Operator)
7474
v.link(n)
7575

7676
case *BinaryNode:

compiler/compiler.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ func (c *compiler) compile(node ast.Node) {
156156
c.ArrayNode(n)
157157
case *ast.MapNode:
158158
c.MapNode(n)
159+
case *ast.PairNode:
160+
c.PairNode(n)
159161
default:
160162
panic(fmt.Sprintf("undefined node type (%T)", node))
161163
}
@@ -611,14 +613,18 @@ func (c *compiler) ArrayNode(node *ast.ArrayNode) {
611613

612614
func (c *compiler) MapNode(node *ast.MapNode) {
613615
for _, pair := range node.Pairs {
614-
c.compile(pair.Key)
615-
c.compile(pair.Value)
616+
c.compile(pair)
616617
}
617618

618619
c.emitPush(len(node.Pairs))
619620
c.emit(OpMap)
620621
}
621622

623+
func (c *compiler) PairNode(node *ast.PairNode) {
624+
c.compile(node.Key)
625+
c.compile(node.Value)
626+
}
627+
622628
func encode(i uint16) []byte {
623629
b := make([]byte, 2)
624630
binary.LittleEndian.PutUint16(b, i)

optimizer/optimizer.go

Lines changed: 54 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ import (
55
"math"
66
)
77

8-
type fold struct{}
8+
type fold struct {
9+
applied bool
10+
}
911
type inRange struct{}
1012
type constRange struct{}
1113

@@ -16,14 +18,22 @@ func patch(node *Node, newNode Node) {
1618
}
1719

1820
func (*fold) Enter(node *Node) {}
19-
func (*fold) Exit(node *Node) {
21+
func (fold *fold) Exit(node *Node) {
22+
patch := func(newNode Node) {
23+
fold.applied = true
24+
patch(node, newNode)
25+
}
26+
2027
switch n := (*node).(type) {
2128
case *UnaryNode:
22-
if n.Operator == "-" {
29+
switch n.Operator {
30+
case "-":
31+
if i, ok := n.Node.(*IntegerNode); ok {
32+
patch(&IntegerNode{Value: -i.Value})
33+
}
34+
case "+":
2335
if i, ok := n.Node.(*IntegerNode); ok {
24-
patch(node, &IntegerNode{
25-
Value: -i.Value,
26-
})
36+
patch(&IntegerNode{Value: i.Value})
2737
}
2838
}
2939

@@ -32,24 +42,42 @@ func (*fold) Exit(node *Node) {
3242
case "+":
3343
if a, ok := n.Left.(*IntegerNode); ok {
3444
if b, ok := n.Right.(*IntegerNode); ok {
35-
patch(node, &IntegerNode{
36-
Value: a.Value + b.Value,
37-
})
45+
patch(&IntegerNode{Value: a.Value + b.Value})
3846
}
3947
}
4048
if a, ok := n.Left.(*StringNode); ok {
4149
if b, ok := n.Right.(*StringNode); ok {
42-
patch(node, &StringNode{
43-
Value: a.Value + b.Value,
44-
})
50+
patch(&StringNode{Value: a.Value + b.Value})
51+
}
52+
}
53+
case "-":
54+
if a, ok := n.Left.(*IntegerNode); ok {
55+
if b, ok := n.Right.(*IntegerNode); ok {
56+
patch(&IntegerNode{Value: a.Value - b.Value})
57+
}
58+
}
59+
case "*":
60+
if a, ok := n.Left.(*IntegerNode); ok {
61+
if b, ok := n.Right.(*IntegerNode); ok {
62+
patch(&IntegerNode{Value: a.Value * b.Value})
63+
}
64+
}
65+
case "/":
66+
if a, ok := n.Left.(*IntegerNode); ok {
67+
if b, ok := n.Right.(*IntegerNode); ok {
68+
patch(&IntegerNode{Value: a.Value / b.Value})
69+
}
70+
}
71+
case "%":
72+
if a, ok := n.Left.(*IntegerNode); ok {
73+
if b, ok := n.Right.(*IntegerNode); ok {
74+
patch(&IntegerNode{Value: a.Value % b.Value})
4575
}
4676
}
4777
case "**":
4878
if a, ok := n.Left.(*IntegerNode); ok {
4979
if b, ok := n.Right.(*IntegerNode); ok {
50-
patch(node, &FloatNode{
51-
Value: math.Pow(float64(a.Value), float64(b.Value)),
52-
})
80+
patch(&FloatNode{Value: math.Pow(float64(a.Value), float64(b.Value))})
5381
}
5482
}
5583
}
@@ -67,9 +95,7 @@ func (*fold) Exit(node *Node) {
6795
for i, a := range n.Nodes {
6896
value[i] = a.(*IntegerNode).Value
6997
}
70-
patch(node, &ConstantNode{
71-
Value: value,
72-
})
98+
patch(&ConstantNode{Value: value})
7399
}
74100

75101
string:
@@ -83,9 +109,7 @@ func (*fold) Exit(node *Node) {
83109
for i, a := range n.Nodes {
84110
value[i] = a.(*StringNode).Value
85111
}
86-
patch(node, &ConstantNode{
87-
Value: value,
88-
})
112+
patch(&ConstantNode{Value: value})
89113
}
90114

91115
}
@@ -148,7 +172,15 @@ func (*constRange) Exit(node *Node) {
148172
}
149173

150174
func Optimize(node *Node) {
151-
Walk(node, &fold{})
175+
limit := 1000
176+
for {
177+
fold := &fold{}
178+
Walk(node, fold)
179+
limit--
180+
if !fold.applied || limit == 0 {
181+
break
182+
}
183+
}
152184
Walk(node, &inRange{})
153185
Walk(node, &constRange{})
154186
}

parser/parser.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,9 +282,9 @@ func (p *parser) ExitMapLiteral(ctx *gen.MapLiteralContext) {
282282
return
283283
}
284284

285-
nodes := make([]*ast.PairNode, 0)
285+
nodes := make([]ast.Node, 0)
286286
for range e.GetList() {
287-
nodes = append([]*ast.PairNode{p.pop(ctx).(*ast.PairNode)}, nodes...)
287+
nodes = append([]ast.Node{p.pop(ctx).(*ast.PairNode)}, nodes...)
288288
}
289289
p.push(&ast.MapNode{Pairs: nodes}).SetLocation(location(ctx))
290290
}

parser/parser_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,19 +144,19 @@ func TestParse(t *testing.T) {
144144
},
145145
{
146146
"{foo:1, bar:2}",
147-
&ast.MapNode{Pairs: []*ast.PairNode{{Key: &ast.StringNode{Value: "foo"}, Value: &ast.IntegerNode{Value: 1}}, {Key: &ast.StringNode{Value: "bar"}, Value: &ast.IntegerNode{Value: 2}}}},
147+
&ast.MapNode{Pairs: []ast.Node{&ast.PairNode{Key: &ast.StringNode{Value: "foo"}, Value: &ast.IntegerNode{Value: 1}}, &ast.PairNode{Key: &ast.StringNode{Value: "bar"}, Value: &ast.IntegerNode{Value: 2}}}},
148148
},
149149
{
150150
`{"a": 1, 'b': 2}`,
151-
&ast.MapNode{Pairs: []*ast.PairNode{{Key: &ast.StringNode{Value: "a"}, Value: &ast.IntegerNode{Value: 1}}, {Key: &ast.StringNode{Value: "b"}, Value: &ast.IntegerNode{Value: 2}}}},
151+
&ast.MapNode{Pairs: []ast.Node{&ast.PairNode{Key: &ast.StringNode{Value: "a"}, Value: &ast.IntegerNode{Value: 1}}, &ast.PairNode{Key: &ast.StringNode{Value: "b"}, Value: &ast.IntegerNode{Value: 2}}}},
152152
},
153153
{
154154
"[1].foo",
155155
&ast.PropertyNode{Node: &ast.ArrayNode{Nodes: []ast.Node{&ast.IntegerNode{Value: 1}}}, Property: "foo"},
156156
},
157157
{
158158
"{foo:1}.bar",
159-
&ast.PropertyNode{Node: &ast.MapNode{Pairs: []*ast.PairNode{{Key: &ast.StringNode{Value: "foo"}, Value: &ast.IntegerNode{Value: 1}}}}, Property: "bar"},
159+
&ast.PropertyNode{Node: &ast.MapNode{Pairs: []ast.Node{&ast.PairNode{Key: &ast.StringNode{Value: "foo"}, Value: &ast.IntegerNode{Value: 1}}}}, Property: "bar"},
160160
},
161161
{
162162
"len(foo)",

0 commit comments

Comments
 (0)