Skip to content

Commit 480ffc5

Browse files
authored
rewrite rule for unrolling scf.ifelse (#244)
also adjust the aggressive fold rule order.
1 parent c627202 commit 480ffc5

File tree

6 files changed

+93
-15
lines changed

6 files changed

+93
-15
lines changed

src/kirin/dialects/scf/constprop.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ def if_else(
4141
interp_, frame, stmt, cond, body
4242
)
4343
frame.entries.update(body_frame.entries)
44+
if not body_frame.frame_is_not_pure and not isinstance(
45+
body.blocks[0].last_stmt, func.Return
46+
):
47+
frame.should_be_pure.add(stmt)
4448
return ret
4549
else:
4650
then_frame, then_results = self._prop_const_cond_ifelse(
@@ -80,11 +84,6 @@ def _prop_const_cond_ifelse(
8084
body_frame.entries.update(frame.entries)
8185
body_frame.set(body.blocks[0].args[0], cond)
8286
results = interp_.run_ssacfg_region(body_frame, body)
83-
84-
if not body_frame.frame_is_not_pure and not isinstance(
85-
body.blocks[0].last_stmt, func.Return
86-
):
87-
frame.should_be_pure.add(stmt)
8887
return body_frame, results
8988

9089
@interp.impl(For)

src/kirin/dialects/scf/stmts.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ def print_impl(self, printer: Printer) -> None:
8787
printer.plain_print(" else ", style="keyword")
8888
printer.print(self.else_body)
8989

90+
with printer.rich(style="comment"):
91+
printer.plain_print(f" -> purity={self.purity}")
92+
9093
def verify(self) -> None:
9194
from kirin.dialects.func import Return
9295

@@ -225,6 +228,8 @@ def print_impl(self, printer: Printer) -> None:
225228
printer.print_stmt(stmt)
226229
printer.print_newline()
227230
printer.plain_print("}")
231+
with printer.rich(style="comment"):
232+
printer.plain_print(f" -> purity={self.purity}")
228233

229234

230235
@statement(dialect=dialect)

src/kirin/dialects/scf/unroll.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,56 @@
11
from kirin import ir
22
from kirin.analysis import const
3+
from kirin.dialects import func
34
from kirin.rewrite.abc import RewriteRule
45
from kirin.rewrite.result import RewriteResult
56
from kirin.dialects.py.constant import Constant
67

7-
from .stmts import For, Yield
8+
from .stmts import For, Yield, IfElse
9+
10+
11+
class PickIfElse(RewriteRule):
12+
13+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
14+
if not isinstance(node, IfElse):
15+
return RewriteResult()
16+
17+
if not isinstance(hint := node.cond.hints.get("const"), const.Value):
18+
return RewriteResult()
19+
20+
if hint.data:
21+
return self.insert_body(node, node.then_body)
22+
else:
23+
return self.insert_body(node, node.else_body)
24+
25+
def insert_body(self, node: IfElse, body: ir.Region):
26+
body_block = body.clone().blocks[0]
27+
body_block.args[0].replace_by(node.cond)
28+
block_stmt = body_block.first_stmt
29+
while block_stmt and not block_stmt.has_trait(ir.IsTerminator):
30+
block_stmt.detach()
31+
block_stmt.insert_before(node)
32+
block_stmt = body_block.first_stmt
33+
34+
terminator = body_block.last_stmt
35+
if isinstance(terminator, Yield):
36+
for result, output in zip(node.results, terminator.values):
37+
result.replace_by(output)
38+
node.delete()
39+
return RewriteResult(has_done_something=True)
40+
elif isinstance(terminator, func.Return):
41+
block = node.parent
42+
assert block is not None
43+
stmt = block.last_stmt
44+
while stmt is not None and stmt is not node: # remove the rest of the block
45+
delete_stmt = stmt
46+
stmt = stmt.prev_stmt
47+
delete_stmt.delete()
48+
49+
terminator.detach()
50+
node.replace_by(terminator)
51+
return RewriteResult(has_done_something=True)
52+
else:
53+
return RewriteResult()
854

955

1056
class ForLoop(RewriteRule):

src/kirin/ir/nodes/region.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def clone(self, ssamap: dict[SSAValue, SSAValue] | None = None) -> Region:
152152
for stmt in block.stmts:
153153
new_stmt = stmt.from_stmt(
154154
stmt,
155-
args=[_ssamap[arg] for arg in stmt.args],
155+
args=[_ssamap.get(arg, arg) for arg in stmt.args],
156156
regions=[region.clone(_ssamap) for region in stmt.regions],
157157
successors=[
158158
successor_map[successor] for successor in stmt.successors
Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,43 @@
1-
from dataclasses import dataclass
1+
from dataclasses import field, dataclass
22

33
from kirin.passes import Pass
4-
from kirin.rewrite import aggressive
4+
from kirin.rewrite import (
5+
Walk,
6+
Chain,
7+
Inline,
8+
Fixpoint,
9+
WrapConst,
10+
Call2Invoke,
11+
ConstantFold,
12+
CFGCompactify,
13+
InlineGetItem,
14+
InlineGetField,
15+
DeadCodeElimination,
16+
)
517
from kirin.analysis import const
618
from kirin.ir.method import Method
719
from kirin.rewrite.abc import RewriteResult
820

921

1022
@dataclass
1123
class Fold(Pass):
24+
constprop: const.Propagate = field(init=False)
25+
26+
def __post_init__(self):
27+
self.constprop = const.Propagate(self.dialects)
1228

1329
def unsafe_run(self, mt: Method) -> RewriteResult:
14-
constprop = const.Propagate(self.dialects)
15-
constprop_results, _ = constprop.run_analysis(mt)
16-
return aggressive.Fold(constprop_results).rewrite(mt.code)
30+
result = RewriteResult()
31+
frame, _ = self.constprop.run_analysis(mt)
32+
result = Walk(WrapConst(frame)).rewrite(mt.code).join(result)
33+
rule = Chain(
34+
ConstantFold(),
35+
Call2Invoke(),
36+
InlineGetField(),
37+
InlineGetItem(),
38+
DeadCodeElimination(),
39+
)
40+
result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
41+
result = Walk(Inline(lambda _: True)).rewrite(mt.code).join(result)
42+
result = Fixpoint(CFGCompactify()).rewrite(mt.code).join(result)
43+
return result

src/kirin/print/printer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,10 @@ def print_stmt(self, node: "ir.Statement"):
145145
def print_hint(
146146
self,
147147
*values: "ir.SSAValue",
148-
prefix: str = "//hint<",
148+
prefix: str = " // hint<",
149149
suffix: str = ">",
150150
):
151-
if not self.hint:
151+
if not self.hint or not values:
152152
return
153153

154154
self.plain_print(prefix)
@@ -157,7 +157,8 @@ def print_hint(
157157
if idx > 0:
158158
self.plain_print(", ")
159159

160-
if item.hints.get(self.hint):
160+
self.plain_print("=")
161+
if item.hints.get(self.hint) is not None:
161162
self.plain_print(repr(item.hints.get(self.hint)))
162163
else:
163164
self.plain_print("missing")

0 commit comments

Comments
 (0)