Skip to content

Commit 6dbcb97

Browse files
authored
Dl/peephole optimize (#486)
adds Peephole optimization rewrites and test cases. For issue #140
1 parent 4fc3e17 commit 6dbcb97

File tree

3 files changed

+103
-0
lines changed

3 files changed

+103
-0
lines changed

src/kirin/rewrite/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .getitem import InlineGetItem as InlineGetItem
99
from .fixpoint import Fixpoint as Fixpoint
1010
from .getfield import InlineGetField as InlineGetField
11+
from .peephole import PeepholeOptimize as PeepholeOptimize
1112
from .apply_type import ApplyType as ApplyType
1213
from .compactify import CFGCompactify as CFGCompactify
1314
from .wrap_const import WrapConst as WrapConst

src/kirin/rewrite/peephole.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from dataclasses import dataclass
2+
3+
from kirin import ir, types
4+
from kirin.dialects import py
5+
from kirin.rewrite.abc import RewriteRule, RewriteResult
6+
7+
8+
@dataclass
9+
class PeepholeOptimize(RewriteRule):
10+
11+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
12+
result_types = node.results.types
13+
if not all(
14+
t.is_subseteq(types.Union(types.Float, types.Int)) for t in result_types
15+
):
16+
return RewriteResult(has_done_something=False)
17+
18+
if isinstance(node, py.binop.Add):
19+
# add(%a, %a) -> mul(2, %a)
20+
if node.lhs is node.rhs:
21+
x = py.Constant(2)
22+
x.insert_before(node)
23+
new_stmt = py.binop.Mult(x.result, node.rhs)
24+
node.replace_by(new_stmt)
25+
return RewriteResult(has_done_something=True)
26+
27+
# add(mul(2, %a), %a) -> mul(3, %a)
28+
elif isinstance(mult_node := node.lhs.owner, py.binop.Mult) and isinstance(
29+
const_node := mult_node.lhs.owner, py.Constant
30+
):
31+
x = const_node.value.unwrap()
32+
const_node.replace_by(py.Constant(x + 1))
33+
node.replace_by(py.binop.Mult(mult_node.lhs, node.rhs))
34+
mult_node.delete()
35+
return RewriteResult(has_done_something=True)
36+
37+
# add(%a, mul(2, %a)) -> mul(3, %a)
38+
elif isinstance(mult_node := node.rhs.owner, py.binop.Mult) and isinstance(
39+
const_node := mult_node.lhs.owner, py.Constant
40+
):
41+
x = const_node.value.unwrap()
42+
const_node.replace_by(py.Constant(x + 1))
43+
node.replace_by(py.binop.Mult(mult_node.lhs, node.lhs))
44+
mult_node.delete()
45+
return RewriteResult(has_done_something=True)
46+
47+
return RewriteResult(has_done_something=False)

test/rewrite/test_peephole.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from kirin.prelude import basic_no_opt
2+
from kirin.rewrite import Walk, Fixpoint
3+
from kirin.rewrite.peephole import PeepholeOptimize
4+
5+
6+
# add(%a, %a) -> mul(2, %a)
7+
@basic_no_opt
8+
def peephole1(a: int):
9+
x = a + a
10+
return x
11+
12+
13+
# add(mul(2, %a), %a) -> mul(3, %a)
14+
@basic_no_opt
15+
def peephole2(a: int):
16+
x = 2 * a + a
17+
return x
18+
19+
20+
# add(%a, mul(2, %a)) -> mul(3, %a)
21+
@basic_no_opt
22+
def peephole3(a: int):
23+
x = a + 2 * a
24+
return x
25+
26+
27+
# add(%a, add(%a, mul(2, %a))) -> mul(4, %a)
28+
@basic_no_opt
29+
def peephole4(a: int):
30+
x = a + a + 2 * a
31+
return x
32+
33+
34+
def aux(program):
35+
for i in range(5):
36+
before = program(i)
37+
Fixpoint(Walk(PeepholeOptimize())).rewrite(program.code)
38+
after = program(i)
39+
assert before == after
40+
41+
42+
def test_peephole_opt1():
43+
aux(peephole1)
44+
45+
46+
def test_peephole_opt2():
47+
aux(peephole2)
48+
49+
50+
def test_peephole_opt3():
51+
aux(peephole3)
52+
53+
54+
def test_peephole_opt4():
55+
aux(peephole4)

0 commit comments

Comments
 (0)