Skip to content

Commit 20d4214

Browse files
committed
account for MeasureAndReset
1 parent f3203da commit 20d4214

File tree

3 files changed

+118
-10
lines changed

3 files changed

+118
-10
lines changed

src/bloqade/squin/analysis/nsites/impls.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@ def apply(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: wire.Apply):
1919

2020
return tuple([frame.get(input) for input in stmt.inputs])
2121

22+
@interp.impl(wire.MeasureAndReset)
23+
def measure_and_reset(
24+
self, interp: NSitesAnalysis, frame: interp.Frame, stmt: wire.MeasureAndReset
25+
):
26+
27+
# MeasureAndReset produces both a new wire
28+
# and an integer which don't have any sites at all
29+
return (NoSites(), NoSites())
30+
2231

2332
@op.dialect.register(key="op.nsites")
2433
class SquinOp(interp.MethodTable):

src/bloqade/squin/rewrite/stim.py

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
8282
@dataclass
8383
class _SquinToStim(RewriteRule):
8484

85-
def get_address(self, value: ir.SSAValue) -> AddressAttribute:
85+
def get_address_attr(self, value: ir.SSAValue) -> AddressAttribute:
8686

8787
try:
8888
address_attr = value.hints["address"]
@@ -91,7 +91,7 @@ def get_address(self, value: ir.SSAValue) -> AddressAttribute:
9191
except KeyError:
9292
raise KeyError(f"The address analysis hint for {value} does not exist")
9393

94-
def get_sites(self, value: ir.SSAValue):
94+
def get_sites_attr(self, value: ir.SSAValue):
9595
try:
9696
return value.hints["sites"]
9797
except KeyError:
@@ -157,7 +157,7 @@ def insert_qubit_idx_from_wire_ssa(
157157
) -> tuple[ir.SSAValue, ...]:
158158
qubit_idx_ssas = []
159159
for wire_ssa in wire_ssas:
160-
address_attribute = self.get_address(wire_ssa) # get AddressWire
160+
address_attribute = self.get_address_attr(wire_ssa) # get AddressWire
161161
# get parent qubit idx
162162
wire_address = address_attribute.address
163163
assert isinstance(wire_address, AddressWire)
@@ -178,7 +178,7 @@ def insert_qubit_idx_after_apply(
178178

179179
if isinstance(apply_stmt, qubit.Apply):
180180
qubits = apply_stmt.qubits
181-
address_attribute: AddressAttribute = self.get_address(qubits)
181+
address_attribute: AddressAttribute = self.get_address_attr(qubits)
182182
# Should get an AddressTuple out of the address stored in attribute
183183
return self.insert_qubit_idx_from_address(
184184
address=address_attribute, stmt_to_insert_before=apply_stmt
@@ -213,6 +213,8 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
213213
return self.rewrite_Measure(node)
214214
case wire.Reset() | qubit.Reset():
215215
return self.rewrite_Reset(node)
216+
case wire.MeasureAndReset() | qubit.MeasureAndReset():
217+
return self.rewrite_MeasureAndReset(node)
216218
case _:
217219
return RewriteResult()
218220

@@ -248,7 +250,7 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult:
248250
## 1QGate a b c d ....
249251

250252
if isinstance(apply_stmt, qubit.Apply):
251-
address_attr = self.get_address(apply_stmt.qubits)
253+
address_attr = self.get_address_attr(apply_stmt.qubits)
252254
qubit_idx_ssas = self.insert_qubit_idx_from_address(
253255
address=address_attr, stmt_to_insert_before=apply_stmt
254256
)
@@ -321,13 +323,13 @@ def rewrite_Measure(
321323
if isinstance(measure_stmt, qubit.Measure):
322324
qubit_ilist_ssa = measure_stmt.qubits
323325
# qubits are in an ilist which makes up an AddressTuple
324-
address_attr = self.get_address(qubit_ilist_ssa)
326+
address_attr = self.get_address_attr(qubit_ilist_ssa)
325327

326328
elif isinstance(measure_stmt, wire.Measure):
327329
# Wire Terminator, should kill the existence of
328330
# the wire here so DCE can sweep up the rest like with rewriting wrap
329331
wire_ssa = measure_stmt.wire
330-
address_attr = self.get_address(wire_ssa)
332+
address_attr = self.get_address_attr(wire_ssa)
331333

332334
# DCE can't remove the old measure_stmt for both wire and qubit versions
333335
# because of the fact it has a result that can be depended on by other statements
@@ -368,12 +370,12 @@ def rewrite_Reset(self, reset_stmt: qubit.Reset | wire.Reset) -> RewriteResult:
368370
if isinstance(reset_stmt, qubit.Reset):
369371
qubit_ilist_ssa = reset_stmt.qubits
370372
# qubits are in an ilist which makes up an AddressTuple
371-
address_attr = self.get_address(qubit_ilist_ssa)
373+
address_attr = self.get_address_attr(qubit_ilist_ssa)
372374
qubit_idx_ssas = self.insert_qubit_idx_from_address(
373375
address=address_attr, stmt_to_insert_before=reset_stmt
374376
)
375377
elif isinstance(reset_stmt, wire.Reset):
376-
address_attr = self.get_address(reset_stmt.wire)
378+
address_attr = self.get_address_attr(reset_stmt.wire)
377379
qubit_idx_ssas = self.insert_qubit_idx_from_address(
378380
address=address_attr, stmt_to_insert_before=reset_stmt
379381
)
@@ -390,4 +392,43 @@ def rewrite_Reset(self, reset_stmt: qubit.Reset | wire.Reset) -> RewriteResult:
390392
def rewrite_MeasureAndReset(
391393
self, meas_and_reset_stmt: qubit.MeasureAndReset | wire.MeasureAndReset
392394
):
393-
pass
395+
"""
396+
qubit.MeasureAndReset(qubits) -> result
397+
Could be translated (roughly equivalent) to
398+
399+
stim.MZ(tuple[SSAvals for ints])
400+
stim.RZ(tuple[SSAvals for ints])
401+
402+
Stim does have MRZ, might be more reflective of what we want/
403+
lines up the semantics better
404+
405+
"""
406+
407+
if isinstance(meas_and_reset_stmt, qubit.MeasureAndReset):
408+
409+
address_attr = self.get_address_attr(meas_and_reset_stmt.qubits)
410+
qubit_idx_ssas = self.insert_qubit_idx_from_address(
411+
address=address_attr, stmt_to_insert_before=meas_and_reset_stmt
412+
)
413+
414+
elif isinstance(meas_and_reset_stmt, wire.MeasureAndReset):
415+
address_attr = self.get_address_attr(meas_and_reset_stmt.wire)
416+
qubit_idx_ssas = self.insert_qubit_idx_from_address(
417+
address_attr, stmt_to_insert_before=meas_and_reset_stmt
418+
)
419+
420+
else:
421+
raise TypeError(
422+
"Unsupported statement detected, only qubit.MeasureAndReset and wire.MeasureAndReset are supported"
423+
)
424+
425+
error_p_stmt = py.Constant(0.0)
426+
stim_mz_stmt = stim.collapse.MZ(targets=qubit_idx_ssas, p=error_p_stmt.result)
427+
stim_rz_stmt = stim.collapse.RZ(
428+
targets=qubit_idx_ssas,
429+
)
430+
error_p_stmt.insert_before(meas_and_reset_stmt)
431+
stim_mz_stmt.insert_before(meas_and_reset_stmt)
432+
stim_rz_stmt.insert_before(meas_and_reset_stmt)
433+
434+
return RewriteResult(has_done_something=True)

test/squin/stim/stim.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,61 @@ def test_qubit_reset():
162162
rewrite_result = squin_to_stim(constructed_method)
163163
print(rewrite_result)
164164
constructed_method.print()
165+
166+
167+
def test_qubit_measure_and_reset():
168+
169+
stmts: list[ir.Statement] = [
170+
# Create qubit register
171+
(n_qubits := as_int(1)),
172+
(qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)),
173+
# Get qubits out
174+
(idx0 := as_int(0)),
175+
(q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)),
176+
# qubit.reset only accepts ilist of qubits
177+
(qlist := ilist.New(values=[q0.result])),
178+
(squin.qubit.MeasureAndReset(qlist.result)),
179+
(ret_none := func.ConstantNone()),
180+
(func.Return(ret_none)),
181+
]
182+
183+
constructed_method = gen_func_from_stmts(stmts)
184+
constructed_method.print()
185+
186+
# analysis_res, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis(constructed_method)
187+
# constructed_method.print(analysis=analysis_res.entries)
188+
189+
squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects)
190+
rewrite_result = squin_to_stim(constructed_method)
191+
print(rewrite_result)
192+
constructed_method.print()
193+
194+
195+
def test_wire_measure_and_reset():
196+
197+
stmts: list[ir.Statement] = [
198+
# Create qubit register
199+
(n_qubits := as_int(1)),
200+
(qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)),
201+
# Get qubits out
202+
(idx0 := as_int(0)),
203+
(q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)),
204+
# get wire out
205+
(w0 := squin.wire.Unwrap(q0.result)),
206+
# qubit.reset only accepts ilist of qubits
207+
(squin.wire.MeasureAndReset(w0.result)),
208+
(ret_none := func.ConstantNone()),
209+
(func.Return(ret_none)),
210+
]
211+
212+
constructed_method = gen_func_from_stmts(stmts)
213+
constructed_method.print()
214+
215+
squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects)
216+
rewrite_result = squin_to_stim(constructed_method)
217+
print(rewrite_result)
218+
constructed_method.print()
219+
220+
221+
# test_wire_measure_and_reset()
222+
# test_qubit_measure_and_reset()

0 commit comments

Comments
 (0)