|
1 | | -from ... import ir |
| 1 | +from dataclasses import field, dataclass |
| 2 | + |
| 3 | +from kirin import ir |
| 4 | +from kirin.rewrite.abc import RewriteRule, RewriteResult |
| 5 | + |
| 6 | +from ..cf import Branch, ConditionalBranch |
| 7 | +from ..func import ConstantNone |
2 | 8 | from .stmts import For, Yield, IfElse |
3 | | -from ...rewrite.abc import RewriteRule, RewriteResult |
| 9 | +from ..py.cmp import Is |
| 10 | +from ..py.iterable import Iter, Next |
4 | 11 |
|
5 | 12 |
|
6 | | -class ScfToCfRule(RewriteRule): |
| 13 | +class ScfRule(RewriteRule): |
7 | 14 |
|
8 | | - def rewrite_ifelse( |
9 | | - self, node: ir.Region, block_idx: int, curr_block: ir.Block, stmt: IfElse |
10 | | - ): |
11 | | - from kirin.dialects import cf |
| 15 | + def get_entr_and_exit_blks(self, node: For | IfElse): |
| 16 | + """Get the enter and exit blocks for the given SCF node. |
12 | 17 |
|
13 | | - # create a new block for entering the if statement |
14 | | - entry_block = ir.Block() |
15 | | - for arg in curr_block.args: |
16 | | - arg.replace_by(entry_block.args.append_from(arg.type, arg.name)) |
| 18 | + The exit block is a new block that will be created to hold the |
| 19 | + statements that follow the SCF node in the current block and the |
| 20 | + enter block is a new block that will be created to hold the |
| 21 | + any logic required to enter the SCF node. |
17 | 22 |
|
18 | | - # delete the args of the old block and replace with the result of the # if statement |
19 | | - for arg in curr_block.args: |
20 | | - curr_block.args.delete(arg) |
| 23 | + """ |
| 24 | + # split the current block into two parts |
| 25 | + exit_block = ir.Block() |
| 26 | + stmt = node.next_stmt |
| 27 | + while stmt is not None: |
| 28 | + next_stmt = stmt.next_stmt |
| 29 | + stmt.detach() |
| 30 | + exit_block.stmts.append(stmt) |
| 31 | + stmt = next_stmt |
21 | 32 |
|
22 | | - for arg in stmt.results: |
23 | | - arg.replace_by(curr_block.args.append_from(arg.type, arg.name)) |
| 33 | + for result in node.results: |
| 34 | + result.replace_by(exit_block.args.append_from(result.type, result.name)) |
24 | 35 |
|
25 | | - (then_block := stmt.then_body.blocks[0]).detach() |
26 | | - (else_block := stmt.else_body.blocks[0]).detach() |
| 36 | + curr_block = node.parent_block |
| 37 | + assert isinstance(curr_block, ir.Block), "Node must be inside a block" |
27 | 38 |
|
28 | | - entry_block.stmts.append( |
29 | | - cf.ConditionalBranch( |
30 | | - cond=stmt.cond, |
31 | | - then_arguments=tuple(stmt.args), |
32 | | - then_successor=then_block, |
33 | | - else_arguments=tuple(stmt.args), |
34 | | - else_successor=else_block, |
35 | | - ) |
| 39 | + curr_block.stmts.append( |
| 40 | + Branch(arguments=(), successor=(entr_block := ir.Block())) |
36 | 41 | ) |
37 | 42 |
|
38 | | - # insert the then/else blocks and add branch to the current block |
39 | | - # if the last statement of the then block is a yield |
40 | | - if isinstance(last_stmt := else_block.last_stmt, Yield): |
41 | | - last_stmt.replace_by( |
42 | | - cf.Branch( |
43 | | - arguments=tuple(last_stmt.args), |
44 | | - successor=curr_block, |
45 | | - ) |
46 | | - ) |
| 43 | + return exit_block, entr_block |
47 | 44 |
|
48 | | - if isinstance(last_stmt := then_block.last_stmt, Yield): |
49 | | - last_stmt.replace_by( |
50 | | - cf.Branch( |
51 | | - arguments=tuple(last_stmt.args), |
52 | | - successor=curr_block, |
53 | | - ) |
54 | | - ) |
| 45 | + def get_curr_blk_info(self, node: For | IfElse) -> tuple[ir.Region, int]: |
| 46 | + """Get the current region and the block index of the node in the region.""" |
| 47 | + curr_block = node.parent_block |
| 48 | + region = node.parent_region |
55 | 49 |
|
56 | | - node.blocks.insert(block_idx, curr_block) |
57 | | - node.blocks.insert(block_idx, else_block) |
58 | | - node.blocks.insert(block_idx, then_block) |
| 50 | + assert isinstance(region, ir.Region), "Node must be inside a region" |
| 51 | + assert isinstance(curr_block, ir.Block), "Node must be inside a block" |
59 | 52 |
|
60 | | - curr_stmt = stmt |
61 | | - next_stmt = stmt.prev_stmt |
62 | | - curr_stmt.delete() |
| 53 | + block_idx = region._block_idx[curr_block] |
| 54 | + return region, block_idx |
63 | 55 |
|
64 | | - return next_stmt, entry_block |
65 | 56 |
|
66 | | - def rewrite_for( |
67 | | - self, node: ir.Region, block_idx: int, curr_block: ir.Block, stmt: For |
68 | | - ): |
69 | | - from kirin.dialects import cf, py, func |
| 57 | +class ForRule(ScfRule): |
| 58 | + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: |
| 59 | + if ( |
| 60 | + not isinstance(node, For) |
| 61 | + # must be inside a callable statement |
| 62 | + or not isinstance(parent_stmt := node.parent_stmt, ir.Statement) |
| 63 | + or not parent_stmt.has_trait(ir.CallableStmtInterface) |
| 64 | + ): |
| 65 | + return RewriteResult() |
70 | 66 |
|
71 | | - (body_block := stmt.body.blocks[0]).detach() |
| 67 | + region, block_idx = self.get_curr_blk_info(node) |
| 68 | + exit_block, entr_block = self.get_entr_and_exit_blks(node) |
72 | 69 |
|
73 | | - entry_block = ir.Block() |
74 | | - for arg in curr_block.args: |
75 | | - arg.replace_by(entry_block.args.append_from(arg.type, arg.name)) |
| 70 | + (body_block := node.body.blocks[0]).detach() |
76 | 71 |
|
77 | 72 | # Get iterator from iterable object |
78 | | - entry_block.stmts.append(iterable_stmt := py.iterable.Iter(stmt.iterable)) |
79 | | - entry_block.stmts.append(const_none := func.ConstantNone()) |
80 | | - last_stmt = entry_block.last_stmt |
81 | | - entry_block.stmts.append( |
82 | | - next_stmt := py.iterable.Next(iterable_stmt.expect_one_result()) |
83 | | - ) |
84 | | - entry_block.stmts.append( |
85 | | - loop_cmp := py.cmp.Is(next_stmt.expect_one_result(), const_none.result) |
| 73 | + entr_block.stmts.append(iterable_stmt := Iter(node.iterable)) |
| 74 | + entr_block.stmts.append(next_stmt := Next(iterable_stmt.expect_one_result())) |
| 75 | + entr_block.stmts.append(const_none := ConstantNone()) |
| 76 | + entr_block.stmts.append( |
| 77 | + loop_cmp := Is(next_stmt.expect_one_result(), const_none.result) |
86 | 78 | ) |
87 | | - entry_block.stmts.append( |
88 | | - cf.ConditionalBranch( |
| 79 | + entr_block.stmts.append( |
| 80 | + ConditionalBranch( |
89 | 81 | cond=loop_cmp.result, |
90 | | - then_arguments=tuple(stmt.initializers), |
91 | | - then_successor=curr_block, |
| 82 | + then_arguments=tuple(node.initializers), |
| 83 | + then_successor=exit_block, |
92 | 84 | else_arguments=(next_stmt.expect_one_result(),) |
93 | | - + tuple(stmt.initializers), |
| 85 | + + tuple(node.initializers), |
94 | 86 | else_successor=body_block, |
95 | 87 | ) |
96 | 88 | ) |
97 | 89 |
|
98 | | - for arg in curr_block.args: |
99 | | - curr_block.args.delete(arg) |
100 | | - |
101 | | - for arg in stmt.results: |
102 | | - arg.replace_by(curr_block.args.append_from(arg.type, arg.name)) |
103 | | - |
104 | 90 | if isinstance(last_stmt := body_block.last_stmt, Yield): |
| 91 | + (next_stmt := Next(iterable_stmt.expect_one_result())).insert_before( |
| 92 | + last_stmt |
| 93 | + ) |
| 94 | + (const_none := ConstantNone()).insert_before(last_stmt) |
105 | 95 | ( |
106 | | - next_stmt := py.iterable.Next(iterable_stmt.expect_one_result()) |
107 | | - ).insert_before(last_stmt) |
108 | | - ( |
109 | | - loop_cmp := py.cmp.Is(next_stmt.expect_one_result(), const_none.result) |
| 96 | + loop_cmp := Is(next_stmt.expect_one_result(), const_none.result) |
110 | 97 | ).insert_before(last_stmt) |
111 | 98 | last_stmt.replace_by( |
112 | | - cf.ConditionalBranch( |
| 99 | + ConditionalBranch( |
113 | 100 | cond=loop_cmp.result, |
114 | 101 | else_arguments=(next_stmt.expect_one_result(),) |
115 | 102 | + tuple(last_stmt.args), |
116 | 103 | else_successor=body_block, |
117 | 104 | then_arguments=tuple(last_stmt.args), |
118 | | - then_successor=curr_block, |
| 105 | + then_successor=exit_block, |
119 | 106 | ) |
120 | 107 | ) |
121 | 108 |
|
122 | 109 | # insert the body block and add branch to the current block |
123 | | - node.blocks.insert(block_idx, curr_block) |
124 | | - node.blocks.insert(block_idx, body_block) |
125 | | - |
126 | | - curr_stmt = stmt |
127 | | - next_stmt = stmt.prev_stmt |
128 | | - curr_stmt.delete() |
| 110 | + region.blocks.insert(block_idx + 1, exit_block) |
| 111 | + region.blocks.insert(block_idx + 1, body_block) |
| 112 | + region.blocks.insert(block_idx + 1, entr_block) |
129 | 113 |
|
130 | | - return next_stmt, entry_block |
| 114 | + node.delete() |
131 | 115 |
|
132 | | - def rewrite_ssacfg(self, node: ir.Region): |
| 116 | + return RewriteResult(has_done_something=True) |
133 | 117 |
|
134 | | - has_done_something = False |
135 | 118 |
|
136 | | - for block_idx in range(len(node.blocks)): |
| 119 | +class IfElseRule(ScfRule): |
| 120 | + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: |
| 121 | + if ( |
| 122 | + not isinstance(node, IfElse) |
| 123 | + or not isinstance(parent_stmt := node.parent_stmt, ir.Statement) |
| 124 | + or not parent_stmt.has_trait(ir.CallableStmtInterface) |
| 125 | + ): |
| 126 | + return RewriteResult() |
137 | 127 |
|
138 | | - block = node.blocks.pop(block_idx) |
| 128 | + region, block_idx = self.get_curr_blk_info(node) |
| 129 | + exit_block, entr_block = self.get_entr_and_exit_blks(node) |
139 | 130 |
|
140 | | - stmt = block.last_stmt |
141 | | - if stmt is None: |
142 | | - continue |
| 131 | + (then_block := node.then_body.blocks[0]).detach() |
| 132 | + (else_block := node.else_body.blocks[0]).detach() |
| 133 | + entr_block.stmts.append( |
| 134 | + ConditionalBranch( |
| 135 | + node.cond, |
| 136 | + then_arguments=tuple(node.args), |
| 137 | + then_successor=then_block, |
| 138 | + else_arguments=tuple(node.args), |
| 139 | + else_successor=else_block, |
| 140 | + ) |
| 141 | + ) |
143 | 142 |
|
144 | | - curr_block = ir.Block() |
| 143 | + if isinstance(last_stmt := then_block.last_stmt, Yield): |
| 144 | + last_stmt.replace_by( |
| 145 | + Branch( |
| 146 | + arguments=tuple(last_stmt.args), |
| 147 | + successor=exit_block, |
| 148 | + ) |
| 149 | + ) |
145 | 150 |
|
146 | | - for arg in block.args: |
147 | | - arg.replace_by(curr_block.args.append_from(arg.type, arg.name)) |
| 151 | + if isinstance(last_stmt := else_block.last_stmt, Yield): |
| 152 | + last_stmt.replace_by( |
| 153 | + Branch( |
| 154 | + arguments=tuple(last_stmt.args), |
| 155 | + successor=exit_block, |
| 156 | + ) |
| 157 | + ) |
148 | 158 |
|
149 | | - while stmt is not None: |
150 | | - if isinstance(stmt, For): |
151 | | - has_done_something = True |
152 | | - stmt, curr_block = self.rewrite_for( |
153 | | - node, block_idx, curr_block, stmt |
154 | | - ) |
| 159 | + # insert the new blocks |
| 160 | + region.blocks.insert(block_idx + 1, exit_block) |
| 161 | + region.blocks.insert(block_idx + 1, else_block) |
| 162 | + region.blocks.insert(block_idx + 1, then_block) |
| 163 | + region.blocks.insert(block_idx + 1, entr_block) |
155 | 164 |
|
156 | | - elif isinstance(stmt, IfElse): |
157 | | - has_done_something = True |
158 | | - stmt, curr_block = self.rewrite_ifelse( |
159 | | - node, block_idx, curr_block, stmt |
160 | | - ) |
161 | | - else: |
162 | | - curr_stmt = stmt |
163 | | - stmt = stmt.prev_stmt |
164 | | - curr_stmt.detach() |
| 165 | + node.delete() |
| 166 | + return RewriteResult(has_done_something=True) |
165 | 167 |
|
166 | | - if curr_block.first_stmt is None: |
167 | | - curr_block.stmts.append(curr_stmt) |
168 | | - else: |
169 | | - curr_stmt.insert_before(curr_block.first_stmt) |
170 | 168 |
|
171 | | - # if the last block is empty, remove it |
172 | | - if curr_block.parent is None and curr_block.first_stmt is not None: |
173 | | - node.blocks.insert(block_idx, curr_block) |
| 169 | +@dataclass |
| 170 | +class ScfToCfRule(RewriteRule): |
174 | 171 |
|
175 | | - return RewriteResult(has_done_something=has_done_something) |
| 172 | + for_rule: ForRule = field(default_factory=ForRule, init=False) |
| 173 | + if_else_rule: IfElseRule = field(default_factory=IfElseRule, init=False) |
176 | 174 |
|
177 | 175 | def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: |
178 | | - if ( |
179 | | - isinstance(node, (For, IfElse)) |
180 | | - or not node.has_trait(ir.HasCFG) |
181 | | - and not node.has_trait(ir.SSACFG) |
182 | | - ): |
183 | | - # do not do rewrite in scf regions |
| 176 | + if isinstance(node, For): |
| 177 | + return self.for_rule.rewrite_Statement(node) |
| 178 | + elif isinstance(node, IfElse): |
| 179 | + return self.if_else_rule.rewrite_Statement(node) |
| 180 | + else: |
184 | 181 | return RewriteResult() |
185 | | - |
186 | | - result = RewriteResult() |
187 | | - for region in node.regions: |
188 | | - result = result.join(self.rewrite_ssacfg(region)) |
189 | | - |
190 | | - return result |
0 commit comments