Skip to content

Commit 45b2a42

Browse files
authored
try to fix some scf bugs (#184)
SCF interpreter is still a bit buggy. Trying to switch to scf to see what break here. UPDATE: Ok I see the problem here, we cannot take the similar route as xDSL because they don't aim to deal with python semantics. It works if we follow MLIR convention, but if we have to support Python. We need to allow `scf.IfElse` statement accept other terminators, e.g `func.Return`, I don't see an easy way transforming to `Yield` in this case because we are doing more dynamic semantics (e.g the return value can be a union of different sized tuple), making the following code seems not possible to have an `yield` equivalent form ```mlir %1, %2 = scf.if a > b { func.return a } else { y = b + 1 scf.yield a,y } ``` we cannot lift the `return` statement outside the if statement here... so directly supporting this would require `run_ssacfg_region` to forward the `ReturnValue` result until we pop the current function frame.
1 parent b97dcf1 commit 45b2a42

29 files changed

+503
-209
lines changed

src/kirin/analysis/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515

1616
from kirin.analysis import const as const
1717
from kirin.analysis.cfg import CFG as CFG
18-
from kirin.analysis.forward import Forward as Forward, ForwardExtra as ForwardExtra
18+
from kirin.analysis.forward import (
19+
Forward as Forward,
20+
ForwardExtra as ForwardExtra,
21+
ForwardFrame as ForwardFrame,
22+
)
1923
from kirin.analysis.callgraph import CallGraph as CallGraph
2024
from kirin.analysis.typeinfer import TypeInference as TypeInference

src/kirin/analysis/const/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
the IR.
99
"""
1010

11-
from .prop import Propagate as Propagate
11+
from .prop import Propagate as Propagate, ExtraFrameInfo as ExtraFrameInfo
1212
from .lattice import (
1313
Pure as Pure,
1414
Value as Value,

src/kirin/analysis/const/prop.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,20 @@ def _try_eval_const_pure(
5959
_frame = self._interp.new_frame(frame.code)
6060
_frame.set_values(stmt.args, tuple(x.data for x in values))
6161
value = self._interp.eval_stmt(_frame, stmt)
62-
if isinstance(value, tuple):
63-
return tuple(JointResult(Value(each), Pure()) for each in value)
64-
elif isinstance(value, interp.ReturnValue):
65-
return interp.ReturnValue(
66-
*tuple(JointResult(Value(x), Pure()) for x in value.results)
67-
)
68-
elif isinstance(value, interp.Successor):
69-
return interp.Successor(
70-
value.block,
71-
*tuple(
72-
JointResult(Value(each), Pure()) for each in value.block_args
73-
),
74-
)
62+
match value:
63+
case tuple():
64+
return tuple(JointResult(Value(each), Pure()) for each in value)
65+
case interp.ReturnValue(ret):
66+
return interp.ReturnValue(JointResult(Value(ret), Pure()))
67+
case interp.YieldValue(yields):
68+
return interp.YieldValue(
69+
tuple(JointResult(Value(each), Pure()) for each in yields)
70+
)
71+
case interp.Successor(block, args):
72+
return interp.Successor(
73+
block,
74+
*tuple(JointResult(Value(each), Pure()) for each in args),
75+
)
7576
except interp.InterpreterError:
7677
pass
7778
return (self.void,)
@@ -104,8 +105,11 @@ def _set_frame_not_pure(self, result: interp.StatementResult[JointResult]):
104105
if isinstance(result, tuple) and all(x.purity is Pure() for x in result):
105106
return
106107

107-
if isinstance(result, interp.ReturnValue) and all(
108-
x.purity is Pure() for x in result.results
108+
if isinstance(result, interp.ReturnValue) and isinstance(result.value, Pure):
109+
return
110+
111+
if isinstance(result, interp.YieldValue) and all(
112+
isinstance(x, Pure) for x in result
109113
):
110114
return
111115

src/kirin/dialects/cf/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
)
66
from kirin.dialects.cf.lower import CfLowering as CfLowering
77
from kirin.dialects.cf.stmts import (
8-
Assert as Assert,
98
Branch as Branch,
109
ConditionalBranch as ConditionalBranch,
1110
)

src/kirin/dialects/cf/constprop.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
from kirin.interp import FrameABC, Successor, MethodTable, impl
22
from kirin.analysis import const
3-
from kirin.dialects.cf.stmts import Assert, Branch, ConditionalBranch
3+
from kirin.dialects.cf.stmts import Branch, ConditionalBranch
44
from kirin.dialects.cf.dialect import dialect
55

66

77
@dialect.register(key="constprop")
88
class DialectConstProp(MethodTable):
99

10-
@impl(Assert)
11-
def assert_stmt(self, interp: const.Propagate, frame, stmt: Assert):
12-
return ()
13-
1410
@impl(Branch)
1511
def branch(
1612
self, interp: const.Propagate, frame: FrameABC[const.JointResult], stmt: Branch

src/kirin/dialects/cf/emit.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from kirin.interp import Successor, MethodTable, impl
55
from kirin.emit.julia import EmitJulia
66

7-
from .stmts import Assert, Branch, ConditionalBranch
7+
from .stmts import Branch, ConditionalBranch
88
from .dialect import dialect
99

1010
IO_t = TypeVar("IO_t", bound=IO)
@@ -13,15 +13,6 @@
1313
@dialect.register(key="emit.julia")
1414
class JuliaMethodTable(MethodTable):
1515

16-
@impl(Assert)
17-
def emit_assert(
18-
self, interp: EmitJulia[IO_t], frame: emit.EmitStrFrame, stmt: Assert
19-
):
20-
interp.writeln(
21-
frame, f"@assert {frame.get(stmt.condition)} {frame.get(stmt.message)}"
22-
)
23-
return ()
24-
2516
@impl(Branch)
2617
def emit_branch(
2718
self, interp: EmitJulia[IO_t], frame: emit.EmitStrFrame, stmt: Branch

src/kirin/dialects/cf/interp.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,11 @@
1-
from kirin.interp import Frame, Successor, Interpreter, MethodTable, WrapException, impl
2-
from kirin.dialects.cf.stmts import Assert, Branch, ConditionalBranch
1+
from kirin.interp import Frame, Successor, Interpreter, MethodTable, impl
2+
from kirin.dialects.cf.stmts import Branch, ConditionalBranch
33
from kirin.dialects.cf.dialect import dialect
44

55

66
@dialect.register
77
class CfInterpreter(MethodTable):
88

9-
@impl(Assert)
10-
def assert_stmt(self, interp: Interpreter, frame: Frame, stmt: Assert):
11-
if frame.get(stmt.condition) is True:
12-
return ()
13-
14-
if stmt.message:
15-
raise WrapException(AssertionError(frame.get(stmt.message)))
16-
else:
17-
raise WrapException(AssertionError("Assertion failed"))
18-
199
@impl(Branch)
2010
def branch(self, interp: Interpreter, frame: Frame, stmt: Branch):
2111
return Successor(stmt.successor, *frame.get_values(stmt.arguments))

src/kirin/dialects/cf/lower.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,6 @@ def lower_Pass(self, state: LoweringState, node: ast.Pass) -> Result:
1717
state.append_stmt(cf.Branch(arguments=(), successor=next))
1818
return Result()
1919

20-
def lower_Assert(self, state: LoweringState, node: ast.Assert) -> Result:
21-
from kirin.dialects.py.constant import Constant
22-
23-
cond = state.visit(node.test).expect_one()
24-
if node.msg:
25-
message = state.visit(node.msg).expect_one()
26-
state.append_stmt(cf.Assert(condition=cond, message=message))
27-
else:
28-
message_stmt = state.append_stmt(Constant(""))
29-
state.append_stmt(cf.Assert(condition=cond, message=message_stmt.result))
30-
return Result()
31-
3220
def lower_If(self, state: LoweringState, node: ast.If) -> Result:
3321
cond = state.visit(node.test).expect_one()
3422

src/kirin/dialects/cf/stmts.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,10 @@
11
from kirin.ir import Block, SSAValue, Statement, IsTerminator
22
from kirin.decl import info, statement
33
from kirin.print.printer import Printer
4-
from kirin.ir.attrs.types import Bool, String
4+
from kirin.ir.attrs.types import Bool
55
from kirin.dialects.cf.dialect import dialect
66

77

8-
@statement(dialect=dialect)
9-
class Assert(Statement):
10-
name = "assert"
11-
traits = frozenset({})
12-
13-
condition: SSAValue
14-
message: SSAValue = info.argument(String)
15-
16-
def print_impl(self, printer: Printer) -> None:
17-
with printer.rich(style="keyword"):
18-
printer.print_name(self)
19-
20-
printer.plain_print(" ")
21-
printer.print(self.condition)
22-
23-
if self.message:
24-
printer.plain_print(", ")
25-
printer.print(self.message)
26-
27-
288
@statement(dialect=dialect)
299
class Branch(Statement):
3010
name = "br"

src/kirin/dialects/cf/typeinfer.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,12 @@
1-
from kirin.ir import types
21
from kirin.interp import Successor, MethodTable, AbstractFrame, impl
3-
from kirin.dialects.cf.stmts import Assert, Branch, ConditionalBranch
2+
from kirin.dialects.cf.stmts import Branch, ConditionalBranch
43
from kirin.analysis.typeinfer import TypeInference
54
from kirin.dialects.cf.dialect import dialect
65

76

87
@dialect.register(key="typeinfer")
98
class TypeInfer(MethodTable):
109

11-
@impl(Assert)
12-
def assert_stmt(self, interp: TypeInference, frame: AbstractFrame, stmt: Assert):
13-
return (types.Bottom,)
14-
1510
@impl(Branch)
1611
def branch(self, interp: TypeInference, frame: AbstractFrame, stmt: Branch):
1712
frame.worklist.append(

0 commit comments

Comments
 (0)