Skip to content

Commit 4e469e3

Browse files
Allow measurements and resets into Gemini noise model. (#583)
To simulate some operations, for example multiple rounds of syndrome extraction, it is useful to be able to measure and reuse the measurement qubits in the cirq simulation to save on RAM (even if on Gemini we will actually use different measurement atoms for each round of syndrome extraction). Here is a minimum fix to allow for measurement and reset gates in the noise models. The next step would be to add error in the measurement that is proportional to the number of gate operations left in the circuit, since in Gemini we will have to wait until the end of the circuit to measure any atom. @david-pl what do you think is the right course of action? --------- Co-authored-by: David Plankensteiner <[email protected]> Co-authored-by: David Plankensteiner <[email protected]>
1 parent 65cbc8d commit 4e469e3

File tree

4 files changed

+201
-31
lines changed

4 files changed

+201
-31
lines changed

src/bloqade/cirq_utils/noise/model.py

Lines changed: 91 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,10 @@ def __post_init__(self):
9898

9999
@staticmethod
100100
def validate_moments(moments: Iterable[cirq.Moment]):
101-
allowed_target_gates: frozenset[cirq.GateFamily] = cirq.CZTargetGateset().gates
101+
reset_family = cirq.GateFamily(gate=cirq.ResetChannel, ignore_global_phase=True)
102+
allowed_target_gates: frozenset[cirq.GateFamily] = cirq.CZTargetGateset(
103+
additional_gates=[reset_family]
104+
).gates
102105

103106
for moment in moments:
104107
for operation in moment:
@@ -117,7 +120,7 @@ def validate_moments(moments: Iterable[cirq.Moment]):
117120
)
118121

119122
def parallel_cz_errors(
120-
self, ctrls: list[int], qargs: list[int], rest: list[int]
123+
self, ctrls: Sequence[int], qargs: Sequence[int], rest: Sequence[int]
121124
) -> dict[tuple[float, float, float, float], list[int]]:
122125
raise NotImplementedError(
123126
"This noise model doesn't support rewrites on bloqade kernels, but should be used with cirq."
@@ -245,8 +248,15 @@ def noisy_moment(self, moment, system_qubits):
245248
# Moment with original ops
246249
original_moment = moment
247250

251+
no_noise_condition = (
252+
len(moment.operations) == 0
253+
or cirq.is_measurement(moment.operations[0])
254+
or isinstance(moment.operations[0].gate, cirq.ResetChannel)
255+
or isinstance(moment.operations[0].gate, cirq.BitFlipChannel)
256+
)
257+
248258
# Check if the moment is empty
249-
if len(moment.operations) == 0:
259+
if no_noise_condition:
250260
move_noise_ops = []
251261
gate_noise_ops = []
252262
# Check if the moment contains 1-qubit gates or 2-qubit gates
@@ -319,20 +329,84 @@ def noisy_moments(
319329

320330
# Split into moments with only 1Q and 2Q gates
321331
moments_1q = [
322-
cirq.Moment([op for op in moment.operations if len(op.qubits) == 1])
332+
cirq.Moment(
333+
[
334+
op
335+
for op in moment.operations
336+
if (len(op.qubits) == 1)
337+
and (not cirq.is_measurement(op))
338+
and (not isinstance(op.gate, cirq.ResetChannel))
339+
]
340+
)
323341
for moment in moments
324342
]
325343
moments_2q = [
326-
cirq.Moment([op for op in moment.operations if len(op.qubits) == 2])
344+
cirq.Moment(
345+
[
346+
op
347+
for op in moment.operations
348+
if (len(op.qubits) == 2) and (not cirq.is_measurement(op))
349+
]
350+
)
327351
for moment in moments
328352
]
329353

330-
assert len(moments_1q) == len(moments_2q)
354+
moments_measurement = [
355+
cirq.Moment(
356+
[
357+
op
358+
for op in moment.operations
359+
if (cirq.is_measurement(op))
360+
or (isinstance(op.gate, cirq.ResetChannel))
361+
]
362+
)
363+
for moment in moments
364+
]
365+
366+
assert len(moments_1q) == len(moments_2q) == len(moments_measurement)
331367

332368
interleaved_moments = []
369+
370+
def count_remaining_cz_moments(moments_2q):
371+
remaining_cz_counts = []
372+
count = 0
373+
for m in moments_2q[::-1]:
374+
if any(isinstance(op.gate, cirq.CZPowGate) for op in m.operations):
375+
count += 1
376+
remaining_cz_counts = [count] + remaining_cz_counts
377+
return remaining_cz_counts
378+
379+
remaining_cz_moments = count_remaining_cz_moments(moments_2q)
380+
381+
pm = 2 * self.sitter_pauli_rates[0]
382+
ps = 2 * self.cz_unpaired_pauli_rates[0]
383+
384+
# probability of a bitflip error for a sitting, unpaired qubit during a move/cz/move cycle.
385+
heuristic_1step_bitflip_error: float = (
386+
2 * pm * (1 - ps) * (1 - pm) + (1 - pm) ** 2 * ps + pm**2 * ps
387+
)
388+
333389
for idx, moment in enumerate(moments_1q):
334390
interleaved_moments.append(moment)
335391
interleaved_moments.append(moments_2q[idx])
392+
# Measurements on Gemini will be at the end, so for circuits with mid-circuit measurements we will insert a
393+
# bitflip error proportional to the number of moments left in the circuit to account for the decoherence
394+
# that will happen before the final terminal measurement.
395+
measured_qubits = []
396+
for op in moments_measurement[idx].operations:
397+
if cirq.is_measurement(op):
398+
measured_qubits += list(op.qubits)
399+
# probability of a bitflip error should be Binomial(moments_left,heuristic_1step_bitflip_error)
400+
delayed_measurement_error = (
401+
1
402+
- (1 - 2 * heuristic_1step_bitflip_error) ** (remaining_cz_moments[idx])
403+
) / 2
404+
interleaved_moments.append(
405+
cirq.Moment(
406+
cirq.bit_flip(delayed_measurement_error).on_each(measured_qubits)
407+
)
408+
)
409+
interleaved_moments.append(moments_measurement[idx])
336410

337411
interleaved_circuit = cirq.Circuit.from_moments(*interleaved_moments)
338412

@@ -368,14 +442,21 @@ def noisy_moment(self, moment, system_qubits):
368442
"all qubits in the circuit must be defined as cirq.GridQubit objects."
369443
)
370444
# Check if the moment is empty
371-
if len(moment.operations) == 0:
445+
if len(moment.operations) == 0 or cirq.is_measurement(moment.operations[0]):
372446
move_moments = []
373447
gate_noise_ops = []
374448
# Check if the moment contains 1-qubit gates or 2-qubit gates
375449
elif len(moment.operations[0].qubits) == 1:
376-
gate_noise_ops, _ = self._single_qubit_moment_noise_ops(
377-
moment, system_qubits
378-
)
450+
if (
451+
(isinstance(moment.operations[0].gate, cirq.ResetChannel))
452+
or (cirq.is_measurement(moment.operations[0]))
453+
or (isinstance(moment.operations[0].gate, cirq.BitFlipChannel))
454+
):
455+
gate_noise_ops = []
456+
else:
457+
gate_noise_ops, _ = self._single_qubit_moment_noise_ops(
458+
moment, system_qubits
459+
)
379460
move_moments = []
380461
elif len(moment.operations[0].qubits) == 2:
381462
cg = OneZoneConflictGraph(moment)

src/bloqade/cirq_utils/parallelize.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,11 @@ def auto_similarity(
119119
flattened_circuit: list[GateOperation] = list(cirq.flatten_op_tree(circuit))
120120
weights = {}
121121
for i in range(len(flattened_circuit)):
122+
if not cirq.has_unitary(flattened_circuit[i]):
123+
continue
122124
for j in range(i + 1, len(flattened_circuit)):
125+
if not cirq.has_unitary(flattened_circuit[j]):
126+
continue
123127
op1 = flattened_circuit[i]
124128
op2 = flattened_circuit[j]
125129
if can_be_parallel(op1, op2):
@@ -297,14 +301,20 @@ def colorize(
297301
for epoch in epochs:
298302
oneq_gates = []
299303
twoq_gates = []
304+
nonunitary_gates = []
300305
for gate in epoch:
301-
if len(gate.val.qubits) == 1:
306+
if not cirq.has_unitary(gate.val):
307+
nonunitary_gates.append(gate.val)
308+
elif len(gate.val.qubits) == 1:
302309
oneq_gates.append(gate.val)
303310
elif len(gate.val.qubits) == 2:
304311
twoq_gates.append(gate.val)
305312
else:
306313
raise RuntimeError("Unsupported gate type")
307314

315+
if len(nonunitary_gates) > 0:
316+
yield nonunitary_gates
317+
308318
if len(oneq_gates) > 0:
309319
yield oneq_gates
310320

test/cirq_utils/noise/test_noise_models.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
)
1515

1616

17-
def create_ghz_circuit(qubits):
17+
def create_ghz_circuit(qubits, measurements: bool = False):
1818
n = len(qubits)
1919
circuit = cirq.Circuit()
2020

@@ -24,26 +24,41 @@ def create_ghz_circuit(qubits):
2424
# Step 2: CNOT chain from qubit i to i+1
2525
for i in range(n - 1):
2626
circuit.append(cirq.CNOT(qubits[i], qubits[i + 1]))
27+
if measurements:
28+
circuit.append(cirq.measure(qubits[i]))
29+
circuit.append(cirq.reset(qubits[i]))
30+
31+
if measurements:
32+
circuit.append(cirq.measure(qubits[-1]))
33+
circuit.append(cirq.reset(qubits[-1]))
2734

2835
return circuit
2936

3037

3138
@pytest.mark.parametrize(
32-
"model,qubits",
39+
"model,qubits,measurements",
3340
[
34-
(GeminiOneZoneNoiseModel(), None),
41+
(GeminiOneZoneNoiseModel(), None, False),
42+
(
43+
GeminiOneZoneNoiseModelConflictGraphMoves(),
44+
cirq.GridQubit.rect(rows=1, cols=2),
45+
False,
46+
),
47+
(GeminiTwoZoneNoiseModel(), None, False),
48+
(GeminiOneZoneNoiseModel(), None, True),
3549
(
3650
GeminiOneZoneNoiseModelConflictGraphMoves(),
3751
cirq.GridQubit.rect(rows=1, cols=2),
52+
True,
3853
),
39-
(GeminiTwoZoneNoiseModel(), None),
54+
(GeminiTwoZoneNoiseModel(), None, True),
4055
],
4156
)
42-
def test_simple_model(model: cirq.NoiseModel, qubits):
57+
def test_simple_model(model: cirq.NoiseModel, qubits, measurements: bool):
4358
if qubits is None:
4459
qubits = cirq.LineQubit.range(2)
4560

46-
circuit = create_ghz_circuit(qubits)
61+
circuit = create_ghz_circuit(qubits, measurements=measurements)
4762

4863
with pytest.raises(ValueError):
4964
# make sure only native gate set is supported
@@ -74,13 +89,25 @@ def test_simple_model(model: cirq.NoiseModel, qubits):
7489
for i in range(4):
7590
pops_bloqade[i] += abs(ket[i]) ** 2 / nshots
7691

77-
for pops in (pops_bloqade, pops_cirq):
78-
assert math.isclose(pops[0], 0.5, abs_tol=1e-1)
79-
assert math.isclose(pops[3], 0.5, abs_tol=1e-1)
80-
assert math.isclose(pops[1], 0.0, abs_tol=1e-1)
81-
assert math.isclose(pops[2], 0.0, abs_tol=1e-1)
82-
83-
assert pops[0] < 0.5001
84-
assert pops[3] < 0.5001
85-
assert pops[1] >= 0.0
86-
assert pops[2] >= 0.0
92+
if measurements is True:
93+
for pops in (pops_bloqade, pops_cirq):
94+
assert math.isclose(pops[0], 1.0, abs_tol=1e-1)
95+
assert math.isclose(pops[3], 0.0, abs_tol=1e-1)
96+
assert math.isclose(pops[1], 0.0, abs_tol=1e-1)
97+
assert math.isclose(pops[2], 0.0, abs_tol=1e-1)
98+
99+
assert pops[0] > 0.99
100+
assert pops[3] >= 0.0
101+
assert pops[1] >= 0.0
102+
assert pops[2] >= 0.0
103+
else:
104+
for pops in (pops_bloqade, pops_cirq):
105+
assert math.isclose(pops[0], 0.5, abs_tol=1e-1)
106+
assert math.isclose(pops[3], 0.5, abs_tol=1e-1)
107+
assert math.isclose(pops[1], 0.0, abs_tol=1e-1)
108+
assert math.isclose(pops[2], 0.0, abs_tol=1e-1)
109+
110+
assert pops[0] < 0.5001
111+
assert pops[3] < 0.5001
112+
assert pops[1] >= 0.0
113+
assert pops[2] >= 0.0

test/cirq_utils/test_parallelize.py

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,67 @@ def test1():
2626
)
2727

2828
circuit_m, _ = moment_similarity(circuit, weight=1.0)
29-
# print(circuit_m)
3029
circuit_b, _ = block_similarity(circuit, weight=1.0, block_id=1)
31-
circuit_m2 = remove_tags(circuit_m)
32-
print(circuit_m2)
30+
remove_tags(circuit_m)
3331
circuit2 = parallelize(circuit)
34-
# print(circuit2)
3532
assert len(circuit2.moments) == 7
3633

3734

35+
def test_measurement_and_reset():
36+
qubits = cirq.LineQubit.range(4)
37+
circuit = cirq.Circuit(
38+
cirq.H(qubits[0]),
39+
cirq.CX(qubits[0], qubits[1]),
40+
cirq.measure(qubits[1]),
41+
cirq.reset(qubits[1]),
42+
cirq.CX(qubits[1], qubits[2]),
43+
cirq.measure(qubits[2]),
44+
cirq.reset(qubits[2]),
45+
cirq.CX(qubits[2], qubits[3]),
46+
cirq.measure(qubits[0]),
47+
cirq.reset(qubits[0]),
48+
)
49+
50+
circuit_m, _ = moment_similarity(circuit, weight=1.0)
51+
circuit_b, _ = block_similarity(circuit, weight=1.0, block_id=1)
52+
remove_tags(circuit_m)
53+
54+
parallelized_circuit = parallelize(circuit)
55+
56+
print(parallelized_circuit)
57+
58+
# NOTE: depending on hardware, cirq produces differing, but unitary equivalent
59+
# native circuits; in some cases, there is a PhZX gate with a negative phase
60+
# which cannot be combined with others in the parallelization leading to a longer circuit
61+
assert len(parallelized_circuit.moments) in (11, 13)
62+
63+
# this circuit should deterministically return all qubits to |0>
64+
# let's check:
65+
simulator = cirq.Simulator()
66+
for _ in range(20): # one in a million chance we miss an error
67+
state_vector = simulator.simulate(parallelized_circuit).state_vector()
68+
assert np.all(
69+
np.isclose(
70+
np.abs(state_vector),
71+
np.concatenate((np.array([1]), np.zeros(2**4 - 1))),
72+
)
73+
)
74+
75+
76+
def test_nonunitary_error_gate():
77+
qubits = cirq.LineQubit.range(2)
78+
circuit = cirq.Circuit(
79+
cirq.H(qubits[0]),
80+
cirq.CX(qubits[0], qubits[1]),
81+
cirq.amplitude_damp(0.5).on(qubits[1]),
82+
cirq.CX(qubits[1], qubits[0]),
83+
)
84+
85+
parallelized_circuit = parallelize(circuit)
86+
87+
assert len(parallelized_circuit.moments) == 7
88+
89+
3890
RNG_STATE = np.random.RandomState(1902833)
3991

4092

0 commit comments

Comments
 (0)