Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/kirin/dialects/func/closurefield.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from kirin import ir
from kirin.dialects import py, func
from kirin.rewrite.abc import RewriteRule, RewriteResult

from .stmts import Invoke, GetField
from ._dialect import dialect


Expand All @@ -13,7 +13,7 @@ class ClosureField(RewriteRule):
"""

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
if not isinstance(node, func.Invoke):
if not isinstance(node, Invoke):
return RewriteResult(has_done_something=False)
method = node.callee
if not method.fields:
Expand All @@ -27,7 +27,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
rewrite_result = TypeInfer(dialects=method.dialects).unsafe_run(method)
return RewriteResult(has_done_something=changed).join(rewrite_result)

def _get_field_index(self, getfield_stmt: func.GetField) -> int | None:
def _get_field_index(self, getfield_stmt: GetField) -> int | None:
fld = getfield_stmt.attributes.get("field")
if fld:
return getfield_stmt.field
Expand All @@ -43,7 +43,7 @@ def _lower_captured_fields(self, method: ir.Method) -> bool:
for region in method.code.regions:
for block in region.blocks:
for stmt in list(block.stmts):
if not isinstance(stmt, func.GetField):
if not isinstance(stmt, GetField):
continue
idx = self._get_field_index(stmt)
if idx is None:
Expand All @@ -53,6 +53,8 @@ def _lower_captured_fields(self, method: ir.Method) -> bool:
if isinstance(captured, ir.Method):
continue
# Replace GetField with Constant.
from kirin.dialects import py

const_stmt = py.Constant(captured)
const_stmt.insert_before(stmt)
if stmt.results and const_stmt.results:
Expand Down
23 changes: 16 additions & 7 deletions src/kirin/dialects/func/lambdalifting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing_extensions import TYPE_CHECKING

if TYPE_CHECKING:
from kirin.dialects import py

from kirin import ir
from kirin.dialects import py, func
from kirin.rewrite.abc import RewriteRule, RewriteResult

from .stmts import Lambda, Function, GetField
from ._dialect import dialect


Expand All @@ -12,12 +17,14 @@ class LambdaLifting(RewriteRule):
"""

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
from kirin.dialects import py

if not isinstance(node, py.Constant):
return RewriteResult(has_done_something=False)
method = self._get_method_from_constant(node)
if method is None:
return RewriteResult(has_done_something=False)
if not isinstance(method.code, func.Lambda):
if not isinstance(method.code, Lambda):
return RewriteResult(has_done_something=False)
self._promote_lambda(method)

Expand All @@ -26,15 +33,15 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
rewrite_result = TypeInfer(dialects=method.dialects).unsafe_run(method)
return RewriteResult(has_done_something=True).join(rewrite_result)

def _get_method_from_constant(self, const_stmt: py.Constant) -> ir.Method | None:
def _get_method_from_constant(self, const_stmt: "py.Constant") -> ir.Method | None:
pyattr_data = const_stmt.value
if isinstance(pyattr_data, ir.PyAttr) and isinstance(
pyattr_data.data, ir.Method
):
return pyattr_data.data
return None

def _get_field_index(self, getfield_stmt: func.GetField) -> int | None:
def _get_field_index(self, getfield_stmt: GetField) -> int | None:
fld = getfield_stmt.attributes.get("field")
if fld:
return getfield_stmt.field
Expand All @@ -44,26 +51,28 @@ def _get_field_index(self, getfield_stmt: func.GetField) -> int | None:
def _promote_lambda(self, method: ir.Method) -> None:
new_method = method.similar()
assert isinstance(
new_method.code, func.Lambda
new_method.code, Lambda
), "expected method.code to be func.Lambda before promotion"

captured_fields = method.fields
if captured_fields:
for stmt in new_method.code.body.blocks[0].stmts:
if not isinstance(stmt, func.GetField):
if not isinstance(stmt, GetField):
continue
idx = self._get_field_index(stmt)
if idx is None:
continue
captured = new_method.fields[idx]
from kirin.dialects import py

const_stmt = py.Constant(captured)
const_stmt.insert_before(stmt)
if stmt.results and const_stmt.results:
stmt.results[0].replace_by(const_stmt.results[0])
stmt.delete()
new_method.code

fn = func.Function(
fn = Function(
sym_name=new_method.code.sym_name,
slots=new_method.code.slots,
signature=new_method.code.signature,
Expand Down