Skip to content

Commit f9b98c9

Browse files
authored
backport 493 (#494)
1 parent 3104f49 commit f9b98c9

File tree

4 files changed

+76
-1
lines changed

4 files changed

+76
-1
lines changed

src/kirin/dialects/scf/unroll.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,4 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
8888
for result, output in zip(node.results, loop_vars):
8989
result.replace_by(output)
9090
node.delete()
91-
return RewriteResult(has_done_something=True, terminated=True)
91+
return RewriteResult(has_done_something=True)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .fold import Fold as Fold
2+
from .unroll import UnrollScf as UnrollScf
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from dataclasses import field, dataclass
2+
3+
from kirin.ir import Method
4+
from kirin.passes import Fold, Pass, TypeInfer
5+
from kirin.rewrite import Walk
6+
from kirin.rewrite.abc import RewriteResult
7+
from kirin.dialects.scf.unroll import ForLoop, PickIfElse
8+
9+
10+
@dataclass
11+
class UnrollScf(Pass):
12+
"""This pass can be used to unroll scf.For loops and inline/expand scf.IfElse when
13+
the input are known at compile time.
14+
15+
usage:
16+
UnrollScf(dialects).fixpoint(method)
17+
18+
Note: This pass should be used in a fixpoint manner, to unroll nested scf nodes.
19+
20+
"""
21+
22+
typeinfer: TypeInfer = field(init=False)
23+
fold: Fold = field(init=False)
24+
25+
def __post_init__(self):
26+
self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise)
27+
self.fold = Fold(self.dialects, no_raise=self.no_raise)
28+
29+
def unsafe_run(self, mt: Method):
30+
result = RewriteResult()
31+
result = Walk(PickIfElse()).rewrite(mt.code).join(result)
32+
result = Walk(ForLoop()).rewrite(mt.code).join(result)
33+
result = self.typeinfer(mt).join(result)
34+
result = self.fold(mt).join(result)
35+
return result

test/passes/test_unroll_scf.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from kirin.prelude import structural
2+
from kirin.dialects import py, func
3+
from kirin.passes.aggressive import UnrollScf
4+
5+
6+
def test_unroll_scf():
7+
@structural
8+
def main(r: list[int], cond: bool):
9+
if cond:
10+
for i in range(4):
11+
tmp = r[-1]
12+
if i < 2:
13+
tmp += i * 2
14+
else:
15+
for j in range(4):
16+
if i > j:
17+
tmp += i + j
18+
else:
19+
tmp += i - j
20+
21+
r.append(tmp)
22+
else:
23+
for i in range(4):
24+
r.append(i)
25+
return r
26+
27+
UnrollScf(structural).fixpoint(main)
28+
29+
num_adds = 0
30+
num_calls = 0
31+
32+
for op in main.callable_region.walk():
33+
if isinstance(op, py.Add):
34+
num_adds += 1
35+
elif isinstance(op, func.Call):
36+
num_calls += 1
37+
38+
assert num_adds == 10
39+
assert num_calls == 8

0 commit comments

Comments
 (0)