Skip to content

Commit 13ae8a5

Browse files
committed
more testing, verification implemented
1 parent c137221 commit 13ae8a5

File tree

2 files changed

+208
-14
lines changed

2 files changed

+208
-14
lines changed

src/bloqade/squin/rewrite/stim.py

Lines changed: 74 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict
1+
from typing import Dict, cast
22
from dataclasses import dataclass
33

44
from kirin import ir
@@ -9,7 +9,7 @@
99
from bloqade import stim
1010
from bloqade.squin import op, wire, qubit
1111
from bloqade.analysis.address import Address, AddressWire, AddressQubit, AddressTuple
12-
from bloqade.squin.analysis.nsites import Sites
12+
from bloqade.squin.analysis.nsites import Sites, NumberSites
1313

1414
# Probably best to move these attributes to a
1515
# separate file? Keep here for now
@@ -93,7 +93,9 @@ def get_address_attr(self, value: ir.SSAValue) -> AddressAttribute:
9393

9494
def get_sites_attr(self, value: ir.SSAValue):
9595
try:
96-
return value.hints["sites"]
96+
sites_attr = value.hints["sites"]
97+
assert isinstance(sites_attr, SitesAttribute)
98+
return sites_attr
9799
except KeyError:
98100
raise KeyError(f"The sites analysis hint for {value} does not exist")
99101

@@ -111,7 +113,7 @@ def get_stim_1q_gate(self, squin_op: op.stmts.Operator):
111113
return stim.gate.H
112114
case op.stmts.S():
113115
return stim.gate.S
114-
case op.stmts.Identity(): # enforce sites defined = num wires in
116+
case op.stmts.Identity():
115117
return stim.gate.Identity
116118
case _:
117119
raise NotImplementedError(
@@ -193,6 +195,64 @@ def insert_qubit_idx_after_apply(
193195
"unsupported statement detected, only wire.Apply and qubit.Apply statements are supported by this method"
194196
)
195197

198+
def verify_num_site_Apply(self, apply_stmt: wire.Apply | qubit.Apply):
199+
200+
# get the number of wires/qubits that went into the statement
201+
if isinstance(apply_stmt, wire.Apply):
202+
num_sites_targeted = len(apply_stmt.inputs)
203+
elif isinstance(apply_stmt, qubit.Apply):
204+
address_attr = self.get_address_attr(apply_stmt.qubits)
205+
# ilist has AddressTuple type,
206+
# should be the case that the types INSIDE the AddressTuple
207+
# are all AddressQubit
208+
address_tuple = address_attr.address
209+
assert isinstance(address_tuple, AddressTuple)
210+
num_sites_targeted = len(address_tuple.data)
211+
else:
212+
raise TypeError(
213+
"Number of sites verification cannot occur on statements other than wire.Apply and qubit.Apply"
214+
)
215+
216+
# The only single qubit operator that can have its size customized is the Identity gate.
217+
# There are two possible valid uses for size.
218+
# Either:
219+
## Apply(Identity(size=n), wire0, ..., wire_n)
220+
# Or:
221+
## Apply(Identity(size=1), wire0, ..., wire_n)
222+
# both should have the same effect, and can naturally be represented in Stim as:
223+
# 1QGate q1 q2 q3 q4
224+
225+
op_ssa = apply_stmt.operator
226+
op_stmt = op_ssa.owner
227+
cast(ir.Statement, op_stmt)
228+
229+
sites_attr = self.get_sites_attr(op_ssa)
230+
sites_type = sites_attr.sites
231+
assert isinstance(sites_type, NumberSites)
232+
num_sites_supported = sites_type.sites
233+
234+
if isinstance(op_stmt, op.stmts.Identity):
235+
if num_sites_supported != 1 or num_sites_supported != num_sites_targeted:
236+
raise ValueError(
237+
"squin.op.Identity must either have sites = 1 or sites = the number of qubits/wires it is being applied on"
238+
)
239+
elif isinstance(op_stmt, op.stmts.Control):
240+
# in Stim control gates have the following supported syntax
241+
## CX 1 2
242+
## CX 1 2 3 4 (equivalent to CX 1 2, then CX 3 4)
243+
244+
if (
245+
num_sites_targeted < num_sites_supported
246+
or num_sites_targeted % num_sites_supported != 0
247+
):
248+
raise ValueError(
249+
"Mismatch found between Control gate supported number of qubits/wires and number of qubits/wires being supplied."
250+
)
251+
else:
252+
return None
253+
254+
return None
255+
196256
# might be worth attempting multiple dispatch like qasm2 rewrites
197257
# for Glob and Parallel to UOp
198258
# The problem is I'd have to introduce names for all the statements
@@ -206,6 +266,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
206266

207267
match node:
208268
case wire.Apply() | qubit.Apply():
269+
self.verify_num_site_Apply(node)
209270
return self.rewrite_Apply(node)
210271
case wire.Wrap():
211272
return self.rewrite_Wrap(node)
@@ -244,11 +305,6 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult:
244305
# but we can handle X, Y, Z, H, and S here just fine
245306
stim_1q_op = self.get_stim_1q_gate(applied_op)
246307

247-
# wire.Apply -> tuple of SSA -> AddressTuple
248-
# qubit.Apply -> list of qubits -> AddressTuple
249-
## Both cases the statements follow the Stim semantics of
250-
## 1QGate a b c d ....
251-
252308
if isinstance(apply_stmt, qubit.Apply):
253309
address_attr = self.get_address_attr(apply_stmt.qubits)
254310
qubit_idx_ssas = self.insert_qubit_idx_from_address(
@@ -266,6 +322,12 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult:
266322
stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas))
267323
stim_1q_stmt.insert_before(apply_stmt)
268324

325+
# Could I safely delete the apply statements?
326+
# If it's a a qubit.Apply yes, because it doesn't return anything
327+
# If it's a wire.Apply no, because the `results` of that Apply get used later on
328+
if isinstance(apply_stmt, qubit.Apply):
329+
apply_stmt.delete()
330+
269331
return RewriteResult(has_done_something=True)
270332

271333
def rewrite_Control(
@@ -278,12 +340,9 @@ def rewrite_Control(
278340
ctrl_op = apply_stmt_ctrl.operator.owner
279341
assert isinstance(ctrl_op, op.stmts.Control)
280342

281-
# enforce that n_controls is 1
282-
283343
ctrl_op_target_gate = ctrl_op.op.owner
284344
assert isinstance(ctrl_op_target_gate, op.stmts.Operator)
285345

286-
# should enforce that this is some multiple of 2
287346
qubit_idx_ssas = self.insert_qubit_idx_after_apply(apply_stmt=apply_stmt_ctrl)
288347
# according to stim, final result can be:
289348
# CX 1 2 3 4 -> CX(1, targ=2), CX(3, targ=4)
@@ -314,6 +373,9 @@ def rewrite_Control(
314373

315374
stim_stmt.insert_before(apply_stmt_ctrl)
316375

376+
if isinstance(apply_stmt_ctrl, qubit.Apply):
377+
apply_stmt_ctrl.delete()
378+
317379
return RewriteResult(has_done_something=True)
318380

319381
def rewrite_Measure(

test/squin/stim/stim.py

Lines changed: 134 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,135 @@ def test_wire_1q():
7878
constructed_method.print()
7979

8080

81+
def test_parallel_wire_1q_application():
82+
83+
stmts: list[ir.Statement] = [
84+
# Create qubit register
85+
(n_qubits := as_int(4)),
86+
(qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)),
87+
# Get qubits out
88+
(idx0 := as_int(0)),
89+
(q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)),
90+
(idx1 := as_int(1)),
91+
(q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)),
92+
(idx2 := as_int(2)),
93+
(q2 := qasm2.core.QRegGet(reg=qreg.result, idx=idx2.result)),
94+
(idx3 := as_int(3)),
95+
(q3 := qasm2.core.QRegGet(reg=qreg.result, idx=idx3.result)),
96+
# Unwrap to get wires
97+
(w0 := squin.wire.Unwrap(qubit=q0.result)),
98+
(w1 := squin.wire.Unwrap(qubit=q1.result)),
99+
(w2 := squin.wire.Unwrap(qubit=q2.result)),
100+
(w3 := squin.wire.Unwrap(qubit=q3.result)),
101+
# Apply with stim semantics
102+
(h_op := squin.op.stmts.H()),
103+
(
104+
app_res := squin.wire.Apply(
105+
h_op.result, w0.result, w1.result, w2.result, w3.result
106+
)
107+
),
108+
# Wrap everything back
109+
(squin.wire.Wrap(app_res.results[0], q0.result)),
110+
(squin.wire.Wrap(app_res.results[1], q1.result)),
111+
(squin.wire.Wrap(app_res.results[2], q2.result)),
112+
(squin.wire.Wrap(app_res.results[3], q3.result)),
113+
(ret_none := func.ConstantNone()),
114+
(func.Return(ret_none)),
115+
]
116+
117+
constructed_method = gen_func_from_stmts(stmts)
118+
119+
constructed_method.print()
120+
121+
squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects)
122+
squin_to_stim(constructed_method)
123+
124+
constructed_method.print()
125+
126+
127+
def test_parallel_qubit_1q_application():
128+
129+
stmts: list[ir.Statement] = [
130+
# Create qubit register
131+
(n_qubits := as_int(4)),
132+
(qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)),
133+
# Get qubits out
134+
(idx0 := as_int(0)),
135+
(q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)),
136+
(idx1 := as_int(1)),
137+
(q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)),
138+
(idx2 := as_int(2)),
139+
(q2 := qasm2.core.QRegGet(reg=qreg.result, idx=idx2.result)),
140+
(idx3 := as_int(3)),
141+
(q3 := qasm2.core.QRegGet(reg=qreg.result, idx=idx3.result)),
142+
# create ilist of qubits
143+
(q_list := ilist.New(values=(q0.result, q1.result, q2.result, q3.result))),
144+
# Apply with stim semantics
145+
(h_op := squin.op.stmts.H()),
146+
(app_res := squin.qubit.Apply(h_op.result, q_list.result)), # noqa: F841
147+
# Measure everything out
148+
(meas_res := squin.qubit.Measure(q_list.result)), # noqa: F841
149+
(ret_none := func.ConstantNone()),
150+
(func.Return(ret_none)),
151+
]
152+
153+
constructed_method = gen_func_from_stmts(stmts)
154+
155+
constructed_method.print()
156+
157+
squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects)
158+
squin_to_stim(constructed_method)
159+
160+
constructed_method.print()
161+
162+
163+
def test_parallel_control_gate_wire_application():
164+
165+
stmts: list[ir.Statement] = [
166+
# Create qubit register
167+
(n_qubits := as_int(4)),
168+
(qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)),
169+
# Get qubits out
170+
(idx0 := as_int(0)),
171+
(q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)),
172+
(idx1 := as_int(1)),
173+
(q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)),
174+
(idx2 := as_int(2)),
175+
(q2 := qasm2.core.QRegGet(reg=qreg.result, idx=idx2.result)),
176+
(idx3 := as_int(3)),
177+
(q3 := qasm2.core.QRegGet(reg=qreg.result, idx=idx3.result)),
178+
# Unwrap to get wires
179+
(w0 := squin.wire.Unwrap(qubit=q0.result)),
180+
(w1 := squin.wire.Unwrap(qubit=q1.result)),
181+
(w2 := squin.wire.Unwrap(qubit=q2.result)),
182+
(w3 := squin.wire.Unwrap(qubit=q3.result)),
183+
# Create and apply CX gate
184+
(x_op := squin.op.stmts.X()),
185+
(ctrl_x_op := squin.op.stmts.Control(x_op.result, n_controls=1)),
186+
(
187+
app_res := squin.wire.Apply(
188+
ctrl_x_op.result, w0.result, w1.result, w2.result, w3.result
189+
)
190+
),
191+
# measure it all out
192+
(meas_res_0 := squin.wire.Measure(app_res.results[0])), # noqa: F841
193+
(meas_res_1 := squin.wire.Measure(app_res.results[1])), # noqa: F841
194+
(meas_res_2 := squin.wire.Measure(app_res.results[2])), # noqa: F841
195+
(meas_res_3 := squin.wire.Measure(app_res.results[3])), # noqa: F841
196+
(ret_none := func.ConstantNone()),
197+
(func.Return(ret_none)),
198+
]
199+
200+
constructed_method = gen_func_from_stmts(stmts)
201+
202+
constructed_method.print()
203+
204+
squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects)
205+
squin_to_stim(constructed_method)
206+
207+
constructed_method.print()
208+
209+
81210
def test_wire_control():
82211

83212
stmts: list[ir.Statement] = [
@@ -255,7 +384,10 @@ def test_wire_measure_and_reset():
255384
constructed_method.print()
256385

257386

258-
test_wire_measure_and_reset()
387+
# test_wire_measure_and_reset()
259388
# test_qubit_measure_and_reset()
260389
# test_wire_reset()
261-
# test_qubit_reset()
390+
391+
# test_parallel_qubit_1q_application()
392+
# test_parallel_wire_1q_application()
393+
test_parallel_control_gate_wire_application()

0 commit comments

Comments
 (0)