Skip to content

Commit 4c406d1

Browse files
authored
structural control flow (#176)
This PR implements a Python-like SCF dialect that can be a lowering target from Python. To support Python-like multiple loop variables, a new dialect is also added as `unpack`. The support of assignment is still not implemented in this PR, we can add that in the future (so you can only unpack in the loop head in this PR) Unlike MLIR SCF, this dialect has less restrictions on the iterable (it is typed `Any`), this will have issues in type inference where we need to implicitly call an `eltype` method to query the element type of a iterable, but I think this what Python loops expect anyways. To use a `affine.for`-like semantics, one can always rewrite this Python-based `for` statement into `affine.for` when constant fold finds the loop iterable is a `range`. I realize to support these statements, we still need to let region evaluation return multiple values, so the API of `BaseInterpreter.run_ssa_cfg_region` (I hope we can move this into a trait instead of inside interpreter so we can support other ways of execution) is changed into returning `tuple[ValueType, ...]` and in the case of empty return, it returns `()` instead of `self.void`. The `run_callable_region` method now checks if region gives `()` and returns `void` to match Python semantics (given we don't allow functions return a tuple directly). I think this is mostly non-breaking but let me know if you are using this API (@kaihsin @weinbe58 ) example: ```python from kirin.prelude import python_basic from kirin.dialects import py, scf, func, ilist xs = ilist.IList([(1, 2), (3, 4)]) @python_basic.union([func, scf, py.range, py.unpack, ilist]) def main(x): for a, b in xs: x = x + a return x main.print() ``` <img width="680" alt="image" src="https://github.com/user-attachments/assets/6b9a9e5d-d561-40da-a6f6-3cf7607df138" /> supercedes #101
1 parent d6d6abd commit 4c406d1

File tree

22 files changed

+564
-45
lines changed

22 files changed

+564
-45
lines changed

example/beer/rewrite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from stmts import RandomBranch
44

55
from kirin.dialects import cf
6+
from kirin.rewrite.abc import RewriteRule, RewriteResult
67
from kirin.ir.nodes.stmt import Statement
7-
from kirin.rewrite.abc import RewriteResult, RewriteRule
88

99

1010
@dataclass

scripts/release.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import argparse
2-
import logging
31
import os
2+
import logging
3+
import argparse
44
import subprocess
55

66
import tomlkit

src/kirin/analysis/const/prop.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def _try_eval_const_pure(
6262
if isinstance(value, tuple):
6363
return tuple(JointResult(Value(each), Pure()) for each in value)
6464
elif isinstance(value, interp.ReturnValue):
65-
return interp.ReturnValue(JointResult(Value(value.result), Pure()))
65+
return interp.ReturnValue(
66+
*tuple(JointResult(Value(x), Pure()) for x in value.results)
67+
)
6668
elif isinstance(value, interp.Successor):
6769
return interp.Successor(
6870
value.block,
@@ -102,7 +104,9 @@ def _set_frame_not_pure(self, result: interp.StatementResult[JointResult]):
102104
if isinstance(result, tuple) and all(x.purity is Pure() for x in result):
103105
return
104106

105-
if isinstance(result, interp.ReturnValue) and result.result.purity is Pure():
107+
if isinstance(result, interp.ReturnValue) and all(
108+
x.purity is Pure() for x in result.results
109+
):
106110
return
107111

108112
if isinstance(result, interp.Successor) and all(

src/kirin/dialects/py/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
unary as unary,
1212
assign as assign,
1313
boolop as boolop,
14+
unpack as unpack,
1415
builtin as builtin,
1516
constant as constant,
1617
indexing as indexing,

src/kirin/dialects/py/unpack.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import ast
2+
3+
from kirin import ir, interp, lowering
4+
from kirin.decl import info, statement
5+
from kirin.print import Printer
6+
from kirin.exceptions import DialectLoweringError
7+
8+
dialect = ir.Dialect("py.unpack")
9+
10+
11+
@statement(dialect=dialect, init=False)
12+
class Unpack(ir.Statement):
13+
value: ir.SSAValue = info.argument(ir.types.Any)
14+
names: tuple[str | None, ...] = info.attribute(property=True)
15+
16+
def __init__(self, value: ir.SSAValue, names: tuple[str | None, ...]):
17+
result_types = [ir.types.Any] * len(names)
18+
super().__init__(
19+
args=(value,),
20+
result_types=result_types,
21+
args_slice={"value": 0},
22+
properties={"names": ir.PyAttr(names)},
23+
)
24+
for result, name in zip(self.results, names):
25+
result.name = name
26+
27+
def print_impl(self, printer: Printer) -> None:
28+
printer.print_name(self)
29+
printer.plain_print(" ")
30+
printer.print(self.value)
31+
32+
33+
@dialect.register
34+
class Concrete(interp.MethodTable):
35+
36+
@interp.impl(Unpack)
37+
def unpack(self, interp: interp.Interpreter, frame: interp.Frame, stmt: Unpack):
38+
return tuple(frame.get(stmt.value))
39+
40+
41+
@dialect.register(key="typeinfer")
42+
class TypeInfer(interp.MethodTable):
43+
44+
@interp.impl(Unpack)
45+
def unpack(self, interp, frame: interp.Frame[ir.types.TypeAttribute], stmt: Unpack):
46+
value = frame.get(stmt.value)
47+
if isinstance(value, ir.types.Generic) and value.is_subseteq(ir.types.Tuple):
48+
if value.vararg:
49+
rest = tuple(value.vararg.typ for _ in stmt.names[len(value.vars) :])
50+
return tuple(value.vars) + rest
51+
else:
52+
return value.vars
53+
# TODO: support unpacking other types
54+
return tuple(ir.types.Any for _ in stmt.names)
55+
56+
57+
def unpackable(state: lowering.LoweringState, node: ast.expr, value: ir.SSAValue):
58+
if isinstance(node, ast.Name):
59+
state.current_frame.defs[node.id] = value
60+
value.name = node.id
61+
return
62+
elif not isinstance(node, ast.Tuple):
63+
raise DialectLoweringError(f"unsupported unpack node {node}")
64+
65+
names: list[str | None] = []
66+
continue_unpack: list[int] = []
67+
for idx, item in enumerate(node.elts):
68+
if isinstance(item, ast.Name):
69+
names.append(item.id)
70+
else:
71+
names.append(None)
72+
continue_unpack.append(idx)
73+
stmt = state.append_stmt(Unpack(value, tuple(names)))
74+
for name, result in zip(names, stmt.results):
75+
if name is not None:
76+
state.current_frame.defs[name] = result
77+
78+
for idx in continue_unpack:
79+
unpackable(state, node.elts[idx], stmt.results[idx])

src/kirin/dialects/scf/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from . import interp as interp, lowering as lowering
2+
from .stmts import For as For, Yield as Yield, IfElse as IfElse
3+
from ._dialect import dialect as dialect

src/kirin/dialects/scf/_dialect.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from kirin import ir
2+
3+
dialect = ir.Dialect("scf")

src/kirin/dialects/scf/interp.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from kirin import interp
2+
3+
from .stmts import For, Yield, IfElse
4+
from ._dialect import dialect
5+
6+
7+
@dialect.register
8+
class Concrete(interp.MethodTable):
9+
10+
@interp.impl(Yield)
11+
def yield_stmt(self, interp_: interp.Interpreter, frame: interp.Frame, stmt: Yield):
12+
return interp.ReturnValue(*frame.get_values(stmt.values))
13+
14+
@interp.impl(IfElse)
15+
def if_else(self, interp: interp.Interpreter, frame: interp.Frame, stmt: IfElse):
16+
cond = frame.get(stmt.cond)
17+
if cond:
18+
body = stmt.then_body
19+
else:
20+
body = stmt.else_body
21+
return interp.run_ssacfg_region(frame, body)
22+
23+
@interp.impl(For)
24+
def for_loop(self, interpreter: interp.Interpreter, frame: interp.Frame, stmt: For):
25+
iterable = frame.get(stmt.iterable)
26+
loop_vars = frame.get_values(stmt.initializers)
27+
block_args = stmt.body.blocks[0].args
28+
for value in iterable:
29+
frame.set_values(block_args, (value,) + loop_vars)
30+
loop_vars = interpreter.run_ssacfg_region(frame, stmt.body)
31+
32+
return loop_vars

src/kirin/dialects/scf/lowering.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import ast
2+
3+
from kirin import ir, lowering
4+
from kirin.exceptions import DialectLoweringError
5+
from kirin.dialects.py.unpack import unpackable
6+
7+
from .stmts import For, Yield, IfElse
8+
from ._dialect import dialect
9+
10+
11+
@dialect.register
12+
class Lowering(lowering.FromPythonAST):
13+
14+
def lower_If(self, state: lowering.LoweringState, node: ast.If) -> lowering.Result:
15+
cond = state.visit(node.test).expect_one()
16+
frame = state.current_frame
17+
body_frame = lowering.Frame.from_stmts(node.body, state)
18+
state.push_frame(body_frame)
19+
state.exhaust(body_frame)
20+
state.pop_frame()
21+
22+
else_frame = lowering.Frame.from_stmts(node.orelse, state)
23+
state.push_frame(else_frame)
24+
state.exhaust(else_frame)
25+
state.pop_frame()
26+
27+
yield_names: list[str] = []
28+
body_yields: list[ir.SSAValue] = []
29+
else_yields: list[ir.SSAValue] = []
30+
if node.orelse:
31+
for name in body_frame.defs.keys():
32+
if name in else_frame.defs:
33+
yield_names.append(name)
34+
body_yields.append(body_frame.get_scope(name))
35+
else_yields.append(else_frame.get_scope(name))
36+
else:
37+
for name in body_frame.defs.keys():
38+
if name in frame.defs:
39+
yield_names.append(name)
40+
body_yields.append(body_frame.get_scope(name))
41+
value = frame.get(name)
42+
if value is None:
43+
raise DialectLoweringError(f"expected value for {name}")
44+
else_yields.append(value)
45+
46+
body_frame.append_stmt(Yield(*body_yields))
47+
else_frame.append_stmt(Yield(*else_yields))
48+
stmt = IfElse(
49+
cond,
50+
then_body=body_frame.current_region,
51+
else_body=else_frame.current_region,
52+
)
53+
for result, name, body, else_ in zip(
54+
stmt.results, yield_names, body_yields, else_yields
55+
):
56+
result.name = name
57+
result.type = body.type.join(else_.type)
58+
frame.defs[name] = result
59+
state.append_stmt(stmt)
60+
return lowering.Result()
61+
62+
def lower_For(
63+
self, state: lowering.LoweringState, node: ast.For
64+
) -> lowering.Result:
65+
iter_ = state.visit(node.iter).expect_one()
66+
67+
yields: list[str] = []
68+
69+
def new_block_arg_if_inside_loop(frame: lowering.Frame, capture: ir.SSAValue):
70+
if not capture.name:
71+
raise DialectLoweringError("unexpected loop variable captured")
72+
yields.append(capture.name)
73+
return frame.current_block.args.append_from(capture.type, capture.name)
74+
75+
body_frame = state.push_frame(
76+
lowering.Frame.from_stmts(
77+
node.body, state, capture_callback=new_block_arg_if_inside_loop
78+
)
79+
)
80+
loop_var = body_frame.current_block.args.append_from(ir.types.Any)
81+
unpackable(state, node.target, loop_var)
82+
state.exhaust(body_frame)
83+
# NOTE: this frame won't have phi nodes
84+
body_frame.append_stmt(Yield(*[body_frame.defs[name] for name in yields])) # type: ignore
85+
state.pop_frame()
86+
87+
initializers: list[ir.SSAValue] = []
88+
for name in yields:
89+
value = state.current_frame.get(name)
90+
if value is None:
91+
raise DialectLoweringError(f"expected value for {name}")
92+
initializers.append(value)
93+
stmt = For(iter_, body_frame.current_region, *initializers)
94+
for name, result in zip(yields, stmt.results):
95+
state.current_frame.defs[name] = result
96+
state.append_stmt(stmt)
97+
return lowering.Result()

0 commit comments

Comments
 (0)