Skip to content

Commit 4589d0a

Browse files
authored
Dl/scf lowering miss yield (#487)
Fixes issue #468
1 parent 3c0347d commit 4589d0a

File tree

2 files changed

+53
-3
lines changed

2 files changed

+53
-3
lines changed

src/kirin/dialects/scf/lowering.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,27 @@ def lower_If(self, state: lowering.State, node: ast.If) -> lowering.Result:
4545
yield_names: list[str] = []
4646
body_yields: list[ir.SSAValue] = []
4747
else_yields: list[ir.SSAValue] = []
48-
for name in body_frame.defs.keys():
49-
if name in else_frame.defs:
48+
all_names: set[str] = set(body_frame.defs.keys()) | (
49+
set(else_frame.defs.keys())
50+
)
51+
for name in all_names:
52+
if name in body_frame.defs and name in else_frame.defs:
5053
yield_names.append(name)
5154
body_yields.append(body_frame[name])
5255
else_yields.append(else_frame[name])
53-
elif (value := self._frame_or_any_parent_has_def(frame, name)) is not None:
56+
elif (
57+
name not in body_frame.defs
58+
and (value := self._frame_or_any_parent_has_def(frame, name))
59+
is not None
60+
):
61+
yield_names.append(name)
62+
body_yields.append(value)
63+
else_yields.append(else_frame[name])
64+
elif (
65+
name not in else_frame.defs
66+
and (value := self._frame_or_any_parent_has_def(frame, name))
67+
is not None
68+
):
5469
yield_names.append(name)
5570
body_yields.append(body_frame[name])
5671
else_yields.append(value)

test/dialects/scf/test_ifelse.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,38 @@ def main_nested_if(n: int):
110110

111111
assert main_nested_if(3) == 4 == main_nested_if2(3)
112112
assert main_nested_if(10) == 0 == main_nested_if2(8)
113+
114+
115+
def test_def_only_else():
116+
@kernel
117+
def main(n: int):
118+
c = 1.0
119+
if n <= 0:
120+
return 0.0
121+
else:
122+
c = 2.0
123+
return c
124+
125+
main.print()
126+
assert main(1) == 2.0
127+
assert main(0) == 0.0
128+
129+
130+
def test_def_only_else_nested():
131+
@kernel
132+
def main(n: int):
133+
c = 1.0
134+
if n <= 0:
135+
return 0.0
136+
else:
137+
if n <= 2:
138+
return 1.0
139+
else:
140+
c = 4.0
141+
return c
142+
143+
main.print()
144+
assert main(3) == 4.0
145+
assert main(2) == 1.0
146+
assert main(1) == 1.0
147+
assert main(0) == 0.0

0 commit comments

Comments
 (0)