diff --git a/src/kirin/dialects/func/__init__.py b/src/kirin/dialects/func/__init__.py index 62c3d28e6..0aeecc34f 100644 --- a/src/kirin/dialects/func/__init__.py +++ b/src/kirin/dialects/func/__init__.py @@ -20,6 +20,4 @@ from . import ( _julia as _julia, - closurefield as closurefield, - lambdalifting as lambdalifting, ) diff --git a/src/kirin/dialects/func/rewrite/__init__.py b/src/kirin/dialects/func/rewrite/__init__.py new file mode 100644 index 000000000..cf784a986 --- /dev/null +++ b/src/kirin/dialects/func/rewrite/__init__.py @@ -0,0 +1,2 @@ +from .closurefield import ClosureField as ClosureField +from .lambdalifting import LambdaLifting as LambdaLifting diff --git a/src/kirin/dialects/func/closurefield.py b/src/kirin/dialects/func/rewrite/closurefield.py similarity index 86% rename from src/kirin/dialects/func/closurefield.py rename to src/kirin/dialects/func/rewrite/closurefield.py index a3e160432..072891083 100644 --- a/src/kirin/dialects/func/closurefield.py +++ b/src/kirin/dialects/func/rewrite/closurefield.py @@ -1,8 +1,9 @@ from kirin import ir -from kirin.dialects import py, func +from kirin.passes import TypeInfer from kirin.rewrite.abc import RewriteRule, RewriteResult -from ._dialect import dialect +from ..stmts import Invoke, GetField +from .._dialect import dialect @dialect.canonicalize @@ -13,7 +14,7 @@ class ClosureField(RewriteRule): """ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: - if not isinstance(node, func.Invoke): + if not isinstance(node, Invoke): return RewriteResult(has_done_something=False) method = node.callee if not method.fields: @@ -22,12 +23,11 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: changed = self._lower_captured_fields(method) if changed: method.fields = () - from kirin.passes import TypeInfer rewrite_result = TypeInfer(dialects=method.dialects).unsafe_run(method) return RewriteResult(has_done_something=changed).join(rewrite_result) - def _get_field_index(self, getfield_stmt: func.GetField) -> int | None: + def _get_field_index(self, getfield_stmt: GetField) -> int | None: fld = getfield_stmt.attributes.get("field") if fld: return getfield_stmt.field @@ -43,7 +43,7 @@ def _lower_captured_fields(self, method: ir.Method) -> bool: for region in method.code.regions: for block in region.blocks: for stmt in list(block.stmts): - if not isinstance(stmt, func.GetField): + if not isinstance(stmt, GetField): continue idx = self._get_field_index(stmt) if idx is None: @@ -53,6 +53,8 @@ def _lower_captured_fields(self, method: ir.Method) -> bool: if isinstance(captured, ir.Method): continue # Replace GetField with Constant. + from kirin.dialects import py + const_stmt = py.Constant(captured) const_stmt.insert_before(stmt) if stmt.results and const_stmt.results: diff --git a/src/kirin/dialects/func/lambdalifting.py b/src/kirin/dialects/func/rewrite/lambdalifting.py similarity index 83% rename from src/kirin/dialects/func/lambdalifting.py rename to src/kirin/dialects/func/rewrite/lambdalifting.py index d175f646d..956c09abd 100644 --- a/src/kirin/dialects/func/lambdalifting.py +++ b/src/kirin/dialects/func/rewrite/lambdalifting.py @@ -1,8 +1,10 @@ from kirin import ir -from kirin.dialects import py, func +from kirin.passes import TypeInfer +from kirin.dialects import py from kirin.rewrite.abc import RewriteRule, RewriteResult -from ._dialect import dialect +from ..stmts import Lambda, Function, GetField +from .._dialect import dialect @dialect.canonicalize @@ -12,17 +14,17 @@ class LambdaLifting(RewriteRule): """ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + from kirin.dialects import py + if not isinstance(node, py.Constant): return RewriteResult(has_done_something=False) method = self._get_method_from_constant(node) if method is None: return RewriteResult(has_done_something=False) - if not isinstance(method.code, func.Lambda): + if not isinstance(method.code, Lambda): return RewriteResult(has_done_something=False) self._promote_lambda(method) - from kirin.passes import TypeInfer - rewrite_result = TypeInfer(dialects=method.dialects).unsafe_run(method) return RewriteResult(has_done_something=True).join(rewrite_result) @@ -34,7 +36,7 @@ def _get_method_from_constant(self, const_stmt: py.Constant) -> ir.Method | None return pyattr_data.data return None - def _get_field_index(self, getfield_stmt: func.GetField) -> int | None: + def _get_field_index(self, getfield_stmt: GetField) -> int | None: fld = getfield_stmt.attributes.get("field") if fld: return getfield_stmt.field @@ -44,18 +46,20 @@ def _get_field_index(self, getfield_stmt: func.GetField) -> int | None: def _promote_lambda(self, method: ir.Method) -> None: new_method = method.similar() assert isinstance( - new_method.code, func.Lambda + new_method.code, Lambda ), "expected method.code to be func.Lambda before promotion" captured_fields = method.fields if captured_fields: for stmt in new_method.code.body.blocks[0].stmts: - if not isinstance(stmt, func.GetField): + if not isinstance(stmt, GetField): continue idx = self._get_field_index(stmt) if idx is None: continue captured = new_method.fields[idx] + from kirin.dialects import py + const_stmt = py.Constant(captured) const_stmt.insert_before(stmt) if stmt.results and const_stmt.results: @@ -63,7 +67,7 @@ def _promote_lambda(self, method: ir.Method) -> None: stmt.delete() new_method.code - fn = func.Function( + fn = Function( sym_name=new_method.code.sym_name, slots=new_method.code.slots, signature=new_method.code.signature, diff --git a/test/dialects/func/test_closurefield.py b/test/dialects/func/test_closurefield.py index d7e7cb291..03137d50f 100644 --- a/test/dialects/func/test_closurefield.py +++ b/test/dialects/func/test_closurefield.py @@ -3,6 +3,7 @@ from kirin import rewrite from kirin.prelude import basic from kirin.dialects import py, func +from kirin.dialects.func.rewrite import closurefield def test_rewrite_closure_inner_lambda(): @@ -26,7 +27,7 @@ def main_lambda(z: int): inner_getfield_stmt, func.GetField ), "expected GetField before rewrite" - rewrite.Walk(func.closurefield.ClosureField()).rewrite(main_lambda.code) + rewrite.Walk(closurefield.ClosureField()).rewrite(main_lambda.code) inner_getfield_stmt = inner_lambda.regions[0].blocks[0].stmts.at(0) assert isinstance( @@ -47,6 +48,6 @@ def boo(y): return boo(4) before = bar.code.regions[0].blocks[0].stmts.at(0) - rewrite.Walk(func.closurefield.ClosureField()).rewrite(bar.code) + rewrite.Walk(closurefield.ClosureField()).rewrite(bar.code) after = bar.code.regions[0].blocks[0].stmts.at(0) assert before is after diff --git a/test/dialects/func/test_lambdalifting.py b/test/dialects/func/test_lambdalifting.py index cece8e170..aa37896db 100644 --- a/test/dialects/func/test_lambdalifting.py +++ b/test/dialects/func/test_lambdalifting.py @@ -1,6 +1,7 @@ from kirin import ir, rewrite from kirin.prelude import basic from kirin.dialects import py, func +from kirin.dialects.func.rewrite import lambdalifting def test_rewrite_inner_lambda(): @@ -20,7 +21,7 @@ def inner(x: int): pyconstant_stmt.value.data.code, func.Lambda ), "expected a lambda Method in outer body" - rewrite.Walk(func.lambdalifting.LambdaLifting()).rewrite(outer.code) + rewrite.Walk(lambdalifting.LambdaLifting()).rewrite(outer.code) assert isinstance( pyconstant_stmt.value.data.code, func.Function ), "expected a Function in outer body" @@ -45,7 +46,7 @@ def inner2(x: int): assert isinstance( pyconstant_stmt.value.data.code, func.Lambda ), "expected a lambda Method in outer body" - rewrite.Walk(func.lambdalifting.LambdaLifting()).rewrite(outer2.code) + rewrite.Walk(lambdalifting.LambdaLifting()).rewrite(outer2.code) assert isinstance( pyconstant_stmt.value.data.code, func.Function ), "expected a Function in outer body"