Skip to content

Commit 102fdd9

Browse files
authored
fixing bug with multi-return values in statement (#358)
1 parent e7d3379 commit 102fdd9

File tree

4 files changed

+58
-1
lines changed

4 files changed

+58
-1
lines changed

src/kirin/analysis/const/prop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def try_eval_const_pure(
6767
if method is not None:
6868
value = method(self._interp, _frame, stmt)
6969
else:
70-
return (Unknown(),)
70+
return tuple(Unknown() for _ in stmt.results)
7171
match value:
7272
case tuple():
7373
return tuple(Value(each) for each in value)

test/analysis/dataflow/constprop/__init__.py

Whitespace-only changes.
File renamed without changes.
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from kirin import ir, types, lowering
2+
from kirin.decl import info, statement
3+
from kirin.passes import Fold, TypeInfer
4+
from kirin.prelude import structural_no_opt
5+
from kirin.analysis import const
6+
from kirin.dialects import py, func, ilist
7+
8+
dialect = ir.Dialect("analysis")
9+
10+
11+
@statement(dialect=dialect)
12+
class MultiReturnStatement(ir.Statement):
13+
traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
14+
inputs: tuple[ir.SSAValue] = info.argument(types.Any)
15+
16+
def __init__(self, *args: ir.SSAValue):
17+
super().__init__(
18+
args=args,
19+
result_types=tuple(arg.type for arg in args),
20+
args_slice={"inputs": slice(None)},
21+
)
22+
23+
24+
@ir.dialect_group(structural_no_opt.add(dialect))
25+
def dialect_group_test(self):
26+
fold = Fold(self)
27+
type_infer = TypeInfer(self)
28+
29+
def run_pass(mt):
30+
type_infer(mt)
31+
fold(mt)
32+
33+
return run_pass
34+
35+
36+
def test_multi_return_default_prop():
37+
38+
stmts = [
39+
(a := py.Constant(1)),
40+
(b := py.Constant(2)),
41+
(res := MultiReturnStatement(a.result, b.result)),
42+
(return_result := ilist.New((res.results[0], res.results[1]))),
43+
(func.Return(return_result.result)),
44+
]
45+
46+
body = ir.Region(ir.Block(stmts))
47+
func_code = func.Function(
48+
sym_name="test", signature=func.Signature((), types.Any), body=body
49+
)
50+
mt = ir.Method(None, None, "test", (), dialect_group_test, func_code)
51+
52+
frame, return_result = const.Propagate(dialect_group_test).run_analysis(
53+
mt, no_raise=False
54+
)
55+
56+
assert frame.entries[res.results[0]] == const.Unknown()
57+
assert frame.entries[res.results[1]] == const.Unknown()

0 commit comments

Comments
 (0)