|
14 | 14 | InlineGetItem, |
15 | 15 | InlineGetField, |
16 | 16 | DeadCodeElimination, |
| 17 | + CommonSubexpressionElimination, |
17 | 18 | ) |
18 | 19 | from kirin.dialects import scf, ilist |
19 | 20 | from kirin.ir.method import Method |
20 | 21 | from kirin.rewrite.abc import RewriteResult |
21 | | -from kirin.rewrite.cse import CommonSubexpressionElimination |
22 | 22 | from kirin.passes.aggressive import UnrollScf |
23 | 23 |
|
| 24 | +from .canonicalize_ilist import CanonicalizeIList |
| 25 | + |
24 | 26 |
|
25 | 27 | @dataclass |
26 | 28 | class Fold(Pass): |
@@ -55,30 +57,36 @@ class AggressiveUnroll(Pass): |
55 | 57 | fold: Fold = field(init=False) |
56 | 58 | typeinfer: TypeInfer = field(init=False) |
57 | 59 | scf_unroll: UnrollScf = field(init=False) |
| 60 | + canonicalize_ilist: CanonicalizeIList = field(init=False) |
58 | 61 |
|
59 | 62 | def __post_init__(self): |
60 | 63 | self.fold = Fold(self.dialects, no_raise=self.no_raise) |
61 | 64 | self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise) |
62 | 65 | self.scf_unroll = UnrollScf(self.dialects, no_raise=self.no_raise) |
| 66 | + self.canonicalize_ilist = CanonicalizeIList( |
| 67 | + self.dialects, no_raise=self.no_raise |
| 68 | + ) |
63 | 69 |
|
64 | 70 | def unsafe_run(self, mt: Method) -> RewriteResult: |
65 | 71 | result = RewriteResult() |
| 72 | + result = self.fold.unsafe_run(mt).join(result) |
66 | 73 | result = self.scf_unroll.unsafe_run(mt).join(result) |
| 74 | + self.typeinfer.unsafe_run( |
| 75 | + mt |
| 76 | + ) # Do not join the result of typeinfer or fixpoint will waste time |
67 | 77 | result = ( |
68 | 78 | Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll())) |
69 | 79 | .rewrite(mt.code) |
70 | 80 | .join(result) |
71 | 81 | ) |
72 | | - self.typeinfer.unsafe_run(mt) |
73 | | - result = self.fold.unsafe_run(mt).join(result) |
74 | 82 | result = Walk(Inline(self.inline_heuristic)).rewrite(mt.code).join(result) |
75 | 83 | result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result) |
76 | | - |
| 84 | + result = self.canonicalize_ilist.fixpoint(mt).join(result) |
77 | 85 | rule = Chain( |
78 | 86 | CommonSubexpressionElimination(), |
79 | 87 | DeadCodeElimination(), |
80 | 88 | ) |
81 | | - result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result) |
| 89 | + result = Walk(rule).rewrite(mt.code).join(result) |
82 | 90 |
|
83 | 91 | return result |
84 | 92 |
|
|
0 commit comments