Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/kirin/dialects/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,4 @@

from . import (
_julia as _julia,
closurefield as closurefield,
lambdalifting as lambdalifting,
)
2 changes: 2 additions & 0 deletions src/kirin/dialects/func/rewrite/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .closurefield import ClosureField as ClosureField
from .lambdalifting import LambdaLifting as LambdaLifting
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from kirin import ir
from kirin.dialects import py, func
from kirin.passes import TypeInfer
from kirin.rewrite.abc import RewriteRule, RewriteResult

from ._dialect import dialect
from ..stmts import Invoke, GetField
from .._dialect import dialect


@dialect.canonicalize
Expand All @@ -13,7 +14,7 @@ class ClosureField(RewriteRule):
"""

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
if not isinstance(node, func.Invoke):
if not isinstance(node, Invoke):
return RewriteResult(has_done_something=False)
method = node.callee
if not method.fields:
Expand All @@ -22,12 +23,11 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
changed = self._lower_captured_fields(method)
if changed:
method.fields = ()
from kirin.passes import TypeInfer

rewrite_result = TypeInfer(dialects=method.dialects).unsafe_run(method)
return RewriteResult(has_done_something=changed).join(rewrite_result)

def _get_field_index(self, getfield_stmt: func.GetField) -> int | None:
def _get_field_index(self, getfield_stmt: GetField) -> int | None:
fld = getfield_stmt.attributes.get("field")
if fld:
return getfield_stmt.field
Expand All @@ -43,7 +43,7 @@ def _lower_captured_fields(self, method: ir.Method) -> bool:
for region in method.code.regions:
for block in region.blocks:
for stmt in list(block.stmts):
if not isinstance(stmt, func.GetField):
if not isinstance(stmt, GetField):
continue
idx = self._get_field_index(stmt)
if idx is None:
Expand All @@ -53,6 +53,8 @@ def _lower_captured_fields(self, method: ir.Method) -> bool:
if isinstance(captured, ir.Method):
continue
# Replace GetField with Constant.
from kirin.dialects import py

const_stmt = py.Constant(captured)
const_stmt.insert_before(stmt)
if stmt.results and const_stmt.results:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from kirin import ir
from kirin.dialects import py, func
from kirin.passes import TypeInfer
from kirin.dialects import py
from kirin.rewrite.abc import RewriteRule, RewriteResult

from ._dialect import dialect
from ..stmts import Lambda, Function, GetField
from .._dialect import dialect


@dialect.canonicalize
Expand All @@ -12,17 +14,17 @@ class LambdaLifting(RewriteRule):
"""

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
from kirin.dialects import py

if not isinstance(node, py.Constant):
return RewriteResult(has_done_something=False)
method = self._get_method_from_constant(node)
if method is None:
return RewriteResult(has_done_something=False)
if not isinstance(method.code, func.Lambda):
if not isinstance(method.code, Lambda):
return RewriteResult(has_done_something=False)
self._promote_lambda(method)

from kirin.passes import TypeInfer

rewrite_result = TypeInfer(dialects=method.dialects).unsafe_run(method)
return RewriteResult(has_done_something=True).join(rewrite_result)

Expand All @@ -34,7 +36,7 @@ def _get_method_from_constant(self, const_stmt: py.Constant) -> ir.Method | None
return pyattr_data.data
return None

def _get_field_index(self, getfield_stmt: func.GetField) -> int | None:
def _get_field_index(self, getfield_stmt: GetField) -> int | None:
fld = getfield_stmt.attributes.get("field")
if fld:
return getfield_stmt.field
Expand All @@ -44,26 +46,28 @@ def _get_field_index(self, getfield_stmt: func.GetField) -> int | None:
def _promote_lambda(self, method: ir.Method) -> None:
new_method = method.similar()
assert isinstance(
new_method.code, func.Lambda
new_method.code, Lambda
), "expected method.code to be func.Lambda before promotion"

captured_fields = method.fields
if captured_fields:
for stmt in new_method.code.body.blocks[0].stmts:
if not isinstance(stmt, func.GetField):
if not isinstance(stmt, GetField):
continue
idx = self._get_field_index(stmt)
if idx is None:
continue
captured = new_method.fields[idx]
from kirin.dialects import py

const_stmt = py.Constant(captured)
const_stmt.insert_before(stmt)
if stmt.results and const_stmt.results:
stmt.results[0].replace_by(const_stmt.results[0])
stmt.delete()
new_method.code

fn = func.Function(
fn = Function(
sym_name=new_method.code.sym_name,
slots=new_method.code.slots,
signature=new_method.code.signature,
Expand Down
5 changes: 3 additions & 2 deletions test/dialects/func/test_closurefield.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from kirin import rewrite
from kirin.prelude import basic
from kirin.dialects import py, func
from kirin.dialects.func.rewrite import closurefield


def test_rewrite_closure_inner_lambda():
Expand All @@ -26,7 +27,7 @@ def main_lambda(z: int):
inner_getfield_stmt, func.GetField
), "expected GetField before rewrite"

rewrite.Walk(func.closurefield.ClosureField()).rewrite(main_lambda.code)
rewrite.Walk(closurefield.ClosureField()).rewrite(main_lambda.code)

inner_getfield_stmt = inner_lambda.regions[0].blocks[0].stmts.at(0)
assert isinstance(
Expand All @@ -47,6 +48,6 @@ def boo(y):
return boo(4)

before = bar.code.regions[0].blocks[0].stmts.at(0)
rewrite.Walk(func.closurefield.ClosureField()).rewrite(bar.code)
rewrite.Walk(closurefield.ClosureField()).rewrite(bar.code)
after = bar.code.regions[0].blocks[0].stmts.at(0)
assert before is after
5 changes: 3 additions & 2 deletions test/dialects/func/test_lambdalifting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from kirin import ir, rewrite
from kirin.prelude import basic
from kirin.dialects import py, func
from kirin.dialects.func.rewrite import lambdalifting


def test_rewrite_inner_lambda():
Expand All @@ -20,7 +21,7 @@ def inner(x: int):
pyconstant_stmt.value.data.code, func.Lambda
), "expected a lambda Method in outer body"

rewrite.Walk(func.lambdalifting.LambdaLifting()).rewrite(outer.code)
rewrite.Walk(lambdalifting.LambdaLifting()).rewrite(outer.code)
assert isinstance(
pyconstant_stmt.value.data.code, func.Function
), "expected a Function in outer body"
Expand All @@ -45,7 +46,7 @@ def inner2(x: int):
assert isinstance(
pyconstant_stmt.value.data.code, func.Lambda
), "expected a lambda Method in outer body"
rewrite.Walk(func.lambdalifting.LambdaLifting()).rewrite(outer2.code)
rewrite.Walk(lambdalifting.LambdaLifting()).rewrite(outer2.code)
assert isinstance(
pyconstant_stmt.value.data.code, func.Function
), "expected a Function in outer body"