diff --git a/src/bloqade/cirq_utils/noise/model.py b/src/bloqade/cirq_utils/noise/model.py index 4e3c064d..5ede68d6 100644 --- a/src/bloqade/cirq_utils/noise/model.py +++ b/src/bloqade/cirq_utils/noise/model.py @@ -56,29 +56,21 @@ class GeminiNoiseModelABC(cirq.NoiseModel, MoveNoiseModelABC): """The correlated CZ error rates as a dictionary""" def __post_init__(self): - is_ambiguous = ( - self.cz_paired_correlated_rates is not None - and self.cz_paired_error_probabilities is not None - ) - if is_ambiguous: - raise ValueError( - "Received both `cz_paired_correlated_rates` and `cz_paired_error_probabilities` as input. This is ambiguous, please only set one." - ) - - use_default = ( + if ( self.cz_paired_correlated_rates is None and self.cz_paired_error_probabilities is None - ) - if use_default: + ): # NOTE: no input, set to default value; weird setattr for frozen dataclass object.__setattr__( self, "cz_paired_error_probabilities", _default_cz_paired_correlated_rates(), ) - return + elif ( + self.cz_paired_correlated_rates is not None + and self.cz_paired_error_probabilities is None + ): - if self.cz_paired_correlated_rates is not None: if self.cz_paired_correlated_rates.shape != (4, 4): raise ValueError( "Expected a 4x4 array of probabilities for cz_paired_correlated_rates" @@ -90,15 +82,19 @@ def __post_init__(self): "cz_paired_error_probabilities", correlated_noise_array_to_dict(self.cz_paired_correlated_rates), ) - return - - assert ( - self.cz_paired_error_probabilities is not None - ), "This error should not happen! Please report this issue." + elif ( + self.cz_paired_correlated_rates is not None + and self.cz_paired_error_probabilities is not None + ): + raise ValueError( + "Received both `cz_paired_correlated_rates` and `cz_paired_correlated_rates` as input. This is ambiguous, please only set one." + ) @staticmethod def validate_moments(moments: Iterable[cirq.Moment]): - allowed_target_gates: frozenset[cirq.GateFamily] = cirq.CZTargetGateset().gates + reset_family = cirq.GateFamily(gate=cirq.ResetChannel, ignore_global_phase=True) + allowed_target_gates: frozenset[cirq.GateFamily] = cirq.CZTargetGateset(additional_gates=[reset_family]).gates + # allowed_target_gates: frozenset[cirq.GateFamily] = cirq.CZTargetGateset().gates for moment in moments: for operation in moment: @@ -246,14 +242,18 @@ def noisy_moment(self, moment, system_qubits): original_moment = moment # Check if the moment is empty - if len(moment.operations) == 0: + if len(moment.operations) == 0 or cirq.is_measurement(moment.operations[0]): move_noise_ops = [] gate_noise_ops = [] # Check if the moment contains 1-qubit gates or 2-qubit gates elif len(moment.operations[0].qubits) == 1: - gate_noise_ops, move_noise_ops = self._single_qubit_moment_noise_ops( - moment, system_qubits - ) + if (isinstance(moment.operations[0].gate, cirq.ResetChannel)) or (cirq.is_measurement(moment.operations[0])): + move_noise_ops = [] + gate_noise_ops = [] + else: + gate_noise_ops, move_noise_ops = self._single_qubit_moment_noise_ops( + moment, system_qubits + ) elif len(moment.operations[0].qubits) == 2: control_qubits = [op.qubits[0] for op in moment.operations] target_qubits = [op.qubits[1] for op in moment.operations] @@ -319,20 +319,26 @@ def noisy_moments( # Split into moments with only 1Q and 2Q gates moments_1q = [ - cirq.Moment([op for op in moment.operations if len(op.qubits) == 1]) + cirq.Moment([op for op in moment.operations if (len(op.qubits) == 1) and (not cirq.is_measurement(op)) and (not isinstance(op.gate, cirq.ResetChannel))]) for moment in moments ] moments_2q = [ - cirq.Moment([op for op in moment.operations if len(op.qubits) == 2]) + cirq.Moment([op for op in moment.operations if (len(op.qubits) == 2) and (not cirq.is_measurement(op))]) for moment in moments ] - assert len(moments_1q) == len(moments_2q) + moments_measurement = [ + cirq.Moment([op for op in moment.operations if (cirq.is_measurement(op)) or (isinstance(op.gate, cirq.ResetChannel))]) + for moment in moments + ] + + assert len(moments_1q) == len(moments_2q) == len(moments_measurement) interleaved_moments = [] for idx, moment in enumerate(moments_1q): interleaved_moments.append(moment) interleaved_moments.append(moments_2q[idx]) + interleaved_moments.append(moments_measurement[idx]) interleaved_circuit = cirq.Circuit.from_moments(*interleaved_moments) @@ -368,14 +374,17 @@ def noisy_moment(self, moment, system_qubits): "all qubits in the circuit must be defined as cirq.GridQubit objects." ) # Check if the moment is empty - if len(moment.operations) == 0: + if len(moment.operations) == 0 or cirq.is_measurement(moment.operations[0]): move_moments = [] gate_noise_ops = [] # Check if the moment contains 1-qubit gates or 2-qubit gates elif len(moment.operations[0].qubits) == 1: - gate_noise_ops, _ = self._single_qubit_moment_noise_ops( - moment, system_qubits - ) + if (isinstance(moment.operations[0].gate, cirq.ResetChannel)) or (cirq.is_measurement(moment.operations[0])): + gate_noise_ops = [] + else: + gate_noise_ops, _ = self._single_qubit_moment_noise_ops( + moment, system_qubits + ) move_moments = [] elif len(moment.operations[0].qubits) == 2: cg = OneZoneConflictGraph(moment)