Skip to content

Commit d6d0023

Browse files
squin to stim rewrite (#148)
Co-authored-by: Roger-luo <[email protected]>
1 parent 676ed3a commit d6d0023

File tree

13 files changed

+1176
-1
lines changed

13 files changed

+1176
-1
lines changed

src/bloqade/analysis/address/impls.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,20 @@ def apply(
206206
):
207207
return frame.get_values(stmt.inputs)
208208

209+
@interp.impl(squin.wire.MeasureAndReset)
210+
def measure_and_reset(
211+
self,
212+
interp_: AddressAnalysis,
213+
frame: ForwardFrame[Address],
214+
stmt: squin.wire.MeasureAndReset,
215+
):
216+
217+
# take the address data from the incoming wire
218+
# and propagate that forward to the new wire generated.
219+
# The first entry can safely be NotQubit because
220+
# it's an integer
221+
return (NotQubit(), frame.get(stmt.wire))
222+
209223

210224
@squin.qubit.dialect.register(key="qubit.address")
211225
class SquinQubitMethodTable(interp.MethodTable):

src/bloqade/squin/analysis/nsites/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Need this for impl registration to work properly!
22
from . import impls as impls
33
from .lattice import (
4+
Sites as Sites,
45
NoSites as NoSites,
56
AnySites as AnySites,
67
NumberSites as NumberSites,

src/bloqade/squin/analysis/nsites/impls.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from kirin import interp
22

3-
from bloqade.squin import op
3+
from bloqade.squin import op, wire
44

55
from .lattice import (
66
NoSites,
@@ -9,6 +9,30 @@
99
from .analysis import NSitesAnalysis
1010

1111

12+
@wire.dialect.register(key="op.nsites")
13+
class SquinWire(interp.MethodTable):
14+
15+
@interp.impl(wire.Apply)
16+
@interp.impl(wire.Broadcast)
17+
def apply(
18+
self,
19+
interp: NSitesAnalysis,
20+
frame: interp.Frame,
21+
stmt: wire.Apply | wire.Broadcast,
22+
):
23+
24+
return tuple(frame.get(input) for input in stmt.inputs)
25+
26+
@interp.impl(wire.MeasureAndReset)
27+
def measure_and_reset(
28+
self, interp: NSitesAnalysis, frame: interp.Frame, stmt: wire.MeasureAndReset
29+
):
30+
31+
# MeasureAndReset produces both a new wire
32+
# and an integer which don't have any sites at all
33+
return (NoSites(), NoSites())
34+
35+
1236
@op.dialect.register(key="op.nsites")
1337
class SquinOp(interp.MethodTable):
1438

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .stim import SquinToStim as SquinToStim

src/bloqade/squin/passes/stim.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from dataclasses import dataclass
2+
3+
from kirin.passes import Fold
4+
from kirin.rewrite import (
5+
Walk,
6+
Chain,
7+
Fixpoint,
8+
DeadCodeElimination,
9+
CommonSubexpressionElimination,
10+
)
11+
from kirin.ir.method import Method
12+
from kirin.passes.abc import Pass
13+
from kirin.rewrite.abc import RewriteResult
14+
15+
from bloqade.squin.rewrite import (
16+
SquinWireToStim,
17+
SquinQubitToStim,
18+
WrapSquinAnalysis,
19+
SquinMeasureToStim,
20+
SquinWireIdentityElimination,
21+
)
22+
from bloqade.analysis.address import AddressAnalysis
23+
from bloqade.squin.analysis.nsites import (
24+
NSitesAnalysis,
25+
)
26+
27+
28+
@dataclass
29+
class SquinToStim(Pass):
30+
31+
def unsafe_run(self, mt: Method) -> RewriteResult:
32+
fold_pass = Fold(mt.dialects)
33+
# propagate constants
34+
rewrite_result = fold_pass(mt)
35+
36+
# Get necessary analysis results to plug into hints
37+
address_analysis = AddressAnalysis(mt.dialects)
38+
address_frame, _ = address_analysis.run_analysis(mt)
39+
site_analysis = NSitesAnalysis(mt.dialects)
40+
sites_frame, _ = site_analysis.run_analysis(mt)
41+
42+
# Wrap Rewrite + SquinToStim can happen w/ standard walk
43+
rewrite_result = (
44+
Walk(
45+
Chain(
46+
WrapSquinAnalysis(
47+
address_analysis=address_frame.entries,
48+
op_site_analysis=sites_frame.entries,
49+
),
50+
SquinQubitToStim(),
51+
SquinWireToStim(),
52+
SquinMeasureToStim(), # reduce duplicated logic, can split out even more rules later
53+
SquinWireIdentityElimination(),
54+
)
55+
)
56+
.rewrite(mt.code)
57+
.join(rewrite_result)
58+
)
59+
60+
rewrite_result = (
61+
Fixpoint(
62+
Walk(Chain(DeadCodeElimination(), CommonSubexpressionElimination()))
63+
)
64+
.rewrite(mt.code)
65+
.join(rewrite_result)
66+
)
67+
68+
return rewrite_result
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from .wire_to_stim import SquinWireToStim as SquinWireToStim
2+
from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim
3+
from .squin_measure import SquinMeasureToStim as SquinMeasureToStim
4+
from .wrap_analysis import (
5+
SitesAttribute as SitesAttribute,
6+
AddressAttribute as AddressAttribute,
7+
WrapSquinAnalysis as WrapSquinAnalysis,
8+
)
9+
from .wire_identity_elimination import (
10+
SquinWireIdentityElimination as SquinWireIdentityElimination,
11+
)
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from kirin import ir
2+
from kirin.rewrite.abc import RewriteRule, RewriteResult
3+
4+
from bloqade import stim
5+
from bloqade.squin import op, qubit
6+
from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
7+
from bloqade.squin.rewrite.stim_rewrite_util import (
8+
SQUIN_STIM_GATE_MAPPING,
9+
rewrite_Control,
10+
insert_qubit_idx_from_address,
11+
)
12+
13+
14+
class SquinQubitToStim(RewriteRule):
15+
16+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
17+
18+
match node:
19+
case qubit.Apply() | qubit.Broadcast():
20+
return self.rewrite_Apply_and_Broadcast(node)
21+
case qubit.Reset():
22+
return self.rewrite_Reset(node)
23+
case _:
24+
return RewriteResult()
25+
26+
def rewrite_Apply_and_Broadcast(
27+
self, stmt: qubit.Apply | qubit.Broadcast
28+
) -> RewriteResult:
29+
"""
30+
Rewrite Apply and Broadcast nodes to their stim equivalent statements.
31+
"""
32+
33+
# this is an SSAValue, need it to be the actual operator
34+
applied_op = stmt.operator.owner
35+
assert isinstance(applied_op, op.stmts.Operator)
36+
37+
if isinstance(applied_op, op.stmts.Control):
38+
return rewrite_Control(stmt)
39+
40+
# need to handle Control through separate means
41+
# but we can handle X, Y, Z, H, and S here just fine
42+
stim_1q_op = SQUIN_STIM_GATE_MAPPING.get(type(applied_op))
43+
if stim_1q_op is None:
44+
return RewriteResult()
45+
46+
address_attr = stmt.qubits.hints.get("address")
47+
if address_attr is None:
48+
return RewriteResult()
49+
50+
assert isinstance(address_attr, AddressAttribute)
51+
qubit_idx_ssas = insert_qubit_idx_from_address(
52+
address=address_attr, stmt_to_insert_before=stmt
53+
)
54+
55+
if qubit_idx_ssas is None:
56+
return RewriteResult()
57+
58+
stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas))
59+
stmt.replace_by(stim_1q_stmt)
60+
61+
return RewriteResult(has_done_something=True)
62+
63+
def rewrite_Reset(self, reset_stmt: qubit.Reset) -> RewriteResult:
64+
qubit_ilist_ssa = reset_stmt.qubits
65+
# qubits are in an ilist which makes up an AddressTuple
66+
address_attr = qubit_ilist_ssa.hints.get("address")
67+
if address_attr is None:
68+
return RewriteResult()
69+
70+
assert isinstance(address_attr, AddressAttribute)
71+
qubit_idx_ssas = insert_qubit_idx_from_address(
72+
address=address_attr, stmt_to_insert_before=reset_stmt
73+
)
74+
75+
if qubit_idx_ssas is None:
76+
return RewriteResult()
77+
78+
stim_rz_stmt = stim.collapse.stmts.RZ(targets=qubit_idx_ssas)
79+
reset_stmt.replace_by(stim_rz_stmt)
80+
81+
return RewriteResult(has_done_something=True)
82+
83+
84+
# put rewrites for measure statements in separate rule, then just have to dispatch
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# create rewrite rule name SquinMeasureToStim using kirin
2+
from kirin import ir
3+
from kirin.dialects import py
4+
from kirin.rewrite.abc import RewriteRule, RewriteResult
5+
6+
from bloqade import stim
7+
from bloqade.squin import wire, qubit
8+
from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
9+
from bloqade.squin.rewrite.stim_rewrite_util import (
10+
is_measure_result_used,
11+
insert_qubit_idx_from_address,
12+
)
13+
14+
15+
class SquinMeasureToStim(RewriteRule):
16+
"""
17+
Rewrite squin measure-related statements to stim statements.
18+
"""
19+
20+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
21+
22+
match node:
23+
case qubit.MeasureQubit() | qubit.MeasureQubitList() | wire.Measure():
24+
return self.rewrite_Measure(node)
25+
case qubit.MeasureAndReset() | wire.MeasureAndReset():
26+
return self.rewrite_MeasureAndReset(node)
27+
case _:
28+
return RewriteResult()
29+
30+
def rewrite_Measure(
31+
self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
32+
) -> RewriteResult:
33+
if is_measure_result_used(measure_stmt):
34+
return RewriteResult()
35+
36+
qubit_idx_ssas = self.get_qubit_idx_ssas(measure_stmt)
37+
if qubit_idx_ssas is None:
38+
return RewriteResult()
39+
40+
prob_noise_stmt = py.constant.Constant(0.0)
41+
stim_measure_stmt = stim.collapse.MZ(
42+
p=prob_noise_stmt.result,
43+
targets=qubit_idx_ssas,
44+
)
45+
prob_noise_stmt.insert_before(measure_stmt)
46+
measure_stmt.replace_by(stim_measure_stmt)
47+
48+
return RewriteResult(has_done_something=True)
49+
50+
def rewrite_MeasureAndReset(
51+
self, meas_and_reset_stmt: qubit.MeasureAndReset | wire.MeasureAndReset
52+
) -> RewriteResult:
53+
if not is_measure_result_used(meas_and_reset_stmt):
54+
return RewriteResult()
55+
56+
qubit_idx_ssas = self.get_qubit_idx_ssas(meas_and_reset_stmt)
57+
58+
if qubit_idx_ssas is None:
59+
return RewriteResult()
60+
61+
error_p_stmt = py.Constant(0.0)
62+
stim_mz_stmt = stim.collapse.MZ(targets=qubit_idx_ssas, p=error_p_stmt.result)
63+
stim_rz_stmt = stim.collapse.RZ(
64+
targets=qubit_idx_ssas,
65+
)
66+
67+
error_p_stmt.insert_before(meas_and_reset_stmt)
68+
stim_mz_stmt.insert_before(meas_and_reset_stmt)
69+
meas_and_reset_stmt.replace_by(stim_rz_stmt)
70+
71+
return RewriteResult(has_done_something=True)
72+
73+
def get_qubit_idx_ssas(
74+
self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
75+
) -> tuple[ir.SSAValue, ...] | None:
76+
"""
77+
Extract the address attribute and insert qubit indices for the given measure statement.
78+
"""
79+
match measure_stmt:
80+
case qubit.MeasureQubit():
81+
address_attr = measure_stmt.qubit.hints.get("address")
82+
case qubit.MeasureQubitList():
83+
address_attr = measure_stmt.qubits.hints.get("address")
84+
case wire.Measure():
85+
address_attr = measure_stmt.wire.hints.get("address")
86+
case _:
87+
return None
88+
89+
if address_attr is None:
90+
return None
91+
92+
assert isinstance(address_attr, AddressAttribute)
93+
94+
qubit_idx_ssas = insert_qubit_idx_from_address(
95+
address=address_attr, stmt_to_insert_before=measure_stmt
96+
)
97+
98+
return qubit_idx_ssas

0 commit comments

Comments
 (0)