Skip to content

Commit a61e51f

Browse files
authored
ClosureFieldLowering rewrite: Lower captured fields into constants (#538)
A ClosureFieldLowering pass that lowers closure-captured fields into `py.Constant`. e.g. the higher-order function `outer` returns `inner` which closes over parameter `y`. ``` @basic def outer(y: int): def inner(x: int): return x * y + 1 return outer inner_ker = outer(y=10) ``` `inner_ker.print()` will go from: ``` func.lambda inner(%y : !py.int) -> !Any { ^0(%inner_self, %x): │ %y_1 = func.getfield(%inner_self, 0) : !py.int │ %0 = py.binop.mult(%x : !py.int, %y_1) : ~T │ %1 = py.constant.constant 1 : !py.int │ %2 = py.binop.add(%0, %1) : ~T │ func.return %2 } // func.lambda inner ``` to: ``` func.lambda inner(%y : !py.int) -> !Any { ^0(%inner_self, %x): │ %y_1 = py.constant.constant 10 : !py.int │ %0 = py.binop.mult(%x : !py.int, %y_1) : ~T │ %1 = py.constant.constant 1 : !py.int │ %2 = py.binop.add(%0, %1) : ~T │ func.return %2 } // func.lambda inner ```
1 parent 8046478 commit a61e51f

File tree

5 files changed

+90
-1
lines changed

5 files changed

+90
-1
lines changed

src/kirin/ir/nodes/block.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,10 @@ def is_structurally_equal( # pyright: ignore[reportIncompatibleMethodOverride]
391391
if context is None:
392392
context = {}
393393

394+
if self in context:
395+
return context[self] is other
396+
context[self] = other
397+
394398
if len(self._args) != len(other._args) or len(self.stmts) != len(other.stmts):
395399
return False
396400

src/kirin/passes/fold.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
CFGCompactify,
1111
InlineGetItem,
1212
DeadCodeElimination,
13+
ClosureFieldLowering,
1314
)
1415
from kirin.passes.abc import Pass
1516
from kirin.rewrite.abc import RewriteResult
@@ -28,6 +29,7 @@ class Fold(Pass):
2829
- `InlineGetItem`
2930
- `Call2Invoke`
3031
- `DeadCodeElimination`
32+
- `ClosureFieldLowering`
3133
"""
3234

3335
hint_const: HintConst = field(init=False)
@@ -46,6 +48,7 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
4648
InlineGetItem(),
4749
Call2Invoke(),
4850
DeadCodeElimination(),
51+
ClosureFieldLowering(),
4952
)
5053
)
5154
)

src/kirin/rewrite/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@
1313
from .wrap_const import WrapConst as WrapConst
1414
from .call2invoke import Call2Invoke as Call2Invoke
1515
from .type_assert import InlineTypeAssert as InlineTypeAssert
16+
from .closurefieldlowering import ClosureFieldLowering as ClosureFieldLowering
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from dataclasses import dataclass
2+
3+
from kirin import ir
4+
from kirin.dialects import py, func
5+
from kirin.rewrite.abc import RewriteRule, RewriteResult
6+
7+
8+
@dataclass
9+
class ClosureFieldLowering(RewriteRule):
10+
"""Lowers captured closure fields into py.Constants.
11+
- Trigger on func.Invoke
12+
- If the callee Method has non-empty .fields, lower its func.GetField to py.Constant
13+
"""
14+
15+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
16+
if not isinstance(node, func.Invoke):
17+
return RewriteResult(has_done_something=False)
18+
19+
method = node.callee
20+
if not method.fields:
21+
return RewriteResult(has_done_something=False)
22+
# Replace func.GetField with py.Constant.
23+
changed = self._lower_captured_fields(method)
24+
if changed:
25+
method.fields = ()
26+
return RewriteResult(has_done_something=changed)
27+
28+
def _get_field_index(self, getfield_stmt: func.GetField) -> int | None:
29+
fld = getfield_stmt.attributes.get("field")
30+
if fld:
31+
return getfield_stmt.field
32+
else:
33+
return None
34+
35+
def _lower_captured_fields(self, method: ir.Method) -> bool:
36+
changed = False
37+
fields = method.fields
38+
if not fields:
39+
return False
40+
41+
for region in method.code.regions:
42+
for block in region.blocks:
43+
for stmt in list(block.stmts):
44+
if not isinstance(stmt, func.GetField):
45+
continue
46+
idx = self._get_field_index(stmt)
47+
if idx is None:
48+
continue
49+
captured = fields[idx]
50+
# Skip Methods.
51+
if isinstance(captured, ir.Method):
52+
continue
53+
# Replace GetField with Constant.
54+
const_stmt = py.Constant(captured)
55+
const_stmt.insert_before(stmt)
56+
if stmt.results and const_stmt.results:
57+
stmt.results[0].replace_by(const_stmt.results[0])
58+
stmt.delete()
59+
changed = True
60+
return changed

test/serialization/test_jsonserializer.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
def foo(x: int, y: float, z: bool):
1010
c = [[(200.0, 200.0), (210.0, 200.0)]]
1111
if z:
12-
c.append([(222.0, 333.0)])
12+
c = [(222.0, 333.0)]
1313
else:
1414
return [1, 2, 3, 4]
1515
return c
@@ -47,6 +47,23 @@ def my_kernel2(y: int):
4747
return my_kernel1(y) * 10
4848

4949

50+
@basic
51+
def foo2(y: int):
52+
53+
def inner(x: int):
54+
return x * y + 1
55+
56+
return inner
57+
58+
59+
inner_ker = foo2(y=10)
60+
61+
62+
@basic
63+
def main_lambda(z: int):
64+
return inner_ker(z)
65+
66+
5067
@basic
5168
def slicing():
5269
in1 = ("a", "b", "c", "d", "e", "f", "g", "h")
@@ -94,6 +111,10 @@ def test_round_trip5():
94111
round_trip(slicing)
95112

96113

114+
def test_round_trip6():
115+
round_trip(main_lambda)
116+
117+
97118
def test_deterministic():
98119
serializer = Serializer()
99120
s1 = serializer.encode(loop_ilist)

0 commit comments

Comments
 (0)