Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/kirin/dialects/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,4 @@

from . import (
_julia as _julia,
closurefield as closurefield,
lambdalifting as lambdalifting,
)
2 changes: 2 additions & 0 deletions src/kirin/dialects/func/rewrite/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .closurefield import ClosureField as ClosureField
from .lambdalifting import LambdaLifting as LambdaLifting
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from kirin import ir
from kirin.dialects import py, func
from kirin.rewrite.abc import RewriteRule, RewriteResult

from ._dialect import dialect
from ..stmts import Invoke, GetField
from .._dialect import dialect


@dialect.canonicalize
Expand All @@ -13,7 +13,7 @@ class ClosureField(RewriteRule):
"""

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
if not isinstance(node, func.Invoke):
if not isinstance(node, Invoke):
return RewriteResult(has_done_something=False)
method = node.callee
if not method.fields:
Expand All @@ -27,7 +27,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
rewrite_result = TypeInfer(dialects=method.dialects).unsafe_run(method)
return RewriteResult(has_done_something=changed).join(rewrite_result)

def _get_field_index(self, getfield_stmt: func.GetField) -> int | None:
def _get_field_index(self, getfield_stmt: GetField) -> int | None:
fld = getfield_stmt.attributes.get("field")
if fld:
return getfield_stmt.field
Expand All @@ -43,7 +43,7 @@ def _lower_captured_fields(self, method: ir.Method) -> bool:
for region in method.code.regions:
for block in region.blocks:
for stmt in list(block.stmts):
if not isinstance(stmt, func.GetField):
if not isinstance(stmt, GetField):
continue
idx = self._get_field_index(stmt)
if idx is None:
Expand All @@ -53,6 +53,8 @@ def _lower_captured_fields(self, method: ir.Method) -> bool:
if isinstance(captured, ir.Method):
continue
# Replace GetField with Constant.
from kirin.dialects import py

const_stmt = py.Constant(captured)
const_stmt.insert_before(stmt)
if stmt.results and const_stmt.results:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from typing_extensions import TYPE_CHECKING

if TYPE_CHECKING:
from kirin.dialects import py

from kirin import ir
from kirin.dialects import py, func
from kirin.rewrite.abc import RewriteRule, RewriteResult

from ._dialect import dialect
from ..stmts import Lambda, Function, GetField
from .._dialect import dialect


@dialect.canonicalize
Expand All @@ -12,12 +17,14 @@ class LambdaLifting(RewriteRule):
"""

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
from kirin.dialects import py

if not isinstance(node, py.Constant):
return RewriteResult(has_done_something=False)
method = self._get_method_from_constant(node)
if method is None:
return RewriteResult(has_done_something=False)
if not isinstance(method.code, func.Lambda):
if not isinstance(method.code, Lambda):
return RewriteResult(has_done_something=False)
self._promote_lambda(method)

Expand All @@ -26,15 +33,15 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
rewrite_result = TypeInfer(dialects=method.dialects).unsafe_run(method)
return RewriteResult(has_done_something=True).join(rewrite_result)

def _get_method_from_constant(self, const_stmt: py.Constant) -> ir.Method | None:
def _get_method_from_constant(self, const_stmt: "py.Constant") -> ir.Method | None:
pyattr_data = const_stmt.value
if isinstance(pyattr_data, ir.PyAttr) and isinstance(
pyattr_data.data, ir.Method
):
return pyattr_data.data
return None

def _get_field_index(self, getfield_stmt: func.GetField) -> int | None:
def _get_field_index(self, getfield_stmt: GetField) -> int | None:
fld = getfield_stmt.attributes.get("field")
if fld:
return getfield_stmt.field
Expand All @@ -44,26 +51,28 @@ def _get_field_index(self, getfield_stmt: func.GetField) -> int | None:
def _promote_lambda(self, method: ir.Method) -> None:
new_method = method.similar()
assert isinstance(
new_method.code, func.Lambda
new_method.code, Lambda
), "expected method.code to be func.Lambda before promotion"

captured_fields = method.fields
if captured_fields:
for stmt in new_method.code.body.blocks[0].stmts:
if not isinstance(stmt, func.GetField):
if not isinstance(stmt, GetField):
continue
idx = self._get_field_index(stmt)
if idx is None:
continue
captured = new_method.fields[idx]
from kirin.dialects import py

const_stmt = py.Constant(captured)
const_stmt.insert_before(stmt)
if stmt.results and const_stmt.results:
stmt.results[0].replace_by(const_stmt.results[0])
stmt.delete()
new_method.code

fn = func.Function(
fn = Function(
sym_name=new_method.code.sym_name,
slots=new_method.code.slots,
signature=new_method.code.signature,
Expand Down
5 changes: 3 additions & 2 deletions test/dialects/func/test_closurefield.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from kirin import rewrite
from kirin.prelude import basic
from kirin.dialects import py, func
from kirin.dialects.func.rewrite import closurefield


def test_rewrite_closure_inner_lambda():
Expand All @@ -26,7 +27,7 @@ def main_lambda(z: int):
inner_getfield_stmt, func.GetField
), "expected GetField before rewrite"

rewrite.Walk(func.closurefield.ClosureField()).rewrite(main_lambda.code)
rewrite.Walk(closurefield.ClosureField()).rewrite(main_lambda.code)

inner_getfield_stmt = inner_lambda.regions[0].blocks[0].stmts.at(0)
assert isinstance(
Expand All @@ -47,6 +48,6 @@ def boo(y):
return boo(4)

before = bar.code.regions[0].blocks[0].stmts.at(0)
rewrite.Walk(func.closurefield.ClosureField()).rewrite(bar.code)
rewrite.Walk(closurefield.ClosureField()).rewrite(bar.code)
after = bar.code.regions[0].blocks[0].stmts.at(0)
assert before is after
5 changes: 3 additions & 2 deletions test/dialects/func/test_lambdalifting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from kirin import ir, rewrite
from kirin.prelude import basic
from kirin.dialects import py, func
from kirin.dialects.func.rewrite import lambdalifting


def test_rewrite_inner_lambda():
Expand All @@ -20,7 +21,7 @@ def inner(x: int):
pyconstant_stmt.value.data.code, func.Lambda
), "expected a lambda Method in outer body"

rewrite.Walk(func.lambdalifting.LambdaLifting()).rewrite(outer.code)
rewrite.Walk(lambdalifting.LambdaLifting()).rewrite(outer.code)
assert isinstance(
pyconstant_stmt.value.data.code, func.Function
), "expected a Function in outer body"
Expand All @@ -45,7 +46,7 @@ def inner2(x: int):
assert isinstance(
pyconstant_stmt.value.data.code, func.Lambda
), "expected a lambda Method in outer body"
rewrite.Walk(func.lambdalifting.LambdaLifting()).rewrite(outer2.code)
rewrite.Walk(lambdalifting.LambdaLifting()).rewrite(outer2.code)
assert isinstance(
pyconstant_stmt.value.data.code, func.Function
), "expected a Function in outer body"