Skip to content

Commit f7474f2

Browse files
authored
Fix bug that caused error in type inference when returning from if (#427)
@weinbe58 @Roger-luo I'm still not 100% certain what the issue was, but this fixes #402 . I think it's because using the `frame` instead of creating a new one meant it still had the function as `node` rather than the actual `IfElse` statement. So we do need to create a new one here. Is that correct? Please take a look and let me know whether that explanation is correct, I'd like to take this opportunity to learn ;)
1 parent 80af27d commit f7474f2

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

src/kirin/dialects/scf/absint.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,10 @@ def _infer_if_else_cond(
6262
if isinstance(body_term, func.Return):
6363
frame.worklist.append(interp.Successor(body_block, frame.get(stmt.cond)))
6464
return
65-
return interp_.frame_call_region(frame, stmt, body, frame.get(stmt.cond))
65+
66+
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
67+
ret = interp_.frame_call_region(
68+
body_frame, stmt, body, frame.get(stmt.cond)
69+
)
70+
frame.entries.update(body_frame.entries)
71+
return ret

test/analysis/dataflow/typeinfer/test_inter_method.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@ def foo(x: int):
1010
return x - 1.0
1111

1212

13-
@basic(typeinfer=True)
13+
@basic(typeinfer=True, no_raise=False)
1414
def main(x: int):
1515
return foo(x)
1616

1717

18-
@basic(typeinfer=True)
18+
@basic(typeinfer=True, no_raise=False)
1919
def moo(x):
2020
return foo(x)
2121

@@ -28,3 +28,18 @@ def test_inter_method_infer():
2828
assert foo.arg_types[0] == types.Int
2929
assert foo.inferred is False
3030
assert foo.return_type is types.Any
31+
32+
33+
def test_infer_if_return():
34+
from kirin.prelude import structural
35+
36+
@structural(typeinfer=True, fold=True, no_raise=False)
37+
def test(b: bool):
38+
if b:
39+
return False
40+
else:
41+
b = not b
42+
43+
return b
44+
45+
test.print()

0 commit comments

Comments
 (0)