Skip to content

Commit 048f263

Browse files
authored
fixing bug in which purity wasn't propagated properly (#460)
The following function is definitely pure: ```python @structural_no_opt def test_func(src: ilist.IList[float, Any]): def inner(i: int): if src[i] < 0: return 0.0 elif src[i] < 1.0: return 1.0 else: return 2.0 return ilist.map(inner, ilist.range(len(src))) ``` but there is a bug in const prop when you have return values. The propagation of frame purity was not occurring, making the constant property of this function impure, which created other issues regarding folding and DCE.
1 parent 2822bd1 commit 048f263

File tree

2 files changed

+170
-14
lines changed

2 files changed

+170
-14
lines changed

src/kirin/dialects/scf/constprop.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from kirin import interp
44
from kirin.analysis import const
5-
from kirin.dialects import func
65

76
from .stmts import For, Yield, IfElse
87
from ._dialect import dialect
@@ -41,12 +40,9 @@ def if_else(
4140
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
4241
ret = interp_.frame_call_region(body_frame, stmt, body, cond)
4342
frame.entries.update(body_frame.entries)
44-
45-
if not body_frame.frame_is_not_pure and not isinstance(
46-
body.blocks[0].last_stmt, func.Return
47-
):
48-
frame.should_be_pure.add(stmt)
49-
return ret
43+
frame.frame_is_not_pure = (
44+
frame.frame_is_not_pure or body_frame.frame_is_not_pure
45+
)
5046
else:
5147
with interp_.new_frame(stmt, has_parent_access=True) as then_frame:
5248
then_results = interp_.frame_call_region(
@@ -62,22 +58,28 @@ def if_else(
6258
# parent frame variables value except cond
6359
frame.entries.update(then_frame.entries)
6460
frame.entries.update(else_frame.entries)
61+
# update frame purity
62+
# if either frame is not pure, then the whole if-else is not pure
63+
frame.frame_is_not_pure = (
64+
frame.frame_is_not_pure
65+
or then_frame.frame_is_not_pure
66+
or else_frame.frame_is_not_pure
67+
)
6568
# TODO: pick the non-return value
6669
if isinstance(then_results, interp.ReturnValue) and isinstance(
6770
else_results, interp.ReturnValue
6871
):
69-
return interp.ReturnValue(then_results.value.join(else_results.value))
72+
ret = interp.ReturnValue(then_results.value.join(else_results.value))
7073
elif isinstance(then_results, interp.ReturnValue):
7174
ret = else_results
7275
elif isinstance(else_results, interp.ReturnValue):
7376
ret = then_results
7477
else:
75-
if not (
76-
then_frame.frame_is_not_pure is True
77-
or else_frame.frame_is_not_pure is True
78-
):
79-
frame.should_be_pure.add(stmt)
8078
ret = interp_.join_results(then_results, else_results)
79+
80+
if not frame.frame_is_not_pure:
81+
frame.should_be_pure.add(stmt)
82+
8183
return ret
8284

8385
@interp.impl(For)

test/dialects/scf/test_constprop.py

Lines changed: 155 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from typing import Any
2+
13
from kirin.prelude import structural_no_opt
24
from kirin.analysis import const
3-
from kirin.dialects import scf, func
5+
from kirin.dialects import scf, func, ilist
46

57
prop = const.Propagate(structural_no_opt)
68

@@ -100,3 +102,155 @@ def simple_ifelse(x: int):
100102
assert isinstance(terminator, func.Return)
101103
assert isinstance(value := frame.entries[terminator.value], const.Value)
102104
assert value.data == 0
105+
106+
107+
def test_purity_1():
108+
109+
@structural_no_opt
110+
def test_func(src: ilist.IList[float, Any]):
111+
112+
def inner(i: int):
113+
if src[i] < 0:
114+
return 0.0
115+
elif src[i] < 1.0:
116+
return 1.0
117+
else:
118+
return 2.0
119+
120+
return ilist.map(inner, ilist.range(len(src)))
121+
122+
frame, ret = prop.run(test_func)
123+
124+
assert not frame.frame_is_not_pure, "function should be pure"
125+
126+
127+
def test_purity_2():
128+
129+
@structural_no_opt
130+
def test_func(src: ilist.IList[float, Any]):
131+
132+
def inner(i: int):
133+
value = 0.0
134+
if src[i] < 0:
135+
value = 0.0
136+
elif src[i] < 1.0:
137+
return 1.0
138+
else:
139+
value = 2.0
140+
141+
return value
142+
143+
return ilist.map(inner, ilist.range(len(src)))
144+
145+
frame, ret = prop.run(test_func)
146+
147+
assert not frame.frame_is_not_pure, "function should be pure"
148+
149+
150+
def test_purity_3():
151+
152+
@structural_no_opt
153+
def test_func(src: ilist.IList[float, Any]):
154+
155+
def inner(i: int):
156+
value = 0.0
157+
if src[i] < 0:
158+
value = 0.0
159+
elif src[i] < 1.0:
160+
return 1.0
161+
else:
162+
return 2.0
163+
164+
return value
165+
166+
return ilist.map(inner, ilist.range(len(src)))
167+
168+
frame, ret = prop.run(test_func)
169+
170+
assert not frame.frame_is_not_pure, "function should be pure"
171+
172+
173+
def test_purity_4():
174+
175+
@structural_no_opt
176+
def test_func(src: list[float]):
177+
178+
if True:
179+
return src
180+
else:
181+
src.append(2.0)
182+
return src
183+
184+
frame, ret = prop.run(test_func)
185+
186+
assert not frame.frame_is_not_pure, "function should be pure"
187+
188+
189+
def test_purity_5():
190+
191+
@structural_no_opt
192+
def test_func(src: list[float]):
193+
194+
if False:
195+
src.append(2.0)
196+
197+
return src
198+
199+
frame, ret = prop.run(test_func)
200+
201+
assert not frame.frame_is_not_pure, "function should be pure"
202+
203+
204+
def test_purity_6():
205+
206+
@structural_no_opt
207+
def test_func(src: list[float]):
208+
209+
if True:
210+
return src
211+
else:
212+
src.append(2.0)
213+
214+
return src
215+
216+
frame, ret = prop.run(test_func)
217+
218+
assert not frame.frame_is_not_pure, "function should be pure"
219+
220+
221+
def test_purity_7():
222+
223+
@structural_no_opt
224+
def test_func(src: list[float], cond: bool):
225+
226+
if cond:
227+
src.append(2.0)
228+
return src
229+
else:
230+
return src
231+
232+
frame, ret = prop.run(test_func)
233+
234+
assert frame.frame_is_not_pure, "function should not be pure"
235+
236+
237+
def test_purity_8():
238+
239+
@structural_no_opt
240+
def test_func(src: ilist.IList[float, Any], dst: ilist.IList[float, Any]):
241+
assert len(src) == len(dst), "src and dst must have the same length"
242+
243+
def inner(i: int):
244+
value = src[i]
245+
if src[i] < dst[i]:
246+
value = dst[i] - 3.0
247+
elif src[i] > dst[i]:
248+
return dst[i] + 3.0
249+
250+
return value
251+
252+
return ilist.map(inner, ilist.range(len(src)))
253+
254+
frame, ret = prop.run(test_func)
255+
256+
assert frame.frame_is_not_pure, "function should be pure"

0 commit comments

Comments
 (0)