@@ -98,7 +98,10 @@ def __post_init__(self):
9898
9999 @staticmethod
100100 def validate_moments (moments : Iterable [cirq .Moment ]):
101- allowed_target_gates : frozenset [cirq .GateFamily ] = cirq .CZTargetGateset ().gates
101+ reset_family = cirq .GateFamily (gate = cirq .ResetChannel , ignore_global_phase = True )
102+ allowed_target_gates : frozenset [cirq .GateFamily ] = cirq .CZTargetGateset (
103+ additional_gates = [reset_family ]
104+ ).gates
102105
103106 for moment in moments :
104107 for operation in moment :
@@ -117,7 +120,7 @@ def validate_moments(moments: Iterable[cirq.Moment]):
117120 )
118121
119122 def parallel_cz_errors (
120- self , ctrls : list [int ], qargs : list [int ], rest : list [int ]
123+ self , ctrls : Sequence [int ], qargs : Sequence [int ], rest : Sequence [int ]
121124 ) -> dict [tuple [float , float , float , float ], list [int ]]:
122125 raise NotImplementedError (
123126 "This noise model doesn't support rewrites on bloqade kernels, but should be used with cirq."
@@ -245,8 +248,15 @@ def noisy_moment(self, moment, system_qubits):
245248 # Moment with original ops
246249 original_moment = moment
247250
251+ no_noise_condition = (
252+ len (moment .operations ) == 0
253+ or cirq .is_measurement (moment .operations [0 ])
254+ or isinstance (moment .operations [0 ].gate , cirq .ResetChannel )
255+ or isinstance (moment .operations [0 ].gate , cirq .BitFlipChannel )
256+ )
257+
248258 # Check if the moment is empty
249- if len ( moment . operations ) == 0 :
259+ if no_noise_condition :
250260 move_noise_ops = []
251261 gate_noise_ops = []
252262 # Check if the moment contains 1-qubit gates or 2-qubit gates
@@ -319,20 +329,84 @@ def noisy_moments(
319329
320330 # Split into moments with only 1Q and 2Q gates
321331 moments_1q = [
322- cirq .Moment ([op for op in moment .operations if len (op .qubits ) == 1 ])
332+ cirq .Moment (
333+ [
334+ op
335+ for op in moment .operations
336+ if (len (op .qubits ) == 1 )
337+ and (not cirq .is_measurement (op ))
338+ and (not isinstance (op .gate , cirq .ResetChannel ))
339+ ]
340+ )
323341 for moment in moments
324342 ]
325343 moments_2q = [
326- cirq .Moment ([op for op in moment .operations if len (op .qubits ) == 2 ])
344+ cirq .Moment (
345+ [
346+ op
347+ for op in moment .operations
348+ if (len (op .qubits ) == 2 ) and (not cirq .is_measurement (op ))
349+ ]
350+ )
327351 for moment in moments
328352 ]
329353
330- assert len (moments_1q ) == len (moments_2q )
354+ moments_measurement = [
355+ cirq .Moment (
356+ [
357+ op
358+ for op in moment .operations
359+ if (cirq .is_measurement (op ))
360+ or (isinstance (op .gate , cirq .ResetChannel ))
361+ ]
362+ )
363+ for moment in moments
364+ ]
365+
366+ assert len (moments_1q ) == len (moments_2q ) == len (moments_measurement )
331367
332368 interleaved_moments = []
369+
370+ def count_remaining_cz_moments (moments_2q ):
371+ remaining_cz_counts = []
372+ count = 0
373+ for m in moments_2q [::- 1 ]:
374+ if any (isinstance (op .gate , cirq .CZPowGate ) for op in m .operations ):
375+ count += 1
376+ remaining_cz_counts = [count ] + remaining_cz_counts
377+ return remaining_cz_counts
378+
379+ remaining_cz_moments = count_remaining_cz_moments (moments_2q )
380+
381+ pm = 2 * self .sitter_pauli_rates [0 ]
382+ ps = 2 * self .cz_unpaired_pauli_rates [0 ]
383+
384+ # probability of a bitflip error for a sitting, unpaired qubit during a move/cz/move cycle.
385+ heuristic_1step_bitflip_error : float = (
386+ 2 * pm * (1 - ps ) * (1 - pm ) + (1 - pm ) ** 2 * ps + pm ** 2 * ps
387+ )
388+
333389 for idx , moment in enumerate (moments_1q ):
334390 interleaved_moments .append (moment )
335391 interleaved_moments .append (moments_2q [idx ])
392+ # Measurements on Gemini will be at the end, so for circuits with mid-circuit measurements we will insert a
393+ # bitflip error proportional to the number of moments left in the circuit to account for the decoherence
394+ # that will happen before the final terminal measurement.
395+ measured_qubits = []
396+ for op in moments_measurement [idx ].operations :
397+ if cirq .is_measurement (op ):
398+ measured_qubits += list (op .qubits )
399+ # probability of a bitflip error should be Binomial(moments_left,heuristic_1step_bitflip_error)
400+ delayed_measurement_error = (
401+ 1
402+ - (1 - 2 * heuristic_1step_bitflip_error ) ** (remaining_cz_moments [idx ])
403+ ) / 2
404+ interleaved_moments .append (
405+ cirq .Moment (
406+ cirq .bit_flip (delayed_measurement_error ).on_each (measured_qubits )
407+ )
408+ )
409+ interleaved_moments .append (moments_measurement [idx ])
336410
337411 interleaved_circuit = cirq .Circuit .from_moments (* interleaved_moments )
338412
@@ -368,14 +442,21 @@ def noisy_moment(self, moment, system_qubits):
368442 "all qubits in the circuit must be defined as cirq.GridQubit objects."
369443 )
370444 # Check if the moment is empty
371- if len (moment .operations ) == 0 :
445+ if len (moment .operations ) == 0 or cirq . is_measurement ( moment . operations [ 0 ]) :
372446 move_moments = []
373447 gate_noise_ops = []
374448 # Check if the moment contains 1-qubit gates or 2-qubit gates
375449 elif len (moment .operations [0 ].qubits ) == 1 :
376- gate_noise_ops , _ = self ._single_qubit_moment_noise_ops (
377- moment , system_qubits
378- )
450+ if (
451+ (isinstance (moment .operations [0 ].gate , cirq .ResetChannel ))
452+ or (cirq .is_measurement (moment .operations [0 ]))
453+ or (isinstance (moment .operations [0 ].gate , cirq .BitFlipChannel ))
454+ ):
455+ gate_noise_ops = []
456+ else :
457+ gate_noise_ops , _ = self ._single_qubit_moment_noise_ops (
458+ moment , system_qubits
459+ )
379460 move_moments = []
380461 elif len (moment .operations [0 ].qubits ) == 2 :
381462 cg = OneZoneConflictGraph (moment )
0 commit comments