Skip to content

Commit f3203da

Browse files
committed
partially working reset rewrite
1 parent 7ba5670 commit f3203da

File tree

2 files changed

+207
-49
lines changed

2 files changed

+207
-49
lines changed

src/bloqade/squin/rewrite/stim.py

Lines changed: 150 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from bloqade import stim
1010
from bloqade.squin import op, wire, qubit
11-
from bloqade.analysis.address import Address, AddressWire, AddressTuple
11+
from bloqade.analysis.address import Address, AddressWire, AddressQubit, AddressTuple
1212
from bloqade.squin.analysis.nsites import Sites
1313

1414
# Probably best to move these attributes to a
@@ -82,9 +82,12 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
8282
@dataclass
8383
class _SquinToStim(RewriteRule):
8484

85-
def get_address(self, value: ir.SSAValue):
85+
def get_address(self, value: ir.SSAValue) -> AddressAttribute:
86+
8687
try:
87-
return value.hints["address"]
88+
address_attr = value.hints["address"]
89+
assert isinstance(address_attr, AddressAttribute)
90+
return address_attr
8891
except KeyError:
8992
raise KeyError(f"The address analysis hint for {value} does not exist")
9093

@@ -115,40 +118,80 @@ def get_stim_1q_gate(self, squin_op: op.stmts.Operator):
115118
f"The squin operator {squin_op} is not supported in the stim dialect"
116119
)
117120

121+
def insert_qubit_idx_from_address(
122+
self, address: AddressAttribute, stmt_to_insert_before: ir.Statement
123+
) -> tuple[ir.SSAValue, ...]:
124+
125+
address_data = address.address
126+
127+
qubit_idx_ssas = []
128+
129+
if isinstance(address_data, AddressTuple):
130+
for address_qubit in address_data.data:
131+
132+
# ensure that the stuff in the AddressTuple should be AddressQubit
133+
# could handle AddressWires as well but don't see the need for that right now
134+
if not isinstance(address_qubit, AddressQubit):
135+
raise ValueError(
136+
"Unsupported Address type detected inside AddressTuple, must be AddressQubit"
137+
)
138+
qubit_idx = address_qubit.data
139+
qubit_idx_stmt = py.Constant(qubit_idx)
140+
qubit_idx_stmt.insert_before(stmt_to_insert_before)
141+
qubit_idx_ssas.append(qubit_idx_stmt.result)
142+
elif isinstance(address_data, AddressWire):
143+
address_qubit = address_data.origin_qubit
144+
qubit_idx = address_qubit.data
145+
qubit_idx_stmt = py.Constant(qubit_idx)
146+
qubit_idx_stmt.insert_before(stmt_to_insert_before)
147+
qubit_idx_ssas.append(qubit_idx_stmt.result)
148+
else:
149+
NotImplementedError(
150+
"qubit idx extraction and insertion only support for AddressTuple[AddressQubit] and AddressWire instances"
151+
)
152+
153+
return tuple(qubit_idx_ssas)
154+
155+
def insert_qubit_idx_from_wire_ssa(
156+
self, wire_ssas: tuple[ir.SSAValue, ...], stmt_to_insert_before: ir.Statement
157+
) -> tuple[ir.SSAValue, ...]:
158+
qubit_idx_ssas = []
159+
for wire_ssa in wire_ssas:
160+
address_attribute = self.get_address(wire_ssa) # get AddressWire
161+
# get parent qubit idx
162+
wire_address = address_attribute.address
163+
assert isinstance(wire_address, AddressWire)
164+
qubit_idx = wire_address.origin_qubit.data
165+
qubit_idx_stmt = py.Constant(qubit_idx)
166+
# accumulate all qubit idx SSA to instantiate stim gate stmt
167+
qubit_idx_ssas.append(qubit_idx_stmt.result)
168+
qubit_idx_stmt.insert_before(stmt_to_insert_before)
169+
170+
return tuple(qubit_idx_ssas)
171+
118172
# get the qubit indices from the Apply statement argument
119173
# wires/qubits
120-
def insert_qubit_idx_ssa(
174+
175+
def insert_qubit_idx_after_apply(
121176
self, apply_stmt: wire.Apply | qubit.Apply
122177
) -> tuple[ir.SSAValue, ...]:
123178

124179
if isinstance(apply_stmt, qubit.Apply):
125180
qubits = apply_stmt.qubits
126181
address_attribute: AddressAttribute = self.get_address(qubits)
127182
# Should get an AddressTuple out of the address stored in attribute
128-
address_tuple = address_attribute.address
129-
qubit_idx_ssas: list[ir.SSAValue] = []
130-
for address_qubit in address_tuple.data:
131-
qubit_idx = address_qubit.data
132-
qubit_idx_stmt = py.Constant(qubit_idx)
133-
qubit_idx_stmt.insert_before(apply_stmt)
134-
qubit_idx_ssas.append(qubit_idx_stmt.result)
135-
136-
return tuple(qubit_idx_ssas)
137-
183+
return self.insert_qubit_idx_from_address(
184+
address=address_attribute, stmt_to_insert_before=apply_stmt
185+
)
138186
elif isinstance(apply_stmt, wire.Apply):
139187
wire_ssas = apply_stmt.inputs
140-
qubit_idx_ssas: list[ir.SSAValue] = []
141-
for wire_ssa in wire_ssas:
142-
address_attribute = self.get_address(wire_ssa)
143-
# get parent qubit idx
144-
wire_address = address_attribute.address
145-
qubit_idx = wire_address.origin_qubit.data
146-
qubit_idx_stmt = py.Constant(qubit_idx)
147-
# accumulate all qubit idx SSA to instantiate stim gate stmt
148-
qubit_idx_ssas.append(qubit_idx_stmt.result)
149-
qubit_idx_stmt.insert_before(apply_stmt)
150-
151-
return tuple(qubit_idx_ssas)
188+
return self.insert_qubit_idx_from_wire_ssa(
189+
wire_ssas=wire_ssas, stmt_to_insert_before=apply_stmt
190+
)
191+
else:
192+
raise TypeError(
193+
"unsupported statement detected, only wire.Apply and qubit.Apply statements are supported by this method"
194+
)
152195

153196
# might be worth attempting multiple dispatch like qasm2 rewrites
154197
# for Glob and Parallel to UOp
@@ -168,6 +211,8 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
168211
return self.rewrite_Wrap(node)
169212
case wire.Measure() | qubit.Measure():
170213
return self.rewrite_Measure(node)
214+
case wire.Reset() | qubit.Reset():
215+
return self.rewrite_Reset(node)
171216
case _:
172217
return RewriteResult()
173218

@@ -188,6 +233,7 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult:
188233

189234
# this is an SSAValue, need it to be the actual operator
190235
applied_op = apply_stmt.operator.owner
236+
assert isinstance(applied_op, op.stmts.Operator)
191237

192238
if isinstance(applied_op, op.stmts.Control):
193239
return self.rewrite_Control(apply_stmt)
@@ -196,7 +242,25 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult:
196242
# but we can handle X, Y, Z, H, and S here just fine
197243
stim_1q_op = self.get_stim_1q_gate(applied_op)
198244

199-
qubit_idx_ssas = self.insert_qubit_idx_ssa(apply_stmt=apply_stmt)
245+
# wire.Apply -> tuple of SSA -> AddressTuple
246+
# qubit.Apply -> list of qubits -> AddressTuple
247+
## Both cases the statements follow the Stim semantics of
248+
## 1QGate a b c d ....
249+
250+
if isinstance(apply_stmt, qubit.Apply):
251+
address_attr = self.get_address(apply_stmt.qubits)
252+
qubit_idx_ssas = self.insert_qubit_idx_from_address(
253+
address=address_attr, stmt_to_insert_before=apply_stmt
254+
)
255+
elif isinstance(apply_stmt, wire.Apply):
256+
qubit_idx_ssas = self.insert_qubit_idx_from_wire_ssa(
257+
wire_ssas=apply_stmt.inputs, stmt_to_insert_before=apply_stmt
258+
)
259+
else:
260+
raise TypeError(
261+
"Unsupported statement detected, only qubit.Apply and wire.Apply are permitted"
262+
)
263+
200264
stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas))
201265
stim_1q_stmt.insert_before(apply_stmt)
202266

@@ -209,13 +273,16 @@ def rewrite_Control(
209273
# operator of Apply is a Control gate, enforce it's only asking for 1 control qubit,
210274
# and that the target of the control is X, Y, Z in squin
211275

212-
ctrl_op: op.stmts.Control = apply_stmt_ctrl.operator.owner
276+
ctrl_op = apply_stmt_ctrl.operator.owner
277+
assert isinstance(ctrl_op, op.stmts.Control)
278+
213279
# enforce that n_controls is 1
214280

215281
ctrl_op_target_gate = ctrl_op.op.owner
282+
assert isinstance(ctrl_op_target_gate, op.stmts.Operator)
216283

217284
# should enforce that this is some multiple of 2
218-
qubit_idx_ssas = self.insert_qubit_idx_ssa(apply_stmt=apply_stmt_ctrl)
285+
qubit_idx_ssas = self.insert_qubit_idx_after_apply(apply_stmt=apply_stmt_ctrl)
219286
# according to stim, final result can be:
220287
# CX 1 2 3 4 -> CX(1, targ=2), CX(3, targ=4)
221288
target_qubits = []
@@ -254,26 +321,26 @@ def rewrite_Measure(
254321
if isinstance(measure_stmt, qubit.Measure):
255322
qubit_ilist_ssa = measure_stmt.qubits
256323
# qubits are in an ilist which makes up an AddressTuple
257-
address_tuple: AddressTuple = self.get_address(qubit_ilist_ssa).address
258-
qubit_idx_ssas = []
259-
for qubit_address in address_tuple:
260-
qubit_idx = qubit_address.data
261-
qubit_idx_stmt = py.constant.Constant(qubit_idx)
262-
qubit_idx_stmt.insert_before(measure_stmt)
263-
qubit_idx_ssas.append(qubit_idx_stmt.result)
264-
qubit_idx_ssas = tuple(qubit_idx_ssas)
324+
address_attr = self.get_address(qubit_ilist_ssa)
265325

266326
elif isinstance(measure_stmt, wire.Measure):
327+
# Wire Terminator, should kill the existence of
328+
# the wire here so DCE can sweep up the rest like with rewriting wrap
267329
wire_ssa = measure_stmt.wire
268-
wire_address: AddressWire = self.get_address(wire_ssa).address
330+
address_attr = self.get_address(wire_ssa)
269331

270-
qubit_idx = wire_address.origin_qubit.data
271-
qubit_idx_stmt = py.constant.Constant(qubit_idx)
272-
qubit_idx_stmt.insert_before(measure_stmt)
273-
qubit_idx_ssas = (qubit_idx_stmt.result,)
332+
# DCE can't remove the old measure_stmt for both wire and qubit versions
333+
# because of the fact it has a result that can be depended on by other statements
334+
# whereas Stim Measure has no such notion
274335

275336
else:
276-
return RewriteResult()
337+
raise TypeError(
338+
"unsupported Statement, only qubit.Measure and wire.Measure are supported"
339+
)
340+
341+
qubit_idx_ssas = self.insert_qubit_idx_from_address(
342+
address=address_attr, stmt_to_insert_before=measure_stmt
343+
)
277344

278345
prob_noise_stmt = py.constant.Constant(0.0)
279346
stim_measure_stmt = stim.collapse.MZ(
@@ -284,3 +351,43 @@ def rewrite_Measure(
284351
stim_measure_stmt.insert_before(measure_stmt)
285352

286353
return RewriteResult(has_done_something=True)
354+
355+
def rewrite_Reset(self, reset_stmt: qubit.Reset | wire.Reset) -> RewriteResult:
356+
"""
357+
qubit.Reset(ilist of qubits) -> nothing
358+
# safe to delete the statement afterwards, no depending results
359+
# DCE could probably do this automatically?
360+
361+
wire.Reset(single wire) -> new wire
362+
# DO NOT DELETE
363+
364+
# assume RZ, but could extend to RY and RX later
365+
Stim RZ(targets = tuple[int of SSAVals])
366+
"""
367+
368+
if isinstance(reset_stmt, qubit.Reset):
369+
qubit_ilist_ssa = reset_stmt.qubits
370+
# qubits are in an ilist which makes up an AddressTuple
371+
address_attr = self.get_address(qubit_ilist_ssa)
372+
qubit_idx_ssas = self.insert_qubit_idx_from_address(
373+
address=address_attr, stmt_to_insert_before=reset_stmt
374+
)
375+
elif isinstance(reset_stmt, wire.Reset):
376+
address_attr = self.get_address(reset_stmt.wire)
377+
qubit_idx_ssas = self.insert_qubit_idx_from_address(
378+
address=address_attr, stmt_to_insert_before=reset_stmt
379+
)
380+
else:
381+
raise TypeError(
382+
"unsupported statement, only qubit.Reset and wire.Reset are supported"
383+
)
384+
385+
stim_rz_stmt = stim.collapse.stmts.RZ(targets=qubit_idx_ssas)
386+
stim_rz_stmt.insert_before(reset_stmt)
387+
388+
return RewriteResult(has_done_something=True)
389+
390+
def rewrite_MeasureAndReset(
391+
self, meas_and_reset_stmt: qubit.MeasureAndReset | wire.MeasureAndReset
392+
):
393+
pass

test/squin/stim/stim.py

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@ def as_float(value: float):
1313
return py.constant.Constant(value=value)
1414

1515

16-
def gen_func_from_stmts(stmts):
16+
def gen_func_from_stmts(stmts, output=types.NoneType):
1717

18-
extended_dialect = squin.groups.wired.add(qasm2.core).add(ilist)
18+
extended_dialect = squin.groups.wired.add(qasm2.core).add(ilist).add(squin.qubit)
1919

2020
block = ir.Block(stmts)
2121
block.args.append_from(types.MethodType[[], types.NoneType], "main_self")
2222
func_wrapper = func.Function(
2323
sym_name="main",
24-
signature=func.Signature(inputs=(), output=types.NoneType),
24+
signature=func.Signature(inputs=(), output=output),
2525
body=ir.Region(blocks=block),
2626
)
2727

@@ -37,7 +37,7 @@ def gen_func_from_stmts(stmts):
3737
return constructed_method
3838

3939

40-
def test_1q():
40+
def test_wire_1q():
4141

4242
stmts: list[ir.Statement] = [
4343
# Create qubit register
@@ -76,7 +76,7 @@ def test_1q():
7676
constructed_method.print()
7777

7878

79-
def test_control():
79+
def test_wire_control():
8080

8181
stmts: list[ir.Statement] = [
8282
# Create qubit register
@@ -110,4 +110,55 @@ def test_control():
110110
constructed_method.print()
111111

112112

113-
test_control()
113+
def test_wire_measure():
114+
115+
stmts: list[ir.Statement] = [
116+
# Create qubit register
117+
(n_qubits := as_int(2)),
118+
(qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)),
119+
# Get qubis out
120+
(idx0 := as_int(0)),
121+
(q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)),
122+
# Unwrap to get wires
123+
(w0 := squin.wire.Unwrap(qubit=q0.result)),
124+
# measure the wires out
125+
(r0 := squin.wire.Measure(w0.result)),
126+
# return ints so DCE doesn't get
127+
# rid of everything
128+
# (ret_none := func.ConstantNone()),
129+
(func.Return(r0)),
130+
]
131+
132+
constructed_method = gen_func_from_stmts(stmts)
133+
constructed_method.print()
134+
135+
squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects)
136+
rewrite_result = squin_to_stim(constructed_method)
137+
print(rewrite_result)
138+
constructed_method.print()
139+
140+
141+
def test_qubit_reset():
142+
143+
stmts: list[ir.Statement] = [
144+
# Create qubit register
145+
(n_qubits := as_int(1)),
146+
(qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)),
147+
# Get qubits out
148+
(idx0 := as_int(0)),
149+
(q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)),
150+
# qubit.reset only accepts ilist of qubits
151+
(qlist := ilist.New(values=[q0.result])),
152+
(squin.qubit.Reset(qubits=qlist.result)),
153+
(squin.qubit.Measure(qubits=qlist.result)),
154+
(ret_none := func.ConstantNone()),
155+
(func.Return(ret_none)),
156+
]
157+
158+
constructed_method = gen_func_from_stmts(stmts)
159+
constructed_method.print()
160+
161+
squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects)
162+
rewrite_result = squin_to_stim(constructed_method)
163+
print(rewrite_result)
164+
constructed_method.print()

0 commit comments

Comments
 (0)