|
1 | 1 | from dataclasses import field, dataclass |
2 | 2 |
|
3 | 3 | from kirin import ir |
4 | | -from kirin.passes import Pass, TypeInfer |
5 | | -from kirin.rewrite import ( |
6 | | - Walk, |
7 | | - Chain, |
8 | | - Inline, |
9 | | - Fixpoint, |
10 | | - WrapConst, |
11 | | - Call2Invoke, |
12 | | - ConstantFold, |
13 | | - CFGCompactify, |
14 | | - InlineGetItem, |
15 | | - InlineGetField, |
16 | | - DeadCodeElimination, |
17 | | - CommonSubexpressionElimination, |
18 | | -) |
19 | | -from kirin.analysis import const |
20 | | -from kirin.dialects import scf, ilist |
| 4 | +from kirin.passes import Pass |
21 | 5 | from kirin.ir.method import Method |
22 | 6 | from kirin.rewrite.abc import RewriteResult |
23 | 7 |
|
24 | 8 | from bloqade.qasm2.dialects import expr |
| 9 | +from bloqade.rewrite.passes import AggressiveUnroll |
25 | 10 |
|
26 | 11 | from .unroll_if import UnrollIfs |
27 | 12 |
|
|
30 | 15 | class QASM2Fold(Pass): |
31 | 16 | """Fold pass for qasm2.extended""" |
32 | 17 |
|
33 | | - constprop: const.Propagate = field(init=False) |
34 | 18 | inline_gate_subroutine: bool = True |
35 | 19 | unroll_ifs: bool = True |
| 20 | + aggressive_unroll: AggressiveUnroll = field(init=False) |
36 | 21 |
|
37 | 22 | def __post_init__(self): |
38 | | - self.constprop = const.Propagate(self.dialects) |
39 | | - self.typeinfer = TypeInfer(self.dialects) |
| 23 | + def inline_simple(node: ir.Statement): |
| 24 | + if isinstance(node, expr.GateFunction): |
| 25 | + return self.inline_gate_subroutine |
40 | 26 |
|
41 | | - def unsafe_run(self, mt: Method) -> RewriteResult: |
42 | | - result = RewriteResult() |
43 | | - frame, _ = self.constprop.run_analysis(mt) |
44 | | - result = Walk(WrapConst(frame)).rewrite(mt.code).join(result) |
45 | | - rule = Chain( |
46 | | - ConstantFold(), |
47 | | - Call2Invoke(), |
48 | | - InlineGetField(), |
49 | | - InlineGetItem(), |
50 | | - DeadCodeElimination(), |
51 | | - CommonSubexpressionElimination(), |
52 | | - ) |
53 | | - result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result) |
| 27 | + return True |
54 | 28 |
|
55 | | - result = ( |
56 | | - Walk( |
57 | | - Chain( |
58 | | - scf.unroll.PickIfElse(), |
59 | | - scf.unroll.ForLoop(), |
60 | | - scf.trim.UnusedYield(), |
61 | | - ) |
62 | | - ) |
63 | | - .rewrite(mt.code) |
64 | | - .join(result) |
| 29 | + self.aggressive_unroll = AggressiveUnroll( |
| 30 | + self.dialects, inline_simple, no_raise=self.no_raise |
65 | 31 | ) |
66 | 32 |
|
67 | | - if self.unroll_ifs: |
68 | | - UnrollIfs(mt.dialects).unsafe_run(mt).join(result) |
69 | | - |
70 | | - # run typeinfer again after unroll etc. because we now insert |
71 | | - # a lot of new nodes, which might have more precise types |
72 | | - self.typeinfer.unsafe_run(mt) |
73 | | - result = ( |
74 | | - Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll())) |
75 | | - .rewrite(mt.code) |
76 | | - .join(result) |
77 | | - ) |
78 | | - |
79 | | - def inline_simple(node: ir.Statement): |
80 | | - if isinstance(node, expr.GateFunction): |
81 | | - return self.inline_gate_subroutine |
| 33 | + def unsafe_run(self, mt: Method) -> RewriteResult: |
| 34 | + result = RewriteResult() |
82 | 35 |
|
83 | | - if not isinstance(node.parent_stmt, (scf.For, scf.IfElse)): |
84 | | - return True # always inline calls outside of loops and if-else |
| 36 | + if self.unroll_ifs: |
| 37 | + result = UnrollIfs(mt.dialects).unsafe_run(mt).join(result) |
85 | 38 |
|
86 | | - # inside loops and if-else, only inline simple functions, i.e. functions with a single block |
87 | | - if (trait := node.get_trait(ir.CallableStmtInterface)) is None: |
88 | | - return False # not a callable, don't inline to be safe |
89 | | - region = trait.get_callable_region(node) |
90 | | - return len(region.blocks) == 1 |
| 39 | + result = self.aggressive_unroll.unsafe_run(mt).join(result) |
91 | 40 |
|
92 | | - result = ( |
93 | | - Walk( |
94 | | - Inline(inline_simple), |
95 | | - ) |
96 | | - .rewrite(mt.code) |
97 | | - .join(result) |
98 | | - ) |
99 | | - result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result) |
100 | 41 | return result |
0 commit comments