|
2 | 2 | from dataclasses import field, dataclass |
3 | 3 |
|
4 | 4 | from kirin import ir |
5 | | -from kirin.passes import Pass, HintConst |
6 | | -from kirin.rewrite import ( |
7 | | - Walk, |
8 | | - Chain, |
9 | | - Fixpoint, |
10 | | - Call2Invoke, |
11 | | - ConstantFold, |
12 | | - InlineGetItem, |
13 | | - InlineGetField, |
14 | | - DeadCodeElimination, |
15 | | -) |
16 | | -from kirin.dialects import ilist |
17 | | -from kirin.ir.method import Method |
| 5 | +from kirin.passes import Pass |
18 | 6 | from kirin.rewrite.abc import RewriteResult |
19 | | -from kirin.rewrite.cse import CommonSubexpressionElimination |
20 | | -from kirin.passes.inline import InlinePass |
21 | 7 |
|
22 | 8 | from bloqade.qasm2.passes.fold import AggressiveUnroll |
23 | 9 | from bloqade.stim.passes.simplify_ifs import StimSimplifyIfs |
24 | 10 |
|
25 | 11 |
|
26 | 12 | @dataclass |
27 | | -class Fold(Pass): |
28 | | - hint_const: HintConst = field(init=False) |
29 | | - |
30 | | - def __post_init__(self): |
31 | | - self.hint_const = HintConst(self.dialects, no_raise=self.no_raise) |
32 | | - |
33 | | - def unsafe_run(self, mt: Method) -> RewriteResult: |
34 | | - result = RewriteResult() |
35 | | - result = self.hint_const.unsafe_run(mt).join(result) |
36 | | - rule = Chain( |
37 | | - ConstantFold(), |
38 | | - Call2Invoke(), |
39 | | - InlineGetField(), |
40 | | - InlineGetItem(), |
41 | | - ilist.rewrite.InlineGetItem(), |
42 | | - ilist.rewrite.HintLen(), |
43 | | - DeadCodeElimination(), |
44 | | - CommonSubexpressionElimination(), |
45 | | - ) |
46 | | - result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result) |
| 13 | +class Flatten(Pass): |
47 | 14 |
|
48 | | - return result |
| 15 | + unroll: AggressiveUnroll = field(init=False) |
| 16 | + simplify_if: StimSimplifyIfs = field(init=False) |
49 | 17 |
|
| 18 | + def __post_init__(self): |
| 19 | + self.unroll = AggressiveUnroll(self.dialects, no_raise=self.no_raise) |
| 20 | + self.simplify_if = StimSimplifyIfs(self.dialects, no_raise=self.no_raise) |
50 | 21 |
|
51 | | -class Flatten(Pass): |
52 | 22 | def unsafe_run(self, mt: ir.Method) -> RewriteResult: |
53 | | - rewrite_result = InlinePass(dialects=mt.dialects, no_raise=self.no_raise)(mt) |
54 | | - rewrite_result = AggressiveUnroll(dialects=mt.dialects, no_raise=self.no_raise)( |
55 | | - mt |
56 | | - ).join(rewrite_result) |
57 | | - rewrite_result = StimSimplifyIfs(dialects=mt.dialects, no_raise=self.no_raise)( |
58 | | - mt |
59 | | - ).join(rewrite_result) |
60 | | - |
| 23 | + rewrite_result = RewriteResult() |
| 24 | + rewrite_result = self.simplify_if(mt).join(rewrite_result) |
| 25 | + rewrite_result = self.unroll(mt).join(rewrite_result) |
61 | 26 | return rewrite_result |
0 commit comments