Skip to content

Commit f62319f

Browse files
committed
Tune aggressive unroll for QEC needs
1 parent a24c328 commit f62319f

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

src/bloqade/rewrite/passes/aggressive_unroll.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from kirin.rewrite.cse import CommonSubexpressionElimination
2222
from kirin.passes.aggressive import UnrollScf
2323

24+
from .canonicalize_ilist import CanonicalizeIList
25+
2426

2527
@dataclass
2628
class Fold(Pass):
@@ -55,30 +57,36 @@ class AggressiveUnroll(Pass):
5557
fold: Fold = field(init=False)
5658
typeinfer: TypeInfer = field(init=False)
5759
scf_unroll: UnrollScf = field(init=False)
60+
canonicalize_ilist: CanonicalizeIList = field(init=False)
5861

5962
def __post_init__(self):
6063
self.fold = Fold(self.dialects, no_raise=self.no_raise)
6164
self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise)
6265
self.scf_unroll = UnrollScf(self.dialects, no_raise=self.no_raise)
66+
self.canonicalize_ilist = CanonicalizeIList(
67+
self.dialects, no_raise=self.no_raise
68+
)
6369

6470
def unsafe_run(self, mt: Method) -> RewriteResult:
6571
result = RewriteResult()
72+
result = self.fold.unsafe_run(mt).join(result)
6673
result = self.scf_unroll.unsafe_run(mt).join(result)
74+
self.typeinfer.unsafe_run(
75+
mt
76+
) # Do not join the result of typeinfer or fixpoint will waste time
6777
result = (
6878
Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
6979
.rewrite(mt.code)
7080
.join(result)
7181
)
72-
self.typeinfer.unsafe_run(mt)
73-
result = self.fold.unsafe_run(mt).join(result)
7482
result = Walk(Inline(self.inline_heuristic)).rewrite(mt.code).join(result)
7583
result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result)
76-
84+
result = self.canonicalize_ilist.fixpoint(mt).join(result)
7785
rule = Chain(
7886
CommonSubexpressionElimination(),
7987
DeadCodeElimination(),
8088
)
81-
result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
89+
result = Walk(rule).rewrite(mt.code).join(result)
8290

8391
return result
8492

0 commit comments

Comments
 (0)