Skip to content

Commit 697e58e

Browse files
authored
fix code block scope after if-else statement (#201)
1 parent 8352721 commit 697e58e

File tree

2 files changed

+79
-13
lines changed

2 files changed

+79
-13
lines changed

src/kirin/dialects/lowering/cf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def lower_If(self, state: LoweringState, node: ast.If) -> Result:
126126
)
127127
)
128128

129+
after_frame.defs.update(frame.defs)
129130
phi: set[str] = set()
130131
for name in if_frame.defs.keys():
131132
if frame.get(name):

test/lowering/test_for.py

Lines changed: 78 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,90 @@
11
from kirin.prelude import basic_no_opt
2+
from kirin.dialects import cf, func
23
from kirin.lowering import Lowering
34

45
lowering = Lowering(basic_no_opt)
5-
range_a = range(10)
66

77

8-
def simple_loop(x):
9-
for i in range(10):
10-
for j in range(10):
11-
x = x + i + j
8+
def test_simple_loop():
9+
def simple_loop(x):
10+
for i in range(10):
11+
for j in range(10):
12+
x = x + i + j
1213

14+
code = lowering.run(simple_loop)
15+
assert isinstance(code, func.Function)
16+
assert isinstance(stmt := code.body.blocks[0].last_stmt, cf.ConditionalBranch)
17+
assert stmt.then_arguments[0] is code.body.blocks[0].args[1]
18+
assert stmt.then_successor is code.body.blocks[4]
19+
assert stmt.else_arguments[0] is code.body.blocks[0].stmts.at(6).results[0]
20+
assert stmt.else_arguments[1] is code.body.blocks[0].args[1]
21+
assert stmt.else_successor is code.body.blocks[1]
1322

14-
code = lowering.run(simple_loop)
15-
# code.print()
23+
assert isinstance(stmt := code.body.blocks[1].last_stmt, cf.ConditionalBranch)
24+
assert stmt.then_arguments[0] is code.body.blocks[1].args[1]
25+
assert stmt.then_arguments[1] is code.body.blocks[1].args[0]
26+
assert stmt.else_arguments[0] is code.body.blocks[1].stmts.at(-3).results[0]
27+
assert stmt.else_arguments[1] is code.body.blocks[1].args[1]
28+
assert stmt.else_arguments[2] is code.body.blocks[1].args[0]
29+
assert stmt.else_successor is code.body.blocks[2]
1630

31+
assert isinstance(stmt := code.body.blocks[2].last_stmt, cf.ConditionalBranch)
32+
var_x = code.body.blocks[2].stmts.at(1).results[0]
33+
var_i = code.body.blocks[2].args[2]
34+
assert stmt.then_arguments[0] is var_x
35+
assert stmt.then_arguments[1] is var_i
36+
assert stmt.then_successor is code.body.blocks[3]
37+
assert stmt.else_arguments[0] is code.body.blocks[2].stmts.at(-3).results[0]
38+
assert stmt.else_arguments[1] is var_x
39+
assert stmt.else_arguments[2] is var_i
40+
# code.print()
1741

18-
def branch_pass():
19-
if True:
20-
pass
21-
else:
22-
pass
2342

43+
def test_branch_pass():
44+
def branch_pass():
45+
if True:
46+
pass
47+
else:
48+
pass
2449

25-
code = lowering.run(branch_pass, compactify=False)
50+
code = lowering.run(branch_pass)
51+
assert isinstance(code, func.Function)
52+
assert isinstance(code.body.blocks[0].last_stmt, func.Return)
53+
# code.print()
54+
55+
56+
def test_side_effect():
57+
def side_effect(reg, n: int):
58+
if n == 0:
59+
return
60+
61+
for i in range(10):
62+
reg[0] = i
63+
64+
code = lowering.run(side_effect)
65+
assert isinstance(code, func.Function)
66+
assert isinstance(stmt := code.body.blocks[0].last_stmt, cf.ConditionalBranch)
67+
assert stmt.then_arguments[0] is code.body.blocks[0].stmts.at(-2).results[0]
68+
assert stmt.then_successor is code.body.blocks[1]
69+
assert stmt.else_arguments == ()
70+
assert stmt.else_successor is code.body.blocks[2]
71+
72+
assert isinstance(code.body.blocks[1].last_stmt, func.Return)
73+
74+
assert isinstance(stmt := code.body.blocks[2].last_stmt, cf.ConditionalBranch)
75+
reg = code.body.blocks[0].args[1]
76+
assert stmt.then_arguments[0] is reg
77+
assert stmt.then_successor is code.body.blocks[4]
78+
assert stmt.else_arguments[0] is code.body.blocks[2].stmts.at(-3).results[0]
79+
assert stmt.else_arguments[1] is reg
80+
assert stmt.else_successor is code.body.blocks[3]
81+
82+
assert isinstance(stmt := code.body.blocks[3].last_stmt, cf.ConditionalBranch)
83+
reg = code.body.blocks[3].args[1]
84+
assert stmt.then_arguments[0] is reg
85+
assert stmt.then_successor is code.body.blocks[4]
86+
assert stmt.else_arguments[0] is code.body.blocks[3].stmts.at(-3).results[0]
87+
assert stmt.else_arguments[1] is reg
88+
89+
assert isinstance(code.body.blocks[4].last_stmt, func.Return)
90+
# code.print()

0 commit comments

Comments
 (0)