Skip to content

Commit f5bee94

Browse files
authored
Fix circular import error. issue #549 (#550)
Fixes bug in issue #549: circular import error when importing `kirin.dialects.py`. Small edits to `src/kirin/dialects/func/closurefield.py` and `src/kirin/dialects/func/lambdalifting.py`. Delaying import.
1 parent 10b4f47 commit f5bee94

File tree

6 files changed

+29
-21
lines changed

6 files changed

+29
-21
lines changed

src/kirin/dialects/func/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,4 @@
2020

2121
from . import (
2222
_julia as _julia,
23-
closurefield as closurefield,
24-
lambdalifting as lambdalifting,
2523
)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .closurefield import ClosureField as ClosureField
2+
from .lambdalifting import LambdaLifting as LambdaLifting

src/kirin/dialects/func/closurefield.py renamed to src/kirin/dialects/func/rewrite/closurefield.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from kirin import ir
2-
from kirin.dialects import py, func
2+
from kirin.passes import TypeInfer
33
from kirin.rewrite.abc import RewriteRule, RewriteResult
44

5-
from ._dialect import dialect
5+
from ..stmts import Invoke, GetField
6+
from .._dialect import dialect
67

78

89
@dialect.canonicalize
@@ -13,7 +14,7 @@ class ClosureField(RewriteRule):
1314
"""
1415

1516
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
16-
if not isinstance(node, func.Invoke):
17+
if not isinstance(node, Invoke):
1718
return RewriteResult(has_done_something=False)
1819
method = node.callee
1920
if not method.fields:
@@ -22,12 +23,11 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
2223
changed = self._lower_captured_fields(method)
2324
if changed:
2425
method.fields = ()
25-
from kirin.passes import TypeInfer
2626

2727
rewrite_result = TypeInfer(dialects=method.dialects).unsafe_run(method)
2828
return RewriteResult(has_done_something=changed).join(rewrite_result)
2929

30-
def _get_field_index(self, getfield_stmt: func.GetField) -> int | None:
30+
def _get_field_index(self, getfield_stmt: GetField) -> int | None:
3131
fld = getfield_stmt.attributes.get("field")
3232
if fld:
3333
return getfield_stmt.field
@@ -43,7 +43,7 @@ def _lower_captured_fields(self, method: ir.Method) -> bool:
4343
for region in method.code.regions:
4444
for block in region.blocks:
4545
for stmt in list(block.stmts):
46-
if not isinstance(stmt, func.GetField):
46+
if not isinstance(stmt, GetField):
4747
continue
4848
idx = self._get_field_index(stmt)
4949
if idx is None:
@@ -53,6 +53,8 @@ def _lower_captured_fields(self, method: ir.Method) -> bool:
5353
if isinstance(captured, ir.Method):
5454
continue
5555
# Replace GetField with Constant.
56+
from kirin.dialects import py
57+
5658
const_stmt = py.Constant(captured)
5759
const_stmt.insert_before(stmt)
5860
if stmt.results and const_stmt.results:

src/kirin/dialects/func/lambdalifting.py renamed to src/kirin/dialects/func/rewrite/lambdalifting.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from kirin import ir
2-
from kirin.dialects import py, func
2+
from kirin.passes import TypeInfer
3+
from kirin.dialects import py
34
from kirin.rewrite.abc import RewriteRule, RewriteResult
45

5-
from ._dialect import dialect
6+
from ..stmts import Lambda, Function, GetField
7+
from .._dialect import dialect
68

79

810
@dialect.canonicalize
@@ -12,17 +14,17 @@ class LambdaLifting(RewriteRule):
1214
"""
1315

1416
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
17+
from kirin.dialects import py
18+
1519
if not isinstance(node, py.Constant):
1620
return RewriteResult(has_done_something=False)
1721
method = self._get_method_from_constant(node)
1822
if method is None:
1923
return RewriteResult(has_done_something=False)
20-
if not isinstance(method.code, func.Lambda):
24+
if not isinstance(method.code, Lambda):
2125
return RewriteResult(has_done_something=False)
2226
self._promote_lambda(method)
2327

24-
from kirin.passes import TypeInfer
25-
2628
rewrite_result = TypeInfer(dialects=method.dialects).unsafe_run(method)
2729
return RewriteResult(has_done_something=True).join(rewrite_result)
2830

@@ -34,7 +36,7 @@ def _get_method_from_constant(self, const_stmt: py.Constant) -> ir.Method | None
3436
return pyattr_data.data
3537
return None
3638

37-
def _get_field_index(self, getfield_stmt: func.GetField) -> int | None:
39+
def _get_field_index(self, getfield_stmt: GetField) -> int | None:
3840
fld = getfield_stmt.attributes.get("field")
3941
if fld:
4042
return getfield_stmt.field
@@ -44,26 +46,28 @@ def _get_field_index(self, getfield_stmt: func.GetField) -> int | None:
4446
def _promote_lambda(self, method: ir.Method) -> None:
4547
new_method = method.similar()
4648
assert isinstance(
47-
new_method.code, func.Lambda
49+
new_method.code, Lambda
4850
), "expected method.code to be func.Lambda before promotion"
4951

5052
captured_fields = method.fields
5153
if captured_fields:
5254
for stmt in new_method.code.body.blocks[0].stmts:
53-
if not isinstance(stmt, func.GetField):
55+
if not isinstance(stmt, GetField):
5456
continue
5557
idx = self._get_field_index(stmt)
5658
if idx is None:
5759
continue
5860
captured = new_method.fields[idx]
61+
from kirin.dialects import py
62+
5963
const_stmt = py.Constant(captured)
6064
const_stmt.insert_before(stmt)
6165
if stmt.results and const_stmt.results:
6266
stmt.results[0].replace_by(const_stmt.results[0])
6367
stmt.delete()
6468
new_method.code
6569

66-
fn = func.Function(
70+
fn = Function(
6771
sym_name=new_method.code.sym_name,
6872
slots=new_method.code.slots,
6973
signature=new_method.code.signature,

test/dialects/func/test_closurefield.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from kirin import rewrite
44
from kirin.prelude import basic
55
from kirin.dialects import py, func
6+
from kirin.dialects.func.rewrite import closurefield
67

78

89
def test_rewrite_closure_inner_lambda():
@@ -26,7 +27,7 @@ def main_lambda(z: int):
2627
inner_getfield_stmt, func.GetField
2728
), "expected GetField before rewrite"
2829

29-
rewrite.Walk(func.closurefield.ClosureField()).rewrite(main_lambda.code)
30+
rewrite.Walk(closurefield.ClosureField()).rewrite(main_lambda.code)
3031

3132
inner_getfield_stmt = inner_lambda.regions[0].blocks[0].stmts.at(0)
3233
assert isinstance(
@@ -47,6 +48,6 @@ def boo(y):
4748
return boo(4)
4849

4950
before = bar.code.regions[0].blocks[0].stmts.at(0)
50-
rewrite.Walk(func.closurefield.ClosureField()).rewrite(bar.code)
51+
rewrite.Walk(closurefield.ClosureField()).rewrite(bar.code)
5152
after = bar.code.regions[0].blocks[0].stmts.at(0)
5253
assert before is after

test/dialects/func/test_lambdalifting.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from kirin import ir, rewrite
22
from kirin.prelude import basic
33
from kirin.dialects import py, func
4+
from kirin.dialects.func.rewrite import lambdalifting
45

56

67
def test_rewrite_inner_lambda():
@@ -20,7 +21,7 @@ def inner(x: int):
2021
pyconstant_stmt.value.data.code, func.Lambda
2122
), "expected a lambda Method in outer body"
2223

23-
rewrite.Walk(func.lambdalifting.LambdaLifting()).rewrite(outer.code)
24+
rewrite.Walk(lambdalifting.LambdaLifting()).rewrite(outer.code)
2425
assert isinstance(
2526
pyconstant_stmt.value.data.code, func.Function
2627
), "expected a Function in outer body"
@@ -45,7 +46,7 @@ def inner2(x: int):
4546
assert isinstance(
4647
pyconstant_stmt.value.data.code, func.Lambda
4748
), "expected a lambda Method in outer body"
48-
rewrite.Walk(func.lambdalifting.LambdaLifting()).rewrite(outer2.code)
49+
rewrite.Walk(lambdalifting.LambdaLifting()).rewrite(outer2.code)
4950
assert isinstance(
5051
pyconstant_stmt.value.data.code, func.Function
5152
), "expected a Function in outer body"

0 commit comments

Comments
 (0)