|
10 | 10 | from kirin.ir.method import Method |
11 | 11 | from kirin.passes.abc import Pass |
12 | 12 | from kirin.rewrite.abc import RewriteResult |
| 13 | +from kirin.passes.hint_const import HintConst |
13 | 14 |
|
14 | 15 | from bloqade.stim.rewrite import ( |
| 16 | + IfToStimPartial, |
15 | 17 | PyConstantToStim, |
| 18 | + ResolveGetRecIdx, |
16 | 19 | SquinNoiseToStim, |
17 | 20 | SquinQubitToStim, |
| 21 | + SetDetectorPartial, |
18 | 22 | SquinMeasureToStim, |
| 23 | + SetObservablePartial, |
19 | 24 | ) |
20 | 25 | from bloqade.squin.rewrite import ( |
21 | 26 | SquinU3ToClifford, |
|
24 | 29 | ) |
25 | 30 | from bloqade.rewrite.passes import CanonicalizeIList |
26 | 31 | from bloqade.analysis.address import AddressAnalysis |
| 32 | +from bloqade.record_idx_helper import dialect as record_idx_helper_dialect |
27 | 33 | from bloqade.analysis.measure_id import MeasurementIDAnalysis |
28 | 34 | from bloqade.stim.passes.flatten import Flatten |
29 | 35 |
|
30 | | -from ..rewrite import IfToStim, SetDetectorToStim, SetObservableToStim |
31 | | - |
32 | 36 |
|
33 | 37 | @dataclass |
34 | 38 | class SquinToStimPass(Pass): |
35 | 39 |
|
36 | 40 | def unsafe_run(self, mt: Method) -> RewriteResult: |
37 | 41 |
|
38 | | - # inline aggressively: |
39 | 42 | rewrite_result = Flatten(dialects=mt.dialects, no_raise=self.no_raise).fixpoint( |
40 | 43 | mt |
41 | 44 | ) |
42 | 45 |
|
43 | | - # after this the program should be in a state where it is analyzable |
44 | | - # ------------------------------------------------------------------- |
45 | | - |
46 | | - mia = MeasurementIDAnalysis(dialects=mt.dialects) |
47 | | - meas_analysis_frame, _ = mia.run(mt) |
48 | | - |
49 | 46 | aa = AddressAnalysis(dialects=mt.dialects) |
50 | 47 | address_analysis_frame, _ = aa.run(mt) |
51 | 48 |
|
52 | | - # wrap the address analysis result |
53 | 49 | rewrite_result = ( |
54 | 50 | Walk(WrapAddressAnalysis(address_analysis=address_analysis_frame.entries)) |
55 | 51 | .rewrite(mt.code) |
56 | 52 | .join(rewrite_result) |
57 | 53 | ) |
58 | 54 |
|
59 | | - # 2. rewrite |
60 | | - ## Invoke DCE afterwards to eliminate any GetItems |
61 | | - ## that are no longer being used. This allows for |
62 | | - ## SquinMeasureToStim to safely eliminate |
63 | | - ## unused measure statements. |
| 55 | + # --- partial rewrite (before analysis) --- |
64 | 56 | rewrite_result = ( |
65 | | - Chain( |
66 | | - Walk(IfToStim(measure_frame=meas_analysis_frame)), |
67 | | - Walk(SetDetectorToStim(measure_id_frame=meas_analysis_frame)), |
68 | | - Walk(SetObservableToStim(measure_id_frame=meas_analysis_frame)), |
69 | | - Fixpoint(Walk(DeadCodeElimination())), |
| 57 | + Walk( |
| 58 | + Chain( |
| 59 | + SetDetectorPartial(), |
| 60 | + SetObservablePartial(), |
| 61 | + IfToStimPartial(), |
| 62 | + ) |
70 | 63 | ) |
71 | 64 | .rewrite(mt.code) |
72 | 65 | .join(rewrite_result) |
73 | 66 | ) |
74 | 67 |
|
75 | | - # Rewrite the noise statements first. |
76 | 68 | rewrite_result = Walk(SquinNoiseToStim()).rewrite(mt.code).join(rewrite_result) |
77 | | - |
78 | | - # Wrap Rewrite + SquinToStim can happen w/ standard walk |
79 | 69 | rewrite_result = Walk(SquinU3ToClifford()).rewrite(mt.code).join(rewrite_result) |
| 70 | + rewrite_result = Walk(SquinQubitToStim()).rewrite(mt.code).join(rewrite_result) |
80 | 71 |
|
| 72 | + # --- analysis (produces RecId for GetRecIdxFromMeasurement / GetRecIdxFromPredicate) --- |
| 73 | + analysis_dialects = mt.dialects.add(record_idx_helper_dialect) |
81 | 74 | rewrite_result = ( |
82 | | - Walk( |
83 | | - Chain( |
84 | | - SquinQubitToStim(), |
85 | | - SquinMeasureToStim(), |
86 | | - ) |
| 75 | + HintConst(analysis_dialects, no_raise=self.no_raise) |
| 76 | + .unsafe_run(mt) |
| 77 | + .join(rewrite_result) |
| 78 | + ) |
| 79 | + mia = MeasurementIDAnalysis(dialects=analysis_dialects) |
| 80 | + meas_analysis_frame, _ = mia.run(mt) |
| 81 | + |
| 82 | + # --- post-analysis: resolve helper stmts into direct integer constants --- |
| 83 | + rewrite_result = ( |
| 84 | + Chain( |
| 85 | + Walk(ResolveGetRecIdx(measure_id_frame=meas_analysis_frame)), |
| 86 | + Fixpoint(Walk(DeadCodeElimination())), |
87 | 87 | ) |
88 | 88 | .rewrite(mt.code) |
89 | 89 | .join(rewrite_result) |
90 | 90 | ) |
91 | 91 |
|
| 92 | + # --- rewrite measures (must stay until after analysis) --- |
| 93 | + rewrite_result = ( |
| 94 | + Walk(SquinMeasureToStim()).rewrite(mt.code).join(rewrite_result) |
| 95 | + ) |
| 96 | + |
92 | 97 | rewrite_result = ( |
93 | 98 | CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise) |
94 | 99 | .unsafe_run(mt) |
95 | 100 | .join(rewrite_result) |
96 | 101 | ) |
97 | 102 |
|
98 | | - # Convert all PyConsts to Stim Constants |
99 | 103 | rewrite_result = Walk(PyConstantToStim()).rewrite(mt.code).join(rewrite_result) |
100 | 104 |
|
101 | | - # clear up leftover stmts |
102 | | - # - remove any squin.qalloc that's left around |
103 | 105 | rewrite_result = ( |
104 | 106 | Fixpoint( |
105 | 107 | Walk( |
|
0 commit comments