Skip to content

Commit ba807ac

Browse files
authored
support lowering loop to cf (#187)
this PR supports lowering loop to `cf`, e.g ```python @basic_no_opt def main(x: int): for i in range(5): for j in range(5): x = x + 1 return x ``` <details> <summary>compiles to</summary> ```mlir func.func main(!py.int) -> !Any { ^0(%main_self, %x_1): │ %0 = py.constant.constant 0 : !py.int │ %1 = py.constant.constant 5 : !py.int │ %2 = py.constant.constant 1 : !py.int │ %3 = py.range.range(start=%0, stop=%1, step=%2) : !py.range │ %4 = py.iterable.iter(value=%3) : !Any │ %5 = py.constant.constant None : !py.NoneType │ %6 = py.iterable.next(iter=%4) : !Any │ %7 = py.cmp.is(lhs=%6, rhs=%5) : !py.bool │ cf.cond_br %7 goto ^4(%x_1 : !py.int) else ^1(%6, %x_1 : !py.int) ^1(%i, %x_2): │ %8 = py.constant.constant 0 : !py.int │ %9 = py.constant.constant 5 : !py.int │ %10 = py.constant.constant 1 : !py.int │ %11 = py.range.range(start=%8, stop=%9, step=%10) : !py.range │ %12 = py.iterable.iter(value=%11) : !Any │ %13 = py.constant.constant None : !py.NoneType │ %14 = py.iterable.next(iter=%12) : !Any │ %15 = py.cmp.is(lhs=%14, rhs=%13) : !py.bool │ cf.cond_br %15 goto ^3(%x_2 : !py.int) else ^2(%14, %x_2 : !py.int) ^2(%j, %x_3): │ %16 = py.constant.constant 1 : !py.int │ %x = py.binop.add(%x_3 : !py.int, %16) : ~T │ %17 = py.iterable.next(iter=%12) : !Any │ %18 = py.cmp.is(lhs=%17, rhs=%13) : !py.bool │ cf.cond_br %18 goto ^3(%x) else ^2(%17, %x) ^3(%x_4): │ %19 = py.iterable.next(iter=%4) : !Any │ %20 = py.cmp.is(lhs=%19, rhs=%5) : !py.bool │ cf.cond_br %20 goto ^4(%x_4 : ~T) else ^1(%19, %x_4 : ~T) ^4(%x_5): │ func.return %x_5 : ~T ^5(): │ %21 = func.const.none() : !py.NoneType │ func.return %21 } // func.func main ``` </details>
1 parent dccdd74 commit ba807ac

File tree

9 files changed

+224
-2
lines changed

9 files changed

+224
-2
lines changed

src/kirin/analysis/cfg.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def predecessors(self):
4444
def successors(self):
4545
"""CFG data, mapping a block to its neighbors."""
4646
graph: dict[ir.Block, set[ir.Block]] = {}
47+
visited: set[ir.Block] = set()
4748
worklist: WorkList[ir.Block] = WorkList()
4849
if self.parent.blocks.isempty():
4950
return graph
@@ -54,8 +55,11 @@ def successors(self):
5455
if block.last_stmt is not None:
5556
neighbors.update(block.last_stmt.successors)
5657
worklist.extend(block.last_stmt.successors)
58+
visited.add(block)
5759

5860
block = worklist.pop()
61+
while block is not None and block in visited:
62+
block = worklist.pop()
5963
return graph
6064

6165
# graph interface

src/kirin/dialects/cf/stmts.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ class Branch(Statement):
1313
arguments: tuple[SSAValue, ...]
1414
successor: Block = info.block()
1515

16+
def verify(self) -> None:
17+
return
18+
1619
def print_impl(self, printer: Printer) -> None:
1720
with printer.rich(style="keyword"):
1821
printer.print_name(self)
@@ -61,3 +64,6 @@ def print_impl(self, printer: Printer) -> None:
6164
printer.plain_print("(")
6265
printer.print_seq(self.else_arguments, delim=", ")
6366
printer.plain_print(")")
67+
68+
def verify(self) -> None:
69+
return

src/kirin/dialects/lowering/cf.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import ast
55

66
from kirin import ir
7+
from kirin.dialects import cf, py
78
from kirin.lowering import Frame, Result, FromPythonAST, LoweringState
89
from kirin.exceptions import DialectLoweringError
9-
from kirin.dialects.cf import stmts as cf
1010

1111
dialect = ir.Dialect("lowering.cf")
1212

@@ -21,6 +21,83 @@ def lower_Pass(self, state: LoweringState, node: ast.Pass) -> Result:
2121
state.append_stmt(cf.Branch(arguments=(), successor=next))
2222
return Result()
2323

24+
def lower_For(self, state: LoweringState, node: ast.For) -> Result:
25+
iterable = state.visit(node.iter).expect_one()
26+
iter_stmt = state.append_stmt(py.iterable.Iter(iterable))
27+
none_stmt = state.append_stmt(py.Constant(None))
28+
yields: list[str] = []
29+
30+
def new_block_arg_if_inside_loop(frame: Frame, capture: ir.SSAValue):
31+
if not capture.name:
32+
raise DialectLoweringError("unexpected loop variable captured")
33+
yields.append(capture.name)
34+
return frame.entry_block.args.append_from(capture.type, capture.name)
35+
36+
frame = state.current_frame
37+
next_block = ir.Block()
38+
39+
body_frame = state.push_frame(
40+
Frame.from_stmts(
41+
node.body,
42+
state,
43+
region=frame.current_region,
44+
globals=state.current_frame.globals,
45+
capture_callback=new_block_arg_if_inside_loop,
46+
)
47+
)
48+
body_frame.next_block = next_block
49+
next_value = body_frame.entry_block.args.append_from(ir.types.Any, "next_value")
50+
py.unpack.unpackable(state, node.target, next_value)
51+
state.exhaust(body_frame)
52+
next_stmt = body_frame.append_stmt(py.iterable.Next(iter_stmt.iter))
53+
cond_stmt = body_frame.append_stmt(py.cmp.Is(next_stmt.value, none_stmt.result))
54+
yield_args = tuple(body_frame.get_scope(name) for name in yields)
55+
state.pop_frame()
56+
57+
next_frame = state.push_frame(
58+
Frame.from_stmts(
59+
frame.stream.split(),
60+
state,
61+
region=frame.current_region,
62+
block=next_block,
63+
globals=frame.globals,
64+
)
65+
)
66+
next_frame.next_block = frame.next_block
67+
for name, arg in zip(yields, yield_args):
68+
input = next_frame.current_block.args.append_from(arg.type, name)
69+
next_frame.defs[name] = input
70+
state.exhaust()
71+
state.pop_frame()
72+
73+
yield_args = tuple(body_frame.get_scope(name) for name in yields)
74+
body_frame.append_stmt(
75+
cf.ConditionalBranch(
76+
cond_stmt.result,
77+
yield_args,
78+
(next_stmt.value,) + yield_args,
79+
then_successor=next_frame.entry_block,
80+
else_successor=body_frame.entry_block,
81+
)
82+
)
83+
84+
next_stmt = frame.append_stmt(py.iterable.Next(iter_stmt.iter))
85+
cond_stmt = frame.append_stmt(py.cmp.Is(next_stmt.value, none_stmt.result))
86+
yield_args = tuple(frame.get_scope(name) for name in yields)
87+
frame.append_stmt(
88+
cf.ConditionalBranch(
89+
cond_stmt.result,
90+
yield_args,
91+
(next_stmt.value,) + yield_args,
92+
then_successor=next_frame.entry_block,
93+
else_successor=body_frame.entry_block,
94+
)
95+
)
96+
frame.current_block = next_frame.current_block
97+
frame.next_block = next_frame.next_block
98+
frame.defs.update(next_frame.defs)
99+
return Result()
100+
24101
def lower_If(self, state: LoweringState, node: ast.If) -> Result:
25102
cond = state.visit(node.test).expect_one()
26103

src/kirin/dialects/py/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
builtin as builtin,
1616
constant as constant,
1717
indexing as indexing,
18+
iterable as iterable,
1819
)
1920
from .len import Len as Len
2021
from .attr import GetAttr as GetAttr

src/kirin/dialects/py/iterable.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""This module provides access to Python iterables.
2+
3+
This is used to lower Python loops into `cf` dialect.
4+
"""
5+
6+
from ast import Call
7+
8+
from kirin import ir, interp, lowering
9+
from kirin.decl import info, statement
10+
from kirin.exceptions import DialectLoweringError
11+
12+
dialect = ir.Dialect("py.iterable")
13+
14+
PyRangeIterType = ir.types.PyClass(type(iter(range(0))))
15+
16+
17+
@statement(dialect=dialect)
18+
class Iter(ir.Statement):
19+
"""This is equivalent to `iter(value)` in Python."""
20+
21+
traits = frozenset({ir.Pure()})
22+
value: ir.SSAValue = info.argument(ir.types.Any)
23+
iter: ir.ResultValue = info.result(ir.types.Any)
24+
25+
26+
@statement(dialect=dialect)
27+
class Next(ir.Statement):
28+
"""This is equivalent to `next(iterable, None)` in Python."""
29+
30+
iter: ir.SSAValue = info.argument(ir.types.Any)
31+
value: ir.ResultValue = info.result(ir.types.Any)
32+
33+
34+
@dialect.register
35+
class Concrete(interp.MethodTable):
36+
37+
@interp.impl(Iter)
38+
def iter_(self, interp, frame: interp.Frame, stmt: Iter):
39+
return (iter(frame.get(stmt.value)),)
40+
41+
@interp.impl(Next)
42+
def next_(self, interp, frame: interp.Frame, stmt: Next):
43+
return (next(frame.get(stmt.iter), None),)
44+
45+
46+
@dialect.register(key="typeinfer")
47+
class TypeInfer(interp.MethodTable):
48+
49+
@interp.impl(Iter, ir.types.PyClass(range))
50+
def iter_(self, interp, frame: interp.Frame, stmt: Iter):
51+
return (PyRangeIterType,)
52+
53+
@interp.impl(Next, PyRangeIterType)
54+
def next_(self, interp, frame: interp.Frame, stmt: Next):
55+
return (ir.types.Int,)
56+
57+
58+
@dialect.register
59+
class Lowering(lowering.FromPythonAST):
60+
61+
def lower_Call_iter(
62+
self, state: lowering.LoweringState, node: Call
63+
) -> lowering.Result:
64+
if len(node.args) != 1:
65+
raise DialectLoweringError("iter() takes exactly 1 argument")
66+
return lowering.Result(
67+
state.append_stmt(Iter(state.visit(node.args[0]).expect_one()))
68+
)
69+
70+
def lower_Call_next(
71+
self, state: lowering.LoweringState, node: Call
72+
) -> lowering.Result:
73+
if len(node.args) == 2:
74+
raise DialectLoweringError(
75+
"next() does not throw StopIteration inside kernel"
76+
)
77+
if len(node.args) != 1:
78+
raise DialectLoweringError("next() takes exactly 1 argument")
79+
return lowering.Result(
80+
state.append_stmt(Next(state.visit(node.args[0]).expect_one()))
81+
)

src/kirin/interp/abstract.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import TypeVar, Iterable, TypeAlias, overload
33
from dataclasses import field, dataclass
44

5-
from kirin.ir import Region, SSAValue, Statement
5+
from kirin.ir import Block, Region, SSAValue, Statement
66
from kirin.lattice import BoundedLattice
77
from kirin.worklist import WorkList
88
from kirin.interp.base import BaseInterpreter, InterpreterMeta
@@ -27,6 +27,7 @@ class AbstractFrame(Frame[ResultType]):
2727
"""
2828

2929
worklist: WorkList[Successor[ResultType]] = field(default_factory=WorkList)
30+
visited: dict[Block, set[Successor[ResultType]]] = field(default_factory=dict)
3031

3132

3233
AbstractFrameType = TypeVar("AbstractFrameType", bound=AbstractFrame)
@@ -151,8 +152,18 @@ def run_ssacfg_region(
151152
Successor(region.blocks[0], *frame.get_values(region.blocks[0].args))
152153
)
153154
while (succ := frame.worklist.pop()) is not None:
155+
if succ.block in frame.visited:
156+
if succ in frame.visited[succ.block]:
157+
continue
158+
else:
159+
frame.visited[succ.block] = set()
154160
self.prehook_succ(frame, succ)
155161
block_result = self.run_block(frame, succ)
162+
if len(frame.visited[succ.block]) < 128:
163+
frame.visited[succ.block].add(succ)
164+
else:
165+
continue
166+
156167
if isinstance(block_result, Successor):
157168
raise InterpreterError(
158169
"unexpected successor, successors should be in worklist"

src/kirin/lowering/dialect.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ class FromPythonAST:
7575
def lower_Call_slice(self, state: LoweringState, node: ast.Call) -> Result: ...
7676
def lower_Call_range(self, state: LoweringState, node: ast.Call) -> Result: ...
7777
def lower_Call_len(self, state: LoweringState, node: ast.Call) -> Result: ...
78+
def lower_Call_iter(self, state: LoweringState, node: ast.Call) -> Result: ...
79+
def lower_Call_next(self, state: LoweringState, node: ast.Call) -> Result: ...
7880
def lower_Call_local(
7981
self, state: LoweringState, callee: SSAValue, node: ast.Call
8082
) -> Result: ...

src/kirin/prelude.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
builtin,
2525
constant,
2626
indexing,
27+
iterable,
2728
assertion,
2829
)
2930
from kirin.passes.fold import Fold
@@ -45,6 +46,7 @@
4546
len,
4647
tuple,
4748
assertion,
49+
iterable,
4850
]
4951
)
5052
def python_basic(self):

test/program/py/test_loop.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from kirin.prelude import basic_no_opt
2+
3+
4+
def test_simple():
5+
@basic_no_opt
6+
def main(x: int):
7+
for i in range(5):
8+
x = x + 1
9+
return x
10+
11+
assert main.py_func is not None
12+
assert main.py_func(1) == main(1)
13+
14+
15+
# generate some more complicated loop
16+
def test_nested():
17+
@basic_no_opt
18+
def main(x: int):
19+
for i in range(5):
20+
for j in range(5):
21+
x = x + 1
22+
return x
23+
24+
assert main.py_func is not None
25+
assert main.py_func(1) == main(1)
26+
27+
28+
def test_nested2():
29+
@basic_no_opt
30+
def main(x: int):
31+
for i in range(5):
32+
for j in range(5):
33+
for k in range(5):
34+
x = x + 1
35+
return x
36+
37+
assert main.py_func is not None
38+
assert main.py_func(1) == main(1)

0 commit comments

Comments
 (0)