diff --git a/src/bloqade/rewrite/passes/aggressive_unroll.py b/src/bloqade/rewrite/passes/aggressive_unroll.py index 2272d3c8..989eac70 100644 --- a/src/bloqade/rewrite/passes/aggressive_unroll.py +++ b/src/bloqade/rewrite/passes/aggressive_unroll.py @@ -14,13 +14,15 @@ InlineGetItem, InlineGetField, DeadCodeElimination, + CommonSubexpressionElimination, ) from kirin.dialects import scf, ilist from kirin.ir.method import Method from kirin.rewrite.abc import RewriteResult -from kirin.rewrite.cse import CommonSubexpressionElimination from kirin.passes.aggressive import UnrollScf +from .canonicalize_ilist import CanonicalizeIList + @dataclass class Fold(Pass): @@ -55,30 +57,36 @@ class AggressiveUnroll(Pass): fold: Fold = field(init=False) typeinfer: TypeInfer = field(init=False) scf_unroll: UnrollScf = field(init=False) + canonicalize_ilist: CanonicalizeIList = field(init=False) def __post_init__(self): self.fold = Fold(self.dialects, no_raise=self.no_raise) self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise) self.scf_unroll = UnrollScf(self.dialects, no_raise=self.no_raise) + self.canonicalize_ilist = CanonicalizeIList( + self.dialects, no_raise=self.no_raise + ) def unsafe_run(self, mt: Method) -> RewriteResult: result = RewriteResult() + result = self.fold.unsafe_run(mt).join(result) result = self.scf_unroll.unsafe_run(mt).join(result) + self.typeinfer.unsafe_run( + mt + ) # Do not join the result of typeinfer or fixpoint will waste time result = ( Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll())) .rewrite(mt.code) .join(result) ) - self.typeinfer.unsafe_run(mt) - result = self.fold.unsafe_run(mt).join(result) result = Walk(Inline(self.inline_heuristic)).rewrite(mt.code).join(result) result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result) - + result = self.canonicalize_ilist.fixpoint(mt).join(result) rule = Chain( CommonSubexpressionElimination(), DeadCodeElimination(), ) - result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result) + result = Walk(rule).rewrite(mt.code).join(result) return result