Skip to content

Commit 8911caa

Browse files
authored
fix inference with scf (#241)
this PR fixes inference when there are `func.Return` inside the `scf` region.
1 parent be86bd0 commit 8911caa

File tree

21 files changed

+407
-149
lines changed

21 files changed

+407
-149
lines changed

src/kirin/analysis/typeinfer/analysis.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@ class TypeInference(Forward[types.TypeAttribute]):
2424
keys = ["typeinfer"]
2525
lattice = types.TypeAttribute
2626

27+
def run_analysis(
28+
self, method: ir.Method, args: tuple[types.TypeAttribute, ...] | None = None
29+
) -> tuple[ForwardFrame[types.TypeAttribute], types.TypeAttribute]:
30+
if args is None:
31+
args = method.arg_types
32+
return super().run_analysis(method, args)
33+
2734
# NOTE: unlike concrete interpreter, instead of using type information
2835
# within the IR. Type inference will use the interpreted
2936
# value (which is a type) to determine the method dispatch.

src/kirin/dialects/eltype.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class ElType(ir.Statement):
1717
This statement is used by other dialects to query the element type of a value.
1818
"""
1919

20-
container: ir.SSAValue = info.argument(types.PyClass(types.TypeAttribute))
20+
container: ir.SSAValue = info.argument(types.Any)
2121
"""The value to query the element type of."""
2222
elem: ir.ResultValue = info.result(types.PyClass(types.TypeAttribute))
2323
"""The element type of the value."""

src/kirin/dialects/func/typeinfer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from kirin import ir, types
44
from kirin.interp import Frame, MethodTable, ReturnValue, impl
5+
from kirin.analysis import const
56
from kirin.analysis.typeinfer import TypeInference, TypeResolution
67
from kirin.dialects.func.stmts import (
78
Call,
@@ -24,6 +25,8 @@ def const_none(self, interp: TypeInference, frame: Frame, stmt: ConstantNone):
2425

2526
@impl(Return)
2627
def return_(self, interp: TypeInference, frame: Frame, stmt: Return) -> ReturnValue:
28+
if isinstance(hint := stmt.value.hints.get("const"), const.Value):
29+
return ReturnValue(types.Literal(hint.data))
2730
return ReturnValue(frame.get(stmt.value))
2831

2932
@impl(Call)

src/kirin/dialects/ilist/typeinfer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from kirin import types
22
from kirin.interp import Frame, MethodTable, impl
3+
from kirin.dialects.eltype import ElType
34
from kirin.dialects.py.binop import Add
45
from kirin.analysis.typeinfer import TypeInference
56
from kirin.dialects.py.indexing import GetItem
@@ -19,6 +20,16 @@ def _get_list_len(typ: types.Generic):
1920
else:
2021
return types.Any
2122

23+
@impl(ElType, types.PyClass(IList))
24+
def eltype_list(
25+
self, interp: TypeInference, frame: Frame[types.TypeAttribute], stmt: ElType
26+
):
27+
list_type = frame.get(stmt.container)
28+
if isinstance(list_type, types.Generic):
29+
return (list_type.vars[0],)
30+
else:
31+
return (types.Any,)
32+
2233
@impl(New)
2334
def new(self, interp: TypeInference, frame: Frame[types.TypeAttribute], stmt: New):
2435
values = frame.get_values(stmt.values)

src/kirin/dialects/py/list/typeinfer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from kirin import types, interp
2+
from kirin.dialects.eltype import ElType
23
from kirin.dialects.py.binop import Add
34
from kirin.dialects.py.indexing import GetItem
45

@@ -8,6 +9,14 @@
89
@dialect.register(key="typeinfer")
910
class TypeInfer(interp.MethodTable):
1011

12+
@interp.impl(ElType, types.PyClass(list))
13+
def eltype_list(self, interp, frame: interp.Frame, stmt: ElType):
14+
list_type = frame.get(stmt.container)
15+
if isinstance(list_type, types.Generic):
16+
return (list_type.vars[0],)
17+
else:
18+
return (types.Any,)
19+
1120
@interp.impl(Add, types.PyClass(list), types.PyClass(list))
1221
def add(self, interp, frame: interp.Frame, stmt: Add):
1322
lhs_type = frame.get(stmt.lhs)

src/kirin/dialects/py/range.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
import ast
1515
from dataclasses import dataclass
1616

17-
from kirin import ir, types, lowering, exceptions
17+
from kirin import ir, types, interp, lowering, exceptions
1818
from kirin.decl import info, statement
19+
from kirin.dialects import eltype
1920

2021
dialect = ir.Dialect("py.range")
2122

@@ -48,6 +49,14 @@ def lower_Call_range(
4849
return _lower_range(state, node)
4950

5051

52+
@dialect.register(key="typeinfer")
53+
class TypeInfer(interp.MethodTable):
54+
55+
@interp.impl(eltype.ElType, types.PyClass(range))
56+
def eltype_range(self, interp_, frame: interp.Frame, stmt: eltype.ElType):
57+
return (types.Int,)
58+
59+
5160
def _lower_range(state: lowering.LoweringState, node: ast.Call) -> lowering.Result:
5261
if len(node.args) == 1:
5362
start = state.visit(ast.Constant(0)).expect_one()

src/kirin/dialects/py/tuple.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from kirin.decl import info, statement
1919
from kirin.analysis import const
2020
from kirin.emit.julia import EmitJulia, EmitStrFrame
21+
from kirin.dialects.eltype import ElType
2122
from kirin.dialects.py.binop import Add
2223

2324
dialect = ir.Dialect("py.tuple")
@@ -49,6 +50,17 @@ def new(self, interp: interp.Interpreter, frame: interp.Frame, stmt: New):
4950
@dialect.register(key="typeinfer")
5051
class TypeInfer(interp.MethodTable):
5152

53+
@interp.impl(ElType, types.PyClass(tuple))
54+
def eltype_tuple(self, interp, frame: interp.Frame, stmt: ElType):
55+
tuple_type = frame.get(stmt.container)
56+
if isinstance(tuple_type, types.Generic):
57+
ret = tuple_type.vars[0]
58+
for var in tuple_type.vars[1:]:
59+
ret = ret.join(var)
60+
return (ret,)
61+
else:
62+
return (types.Any,)
63+
5264
@interp.impl(Add, types.PyClass(tuple), types.PyClass(tuple))
5365
def add(self, interp, frame: interp.Frame[types.TypeAttribute], stmt):
5466
lhs = frame.get(stmt.lhs)
Lines changed: 79 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections.abc import Iterable
22

3-
from kirin import interp
3+
from kirin import ir, interp
44
from kirin.analysis import const
55

66
from .stmts import For, Yield, IfElse
@@ -32,35 +32,50 @@ def if_else(
3232
):
3333
cond = frame.get(stmt.cond)
3434
if isinstance(cond, const.Value):
35-
with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
36-
body_frame.entries.update(frame.entries)
37-
if cond.data:
38-
results = interp_.run_ssacfg_region(body_frame, stmt.then_body)
39-
else:
40-
results = interp_.run_ssacfg_region(body_frame, stmt.else_body)
35+
if cond.data:
36+
body = stmt.then_body
37+
else:
38+
body = stmt.else_body
39+
body_frame, ret = self._prop_const_cond_ifelse(
40+
interp_, frame, stmt, cond, body
41+
)
42+
frame.entries.update(body_frame.entries)
43+
return ret
44+
else:
45+
then_frame, then_results = self._prop_const_cond_ifelse(
46+
interp_, frame, stmt, const.Value(True), stmt.then_body
47+
)
48+
else_frame, else_results = self._prop_const_cond_ifelse(
49+
interp_, frame, stmt, const.Value(False), stmt.else_body
50+
)
51+
ret = interp_.join_results(then_results, else_results)
4152

42-
if not body_frame.frame_is_not_pure:
53+
if not then_frame.frame_is_not_pure or not else_frame.frame_is_not_pure:
4354
frame.should_be_pure.add(stmt)
44-
else:
45-
with interp_.state.new_frame(interp_.new_frame(stmt)) as then_body_frame:
46-
then_body_frame.entries.update(frame.entries)
47-
then_results = interp_.run_ssacfg_region(
48-
then_body_frame, stmt.then_body
49-
)
5055

51-
with interp_.state.new_frame(interp_.new_frame(stmt)) as else_body_frame:
52-
else_body_frame.entries.update(frame.entries)
53-
else_results = interp_.run_ssacfg_region(
54-
else_body_frame, stmt.else_body
55-
)
56-
results = interp_.join_results(then_results, else_results)
56+
# NOTE: then_frame and else_frame do not change
57+
# parent frame variables value except cond
58+
frame.entries.update(then_frame.entries)
59+
frame.entries.update(else_frame.entries)
60+
frame.set(stmt.cond, cond)
61+
return ret
5762

58-
if (
59-
not then_body_frame.frame_is_not_pure
60-
or not else_body_frame.frame_is_not_pure
61-
):
62-
frame.should_be_pure.add(stmt)
63-
return results
63+
def _prop_const_cond_ifelse(
64+
self,
65+
interp_: const.Propagate,
66+
frame: const.Frame,
67+
stmt: IfElse,
68+
cond: const.Value,
69+
body: ir.Region,
70+
):
71+
with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
72+
body_frame.entries.update(frame.entries)
73+
body_frame.set(body.blocks[0].args[0], cond)
74+
results = interp_.run_ssacfg_region(body_frame, body)
75+
76+
if not body_frame.frame_is_not_pure:
77+
frame.should_be_pure.add(stmt)
78+
return body_frame, results
6479

6580
@interp.impl(For)
6681
def for_loop(
@@ -70,33 +85,44 @@ def for_loop(
7085
stmt: For,
7186
):
7287
iterable = frame.get(stmt.iterable)
73-
loop_vars = frame.get_values(stmt.initializers)
74-
block_args = stmt.body.blocks[0].args
75-
7688
if isinstance(iterable, const.Value):
77-
frame_is_not_pure = False
78-
if not isinstance(iterable.data, Iterable):
79-
raise interp.InterpreterError(
80-
f"Expected iterable, got {type(iterable.data)}"
81-
)
82-
for value in iterable.data:
83-
with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
84-
body_frame.entries.update(frame.entries)
85-
body_frame.set_values(
86-
block_args,
87-
(const.Value(value),) + loop_vars,
88-
)
89-
loop_vars = interp_.run_ssacfg_region(body_frame, stmt.body)
90-
91-
if body_frame.frame_is_not_pure:
92-
frame_is_not_pure = True
93-
if loop_vars is None:
94-
loop_vars = ()
95-
elif isinstance(loop_vars, interp.ReturnValue):
96-
return loop_vars
97-
98-
if not frame_is_not_pure:
99-
frame.should_be_pure.add(stmt)
100-
return loop_vars
89+
return self._prop_const_iterable_forloop(interp_, frame, stmt, iterable)
10190
else: # TODO: support other iteration
10291
return tuple(interp_.lattice.top() for _ in stmt.results)
92+
93+
def _prop_const_iterable_forloop(
94+
self,
95+
interp_: const.Propagate,
96+
frame: const.Frame,
97+
stmt: For,
98+
iterable: const.Value,
99+
):
100+
frame_is_not_pure = False
101+
if not isinstance(iterable.data, Iterable):
102+
raise interp.InterpreterError(
103+
f"Expected iterable, got {type(iterable.data)}"
104+
)
105+
106+
loop_vars = frame.get_values(stmt.initializers)
107+
body_block = stmt.body.blocks[0]
108+
block_args = body_block.args
109+
110+
for value in iterable.data:
111+
with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
112+
body_frame.entries.update(frame.entries)
113+
body_frame.set_values(
114+
block_args,
115+
(const.Value(value),) + loop_vars,
116+
)
117+
loop_vars = interp_.run_ssacfg_region(body_frame, stmt.body)
118+
119+
if body_frame.frame_is_not_pure:
120+
frame_is_not_pure = True
121+
if loop_vars is None:
122+
loop_vars = ()
123+
elif isinstance(loop_vars, interp.ReturnValue):
124+
return loop_vars
125+
126+
if not frame_is_not_pure:
127+
frame.should_be_pure.add(stmt)
128+
return loop_vars

src/kirin/dialects/scf/lowering.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,19 @@ def lower_If(self, state: lowering.LoweringState, node: ast.If) -> lowering.Resu
1515
cond = state.visit(node.test).expect_one()
1616
frame = state.current_frame
1717
body_frame = lowering.Frame.from_stmts(node.body, state, globals=frame.globals)
18+
then_cond = body_frame.curr_block.args.append_from(types.Bool, cond.name)
19+
if cond.name:
20+
body_frame.defs[cond.name] = then_cond
1821
state.push_frame(body_frame)
1922
state.exhaust(body_frame)
2023
state.pop_frame(finalize_next=False) # NOTE: scf does not have multiple blocks
2124

2225
else_frame = lowering.Frame.from_stmts(
2326
node.orelse, state, globals=frame.globals
2427
)
28+
else_cond = else_frame.curr_block.args.append_from(types.Bool, cond.name)
29+
if cond.name:
30+
else_frame.defs[cond.name] = else_cond
2531
state.push_frame(else_frame)
2632
state.exhaust(else_frame)
2733
state.pop_frame(finalize_next=False) # NOTE: scf does not have multiple blocks
@@ -96,7 +102,11 @@ def new_block_arg_if_inside_loop(frame: lowering.Frame, capture: ir.SSAValue):
96102
unpacking(state, node.target, loop_var)
97103
state.exhaust(body_frame)
98104
# NOTE: this frame won't have phi nodes
99-
body_frame.append_stmt(Yield(*[body_frame.defs[name] for name in yields])) # type: ignore
105+
if yields and (
106+
body_frame.curr_block.last_stmt is None
107+
or not body_frame.curr_block.last_stmt.has_trait(ir.IsTerminator)
108+
):
109+
body_frame.append_stmt(Yield(*[body_frame.defs[name] for name in yields])) # type: ignore
100110
state.pop_frame(finalize_next=False) # NOTE: scf does not have multiple blocks
101111

102112
initializers: list[ir.SSAValue] = []

0 commit comments

Comments
 (0)