Skip to content

Commit 7ba5670

Browse files
committed
finally put everything into a pass
1 parent e76b3f3 commit 7ba5670

File tree

5 files changed

+119
-70
lines changed

5 files changed

+119
-70
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .stim import SquinToStim as SquinToStim

src/bloqade/squin/passes/stim.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from dataclasses import dataclass
2+
3+
from kirin.passes import Fold
4+
from kirin.rewrite import (
5+
Walk,
6+
Chain,
7+
Fixpoint,
8+
DeadCodeElimination,
9+
CommonSubexpressionElimination,
10+
)
11+
from kirin.ir.method import Method
12+
from kirin.passes.abc import Pass
13+
from kirin.rewrite.abc import RewriteResult
14+
15+
import bloqade.squin.rewrite as squin_rewrite
16+
from bloqade.analysis.address import AddressAnalysis
17+
from bloqade.squin.analysis.nsites import (
18+
NSitesAnalysis,
19+
)
20+
21+
22+
@dataclass
23+
class SquinToStim(Pass):
24+
25+
def unsafe_run(self, mt: Method) -> RewriteResult:
26+
fold_pass = Fold(mt.dialects)
27+
# propagate constants
28+
rewrite_result = fold_pass(mt)
29+
30+
# Get necessary analysis results to plug into hints
31+
address_analysis = AddressAnalysis(mt.dialects)
32+
address_frame, _ = address_analysis.run_analysis(mt)
33+
site_analysis = NSitesAnalysis(mt.dialects)
34+
sites_frame, _ = site_analysis.run_analysis(mt)
35+
36+
# Wrap Rewrite + SquinToStim can happen w/ standard walk
37+
rewrite_result = (
38+
Walk(
39+
Chain(
40+
squin_rewrite.WrapSquinAnalysis(
41+
address_analysis=address_frame.entries,
42+
op_site_analysis=sites_frame.entries,
43+
),
44+
squin_rewrite._SquinToStim(),
45+
)
46+
)
47+
.rewrite(mt.code)
48+
.join(rewrite_result)
49+
)
50+
51+
rewrite_result = (
52+
Fixpoint(
53+
Walk(Chain(DeadCodeElimination(), CommonSubexpressionElimination()))
54+
)
55+
.rewrite(mt.code)
56+
.join(rewrite_result)
57+
)
58+
59+
return rewrite_result
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .stim import (
2-
SquinToStim as SquinToStim,
32
SitesAttribute as SitesAttribute,
43
AddressAttribute as AddressAttribute,
54
WrapSquinAnalysis as WrapSquinAnalysis,
5+
_SquinToStim as _SquinToStim,
66
)

src/bloqade/squin/rewrite/stim.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from bloqade import stim
1010
from bloqade.squin import op, wire, qubit
11-
from bloqade.analysis.address import Address
11+
from bloqade.analysis.address import Address, AddressWire, AddressTuple
1212
from bloqade.squin.analysis.nsites import Sites
1313

1414
# Probably best to move these attributes to a
@@ -80,13 +80,19 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
8080

8181

8282
@dataclass
83-
class SquinToStim(RewriteRule):
83+
class _SquinToStim(RewriteRule):
8484

8585
def get_address(self, value: ir.SSAValue):
86-
return value.hints.get("address")
86+
try:
87+
return value.hints["address"]
88+
except KeyError:
89+
raise KeyError(f"The address analysis hint for {value} does not exist")
8790

8891
def get_sites(self, value: ir.SSAValue):
89-
return value.hints.get("sites")
92+
try:
93+
return value.hints["sites"]
94+
except KeyError:
95+
raise KeyError(f"The sites analysis hint for {value} does not exist")
9096

9197
# Go from (most) squin 1Q Ops to stim Ops
9298
## X, Y, Z, H, S, (no T!)
@@ -105,7 +111,9 @@ def get_stim_1q_gate(self, squin_op: op.stmts.Operator):
105111
case op.stmts.Identity(): # enforce sites defined = num wires in
106112
return stim.gate.Identity
107113
case _:
108-
return None
114+
raise NotImplementedError(
115+
f"The squin operator {squin_op} is not supported in the stim dialect"
116+
)
109117

110118
# get the qubit indices from the Apply statement argument
111119
# wires/qubits
@@ -158,6 +166,8 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
158166
return self.rewrite_Apply(node)
159167
case wire.Wrap():
160168
return self.rewrite_Wrap(node)
169+
case wire.Measure() | qubit.Measure():
170+
return self.rewrite_Measure(node)
161171
case _:
162172
return RewriteResult()
163173

@@ -236,3 +246,41 @@ def rewrite_Control(
236246
stim_stmt.insert_before(apply_stmt_ctrl)
237247

238248
return RewriteResult(has_done_something=True)
249+
250+
def rewrite_Measure(
251+
self, measure_stmt: qubit.Measure | wire.Measure
252+
) -> RewriteResult:
253+
254+
if isinstance(measure_stmt, qubit.Measure):
255+
qubit_ilist_ssa = measure_stmt.qubits
256+
# qubits are in an ilist which makes up an AddressTuple
257+
address_tuple: AddressTuple = self.get_address(qubit_ilist_ssa).address
258+
qubit_idx_ssas = []
259+
for qubit_address in address_tuple:
260+
qubit_idx = qubit_address.data
261+
qubit_idx_stmt = py.constant.Constant(qubit_idx)
262+
qubit_idx_stmt.insert_before(measure_stmt)
263+
qubit_idx_ssas.append(qubit_idx_stmt.result)
264+
qubit_idx_ssas = tuple(qubit_idx_ssas)
265+
266+
elif isinstance(measure_stmt, wire.Measure):
267+
wire_ssa = measure_stmt.wire
268+
wire_address: AddressWire = self.get_address(wire_ssa).address
269+
270+
qubit_idx = wire_address.origin_qubit.data
271+
qubit_idx_stmt = py.constant.Constant(qubit_idx)
272+
qubit_idx_stmt.insert_before(measure_stmt)
273+
qubit_idx_ssas = (qubit_idx_stmt.result,)
274+
275+
else:
276+
return RewriteResult()
277+
278+
prob_noise_stmt = py.constant.Constant(0.0)
279+
stim_measure_stmt = stim.collapse.MZ(
280+
p=prob_noise_stmt.result,
281+
targets=qubit_idx_ssas,
282+
)
283+
prob_noise_stmt.insert_before(measure_stmt)
284+
stim_measure_stmt.insert_before(measure_stmt)
285+
286+
return RewriteResult(has_done_something=True)

test/squin/stim/stim.py

Lines changed: 5 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
from kirin import ir, types
2-
from kirin.passes import Fold
3-
from kirin.rewrite import Walk, Fixpoint, DeadCodeElimination
42
from kirin.dialects import py, func, ilist
53

4+
import bloqade.squin.passes as squin_passes
65
from bloqade import qasm2, squin
7-
from bloqade.analysis import address
8-
from bloqade.squin.rewrite import SquinToStim, WrapSquinAnalysis
9-
from bloqade.squin.analysis import nsites
106

117

128
def as_int(value: int):
@@ -38,9 +34,6 @@ def gen_func_from_stmts(stmts):
3834
arg_names=[],
3935
)
4036

41-
fold_pass = Fold(extended_dialect)
42-
fold_pass(constructed_method)
43-
4437
return constructed_method
4538

4639

@@ -77,34 +70,8 @@ def test_1q():
7770

7871
constructed_method.print()
7972

80-
address_frame, _ = address.AddressAnalysis(
81-
constructed_method.dialects
82-
).run_analysis(constructed_method, no_raise=False)
83-
84-
nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis(
85-
constructed_method, no_raise=False
86-
)
87-
88-
constructed_method.print(analysis=address_frame.entries)
89-
constructed_method.print(analysis=nsites_frame.entries)
90-
91-
# attempt to wrap analysis results
92-
93-
wrap_squin_analysis = WrapSquinAnalysis(
94-
address_analysis=address_frame.entries, op_site_analysis=nsites_frame.entries
95-
)
96-
fix_walk_squin_analysis = Fixpoint(Walk(wrap_squin_analysis))
97-
rewrite_res = fix_walk_squin_analysis.rewrite(constructed_method.code)
98-
99-
# attempt rewrite to Stim
100-
# Be careful with Fixpoint, can go to infinity until reaches defined threshold
101-
squin_to_stim = Walk(SquinToStim())
102-
rewrite_res = squin_to_stim.rewrite(constructed_method.code)
103-
104-
# Get rid of the unused statements
105-
dce = Fixpoint(Walk(DeadCodeElimination()))
106-
rewrite_res = dce.rewrite(constructed_method.code)
107-
print(rewrite_res)
73+
squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects)
74+
squin_to_stim(constructed_method)
10875

10976
constructed_method.print()
11077

@@ -137,34 +104,8 @@ def test_control():
137104
constructed_method = gen_func_from_stmts(stmts)
138105
constructed_method.print()
139106

140-
address_frame, _ = address.AddressAnalysis(
141-
constructed_method.dialects
142-
).run_analysis(constructed_method, no_raise=False)
143-
144-
nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis(
145-
constructed_method, no_raise=False
146-
)
147-
148-
constructed_method.print(analysis=address_frame.entries)
149-
constructed_method.print(analysis=nsites_frame.entries)
150-
151-
wrap_squin_analysis = WrapSquinAnalysis(
152-
address_analysis=address_frame.entries, op_site_analysis=nsites_frame.entries
153-
)
154-
fix_walk_squin_analysis = Fixpoint(Walk(wrap_squin_analysis))
155-
rewrite_res = fix_walk_squin_analysis.rewrite(constructed_method.code)
156-
157-
# attempt rewrite to Stim
158-
# Be careful with Fixpoint, can go to infinity until reaches defined threshold
159-
squin_to_stim = Walk(SquinToStim())
160-
rewrite_res = squin_to_stim.rewrite(constructed_method.code)
161-
162-
constructed_method.print()
163-
164-
# Get rid of the unused statements
165-
dce = Fixpoint(Walk(DeadCodeElimination()))
166-
rewrite_res = dce.rewrite(constructed_method.code)
167-
print(rewrite_res)
107+
squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects)
108+
squin_to_stim(constructed_method)
168109

169110
constructed_method.print()
170111

0 commit comments

Comments
 (0)