Skip to content

Commit f6e8b8b

Browse files
authored
Refactor Scf to Cf rewrite rules. (#467)
The first pass on this rule was a bit clunky and was trying to optimize the rewrite. I have refactored the rule now to be less optimal but more readable.
1 parent 3279c6a commit f6e8b8b

File tree

2 files changed

+141
-141
lines changed

2 files changed

+141
-141
lines changed

src/kirin/dialects/scf/scf2cf.py

Lines changed: 125 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -1,190 +1,181 @@
1-
from ... import ir
1+
from dataclasses import field, dataclass
2+
3+
from kirin import ir
4+
from kirin.rewrite.abc import RewriteRule, RewriteResult
5+
6+
from ..cf import Branch, ConditionalBranch
7+
from ..func import ConstantNone
28
from .stmts import For, Yield, IfElse
3-
from ...rewrite.abc import RewriteRule, RewriteResult
9+
from ..py.cmp import Is
10+
from ..py.iterable import Iter, Next
411

512

6-
class ScfToCfRule(RewriteRule):
13+
class ScfRule(RewriteRule):
714

8-
def rewrite_ifelse(
9-
self, node: ir.Region, block_idx: int, curr_block: ir.Block, stmt: IfElse
10-
):
11-
from kirin.dialects import cf
15+
def get_entr_and_exit_blks(self, node: For | IfElse):
16+
"""Get the enter and exit blocks for the given SCF node.
1217
13-
# create a new block for entering the if statement
14-
entry_block = ir.Block()
15-
for arg in curr_block.args:
16-
arg.replace_by(entry_block.args.append_from(arg.type, arg.name))
18+
The exit block is a new block that will be created to hold the
19+
statements that follow the SCF node in the current block and the
20+
enter block is a new block that will be created to hold the
21+
any logic required to enter the SCF node.
1722
18-
# delete the args of the old block and replace with the result of the # if statement
19-
for arg in curr_block.args:
20-
curr_block.args.delete(arg)
23+
"""
24+
# split the current block into two parts
25+
exit_block = ir.Block()
26+
stmt = node.next_stmt
27+
while stmt is not None:
28+
next_stmt = stmt.next_stmt
29+
stmt.detach()
30+
exit_block.stmts.append(stmt)
31+
stmt = next_stmt
2132

22-
for arg in stmt.results:
23-
arg.replace_by(curr_block.args.append_from(arg.type, arg.name))
33+
for result in node.results:
34+
result.replace_by(exit_block.args.append_from(result.type, result.name))
2435

25-
(then_block := stmt.then_body.blocks[0]).detach()
26-
(else_block := stmt.else_body.blocks[0]).detach()
36+
curr_block = node.parent_block
37+
assert isinstance(curr_block, ir.Block), "Node must be inside a block"
2738

28-
entry_block.stmts.append(
29-
cf.ConditionalBranch(
30-
cond=stmt.cond,
31-
then_arguments=tuple(stmt.args),
32-
then_successor=then_block,
33-
else_arguments=tuple(stmt.args),
34-
else_successor=else_block,
35-
)
39+
curr_block.stmts.append(
40+
Branch(arguments=(), successor=(entr_block := ir.Block()))
3641
)
3742

38-
# insert the then/else blocks and add branch to the current block
39-
# if the last statement of the then block is a yield
40-
if isinstance(last_stmt := else_block.last_stmt, Yield):
41-
last_stmt.replace_by(
42-
cf.Branch(
43-
arguments=tuple(last_stmt.args),
44-
successor=curr_block,
45-
)
46-
)
43+
return exit_block, entr_block
4744

48-
if isinstance(last_stmt := then_block.last_stmt, Yield):
49-
last_stmt.replace_by(
50-
cf.Branch(
51-
arguments=tuple(last_stmt.args),
52-
successor=curr_block,
53-
)
54-
)
45+
def get_curr_blk_info(self, node: For | IfElse) -> tuple[ir.Region, int]:
46+
"""Get the current region and the block index of the node in the region."""
47+
curr_block = node.parent_block
48+
region = node.parent_region
5549

56-
node.blocks.insert(block_idx, curr_block)
57-
node.blocks.insert(block_idx, else_block)
58-
node.blocks.insert(block_idx, then_block)
50+
assert isinstance(region, ir.Region), "Node must be inside a region"
51+
assert isinstance(curr_block, ir.Block), "Node must be inside a block"
5952

60-
curr_stmt = stmt
61-
next_stmt = stmt.prev_stmt
62-
curr_stmt.delete()
53+
block_idx = region._block_idx[curr_block]
54+
return region, block_idx
6355

64-
return next_stmt, entry_block
6556

66-
def rewrite_for(
67-
self, node: ir.Region, block_idx: int, curr_block: ir.Block, stmt: For
68-
):
69-
from kirin.dialects import cf, py, func
57+
class ForRule(ScfRule):
58+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
59+
if (
60+
not isinstance(node, For)
61+
# must be inside a callable statement
62+
or not isinstance(parent_stmt := node.parent_stmt, ir.Statement)
63+
or not parent_stmt.has_trait(ir.CallableStmtInterface)
64+
):
65+
return RewriteResult()
7066

71-
(body_block := stmt.body.blocks[0]).detach()
67+
region, block_idx = self.get_curr_blk_info(node)
68+
exit_block, entr_block = self.get_entr_and_exit_blks(node)
7269

73-
entry_block = ir.Block()
74-
for arg in curr_block.args:
75-
arg.replace_by(entry_block.args.append_from(arg.type, arg.name))
70+
(body_block := node.body.blocks[0]).detach()
7671

7772
# Get iterator from iterable object
78-
entry_block.stmts.append(iterable_stmt := py.iterable.Iter(stmt.iterable))
79-
entry_block.stmts.append(const_none := func.ConstantNone())
80-
last_stmt = entry_block.last_stmt
81-
entry_block.stmts.append(
82-
next_stmt := py.iterable.Next(iterable_stmt.expect_one_result())
83-
)
84-
entry_block.stmts.append(
85-
loop_cmp := py.cmp.Is(next_stmt.expect_one_result(), const_none.result)
73+
entr_block.stmts.append(iterable_stmt := Iter(node.iterable))
74+
entr_block.stmts.append(next_stmt := Next(iterable_stmt.expect_one_result()))
75+
entr_block.stmts.append(const_none := ConstantNone())
76+
entr_block.stmts.append(
77+
loop_cmp := Is(next_stmt.expect_one_result(), const_none.result)
8678
)
87-
entry_block.stmts.append(
88-
cf.ConditionalBranch(
79+
entr_block.stmts.append(
80+
ConditionalBranch(
8981
cond=loop_cmp.result,
90-
then_arguments=tuple(stmt.initializers),
91-
then_successor=curr_block,
82+
then_arguments=tuple(node.initializers),
83+
then_successor=exit_block,
9284
else_arguments=(next_stmt.expect_one_result(),)
93-
+ tuple(stmt.initializers),
85+
+ tuple(node.initializers),
9486
else_successor=body_block,
9587
)
9688
)
9789

98-
for arg in curr_block.args:
99-
curr_block.args.delete(arg)
100-
101-
for arg in stmt.results:
102-
arg.replace_by(curr_block.args.append_from(arg.type, arg.name))
103-
10490
if isinstance(last_stmt := body_block.last_stmt, Yield):
91+
(next_stmt := Next(iterable_stmt.expect_one_result())).insert_before(
92+
last_stmt
93+
)
94+
(const_none := ConstantNone()).insert_before(last_stmt)
10595
(
106-
next_stmt := py.iterable.Next(iterable_stmt.expect_one_result())
107-
).insert_before(last_stmt)
108-
(
109-
loop_cmp := py.cmp.Is(next_stmt.expect_one_result(), const_none.result)
96+
loop_cmp := Is(next_stmt.expect_one_result(), const_none.result)
11097
).insert_before(last_stmt)
11198
last_stmt.replace_by(
112-
cf.ConditionalBranch(
99+
ConditionalBranch(
113100
cond=loop_cmp.result,
114101
else_arguments=(next_stmt.expect_one_result(),)
115102
+ tuple(last_stmt.args),
116103
else_successor=body_block,
117104
then_arguments=tuple(last_stmt.args),
118-
then_successor=curr_block,
105+
then_successor=exit_block,
119106
)
120107
)
121108

122109
# insert the body block and add branch to the current block
123-
node.blocks.insert(block_idx, curr_block)
124-
node.blocks.insert(block_idx, body_block)
125-
126-
curr_stmt = stmt
127-
next_stmt = stmt.prev_stmt
128-
curr_stmt.delete()
110+
region.blocks.insert(block_idx + 1, exit_block)
111+
region.blocks.insert(block_idx + 1, body_block)
112+
region.blocks.insert(block_idx + 1, entr_block)
129113

130-
return next_stmt, entry_block
114+
node.delete()
131115

132-
def rewrite_ssacfg(self, node: ir.Region):
116+
return RewriteResult(has_done_something=True)
133117

134-
has_done_something = False
135118

136-
for block_idx in range(len(node.blocks)):
119+
class IfElseRule(ScfRule):
120+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
121+
if (
122+
not isinstance(node, IfElse)
123+
or not isinstance(parent_stmt := node.parent_stmt, ir.Statement)
124+
or not parent_stmt.has_trait(ir.CallableStmtInterface)
125+
):
126+
return RewriteResult()
137127

138-
block = node.blocks.pop(block_idx)
128+
region, block_idx = self.get_curr_blk_info(node)
129+
exit_block, entr_block = self.get_entr_and_exit_blks(node)
139130

140-
stmt = block.last_stmt
141-
if stmt is None:
142-
continue
131+
(then_block := node.then_body.blocks[0]).detach()
132+
(else_block := node.else_body.blocks[0]).detach()
133+
entr_block.stmts.append(
134+
ConditionalBranch(
135+
node.cond,
136+
then_arguments=tuple(node.args),
137+
then_successor=then_block,
138+
else_arguments=tuple(node.args),
139+
else_successor=else_block,
140+
)
141+
)
143142

144-
curr_block = ir.Block()
143+
if isinstance(last_stmt := then_block.last_stmt, Yield):
144+
last_stmt.replace_by(
145+
Branch(
146+
arguments=tuple(last_stmt.args),
147+
successor=exit_block,
148+
)
149+
)
145150

146-
for arg in block.args:
147-
arg.replace_by(curr_block.args.append_from(arg.type, arg.name))
151+
if isinstance(last_stmt := else_block.last_stmt, Yield):
152+
last_stmt.replace_by(
153+
Branch(
154+
arguments=tuple(last_stmt.args),
155+
successor=exit_block,
156+
)
157+
)
148158

149-
while stmt is not None:
150-
if isinstance(stmt, For):
151-
has_done_something = True
152-
stmt, curr_block = self.rewrite_for(
153-
node, block_idx, curr_block, stmt
154-
)
159+
# insert the new blocks
160+
region.blocks.insert(block_idx + 1, exit_block)
161+
region.blocks.insert(block_idx + 1, else_block)
162+
region.blocks.insert(block_idx + 1, then_block)
163+
region.blocks.insert(block_idx + 1, entr_block)
155164

156-
elif isinstance(stmt, IfElse):
157-
has_done_something = True
158-
stmt, curr_block = self.rewrite_ifelse(
159-
node, block_idx, curr_block, stmt
160-
)
161-
else:
162-
curr_stmt = stmt
163-
stmt = stmt.prev_stmt
164-
curr_stmt.detach()
165+
node.delete()
166+
return RewriteResult(has_done_something=True)
165167

166-
if curr_block.first_stmt is None:
167-
curr_block.stmts.append(curr_stmt)
168-
else:
169-
curr_stmt.insert_before(curr_block.first_stmt)
170168

171-
# if the last block is empty, remove it
172-
if curr_block.parent is None and curr_block.first_stmt is not None:
173-
node.blocks.insert(block_idx, curr_block)
169+
@dataclass
170+
class ScfToCfRule(RewriteRule):
174171

175-
return RewriteResult(has_done_something=has_done_something)
172+
for_rule: ForRule = field(default_factory=ForRule, init=False)
173+
if_else_rule: IfElseRule = field(default_factory=IfElseRule, init=False)
176174

177175
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
178-
if (
179-
isinstance(node, (For, IfElse))
180-
or not node.has_trait(ir.HasCFG)
181-
and not node.has_trait(ir.SSACFG)
182-
):
183-
# do not do rewrite in scf regions
176+
if isinstance(node, For):
177+
return self.for_rule.rewrite_Statement(node)
178+
elif isinstance(node, IfElse):
179+
return self.if_else_rule.rewrite_Statement(node)
180+
else:
184181
return RewriteResult()
185-
186-
result = RewriteResult()
187-
for region in node.regions:
188-
result = result.join(self.rewrite_ssacfg(region))
189-
190-
return result

test/dialects/scf/test_scf2cf.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,31 +99,39 @@ def test():
9999

100100
expected_callable_region = ir.Region(
101101
[
102+
curr_block := ir.Block(),
102103
entry_block := ir.Block(),
103104
body_block := ir.Block(),
104105
exit_block := ir.Block(),
105106
]
106107
)
107108

108-
entry_block.args.append_from(types.MethodType, "self")
109-
entry_block.stmts.append(j_start := py.Constant(value=0))
109+
curr_block.args.append_from(types.MethodType, "self")
110+
curr_block.stmts.append(j_start := py.Constant(value=0))
111+
110112
j_start.result.name = "j"
111-
entry_block.stmts.append(iter_start := py.Constant(value=0))
112-
entry_block.stmts.append(iter_end := py.Constant(value=10))
113-
entry_block.stmts.append(iter_step := py.Constant(value=1))
114-
entry_block.stmts.append(
113+
curr_block.stmts.append(iter_start := py.Constant(value=0))
114+
curr_block.stmts.append(iter_end := py.Constant(value=10))
115+
curr_block.stmts.append(iter_step := py.Constant(value=1))
116+
curr_block.stmts.append(
115117
range_stmt := ilist.stmts.Range(
116118
start=iter_start.result,
117119
stop=iter_end.result,
118120
step=iter_step.result,
119121
)
120122
)
121123
range_stmt.result.type = ilist.IListType[types.Int, types.Literal(10)]
124+
curr_block.stmts.append(
125+
cf.Branch(
126+
arguments=(),
127+
successor=entry_block,
128+
)
129+
)
122130
entry_block.stmts.append(iterable_stmt := py.iterable.Iter(range_stmt.result))
123-
entry_block.stmts.append(none_stmt := func.ConstantNone())
124131
entry_block.stmts.append(
125132
first_iter := py.iterable.Next(iterable_stmt.expect_one_result())
126133
)
134+
entry_block.stmts.append(none_stmt := func.ConstantNone())
127135
entry_block.stmts.append(
128136
loop_cmp := py.cmp.Is(first_iter.expect_one_result(), none_stmt.result)
129137
)
@@ -157,6 +165,7 @@ def test():
157165
body_block.stmts.append(
158166
next_iter := py.iterable.Next(iterable_stmt.expect_one_result())
159167
)
168+
body_block.stmts.append(none_stmt := func.ConstantNone())
160169
body_block.stmts.append(
161170
loop_cmp := py.cmp.Is(next_iter.expect_one_result(), none_stmt.result)
162171
)

0 commit comments

Comments
 (0)