Skip to content

Commit 65ed26a

Browse files
johnzl-777kaihsin
andcommitted
Adapt SquinToStim to use new MeasureIDAnalysis (#414)
This is dependent on PR #413 (Please review that one first) but the new SquinToStim pass should now be able to generate the correct record indices. A couple of comments in terms of the changes: - I don't see a need to generate a GetRecord unless it NEEDS to be there. The only time this is a reality is when IfsToStim sees a compatible boolean condition, in which case a GetRecord should be generated. - There was one pre-existing test with regards to SquinToStim that failed with the new changes, but this was because the record indexes were historically incorrect. The test in question is `simple_if_rewrite` in `test/stim/passes/test_squin_meas_to_stim.py` ``` # The test's structure (in pseudocode) is: q = qubits(4) ms = measure(q) if ms[0] # apply some gates that can be translated to feedforward if ms[1] # same as above measure(q) # another measurement ``` Because of #398 the global number of measures was used and the indexes generated were `-7` and `-8`. I think this is still technically correct (they are first - and oldest - measurements that occur) but the more correct behavior is that at the time of the scf.IfElse you only have 4 measurements total so the correct index generated should be -4 and -3. --------- Co-authored-by: Kai-Hsin Wu <[email protected]>
1 parent 8f4a889 commit 65ed26a

File tree

11 files changed

+124
-86
lines changed

11 files changed

+124
-86
lines changed
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
from . import impls as impls
2-
from .analysis import MeasurementIDAnalysis as MeasurementIDAnalysis
2+
from .analysis import (
3+
MeasureIDFrame as MeasureIDFrame,
4+
MeasurementIDAnalysis as MeasurementIDAnalysis,
5+
)
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1-
from .squin_to_stim import SquinToStimPass as SquinToStimPass
1+
from .squin_to_stim import (
2+
SquinToStimPass as SquinToStimPass,
3+
StimSimplifyIfs as StimSimplifyIfs,
4+
AggressiveForLoopUnroll as AggressiveForLoopUnroll,
5+
)

src/bloqade/stim/passes/squin_to_stim.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,14 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
128128
)
129129

130130
# 2. rewrite
131+
## Invoke DCE afterwards to eliminate any GetItems
132+
## that are no longer being used. This allows for
133+
## SquinMeasureToStim to safely eliminate
134+
## unused measure statements.
131135
rewrite_result = (
132-
Walk(
133-
IfToStim(
134-
measure_analysis=meas_analysis_frame.entries,
135-
measure_count=mia.measure_count,
136-
)
136+
Chain(
137+
Walk(IfToStim(measure_frame=meas_analysis_frame)),
138+
Fixpoint(Walk(DeadCodeElimination())),
137139
)
138140
.rewrite(mt.code)
139141
.join(rewrite_result)
@@ -149,17 +151,15 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
149151
Walk(
150152
Chain(
151153
SquinQubitToStim(),
154+
SquinMeasureToStim(),
152155
SquinWireToStim(),
153-
SquinMeasureToStim(
154-
measure_id_result=meas_analysis_frame.entries,
155-
total_measure_count=mia.measure_count,
156-
), # reduce duplicated logic, can split out even more rules later
157156
SquinWireIdentityElimination(),
158157
)
159158
)
160159
.rewrite(mt.code)
161160
.join(rewrite_result)
162161
)
162+
163163
rewrite_result = (
164164
CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise)
165165
.unsafe_run(mt)

src/bloqade/stim/rewrite/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .ifs_to_stim import IfToStim as IfToStim
12
from .squin_noise import SquinNoiseToStim as SquinNoiseToStim
23
from .wire_to_stim import SquinWireToStim as SquinWireToStim
34
from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim

src/bloqade/stim/rewrite/ifs_to_stim.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
SQUIN_STIM_CONTROL_GATE_MAPPING,
1212
insert_qubit_idx_from_address,
1313
)
14+
from bloqade.analysis.measure_id import MeasureIDFrame
1415
from bloqade.stim.dialects.auxiliary import GetRecord
1516
from bloqade.analysis.measure_id.lattice import (
16-
MeasureId,
1717
MeasureIdBool,
1818
)
1919

@@ -127,8 +127,7 @@ class IfToStim(IfElseSimplification, RewriteRule):
127127
Rewrite if statements to stim equivalent statements.
128128
"""
129129

130-
measure_analysis: dict[ir.SSAValue, MeasureId]
131-
measure_count: int
130+
measure_frame: MeasureIDFrame
132131

133132
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
134133

@@ -140,7 +139,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
140139

141140
def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult:
142141

143-
if not isinstance(self.measure_analysis[stmt.cond], MeasureIdBool):
142+
if not isinstance(self.measure_frame.entries[stmt.cond], MeasureIdBool):
144143
return RewriteResult()
145144

146145
# check that there is only qubit.Apply in the then-body,
@@ -161,12 +160,12 @@ def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult:
161160
return RewriteResult()
162161

163162
# get necessary measurement ID type from analysis
164-
measure_id_bool = self.measure_analysis[stmt.cond]
163+
measure_id_bool = self.measure_frame.entries[stmt.cond]
165164
assert isinstance(measure_id_bool, MeasureIdBool)
166165

167166
# generate get record statement
168167
measure_id_idx_stmt = py.Constant(
169-
(measure_id_bool.idx - 1) - self.measure_count
168+
(measure_id_bool.idx - 1) - self.measure_frame.num_measures_at_stmt[stmt]
170169
)
171170
get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) # noqa: F841
172171

src/bloqade/stim/rewrite/squin_measure.py

Lines changed: 2 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -2,47 +2,16 @@
22
from dataclasses import dataclass
33

44
from kirin import ir
5-
from kirin.dialects import py, ilist
5+
from kirin.dialects import py
66
from kirin.rewrite.abc import RewriteRule, RewriteResult
77

88
from bloqade.squin import wire, qubit
99
from bloqade.squin.rewrite import AddressAttribute
10-
from bloqade.stim.dialects import collapse, auxiliary
10+
from bloqade.stim.dialects import collapse
1111
from bloqade.stim.rewrite.util import (
1212
is_measure_result_used,
1313
insert_qubit_idx_from_address,
1414
)
15-
from bloqade.analysis.measure_id.lattice import MeasureId, MeasureIdBool, MeasureIdTuple
16-
17-
18-
def replace_get_record(
19-
node: ir.Statement, measure_id_bool: MeasureIdBool, meas_count: int
20-
):
21-
assert isinstance(measure_id_bool, MeasureIdBool)
22-
target_rec_idx = (measure_id_bool.idx - 1) - meas_count
23-
idx_stmt = py.constant.Constant(target_rec_idx)
24-
idx_stmt.insert_before(node)
25-
get_record_stmt = auxiliary.GetRecord(idx_stmt.result)
26-
node.replace_by(get_record_stmt)
27-
28-
29-
def insert_get_record_list(
30-
node: ir.Statement, measure_id_tuple: MeasureIdTuple, meas_count: int
31-
):
32-
"""
33-
Insert GetRecord statements before the given node
34-
"""
35-
get_record_ssas = []
36-
for measure_id_bool in measure_id_tuple.data:
37-
assert isinstance(measure_id_bool, MeasureIdBool)
38-
target_rec_idx = (measure_id_bool.idx - 1) - meas_count
39-
idx_stmt = py.constant.Constant(target_rec_idx)
40-
idx_stmt.insert_before(node)
41-
get_record_stmt = auxiliary.GetRecord(idx_stmt.result)
42-
get_record_stmt.insert_before(node)
43-
get_record_ssas.append(get_record_stmt.result)
44-
45-
node.replace_by(ilist.New(values=get_record_ssas))
4615

4716

4817
@dataclass
@@ -51,9 +20,6 @@ class SquinMeasureToStim(RewriteRule):
5120
Rewrite squin measure-related statements to stim statements.
5221
"""
5322

54-
measure_id_result: dict[ir.SSAValue, MeasureId]
55-
total_measure_count: int
56-
5723
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
5824

5925
match node:
@@ -70,10 +36,6 @@ def rewrite_Measure(
7036
if qubit_idx_ssas is None:
7137
return RewriteResult()
7238

73-
measure_id = self.measure_id_result[measure_stmt.result]
74-
if not isinstance(measure_id, (MeasureIdBool, MeasureIdTuple)):
75-
return RewriteResult()
76-
7739
prob_noise_stmt = py.constant.Constant(0.0)
7840
stim_measure_stmt = collapse.MZ(
7941
p=prob_noise_stmt.result,
@@ -84,27 +46,6 @@ def rewrite_Measure(
8446

8547
if not is_measure_result_used(measure_stmt):
8648
measure_stmt.delete()
87-
return RewriteResult(has_done_something=True)
88-
89-
# replace dataflow with new stmt!
90-
measure_id = self.measure_id_result[measure_stmt.result]
91-
if isinstance(measure_id, MeasureIdBool):
92-
replace_get_record(
93-
node=measure_stmt,
94-
measure_id_bool=measure_id,
95-
meas_count=self.total_measure_count,
96-
)
97-
elif isinstance(measure_id, MeasureIdTuple):
98-
insert_get_record_list(
99-
node=measure_stmt,
100-
measure_id_tuple=measure_id,
101-
meas_count=self.total_measure_count,
102-
)
103-
else:
104-
# already checked before, so this should not happen
105-
raise ValueError(
106-
f"Unexpected measure ID type: {type(measure_id)} for measure statement {measure_stmt}"
107-
)
10849

10950
return RewriteResult(has_done_something=True)
11051

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
MZ(0.00000000) 0 1 2 3
3+
CZ rec[-4] 0
4+
MZ(0.00000000) 0 1 2 3
5+
CX rec[-4] 1
6+
MZ(0.00000000) 0 1 2 3
7+
CY rec[-4] 2
8+
CZ rec[-1] 3
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
MZ(0.00000000) 0 1 2 3
3+
CZ rec[-4] 0
4+
MZ(0.00000000) 0 1 2 3
5+
CX rec[-4] 0
6+
MZ(0.00000000) 0 1 2 3
7+
CY rec[-12] 1
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
MZ(0.00000000) 0 1 2 3
3+
CZ rec[-4] 0
4+
CX rec[-4] 1 rec[-4] 2 rec[-4] 3
5+
CZ rec[-4] 0 rec[-4] 1 rec[-4] 2 rec[-4] 3
6+
CX rec[-3] 0
7+
CY rec[-3] 1
8+
MZ(0.00000000) 0 1 2 3

test/stim/passes/stim_reference_programs/simple_if_rewrite.txt

Lines changed: 0 additions & 8 deletions
This file was deleted.

0 commit comments

Comments
 (0)