Skip to content

Commit 1642980

Browse files
cduckweinbe58
andauthored
Adjust AggressiveUnroll to fix type error during unroll (#593)
I'm not sure what was causing the error but rearranging the passes with the help of @weinbe58 fixed the problem. For context, this was the error. I did not find a MWE. ```python Traceback (most recent call last): File "my_script.py", line 49, in compile_to_stim PhysicalAndSquinToStim(main.dialects, no_raise=no_raise)(main) File "site-packages/kirin/passes/abc.py", line 30, in __call__ result = self.unsafe_run(mt) File "squin_and_to_stim.py", line 99, in unsafe_run .unsafe_run(mt) File "site-packages/bloqade/rewrite/passes/canonicalize_ilist.py", line 29, in unsafe_run ).rewrite(mt.code) File "site-packages/kirin/rewrite/fixpoint.py", line 24, in rewrite result = self.rule.rewrite(node) File "site-packages/kirin/rewrite/walk.py", line 40, in rewrite result = self.rule.rewrite(subnode) File "site-packages/kirin/rewrite/chain.py", line 29, in rewrite result = rule.rewrite(node) File "site-packages/kirin/rewrite/abc.py", line 38, in rewrite return self.rewrite_Statement(cast(Statement, node)) File "site-packages/kirin/dialects/ilist/rewrite/flatten_add.py", line 49, in rewrite_Statement assert isinstance(rhs_type := rhs.type, types.Generic), "Impossible" AssertionError: Impossible ``` --------- Co-authored-by: Phillip Weinberg <[email protected]>
1 parent b031dd6 commit 1642980

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

src/bloqade/rewrite/passes/aggressive_unroll.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
InlineGetItem,
1515
InlineGetField,
1616
DeadCodeElimination,
17+
CommonSubexpressionElimination,
1718
)
1819
from kirin.dialects import scf, ilist
1920
from kirin.ir.method import Method
2021
from kirin.rewrite.abc import RewriteResult
21-
from kirin.rewrite.cse import CommonSubexpressionElimination
2222
from kirin.passes.aggressive import UnrollScf
2323

24+
from .canonicalize_ilist import CanonicalizeIList
25+
2426

2527
@dataclass
2628
class Fold(Pass):
@@ -55,30 +57,36 @@ class AggressiveUnroll(Pass):
5557
fold: Fold = field(init=False)
5658
typeinfer: TypeInfer = field(init=False)
5759
scf_unroll: UnrollScf = field(init=False)
60+
canonicalize_ilist: CanonicalizeIList = field(init=False)
5861

5962
def __post_init__(self):
6063
self.fold = Fold(self.dialects, no_raise=self.no_raise)
6164
self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise)
6265
self.scf_unroll = UnrollScf(self.dialects, no_raise=self.no_raise)
66+
self.canonicalize_ilist = CanonicalizeIList(
67+
self.dialects, no_raise=self.no_raise
68+
)
6369

6470
def unsafe_run(self, mt: Method) -> RewriteResult:
6571
result = RewriteResult()
72+
result = self.fold.unsafe_run(mt).join(result)
6673
result = self.scf_unroll.unsafe_run(mt).join(result)
74+
self.typeinfer.unsafe_run(
75+
mt
76+
) # Do not join the result of typeinfer or fixpoint will waste time
6777
result = (
6878
Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
6979
.rewrite(mt.code)
7080
.join(result)
7181
)
72-
self.typeinfer.unsafe_run(mt)
73-
result = self.fold.unsafe_run(mt).join(result)
7482
result = Walk(Inline(self.inline_heuristic)).rewrite(mt.code).join(result)
7583
result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result)
76-
84+
result = self.canonicalize_ilist.fixpoint(mt).join(result)
7785
rule = Chain(
7886
CommonSubexpressionElimination(),
7987
DeadCodeElimination(),
8088
)
81-
result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
89+
result = Walk(rule).rewrite(mt.code).join(result)
8290

8391
return result
8492

0 commit comments

Comments
 (0)