Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 40 additions & 31 deletions src/bloqade/cirq_utils/noise/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,29 +56,21 @@ class GeminiNoiseModelABC(cirq.NoiseModel, MoveNoiseModelABC):
"""The correlated CZ error rates as a dictionary"""

def __post_init__(self):
is_ambiguous = (
self.cz_paired_correlated_rates is not None
and self.cz_paired_error_probabilities is not None
)
if is_ambiguous:
raise ValueError(
"Received both `cz_paired_correlated_rates` and `cz_paired_error_probabilities` as input. This is ambiguous, please only set one."
)

use_default = (
if (
self.cz_paired_correlated_rates is None
and self.cz_paired_error_probabilities is None
)
if use_default:
):
# NOTE: no input, set to default value; weird setattr for frozen dataclass
object.__setattr__(
self,
"cz_paired_error_probabilities",
_default_cz_paired_correlated_rates(),
)
return
elif (
self.cz_paired_correlated_rates is not None
and self.cz_paired_error_probabilities is None
):

if self.cz_paired_correlated_rates is not None:
if self.cz_paired_correlated_rates.shape != (4, 4):
raise ValueError(
"Expected a 4x4 array of probabilities for cz_paired_correlated_rates"
Expand All @@ -90,15 +82,19 @@ def __post_init__(self):
"cz_paired_error_probabilities",
correlated_noise_array_to_dict(self.cz_paired_correlated_rates),
)
return

assert (
self.cz_paired_error_probabilities is not None
), "This error should not happen! Please report this issue."
elif (
self.cz_paired_correlated_rates is not None
and self.cz_paired_error_probabilities is not None
):
raise ValueError(
"Received both `cz_paired_correlated_rates` and `cz_paired_correlated_rates` as input. This is ambiguous, please only set one."
)

@staticmethod
def validate_moments(moments: Iterable[cirq.Moment]):
allowed_target_gates: frozenset[cirq.GateFamily] = cirq.CZTargetGateset().gates
reset_family = cirq.GateFamily(gate=cirq.ResetChannel, ignore_global_phase=True)
allowed_target_gates: frozenset[cirq.GateFamily] = cirq.CZTargetGateset(additional_gates=[reset_family]).gates
# allowed_target_gates: frozenset[cirq.GateFamily] = cirq.CZTargetGateset().gates

for moment in moments:
for operation in moment:
Expand Down Expand Up @@ -246,14 +242,18 @@ def noisy_moment(self, moment, system_qubits):
original_moment = moment

# Check if the moment is empty
if len(moment.operations) == 0:
if len(moment.operations) == 0 or cirq.is_measurement(moment.operations[0]):
move_noise_ops = []
gate_noise_ops = []
# Check if the moment contains 1-qubit gates or 2-qubit gates
elif len(moment.operations[0].qubits) == 1:
gate_noise_ops, move_noise_ops = self._single_qubit_moment_noise_ops(
moment, system_qubits
)
if (isinstance(moment.operations[0].gate, cirq.ResetChannel)) or (cirq.is_measurement(moment.operations[0])):
move_noise_ops = []
gate_noise_ops = []
else:
gate_noise_ops, move_noise_ops = self._single_qubit_moment_noise_ops(
moment, system_qubits
)
elif len(moment.operations[0].qubits) == 2:
control_qubits = [op.qubits[0] for op in moment.operations]
target_qubits = [op.qubits[1] for op in moment.operations]
Expand Down Expand Up @@ -319,20 +319,26 @@ def noisy_moments(

# Split into moments with only 1Q and 2Q gates
moments_1q = [
cirq.Moment([op for op in moment.operations if len(op.qubits) == 1])
cirq.Moment([op for op in moment.operations if (len(op.qubits) == 1) and (not cirq.is_measurement(op)) and (not isinstance(op.gate, cirq.ResetChannel))])
for moment in moments
]
moments_2q = [
cirq.Moment([op for op in moment.operations if len(op.qubits) == 2])
cirq.Moment([op for op in moment.operations if (len(op.qubits) == 2) and (not cirq.is_measurement(op))])
for moment in moments
]

assert len(moments_1q) == len(moments_2q)
moments_measurement = [
cirq.Moment([op for op in moment.operations if (cirq.is_measurement(op)) or (isinstance(op.gate, cirq.ResetChannel))])
for moment in moments
]

assert len(moments_1q) == len(moments_2q) == len(moments_measurement)

interleaved_moments = []
for idx, moment in enumerate(moments_1q):
interleaved_moments.append(moment)
interleaved_moments.append(moments_2q[idx])
interleaved_moments.append(moments_measurement[idx])

interleaved_circuit = cirq.Circuit.from_moments(*interleaved_moments)

Expand Down Expand Up @@ -368,14 +374,17 @@ def noisy_moment(self, moment, system_qubits):
"all qubits in the circuit must be defined as cirq.GridQubit objects."
)
# Check if the moment is empty
if len(moment.operations) == 0:
if len(moment.operations) == 0 or cirq.is_measurement(moment.operations[0]):
move_moments = []
gate_noise_ops = []
# Check if the moment contains 1-qubit gates or 2-qubit gates
elif len(moment.operations[0].qubits) == 1:
gate_noise_ops, _ = self._single_qubit_moment_noise_ops(
moment, system_qubits
)
if (isinstance(moment.operations[0].gate, cirq.ResetChannel)) or (cirq.is_measurement(moment.operations[0])):
gate_noise_ops = []
else:
gate_noise_ops, _ = self._single_qubit_moment_noise_ops(
moment, system_qubits
)
move_moments = []
elif len(moment.operations[0].qubits) == 2:
cg = OneZoneConflictGraph(moment)
Expand Down
Loading