Skip to content

Commit 3279c6a

Browse files
authored
Revert "fixing bug in which purity wasn't propagated properly" (#462)
Reverts #460 Turns out that semantically you have to make the IfElse body not pure if there is a return value because of a weird interaction with ConstantFold. Namely the fact that the return value of the if statement has no uses, and so it will be deleted.
1 parent 048f263 commit 3279c6a

File tree

2 files changed

+14
-170
lines changed

2 files changed

+14
-170
lines changed

src/kirin/dialects/scf/constprop.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from kirin import interp
44
from kirin.analysis import const
5+
from kirin.dialects import func
56

67
from .stmts import For, Yield, IfElse
78
from ._dialect import dialect
@@ -40,9 +41,12 @@ def if_else(
4041
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
4142
ret = interp_.frame_call_region(body_frame, stmt, body, cond)
4243
frame.entries.update(body_frame.entries)
43-
frame.frame_is_not_pure = (
44-
frame.frame_is_not_pure or body_frame.frame_is_not_pure
45-
)
44+
45+
if not body_frame.frame_is_not_pure and not isinstance(
46+
body.blocks[0].last_stmt, func.Return
47+
):
48+
frame.should_be_pure.add(stmt)
49+
return ret
4650
else:
4751
with interp_.new_frame(stmt, has_parent_access=True) as then_frame:
4852
then_results = interp_.frame_call_region(
@@ -58,28 +62,22 @@ def if_else(
5862
# parent frame variables value except cond
5963
frame.entries.update(then_frame.entries)
6064
frame.entries.update(else_frame.entries)
61-
# update frame purity
62-
# if either frame is not pure, then the whole if-else is not pure
63-
frame.frame_is_not_pure = (
64-
frame.frame_is_not_pure
65-
or then_frame.frame_is_not_pure
66-
or else_frame.frame_is_not_pure
67-
)
6865
# TODO: pick the non-return value
6966
if isinstance(then_results, interp.ReturnValue) and isinstance(
7067
else_results, interp.ReturnValue
7168
):
72-
ret = interp.ReturnValue(then_results.value.join(else_results.value))
69+
return interp.ReturnValue(then_results.value.join(else_results.value))
7370
elif isinstance(then_results, interp.ReturnValue):
7471
ret = else_results
7572
elif isinstance(else_results, interp.ReturnValue):
7673
ret = then_results
7774
else:
75+
if not (
76+
then_frame.frame_is_not_pure is True
77+
or else_frame.frame_is_not_pure is True
78+
):
79+
frame.should_be_pure.add(stmt)
7880
ret = interp_.join_results(then_results, else_results)
79-
80-
if not frame.frame_is_not_pure:
81-
frame.should_be_pure.add(stmt)
82-
8381
return ret
8482

8583
@interp.impl(For)
Lines changed: 1 addition & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
from typing import Any
2-
31
from kirin.prelude import structural_no_opt
42
from kirin.analysis import const
5-
from kirin.dialects import scf, func, ilist
3+
from kirin.dialects import scf, func
64

75
prop = const.Propagate(structural_no_opt)
86

@@ -102,155 +100,3 @@ def simple_ifelse(x: int):
102100
assert isinstance(terminator, func.Return)
103101
assert isinstance(value := frame.entries[terminator.value], const.Value)
104102
assert value.data == 0
105-
106-
107-
def test_purity_1():
108-
109-
@structural_no_opt
110-
def test_func(src: ilist.IList[float, Any]):
111-
112-
def inner(i: int):
113-
if src[i] < 0:
114-
return 0.0
115-
elif src[i] < 1.0:
116-
return 1.0
117-
else:
118-
return 2.0
119-
120-
return ilist.map(inner, ilist.range(len(src)))
121-
122-
frame, ret = prop.run(test_func)
123-
124-
assert not frame.frame_is_not_pure, "function should be pure"
125-
126-
127-
def test_purity_2():
128-
129-
@structural_no_opt
130-
def test_func(src: ilist.IList[float, Any]):
131-
132-
def inner(i: int):
133-
value = 0.0
134-
if src[i] < 0:
135-
value = 0.0
136-
elif src[i] < 1.0:
137-
return 1.0
138-
else:
139-
value = 2.0
140-
141-
return value
142-
143-
return ilist.map(inner, ilist.range(len(src)))
144-
145-
frame, ret = prop.run(test_func)
146-
147-
assert not frame.frame_is_not_pure, "function should be pure"
148-
149-
150-
def test_purity_3():
151-
152-
@structural_no_opt
153-
def test_func(src: ilist.IList[float, Any]):
154-
155-
def inner(i: int):
156-
value = 0.0
157-
if src[i] < 0:
158-
value = 0.0
159-
elif src[i] < 1.0:
160-
return 1.0
161-
else:
162-
return 2.0
163-
164-
return value
165-
166-
return ilist.map(inner, ilist.range(len(src)))
167-
168-
frame, ret = prop.run(test_func)
169-
170-
assert not frame.frame_is_not_pure, "function should be pure"
171-
172-
173-
def test_purity_4():
174-
175-
@structural_no_opt
176-
def test_func(src: list[float]):
177-
178-
if True:
179-
return src
180-
else:
181-
src.append(2.0)
182-
return src
183-
184-
frame, ret = prop.run(test_func)
185-
186-
assert not frame.frame_is_not_pure, "function should be pure"
187-
188-
189-
def test_purity_5():
190-
191-
@structural_no_opt
192-
def test_func(src: list[float]):
193-
194-
if False:
195-
src.append(2.0)
196-
197-
return src
198-
199-
frame, ret = prop.run(test_func)
200-
201-
assert not frame.frame_is_not_pure, "function should be pure"
202-
203-
204-
def test_purity_6():
205-
206-
@structural_no_opt
207-
def test_func(src: list[float]):
208-
209-
if True:
210-
return src
211-
else:
212-
src.append(2.0)
213-
214-
return src
215-
216-
frame, ret = prop.run(test_func)
217-
218-
assert not frame.frame_is_not_pure, "function should be pure"
219-
220-
221-
def test_purity_7():
222-
223-
@structural_no_opt
224-
def test_func(src: list[float], cond: bool):
225-
226-
if cond:
227-
src.append(2.0)
228-
return src
229-
else:
230-
return src
231-
232-
frame, ret = prop.run(test_func)
233-
234-
assert frame.frame_is_not_pure, "function should not be pure"
235-
236-
237-
def test_purity_8():
238-
239-
@structural_no_opt
240-
def test_func(src: ilist.IList[float, Any], dst: ilist.IList[float, Any]):
241-
assert len(src) == len(dst), "src and dst must have the same length"
242-
243-
def inner(i: int):
244-
value = src[i]
245-
if src[i] < dst[i]:
246-
value = dst[i] - 3.0
247-
elif src[i] > dst[i]:
248-
return dst[i] + 3.0
249-
250-
return value
251-
252-
return ilist.map(inner, ilist.range(len(src)))
253-
254-
frame, ret = prop.run(test_func)
255-
256-
assert frame.frame_is_not_pure, "function should be pure"

0 commit comments

Comments
 (0)