Skip to content

Commit 05d8e61

Browse files
authored
fix more scf bugs for bloqade (#243)
1 parent 8911caa commit 05d8e61

File tree

5 files changed

+89
-46
lines changed

5 files changed

+89
-46
lines changed

src/kirin/dialects/scf/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"""
1212

1313
from . import (
14+
absint as absint,
1415
interp as interp,
1516
unroll as unroll,
1617
lowering as lowering,

src/kirin/dialects/scf/absint.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from kirin import ir, types, interp
2+
from kirin.analysis import const
3+
from kirin.dialects import func
4+
5+
from .stmts import Yield, IfElse
6+
from ._dialect import dialect
7+
8+
9+
@dialect.register(key="absint")
10+
class Methods(interp.MethodTable):
11+
12+
@interp.impl(Yield)
13+
def yield_stmt(
14+
self,
15+
interp_: interp.AbstractInterpreter,
16+
frame: interp.AbstractFrame,
17+
stmt: Yield,
18+
):
19+
return interp.YieldValue(frame.get_values(stmt.values))
20+
21+
@interp.impl(IfElse)
22+
def if_else(
23+
self,
24+
interp_: interp.AbstractInterpreter,
25+
frame: interp.AbstractFrame,
26+
stmt: IfElse,
27+
):
28+
if isinstance(hint := stmt.cond.hints.get("const"), const.Value):
29+
if hint.data:
30+
return self._infer_if_else_cond(interp_, frame, stmt, stmt.then_body)
31+
else:
32+
return self._infer_if_else_cond(interp_, frame, stmt, stmt.else_body)
33+
then_results = self._infer_if_else_cond(interp_, frame, stmt, stmt.then_body)
34+
else_results = self._infer_if_else_cond(interp_, frame, stmt, stmt.else_body)
35+
36+
match (then_results, else_results):
37+
case (interp.ReturnValue(then_value), interp.ReturnValue(else_value)):
38+
return interp.ReturnValue(then_value.join(else_value))
39+
case (interp.ReturnValue(then_value), _):
40+
return then_results
41+
case (_, interp.ReturnValue(else_value)):
42+
return else_results
43+
case _:
44+
return interp_.join_results(then_results, else_results)
45+
46+
def _infer_if_else_cond(
47+
self,
48+
interp_: interp.AbstractInterpreter,
49+
frame: interp.AbstractFrame,
50+
stmt: IfElse,
51+
body: ir.Region,
52+
):
53+
body_block = body.blocks[0]
54+
body_term = body_block.last_stmt
55+
if isinstance(body_term, func.Return):
56+
frame.worklist.append(interp.Successor(body_block, types.Bool))
57+
return
58+
59+
with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
60+
body_frame.entries.update(frame.entries)
61+
body_frame.set(body_block.args[0], types.Bool)
62+
return interp_.run_ssacfg_region(body_frame, stmt.then_body)

src/kirin/dialects/scf/constprop.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from kirin import ir, interp
44
from kirin.analysis import const
5+
from kirin.dialects import func
56

67
from .stmts import For, Yield, IfElse
78
from ._dialect import dialect
@@ -48,16 +49,23 @@ def if_else(
4849
else_frame, else_results = self._prop_const_cond_ifelse(
4950
interp_, frame, stmt, const.Value(False), stmt.else_body
5051
)
51-
ret = interp_.join_results(then_results, else_results)
52-
53-
if not then_frame.frame_is_not_pure or not else_frame.frame_is_not_pure:
54-
frame.should_be_pure.add(stmt)
55-
5652
# NOTE: then_frame and else_frame do not change
5753
# parent frame variables value except cond
5854
frame.entries.update(then_frame.entries)
5955
frame.entries.update(else_frame.entries)
60-
frame.set(stmt.cond, cond)
56+
# TODO: pick the non-return value
57+
if isinstance(then_results, interp.ReturnValue) and isinstance(
58+
else_results, interp.ReturnValue
59+
):
60+
return interp.ReturnValue(then_results.value.join(else_results.value))
61+
elif isinstance(then_results, interp.ReturnValue):
62+
ret = else_results
63+
elif isinstance(else_results, interp.ReturnValue):
64+
ret = then_results
65+
else:
66+
if not then_frame.frame_is_not_pure or not else_frame.frame_is_not_pure:
67+
frame.should_be_pure.add(stmt)
68+
ret = interp_.join_results(then_results, else_results)
6169
return ret
6270

6371
def _prop_const_cond_ifelse(
@@ -73,7 +81,9 @@ def _prop_const_cond_ifelse(
7381
body_frame.set(body.blocks[0].args[0], cond)
7482
results = interp_.run_ssacfg_region(body_frame, body)
7583

76-
if not body_frame.frame_is_not_pure:
84+
if not body_frame.frame_is_not_pure and not isinstance(
85+
body.blocks[0].last_stmt, func.Return
86+
):
7787
frame.should_be_pure.add(stmt)
7888
return body_frame, results
7989

src/kirin/dialects/scf/typeinfer.py

Lines changed: 6 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,18 @@
11
from kirin import ir, types, interp
2-
from kirin.analysis import ForwardFrame, TypeInference, const
2+
from kirin.analysis import ForwardFrame, TypeInference
33
from kirin.dialects import func
44
from kirin.dialects.eltype import ElType
55

6-
from .stmts import For, Yield, IfElse
6+
from . import absint
7+
from .stmts import For, IfElse
78
from ._dialect import dialect
89

910

1011
@dialect.register(key="typeinfer")
11-
class TypeInfer(interp.MethodTable):
12-
13-
@interp.impl(Yield)
14-
def yield_stmt(
15-
self,
16-
interp_: TypeInference,
17-
frame: ForwardFrame[types.TypeAttribute],
18-
stmt: Yield,
19-
):
20-
return interp.YieldValue(frame.get_values(stmt.values))
12+
class TypeInfer(absint.Methods):
2113

2214
@interp.impl(IfElse)
23-
def if_else(
15+
def if_else_(
2416
self,
2517
interp_: TypeInference,
2618
frame: ForwardFrame[types.TypeAttribute],
@@ -29,32 +21,7 @@ def if_else(
2921
frame.set(
3022
stmt.cond, frame.get(stmt.cond).meet(types.Bool)
3123
) # set cond backwards
32-
if isinstance(hint := stmt.cond.hints.get("const"), const.Value):
33-
if hint.data:
34-
return self._infer_if_else_cond(interp_, frame, stmt, stmt.then_body)
35-
else:
36-
return self._infer_if_else_cond(interp_, frame, stmt, stmt.else_body)
37-
then_results = self._infer_if_else_cond(interp_, frame, stmt, stmt.then_body)
38-
else_results = self._infer_if_else_cond(interp_, frame, stmt, stmt.else_body)
39-
return interp_.join_results(then_results, else_results)
40-
41-
def _infer_if_else_cond(
42-
self,
43-
interp_: TypeInference,
44-
frame: ForwardFrame[types.TypeAttribute],
45-
stmt: IfElse,
46-
body: ir.Region,
47-
):
48-
body_block = body.blocks[0]
49-
body_term = body_block.last_stmt
50-
if isinstance(body_term, func.Return): # TODO: use types.Literal?
51-
frame.worklist.append(interp.Successor(body_block, types.Bool))
52-
return
53-
54-
with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
55-
body_frame.entries.update(frame.entries)
56-
body_frame.set(body_block.args[0], types.Bool)
57-
return interp_.run_ssacfg_region(body_frame, stmt.then_body)
24+
return super().if_else(self, interp_, frame, stmt)
5825

5926
@interp.impl(For)
6027
def for_loop(

src/kirin/interp/impl.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ def __repr__(self):
7070
else:
7171
return f"interp {self.parent.name}"
7272

73+
def __call__(self, *args, **kwargs):
74+
return self.impl(*args, **kwargs)
75+
7376

7477
@dataclass
7578
class AttributeImplDef(Def[type[Attribute], "AttributeFunction"]):

0 commit comments

Comments
 (0)