Skip to content

Commit e154b70

Browse files
authored
scf.for loop unroll (#237)
this PR implements loop unroll for `scf.For`, note that the loop body here asssumes no branching inside thus the rewrite is just copy and replacing SSAValues from the previous expanded body.
1 parent 948d8a0 commit e154b70

File tree

3 files changed

+73
-0
lines changed

3 files changed

+73
-0
lines changed

src/kirin/dialects/scf/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from . import (
1414
interp as interp,
15+
unroll as unroll,
1516
lowering as lowering,
1617
constprop as constprop,
1718
typeinfer as typeinfer,

src/kirin/dialects/scf/unroll.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from kirin import ir
2+
from kirin.analysis import const
3+
from kirin.rewrite.abc import RewriteRule
4+
from kirin.rewrite.result import RewriteResult
5+
from kirin.dialects.py.constant import Constant
6+
7+
from .stmts import For, Yield
8+
9+
10+
class ForLoop(RewriteRule):
11+
12+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
13+
if not isinstance(node, For):
14+
return RewriteResult()
15+
16+
# TODO: support for PartialTuple and IList with known length
17+
if not isinstance(hint := node.iterable.hints.get("const"), const.Value):
18+
return RewriteResult()
19+
20+
loop_vars = node.initializers
21+
for item in hint.data:
22+
body = node.body.clone()
23+
block = body.blocks[0]
24+
item_stmt = Constant(item)
25+
item_stmt.insert_before(node)
26+
block.args[0].replace_by(item_stmt.result)
27+
for var, input in zip(block.args[1:], loop_vars):
28+
var.replace_by(input)
29+
30+
block_stmt = block.first_stmt
31+
while block_stmt and not block_stmt.has_trait(ir.IsTerminator):
32+
block_stmt.detach()
33+
block_stmt.insert_before(node)
34+
block_stmt = block.first_stmt
35+
36+
terminator = block.last_stmt
37+
# we assume Yield has the same # of values as initializers
38+
# TODO: check this in validation
39+
if isinstance(terminator, Yield):
40+
loop_vars = terminator.values
41+
42+
for result, output in zip(node.results, loop_vars):
43+
result.replace_by(output)
44+
node.delete()
45+
return RewriteResult(has_done_something=True)

test/dialects/scf/test_unroll.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from kirin.passes import Fold
2+
from kirin.prelude import structural_no_opt
3+
from kirin.rewrite import Walk
4+
from kirin.dialects import py, scf, func
5+
6+
7+
def test_simple_loop_unroll():
8+
@structural_no_opt
9+
def simple_loop(x):
10+
for i in range(3):
11+
x = x + i
12+
return x
13+
14+
fold = Fold(structural_no_opt)
15+
fold(simple_loop)
16+
Walk(scf.unroll.ForLoop()).rewrite(simple_loop.code)
17+
assert len(simple_loop.callable_region.blocks) == 1
18+
stmts = simple_loop.callable_region.blocks[0].stmts
19+
assert isinstance(stmts.at(0), py.Constant)
20+
assert isinstance(stmts.at(1), py.Constant)
21+
assert isinstance(stmts.at(2), py.Add)
22+
assert isinstance(stmts.at(3), py.Constant)
23+
assert isinstance(stmts.at(4), py.Add)
24+
assert isinstance(stmts.at(5), py.Constant)
25+
assert isinstance(stmts.at(6), py.Add)
26+
assert isinstance(stmts.at(7), func.Return)
27+
assert simple_loop(1) == 4

0 commit comments

Comments
 (0)