Skip to content

Commit b057ea8

Browse files
committed
Revert "remove tests for measurement/reset and pyqrack."
This reverts commit 3c21aae.
1 parent aab06de commit b057ea8

File tree

3 files changed

+100
-60
lines changed

3 files changed

+100
-60
lines changed

src/bloqade/cirq_utils/noise/model.py

Lines changed: 73 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ class GeminiNoiseModelABC(cirq.NoiseModel, MoveNoiseModelABC):
5757

5858
def __post_init__(self):
5959
if (
60-
self.cz_paired_correlated_rates is None
61-
and self.cz_paired_error_probabilities is None
60+
self.cz_paired_correlated_rates is None
61+
and self.cz_paired_error_probabilities is None
6262
):
6363
# NOTE: no input, set to default value; weird setattr for frozen dataclass
6464
object.__setattr__(
@@ -67,8 +67,8 @@ def __post_init__(self):
6767
_default_cz_paired_correlated_rates(),
6868
)
6969
elif (
70-
self.cz_paired_correlated_rates is not None
71-
and self.cz_paired_error_probabilities is None
70+
self.cz_paired_correlated_rates is not None
71+
and self.cz_paired_error_probabilities is None
7272
):
7373

7474
if self.cz_paired_correlated_rates.shape != (4, 4):
@@ -83,8 +83,8 @@ def __post_init__(self):
8383
correlated_noise_array_to_dict(self.cz_paired_correlated_rates),
8484
)
8585
elif (
86-
self.cz_paired_correlated_rates is not None
87-
and self.cz_paired_error_probabilities is not None
86+
self.cz_paired_correlated_rates is not None
87+
and self.cz_paired_error_probabilities is not None
8888
):
8989
raise ValueError(
9090
"Received both `cz_paired_correlated_rates` and `cz_paired_correlated_rates` as input. This is ambiguous, please only set one."
@@ -93,7 +93,9 @@ def __post_init__(self):
9393
@staticmethod
9494
def validate_moments(moments: Iterable[cirq.Moment]):
9595
reset_family = cirq.GateFamily(gate=cirq.ResetChannel, ignore_global_phase=True)
96-
allowed_target_gates: frozenset[cirq.GateFamily] = cirq.CZTargetGateset(additional_gates=[reset_family]).gates
96+
allowed_target_gates: frozenset[cirq.GateFamily] = cirq.CZTargetGateset(
97+
additional_gates=[reset_family]
98+
).gates
9799
# allowed_target_gates: frozenset[cirq.GateFamily] = cirq.CZTargetGateset().gates
98100

99101
for moment in moments:
@@ -113,7 +115,7 @@ def validate_moments(moments: Iterable[cirq.Moment]):
113115
)
114116

115117
def parallel_cz_errors(
116-
self, ctrls: list[int], qargs: list[int], rest: list[int]
118+
self, ctrls: list[int], qargs: list[int], rest: list[int]
117119
) -> dict[tuple[float, float, float, float], list[int]]:
118120
raise NotImplementedError(
119121
"This noise model doesn't support rewrites on bloqade kernels, but should be used with cirq."
@@ -179,7 +181,7 @@ class GeminiOneZoneNoiseModel(GeminiNoiseModelABC):
179181
parallelize_circuit: bool = False
180182

181183
def _single_qubit_moment_noise_ops(
182-
self, moment: cirq.Moment, system_qubits: Sequence[cirq.Qid]
184+
self, moment: cirq.Moment, system_qubits: Sequence[cirq.Qid]
183185
) -> tuple[list, list]:
184186
"""
185187
Helper function to determine the noise operations for a single qubit moment.
@@ -211,7 +213,7 @@ def _single_qubit_moment_noise_ops(
211213
op.qubits[0]
212214
for op in moment.operations
213215
if not (
214-
np.isclose(op.gate.x_exponent, 0) and np.isclose(op.gate.z_exponent, 0)
216+
np.isclose(op.gate.x_exponent, 0) and np.isclose(op.gate.z_exponent, 0)
215217
)
216218
]
217219

@@ -247,9 +249,11 @@ def noisy_moment(self, moment, system_qubits):
247249
gate_noise_ops = []
248250
# Check if the moment contains 1-qubit gates or 2-qubit gates
249251
elif len(moment.operations[0].qubits) == 1:
250-
if (isinstance(moment.operations[0].gate, cirq.ResetChannel)) or (
251-
cirq.is_measurement(moment.operations[0])) or (
252-
isinstance(moment.operations[0].gate, cirq.BitFlipChannel)):
252+
if (
253+
(isinstance(moment.operations[0].gate, cirq.ResetChannel))
254+
or (cirq.is_measurement(moment.operations[0]))
255+
or (isinstance(moment.operations[0].gate, cirq.BitFlipChannel))
256+
):
253257
move_noise_ops = []
254258
gate_noise_ops = []
255259
else:
@@ -303,7 +307,7 @@ def noisy_moment(self, moment, system_qubits):
303307
]
304308

305309
def noisy_moments(
306-
self, moments: Iterable[cirq.Moment], system_qubits: Sequence[cirq.Qid]
310+
self, moments: Iterable[cirq.Moment], system_qubits: Sequence[cirq.Qid]
307311
) -> Sequence[cirq.OP_TREE]:
308312
"""Adds possibly stateful noise to a series of moments.
309313
@@ -321,18 +325,37 @@ def noisy_moments(
321325

322326
# Split into moments with only 1Q and 2Q gates
323327
moments_1q = [
324-
cirq.Moment([op for op in moment.operations if (len(op.qubits) == 1) and (not cirq.is_measurement(op)) and (
325-
not isinstance(op.gate, cirq.ResetChannel))])
328+
cirq.Moment(
329+
[
330+
op
331+
for op in moment.operations
332+
if (len(op.qubits) == 1)
333+
and (not cirq.is_measurement(op))
334+
and (not isinstance(op.gate, cirq.ResetChannel))
335+
]
336+
)
326337
for moment in moments
327338
]
328339
moments_2q = [
329-
cirq.Moment([op for op in moment.operations if (len(op.qubits) == 2) and (not cirq.is_measurement(op))])
340+
cirq.Moment(
341+
[
342+
op
343+
for op in moment.operations
344+
if (len(op.qubits) == 2) and (not cirq.is_measurement(op))
345+
]
346+
)
330347
for moment in moments
331348
]
332349

333350
moments_measurement = [
334-
cirq.Moment([op for op in moment.operations if
335-
(cirq.is_measurement(op)) or (isinstance(op.gate, cirq.ResetChannel))])
351+
cirq.Moment(
352+
[
353+
op
354+
for op in moment.operations
355+
if (cirq.is_measurement(op))
356+
or (isinstance(op.gate, cirq.ResetChannel))
357+
]
358+
)
336359
for moment in moments
337360
]
338361

@@ -356,7 +379,9 @@ def count_remaining_cz_moments(moments_2q):
356379
ps = 2 * self.cz_unpaired_pauli_rates[0]
357380

358381
# probability of a bitflip error for a sitting, unpaired qubit during a move/cz/move cycle.
359-
heuristic_1step_bitflip_error: float = 2 * pm * (1 - ps) * (1 - pm) + (1 - pm) ** 2 * ps + pm ** 2 * ps
382+
heuristic_1step_bitflip_error: float = (
383+
2 * pm * (1 - ps) * (1 - pm) + (1 - pm) ** 2 * ps + pm**2 * ps
384+
)
360385

361386
for idx, moment in enumerate(moments_1q):
362387
interleaved_moments.append(moment)
@@ -369,8 +394,15 @@ def count_remaining_cz_moments(moments_2q):
369394
if cirq.is_measurement(op):
370395
measured_qubits += list(op.qubits)
371396
# probability of a bitflip error should be Binomial(moments_left,heuristic_1step_bitflip_error)
372-
delayed_measurement_error = (1 - (1 - 2 * heuristic_1step_bitflip_error) ** (remaining_cz_moments[idx])) / 2
373-
interleaved_moments.append(cirq.Moment(cirq.bit_flip(delayed_measurement_error).on_each(measured_qubits)))
397+
delayed_measurement_error = (
398+
1
399+
- (1 - 2 * heuristic_1step_bitflip_error) ** (remaining_cz_moments[idx])
400+
) / 2
401+
interleaved_moments.append(
402+
cirq.Moment(
403+
cirq.bit_flip(delayed_measurement_error).on_each(measured_qubits)
404+
)
405+
)
374406
interleaved_moments.append(moments_measurement[idx])
375407

376408
interleaved_circuit = cirq.Circuit.from_moments(*interleaved_moments)
@@ -412,9 +444,11 @@ def noisy_moment(self, moment, system_qubits):
412444
gate_noise_ops = []
413445
# Check if the moment contains 1-qubit gates or 2-qubit gates
414446
elif len(moment.operations[0].qubits) == 1:
415-
if (isinstance(moment.operations[0].gate, cirq.ResetChannel)) or (
416-
cirq.is_measurement(moment.operations[0])) or (
417-
isinstance(moment.operations[0].gate, cirq.BitFlipChannel)):
447+
if (
448+
(isinstance(moment.operations[0].gate, cirq.ResetChannel))
449+
or (cirq.is_measurement(moment.operations[0]))
450+
or (isinstance(moment.operations[0].gate, cirq.BitFlipChannel))
451+
):
418452
gate_noise_ops = []
419453
else:
420454
gate_noise_ops, _ = self._single_qubit_moment_noise_ops(
@@ -484,7 +518,7 @@ def noisy_moment(self, moment, system_qubits):
484518
@dataclass(frozen=True)
485519
class GeminiTwoZoneNoiseModel(GeminiNoiseModelABC):
486520
def noisy_moments(
487-
self, moments: Iterable[cirq.Moment], system_qubits: Sequence[cirq.Qid]
521+
self, moments: Iterable[cirq.Moment], system_qubits: Sequence[cirq.Qid]
488522
) -> Sequence[cirq.OP_TREE]:
489523
"""Adds possibly stateful noise to a series of moments.
490524
@@ -516,12 +550,12 @@ def noisy_moments(
516550
[
517551
moment
518552
for moment in _two_zone_utils.get_move_error_channel_two_zoned(
519-
moments[i],
520-
prev_moment,
521-
np.array(self.mover_pauli_rates),
522-
np.array(self.sitter_pauli_rates),
523-
nqubs,
524-
).moments
553+
moments[i],
554+
prev_moment,
555+
np.array(self.mover_pauli_rates),
556+
np.array(self.sitter_pauli_rates),
557+
nqubs,
558+
).moments
525559
if len(moment) > 0
526560
]
527561
)
@@ -532,13 +566,13 @@ def noisy_moments(
532566
[
533567
moment
534568
for moment in _two_zone_utils.get_gate_error_channel(
535-
moments[i],
536-
np.array(self.local_pauli_rates),
537-
np.array(self.global_pauli_rates),
538-
self.two_qubit_pauli,
539-
np.array(self.cz_unpaired_pauli_rates),
540-
nqubs,
541-
).moments
569+
moments[i],
570+
np.array(self.local_pauli_rates),
571+
np.array(self.global_pauli_rates),
572+
self.two_qubit_pauli,
573+
np.array(self.cz_unpaired_pauli_rates),
574+
nqubs,
575+
).moments
542576
if len(moment) > 0
543577
]
544578
)

test/cirq_utils/noise/test_noise_models.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,16 @@ def create_ghz_circuit(qubits, measurements: bool = False):
4040
[
4141
(GeminiOneZoneNoiseModel(), None, False),
4242
(
43-
GeminiOneZoneNoiseModelConflictGraphMoves(),
44-
cirq.GridQubit.rect(rows=1, cols=2),
45-
False
43+
GeminiOneZoneNoiseModelConflictGraphMoves(),
44+
cirq.GridQubit.rect(rows=1, cols=2),
45+
False,
4646
),
4747
(GeminiTwoZoneNoiseModel(), None, False),
4848
(GeminiOneZoneNoiseModel(), None, True),
4949
(
50-
GeminiOneZoneNoiseModelConflictGraphMoves(),
51-
cirq.GridQubit.rect(rows=1, cols=2),
52-
True
50+
GeminiOneZoneNoiseModelConflictGraphMoves(),
51+
cirq.GridQubit.rect(rows=1, cols=2),
52+
True,
5353
),
5454
(GeminiTwoZoneNoiseModel(), None, True),
5555
],
@@ -76,22 +76,21 @@ def test_simple_model(model: cirq.NoiseModel, qubits, measurements: bool):
7676
dm = cirq_sim.simulate(noisy_circuit).final_density_matrix
7777
pops_cirq = np.real(np.diag(dm))
7878

79-
if not measurements:
80-
kernel = load_circuit(noisy_circuit)
81-
pyqrack_sim = StackMemorySimulator(
82-
min_qubits=2, rng_state=np.random.default_rng(1234)
83-
)
79+
kernel = load_circuit(noisy_circuit)
80+
pyqrack_sim = StackMemorySimulator(
81+
min_qubits=2, rng_state=np.random.default_rng(1234)
82+
)
8483

85-
pops_bloqade = [0.0] * 4
84+
pops_bloqade = [0.0] * 4
8685

87-
nshots = 500
88-
for _ in range(nshots):
89-
ket = pyqrack_sim.state_vector(kernel)
90-
for i in range(4):
91-
pops_bloqade[i] += abs(ket[i]) ** 2 / nshots
86+
nshots = 500
87+
for _ in range(nshots):
88+
ket = pyqrack_sim.state_vector(kernel)
89+
for i in range(4):
90+
pops_bloqade[i] += abs(ket[i]) ** 2 / nshots
9291

9392
if measurements is True:
94-
for pops in [pops_cirq]:
93+
for pops in (pops_bloqade, pops_cirq):
9594
assert math.isclose(pops[0], 1.0, abs_tol=1e-1)
9695
assert math.isclose(pops[3], 0.0, abs_tol=1e-1)
9796
assert math.isclose(pops[1], 0.0, abs_tol=1e-1)

test/cirq_utils/test_parallelize.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test1():
2727

2828
circuit_m, _ = moment_similarity(circuit, weight=1.0)
2929
circuit_b, _ = block_similarity(circuit, weight=1.0, block_id=1)
30-
circuit_m2 = remove_tags(circuit_m)
30+
remove_tags(circuit_m)
3131
circuit2 = parallelize(circuit)
3232
assert len(circuit2.moments) == 7
3333

@@ -49,16 +49,23 @@ def test_measurement_and_reset():
4949

5050
circuit_m, _ = moment_similarity(circuit, weight=1.0)
5151
circuit_b, _ = block_similarity(circuit, weight=1.0, block_id=1)
52-
circuit_m2 = remove_tags(circuit_m)
52+
remove_tags(circuit_m)
5353

5454
parallelized_circuit = parallelize(circuit)
5555

56+
assert len(parallelized_circuit.moments) == 11
57+
5658
# this circuit should deterministically return all qubits to |0>
5759
# let's check:
5860
simulator = cirq.Simulator()
5961
for _ in range(20): # one in a million chance we miss an error
6062
state_vector = simulator.simulate(parallelized_circuit).state_vector()
61-
assert np.all(np.isclose(np.abs(state_vector), np.concatenate((np.array([1]), np.zeros(2 ** 4 - 1)))))
63+
assert np.all(
64+
np.isclose(
65+
np.abs(state_vector),
66+
np.concatenate((np.array([1]), np.zeros(2**4 - 1))),
67+
)
68+
)
6269

6370

6471
def test_nonunitary_error_gate():

0 commit comments

Comments
 (0)