Skip to content

Commit 300f9d7

Browse files
committed
account for MeasureAndReset, fix up address analysis
1 parent 20d4214 commit 300f9d7

File tree

3 files changed

+56
-3
lines changed

3 files changed

+56
-3
lines changed

src/bloqade/analysis/address/impls.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,20 @@ def apply(
210210
)
211211
return new_address_wires
212212

213+
@interp.impl(squin.wire.MeasureAndReset)
214+
def measure_and_reset(
215+
self,
216+
interp_: AddressAnalysis,
217+
frame: ForwardFrame[Address],
218+
stmt: squin.wire.MeasureAndReset,
219+
):
220+
221+
# take the address data from the incoming wire
222+
# and propagate that forward to the new wire generated.
223+
# The first entry can safely be NotQubit because
224+
# it's an integer
225+
return (NotQubit(), frame.get(stmt.wire))
226+
213227

214228
@squin.qubit.dialect.register(key="qubit.address")
215229
class SquinQubitMethodTable(interp.MethodTable):

src/bloqade/squin/rewrite/stim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,6 @@ def rewrite_Measure(
322322

323323
if isinstance(measure_stmt, qubit.Measure):
324324
qubit_ilist_ssa = measure_stmt.qubits
325-
# qubits are in an ilist which makes up an AddressTuple
326325
address_attr = self.get_address_attr(qubit_ilist_ssa)
327326

328327
elif isinstance(measure_stmt, wire.Measure):
@@ -386,6 +385,7 @@ def rewrite_Reset(self, reset_stmt: qubit.Reset | wire.Reset) -> RewriteResult:
386385

387386
stim_rz_stmt = stim.collapse.stmts.RZ(targets=qubit_idx_ssas)
388387
stim_rz_stmt.insert_before(reset_stmt)
388+
reset_stmt.delete()
389389

390390
return RewriteResult(has_done_something=True)
391391

test/squin/stim/stim.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from kirin import ir, types
2+
from kirin.passes import Fold
23
from kirin.dialects import py, func, ilist
34

45
import bloqade.squin.passes as squin_passes
56
from bloqade import qasm2, squin
7+
from bloqade.analysis import address
68

79

810
def as_int(value: int):
@@ -150,7 +152,33 @@ def test_qubit_reset():
150152
# qubit.reset only accepts ilist of qubits
151153
(qlist := ilist.New(values=[q0.result])),
152154
(squin.qubit.Reset(qubits=qlist.result)),
153-
(squin.qubit.Measure(qubits=qlist.result)),
155+
# (squin.qubit.Measure(qubits=qlist.result)),
156+
(ret_none := func.ConstantNone()),
157+
(func.Return(ret_none)),
158+
]
159+
160+
constructed_method = gen_func_from_stmts(stmts)
161+
constructed_method.print()
162+
163+
squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects)
164+
rewrite_result = squin_to_stim(constructed_method)
165+
print(rewrite_result)
166+
constructed_method.print()
167+
168+
169+
def test_wire_reset():
170+
171+
stmts: list[ir.Statement] = [
172+
# Create qubit register
173+
(n_qubits := as_int(1)),
174+
(qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)),
175+
# Get qubits out
176+
(idx0 := as_int(0)),
177+
(q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)),
178+
# get wire
179+
(w0 := squin.wire.Unwrap(q0.result)),
180+
# reset the wire
181+
(squin.wire.Reset(w0.result)),
154182
(ret_none := func.ConstantNone()),
155183
(func.Return(ret_none)),
156184
]
@@ -212,11 +240,22 @@ def test_wire_measure_and_reset():
212240
constructed_method = gen_func_from_stmts(stmts)
213241
constructed_method.print()
214242

243+
fold_pass = Fold(constructed_method.dialects)
244+
fold_pass(constructed_method)
245+
# need to make sure the origin qubit data is properly
246+
# propagated to the new wire that wire.MeasureAndReset spits out
247+
address_res, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis(
248+
constructed_method
249+
)
250+
constructed_method.print(analysis=address_res.entries)
251+
215252
squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects)
216253
rewrite_result = squin_to_stim(constructed_method)
217254
print(rewrite_result)
218255
constructed_method.print()
219256

220257

221-
# test_wire_measure_and_reset()
258+
test_wire_measure_and_reset()
222259
# test_qubit_measure_and_reset()
260+
# test_wire_reset()
261+
# test_qubit_reset()

0 commit comments

Comments
 (0)