Skip to content

Commit e6c92b7

Browse files
committed
first round of meeting feedback implemented
1 parent fb9ff37 commit e6c92b7

File tree

7 files changed

+103
-48
lines changed

7 files changed

+103
-48
lines changed

src/bloqade/squin/passes/stim.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212
from kirin.passes.abc import Pass
1313
from kirin.rewrite.abc import RewriteResult
1414

15-
from bloqade.squin.rewrite import SquinWireToStim, SquinQubitToStim, WrapSquinAnalysis
15+
from bloqade.squin.rewrite import (
16+
SquinWireToStim,
17+
SquinQubitToStim,
18+
WrapSquinAnalysis,
19+
SquinWireIdentityElimination,
20+
)
1621
from bloqade.analysis.address import AddressAnalysis
1722
from bloqade.squin.analysis.nsites import (
1823
NSitesAnalysis,
@@ -43,6 +48,7 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
4348
),
4449
SquinQubitToStim(),
4550
SquinWireToStim(),
51+
SquinWireIdentityElimination(),
4652
)
4753
)
4854
.rewrite(mt.code)

src/bloqade/squin/rewrite/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,6 @@
55
AddressAttribute as AddressAttribute,
66
WrapSquinAnalysis as WrapSquinAnalysis,
77
)
8+
from .wire_identity_elimination import (
9+
SquinWireIdentityElimination as SquinWireIdentityElimination,
10+
)

src/bloqade/squin/rewrite/qubit_to_stim.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,19 @@
66
from bloqade.squin import op, qubit
77
from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
88
from bloqade.squin.rewrite.stim_rewrite_util import (
9+
SQUIN_STIM_GATE_MAPPING,
910
rewrite_Control,
10-
get_stim_1q_gate,
1111
are_sites_compatible,
12+
is_measure_result_used,
1213
insert_qubit_idx_from_address,
1314
)
1415

1516

1617
class SquinQubitToStim(RewriteRule):
1718

1819
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
20+
21+
# don't want to alloc dict, change back to if else/match case
1922
rewrite_methods = {
2023
qubit.Apply: self.rewrite_Apply_and_Broadcast,
2124
qubit.Broadcast: self.rewrite_Apply_and_Broadcast,
@@ -31,13 +34,13 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
3134

3235
return rewrite_method(node)
3336

34-
# handle Control
3537
def rewrite_Apply_and_Broadcast(
3638
self, stmt: qubit.Apply | qubit.Broadcast
3739
) -> RewriteResult:
3840
"""
3941
Rewrite Apply and Broadcast nodes to their stim equivalent statements.
4042
"""
43+
# get rid of are_sites_compatible, assume program is properly structured
4144
if not are_sites_compatible(stmt):
4245
return RewriteResult()
4346

@@ -50,7 +53,7 @@ def rewrite_Apply_and_Broadcast(
5053

5154
# need to handle Control through separate means
5255
# but we can handle X, Y, Z, H, and S here just fine
53-
stim_1q_op = get_stim_1q_gate(applied_op)
56+
stim_1q_op = SQUIN_STIM_GATE_MAPPING.get(type(applied_op))
5457
if stim_1q_op is None:
5558
return RewriteResult()
5659

@@ -75,6 +78,9 @@ def rewrite_Measure(
7578
self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList
7679
) -> RewriteResult:
7780

81+
if is_measure_result_used(measure_stmt):
82+
return RewriteResult()
83+
7884
# qubit_ssa will always be an ilist of qubits
7985
# but need to be careful with singular vs plural "qubit" attribute name
8086
if isinstance(measure_stmt, qubit.MeasureQubit):
@@ -132,6 +138,9 @@ def rewrite_MeasureAndReset(
132138
self, meas_and_reset_stmt: qubit.MeasureAndReset
133139
) -> RewriteResult:
134140

141+
if is_measure_result_used(meas_and_reset_stmt):
142+
return RewriteResult()
143+
135144
address_attr = meas_and_reset_stmt.qubits.hints.get("address")
136145
if address_attr is None:
137146
return RewriteResult()
@@ -154,3 +163,6 @@ def rewrite_MeasureAndReset(
154163
meas_and_reset_stmt.replace_by(stim_rz_stmt)
155164

156165
return RewriteResult(has_done_something=True)
166+
167+
168+
# put rewrites for measure statements in separate rule, then just have to dispatch

src/bloqade/squin/rewrite/stim_rewrite_util.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,14 @@
1010
from bloqade.squin.analysis.nsites import NumberSites
1111
from bloqade.squin.rewrite.wrap_analysis import SitesAttribute, AddressAttribute
1212

13-
14-
def get_stim_1q_gate(squin_op: op.stmts.Operator):
15-
"""
16-
Map squin 1Q Ops to stim Ops.
17-
"""
18-
gate_mapping = {
19-
op.stmts.X: stim.gate.X,
20-
op.stmts.Y: stim.gate.Y,
21-
op.stmts.Z: stim.gate.Z,
22-
op.stmts.H: stim.gate.H,
23-
op.stmts.S: stim.gate.S,
24-
op.stmts.Identity: stim.gate.Identity,
25-
}
26-
return gate_mapping.get(type(squin_op))
13+
SQUIN_STIM_GATE_MAPPING = {
14+
op.stmts.X: stim.gate.X,
15+
op.stmts.Y: stim.gate.Y,
16+
op.stmts.Z: stim.gate.Z,
17+
op.stmts.H: stim.gate.H,
18+
op.stmts.S: stim.gate.S,
19+
op.stmts.Identity: stim.gate.Identity,
20+
}
2721

2822

2923
def insert_qubit_idx_from_address(
@@ -190,3 +184,18 @@ def rewrite_Control(
190184
stmt_with_ctrl.replace_by(stim_stmt)
191185

192186
return RewriteResult(has_done_something=True)
187+
188+
189+
def is_measure_result_used(
190+
stmt: (
191+
qubit.MeasureAndReset
192+
| qubit.MeasureQubit
193+
| qubit.MeasureQubitList
194+
| wire.MeasureAndReset
195+
| wire.Measure
196+
),
197+
) -> bool:
198+
"""
199+
Check if the result of a measure statement is used in the program.
200+
"""
201+
return bool(stmt.result.uses)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from kirin import ir
2+
from kirin.rewrite.abc import RewriteRule, RewriteResult
3+
4+
from bloqade.squin import wire
5+
6+
7+
class SquinWireIdentityElimination(RewriteRule):
8+
9+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
10+
"""
11+
Handle the case where an unwrap feeds a wire directly into a wrap,
12+
equivalent to nothing happening/identity operation
13+
14+
w = unwrap(qubit)
15+
wrap(qubit, w)
16+
"""
17+
if isinstance(node, wire.Wrap):
18+
wire_origin_stmt = node.wire.owner
19+
if isinstance(wire_origin_stmt, wire.Unwrap):
20+
node.delete() # get rid of wrap
21+
wire_origin_stmt.delete() # get rid of the unwrap
22+
return RewriteResult(has_done_something=True)
23+
24+
return RewriteResult()

src/bloqade/squin/rewrite/wire_to_stim.py

Lines changed: 21 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
from bloqade.squin import op, wire
77
from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
88
from bloqade.squin.rewrite.stim_rewrite_util import (
9+
SQUIN_STIM_GATE_MAPPING,
910
rewrite_Control,
10-
get_stim_1q_gate,
1111
are_sites_compatible,
12+
is_measure_result_used,
1213
insert_qubit_idx_from_address,
1314
insert_qubit_idx_from_wire_ssa,
1415
)
@@ -17,21 +18,17 @@
1718
class SquinWireToStim(RewriteRule):
1819

1920
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
20-
21-
rewrite_methods = {
22-
wire.Apply: self.rewrite_Apply_and_Broadcast,
23-
wire.Broadcast: self.rewrite_Apply_and_Broadcast,
24-
wire.Wrap: self.rewrite_Wrap,
25-
wire.Measure: self.rewrite_Measure,
26-
wire.Reset: self.rewrite_Reset,
27-
wire.MeasureAndReset: self.rewrite_MeasureAndReset,
28-
}
29-
30-
rewrite_method = rewrite_methods.get(type(node))
31-
if rewrite_method is None:
32-
return RewriteResult()
33-
34-
return rewrite_method(node)
21+
match node:
22+
case wire.Apply() | wire.Broadcast():
23+
return self.rewrite_Apply_and_Broadcast(node)
24+
case wire.Measure():
25+
return self.rewrite_Measure(node)
26+
case wire.Reset():
27+
return self.rewrite_Reset(node)
28+
case wire.MeasureAndReset():
29+
return self.rewrite_MeasureAndReset(node)
30+
case _:
31+
return RewriteResult()
3532

3633
def rewrite_Apply_and_Broadcast(
3734
self, stmt: wire.Apply | wire.Broadcast
@@ -47,7 +44,7 @@ def rewrite_Apply_and_Broadcast(
4744
if isinstance(applied_op, op.stmts.Control):
4845
return rewrite_Control(stmt)
4946

50-
stim_1q_op = get_stim_1q_gate(applied_op)
47+
stim_1q_op = SQUIN_STIM_GATE_MAPPING.get(type(applied_op))
5148
if stim_1q_op is None:
5249
return RewriteResult()
5350

@@ -69,21 +66,11 @@ def rewrite_Apply_and_Broadcast(
6966

7067
return RewriteResult(has_done_something=True)
7168

72-
def rewrite_Wrap(self, wrap_stmt: wire.Wrap) -> RewriteResult:
73-
74-
# structure at this point should be:
75-
## w = wire.Unwrap(wire)
76-
## wire.Wrap(qubit, w)
77-
78-
wire_origin_stmt = wrap_stmt.wire.owner
79-
if isinstance(wire_origin_stmt, wire.Unwrap):
80-
wrap_stmt.delete()
81-
return RewriteResult(has_done_something=True)
82-
83-
return RewriteResult()
84-
8569
def rewrite_Measure(self, measure_stmt: wire.Measure) -> RewriteResult:
8670

71+
if is_measure_result_used(measure_stmt):
72+
return RewriteResult()
73+
8774
wire_ssa = measure_stmt.wire
8875
address_attr = wire_ssa.hints.get("address")
8976
if address_attr is None:
@@ -126,6 +113,9 @@ def rewrite_Reset(self, reset_stmt: wire.Reset) -> RewriteResult:
126113

127114
def rewrite_MeasureAndReset(self, meas_and_reset_stmt: wire.MeasureAndReset):
128115

116+
if is_measure_result_used(meas_and_reset_stmt):
117+
return RewriteResult()
118+
129119
address_attr = meas_and_reset_stmt.wire.hints.get("address")
130120
if address_attr is None:
131121
return RewriteResult()
@@ -141,6 +131,7 @@ def rewrite_MeasureAndReset(self, meas_and_reset_stmt: wire.MeasureAndReset):
141131
stim_rz_stmt = stim.collapse.RZ(
142132
targets=qubit_idx_ssas,
143133
)
134+
144135
error_p_stmt.insert_before(meas_and_reset_stmt)
145136
stim_mz_stmt.insert_before(meas_and_reset_stmt)
146137
meas_and_reset_stmt.replace_by(stim_rz_stmt)

test/squin/stim/stim.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,16 @@ def test_broadcast_wire_1q_application():
157157
constructed_method.print()
158158

159159

160+
# before ANY rewrite, aggressively inline everything, then do the rewrite
161+
# for Stim pass, need to call validation , check any invoke
162+
163+
# Put one codegen test to stim
164+
# finish measurement analysis Friday - if painful, ask help from Kai
165+
# work on other detector rewrite
166+
167+
# later on lower for loop to repeat
168+
169+
160170
def test_broadcast_qubit_1q_application():
161171

162172
stmts: list[ir.Statement] = [

0 commit comments

Comments
 (0)