Skip to content

Commit 300e4f2

Browse files
authored
Moved ClosureField into dialects/func. Implemented LambdaLifting. (#543)
Renamed `ClosureFieldLowering` to `ClosureField` and moved it into the `dialects/func` directory because it is specific to `func.Invoke`. Register the rewrite as canonicalize rule. Also, implemented `LambdaLifting` that rewrites `py.Constant(ir.Method)` with method body being lambda into a regular function. e.g, this function: ```python @basic def outer(): def inner(x: int): return x + 1 return inner ``` Will rewrite from: ``` func.func @outer() -> !Any { ^0(%outer_self): │ %inner = py.constant.constant Method("inner") : !py.Method[!py.tuple[!py.int], !Any] │ func.return %inner } // func.func outer func.lambda inner() -> !Any { ^0(%inner_self, %x): │ %0 = py.constant.constant 1 : !py.int │ %1 = py.binop.add(%x : !py.int, %0) : ~T │ func.return %1 } // func.lambda inner ``` to: ``` func.func @outer() -> !Any { ^0(%outer_self): │ %inner = py.constant.constant Method("inner") : !py.Method[!py.tuple[!py.int], !Any] │ func.return %inner } // func.func outer func.func @inner(x : !py.int) -> !Any { ^0(%inner_self, %x): │ %0 = py.constant.constant 1 : !py.int │ %1 = py.binop.add(%x : !py.int, %0) : ~T │ func.return %1 } // func.func inner ``` Added test files for these 2 rewrites.
1 parent f4bd7b8 commit 300e4f2

File tree

8 files changed

+188
-11
lines changed

8 files changed

+188
-11
lines changed

src/kirin/dialects/func/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,8 @@
1818
)
1919
from kirin.dialects.func._dialect import dialect as dialect
2020

21-
from . import _julia as _julia
21+
from . import (
22+
_julia as _julia,
23+
closurefield as closurefield,
24+
lambdalifting as lambdalifting,
25+
)
Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
from dataclasses import dataclass
2-
31
from kirin import ir
42
from kirin.dialects import py, func
53
from kirin.rewrite.abc import RewriteRule, RewriteResult
64

5+
from ._dialect import dialect
6+
77

8-
@dataclass
9-
class ClosureFieldLowering(RewriteRule):
8+
@dialect.canonicalize
9+
class ClosureField(RewriteRule):
1010
"""Lowers captured closure fields into py.Constants.
1111
- Trigger on func.Invoke
1212
- If the callee Method has non-empty .fields, lower its func.GetField to py.Constant
@@ -15,15 +15,17 @@ class ClosureFieldLowering(RewriteRule):
1515
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
1616
if not isinstance(node, func.Invoke):
1717
return RewriteResult(has_done_something=False)
18-
1918
method = node.callee
2019
if not method.fields:
2120
return RewriteResult(has_done_something=False)
2221
# Replace func.GetField with py.Constant.
2322
changed = self._lower_captured_fields(method)
2423
if changed:
2524
method.fields = ()
26-
return RewriteResult(has_done_something=changed)
25+
from kirin.passes import TypeInfer
26+
27+
rewrite_result = TypeInfer(dialects=method.dialects).unsafe_run(method)
28+
return RewriteResult(has_done_something=changed).join(rewrite_result)
2729

2830
def _get_field_index(self, getfield_stmt: func.GetField) -> int | None:
2931
fld = getfield_stmt.attributes.get("field")
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from kirin import ir
2+
from kirin.dialects import py, func
3+
from kirin.rewrite.abc import RewriteRule, RewriteResult
4+
5+
from ._dialect import dialect
6+
7+
8+
@dialect.canonicalize
9+
class LambdaLifting(RewriteRule):
10+
"""Lifts func.Lambda methods embedded in py.Constant into func.Function.
11+
- Trigger on py.Constant
12+
"""
13+
14+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
15+
if not isinstance(node, py.Constant):
16+
return RewriteResult(has_done_something=False)
17+
method = self._get_method_from_constant(node)
18+
if method is None:
19+
return RewriteResult(has_done_something=False)
20+
if not isinstance(method.code, func.Lambda):
21+
return RewriteResult(has_done_something=False)
22+
self._promote_lambda(method)
23+
24+
from kirin.passes import TypeInfer
25+
26+
rewrite_result = TypeInfer(dialects=method.dialects).unsafe_run(method)
27+
return RewriteResult(has_done_something=True).join(rewrite_result)
28+
29+
def _get_method_from_constant(self, const_stmt: py.Constant) -> ir.Method | None:
30+
pyattr_data = const_stmt.value
31+
if isinstance(pyattr_data, ir.PyAttr) and isinstance(
32+
pyattr_data.data, ir.Method
33+
):
34+
return pyattr_data.data
35+
return None
36+
37+
def _get_field_index(self, getfield_stmt: func.GetField) -> int | None:
38+
fld = getfield_stmt.attributes.get("field")
39+
if fld:
40+
return getfield_stmt.field
41+
else:
42+
return None
43+
44+
def _promote_lambda(self, method: ir.Method) -> None:
45+
new_method = method.similar()
46+
assert isinstance(
47+
new_method.code, func.Lambda
48+
), "expected method.code to be func.Lambda before promotion"
49+
50+
captured_fields = method.fields
51+
if captured_fields:
52+
for stmt in new_method.code.body.blocks[0].stmts:
53+
if not isinstance(stmt, func.GetField):
54+
continue
55+
idx = self._get_field_index(stmt)
56+
if idx is None:
57+
continue
58+
captured = new_method.fields[idx]
59+
const_stmt = py.Constant(captured)
60+
const_stmt.insert_before(stmt)
61+
if stmt.results and const_stmt.results:
62+
stmt.results[0].replace_by(const_stmt.results[0])
63+
stmt.delete()
64+
new_method.code
65+
66+
fn = func.Function(
67+
sym_name=new_method.code.sym_name,
68+
slots=new_method.code.slots,
69+
signature=new_method.code.signature,
70+
body=new_method.code.body,
71+
)
72+
method.code = fn

src/kirin/passes/fold.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
CFGCompactify,
1111
InlineGetItem,
1212
DeadCodeElimination,
13-
ClosureFieldLowering,
1413
)
1514
from kirin.passes.abc import Pass
1615
from kirin.rewrite.abc import RewriteResult
@@ -29,7 +28,6 @@ class Fold(Pass):
2928
- `InlineGetItem`
3029
- `Call2Invoke`
3130
- `DeadCodeElimination`
32-
- `ClosureFieldLowering`
3331
"""
3432

3533
hint_const: HintConst = field(init=False)
@@ -48,7 +46,6 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
4846
InlineGetItem(),
4947
Call2Invoke(),
5048
DeadCodeElimination(),
51-
ClosureFieldLowering(),
5249
)
5350
)
5451
)

src/kirin/rewrite/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,3 @@
1313
from .wrap_const import WrapConst as WrapConst
1414
from .call2invoke import Call2Invoke as Call2Invoke
1515
from .type_assert import InlineTypeAssert as InlineTypeAssert
16-
from .closurefieldlowering import ClosureFieldLowering as ClosureFieldLowering

test/dialects/func/__init__.py

Whitespace-only changes.
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from typing import cast
2+
3+
from kirin import rewrite
4+
from kirin.prelude import basic
5+
from kirin.dialects import py, func
6+
7+
8+
def test_rewrite_closure_inner_lambda():
9+
@basic
10+
def outer(y: int):
11+
def inner(x: int):
12+
return x * y + 1
13+
14+
return inner
15+
16+
inner_ker = outer(y=10)
17+
18+
@basic
19+
def main_lambda(z: int):
20+
return inner_ker(z)
21+
22+
main_invoke = main_lambda.code.regions[0].blocks[0].stmts.at(0)
23+
inner_lambda = cast(func.Invoke, main_invoke).callee.code
24+
inner_getfield_stmt = inner_lambda.regions[0].blocks[0].stmts.at(0)
25+
assert isinstance(
26+
inner_getfield_stmt, func.GetField
27+
), "expected GetField before rewrite"
28+
29+
rewrite.Walk(func.closurefield.ClosureField()).rewrite(main_lambda.code)
30+
31+
inner_getfield_stmt = inner_lambda.regions[0].blocks[0].stmts.at(0)
32+
assert isinstance(
33+
inner_getfield_stmt, py.Constant
34+
), "GetField should be lowered to Constant"
35+
36+
37+
def test_rewrite_closure_no_fields():
38+
@basic
39+
def bar():
40+
def goo(x: int):
41+
a = (3, 4)
42+
return a[0]
43+
44+
def boo(y):
45+
return goo(y) + 1
46+
47+
return boo(4)
48+
49+
before = bar.code.regions[0].blocks[0].stmts.at(0)
50+
rewrite.Walk(func.closurefield.ClosureField()).rewrite(bar.code)
51+
after = bar.code.regions[0].blocks[0].stmts.at(0)
52+
assert before is after
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from kirin import ir, rewrite
2+
from kirin.prelude import basic
3+
from kirin.dialects import py, func
4+
5+
6+
def test_rewrite_inner_lambda():
7+
@basic
8+
def outer():
9+
def inner(x: int):
10+
return x + 1
11+
12+
return inner
13+
14+
pyconstant_stmt = outer.code.regions[0].blocks[0].stmts.at(0)
15+
assert isinstance(pyconstant_stmt, py.Constant), "expected a Constant in outer body"
16+
assert isinstance(
17+
pyconstant_stmt.value, ir.PyAttr
18+
), "expected a PyAttr in outer body"
19+
assert isinstance(
20+
pyconstant_stmt.value.data.code, func.Lambda
21+
), "expected a lambda Method in outer body"
22+
23+
rewrite.Walk(func.lambdalifting.LambdaLifting()).rewrite(outer.code)
24+
assert isinstance(
25+
pyconstant_stmt.value.data.code, func.Function
26+
), "expected a Function in outer body"
27+
28+
29+
def test_rewrite_inner_lambda_with_captured_vars():
30+
@basic
31+
def outer2():
32+
z = 10
33+
y = 3 + z
34+
35+
def inner2(x: int):
36+
return x + y + 5
37+
38+
return inner2
39+
40+
pyconstant_stmt = outer2.code.regions[0].blocks[0].stmts.at(0)
41+
assert isinstance(pyconstant_stmt, py.Constant), "expected a Constant in outer body"
42+
assert isinstance(
43+
pyconstant_stmt.value, ir.PyAttr
44+
), "expected a PyAttr in outer body"
45+
assert isinstance(
46+
pyconstant_stmt.value.data.code, func.Lambda
47+
), "expected a lambda Method in outer body"
48+
rewrite.Walk(func.lambdalifting.LambdaLifting()).rewrite(outer2.code)
49+
assert isinstance(
50+
pyconstant_stmt.value.data.code, func.Function
51+
), "expected a Function in outer body"

0 commit comments

Comments
 (0)