Skip to content

Commit 053e798

Browse files
authored
Merge branch 'main' into david/592-3-fidelity-analysis
2 parents 6755015 + 4e469e3 commit 053e798

File tree

6 files changed

+235
-32
lines changed

6 files changed

+235
-32
lines changed

src/bloqade/analysis/address/impls.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,16 +366,22 @@ def for_loop(
366366
if iter_type is None:
367367
return interp_.eval_fallback(frame, stmt)
368368

369+
body_values = {}
369370
for value in iterable:
370371
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
371372
loop_vars = interp_.frame_call_region(
372373
body_frame, stmt, stmt.body, value, *loop_vars
373374
)
374375

376+
for ssa, val in body_frame.entries.items():
377+
body_values[ssa] = body_values.setdefault(ssa, val).join(val)
378+
375379
if loop_vars is None:
376380
loop_vars = ()
377381

378382
elif isinstance(loop_vars, interp.ReturnValue):
383+
frame.set_values(body_frame.entries.keys(), body_frame.entries.values())
379384
return loop_vars
380385

386+
frame.set_values(body_values.keys(), body_values.values())
381387
return loop_vars

src/bloqade/cirq_utils/noise/model.py

Lines changed: 91 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,10 @@ def __post_init__(self):
9898

9999
@staticmethod
100100
def validate_moments(moments: Iterable[cirq.Moment]):
101-
allowed_target_gates: frozenset[cirq.GateFamily] = cirq.CZTargetGateset().gates
101+
reset_family = cirq.GateFamily(gate=cirq.ResetChannel, ignore_global_phase=True)
102+
allowed_target_gates: frozenset[cirq.GateFamily] = cirq.CZTargetGateset(
103+
additional_gates=[reset_family]
104+
).gates
102105

103106
for moment in moments:
104107
for operation in moment:
@@ -117,7 +120,7 @@ def validate_moments(moments: Iterable[cirq.Moment]):
117120
)
118121

119122
def parallel_cz_errors(
120-
self, ctrls: list[int], qargs: list[int], rest: list[int]
123+
self, ctrls: Sequence[int], qargs: Sequence[int], rest: Sequence[int]
121124
) -> dict[tuple[float, float, float, float], list[int]]:
122125
raise NotImplementedError(
123126
"This noise model doesn't support rewrites on bloqade kernels, but should be used with cirq."
@@ -245,8 +248,15 @@ def noisy_moment(self, moment, system_qubits):
245248
# Moment with original ops
246249
original_moment = moment
247250

251+
no_noise_condition = (
252+
len(moment.operations) == 0
253+
or cirq.is_measurement(moment.operations[0])
254+
or isinstance(moment.operations[0].gate, cirq.ResetChannel)
255+
or isinstance(moment.operations[0].gate, cirq.BitFlipChannel)
256+
)
257+
248258
# Check if the moment is empty
249-
if len(moment.operations) == 0:
259+
if no_noise_condition:
250260
move_noise_ops = []
251261
gate_noise_ops = []
252262
# Check if the moment contains 1-qubit gates or 2-qubit gates
@@ -319,20 +329,84 @@ def noisy_moments(
319329

320330
# Split into moments with only 1Q and 2Q gates
321331
moments_1q = [
322-
cirq.Moment([op for op in moment.operations if len(op.qubits) == 1])
332+
cirq.Moment(
333+
[
334+
op
335+
for op in moment.operations
336+
if (len(op.qubits) == 1)
337+
and (not cirq.is_measurement(op))
338+
and (not isinstance(op.gate, cirq.ResetChannel))
339+
]
340+
)
323341
for moment in moments
324342
]
325343
moments_2q = [
326-
cirq.Moment([op for op in moment.operations if len(op.qubits) == 2])
344+
cirq.Moment(
345+
[
346+
op
347+
for op in moment.operations
348+
if (len(op.qubits) == 2) and (not cirq.is_measurement(op))
349+
]
350+
)
327351
for moment in moments
328352
]
329353

330-
assert len(moments_1q) == len(moments_2q)
354+
moments_measurement = [
355+
cirq.Moment(
356+
[
357+
op
358+
for op in moment.operations
359+
if (cirq.is_measurement(op))
360+
or (isinstance(op.gate, cirq.ResetChannel))
361+
]
362+
)
363+
for moment in moments
364+
]
365+
366+
assert len(moments_1q) == len(moments_2q) == len(moments_measurement)
331367

332368
interleaved_moments = []
369+
370+
def count_remaining_cz_moments(moments_2q):
371+
remaining_cz_counts = []
372+
count = 0
373+
for m in moments_2q[::-1]:
374+
if any(isinstance(op.gate, cirq.CZPowGate) for op in m.operations):
375+
count += 1
376+
remaining_cz_counts = [count] + remaining_cz_counts
377+
return remaining_cz_counts
378+
379+
remaining_cz_moments = count_remaining_cz_moments(moments_2q)
380+
381+
pm = 2 * self.sitter_pauli_rates[0]
382+
ps = 2 * self.cz_unpaired_pauli_rates[0]
383+
384+
# probability of a bitflip error for a sitting, unpaired qubit during a move/cz/move cycle.
385+
heuristic_1step_bitflip_error: float = (
386+
2 * pm * (1 - ps) * (1 - pm) + (1 - pm) ** 2 * ps + pm**2 * ps
387+
)
388+
333389
for idx, moment in enumerate(moments_1q):
334390
interleaved_moments.append(moment)
335391
interleaved_moments.append(moments_2q[idx])
392+
# Measurements on Gemini will be at the end, so for circuits with mid-circuit measurements we will insert a
393+
# bitflip error proportional to the number of moments left in the circuit to account for the decoherence
394+
# that will happen before the final terminal measurement.
395+
measured_qubits = []
396+
for op in moments_measurement[idx].operations:
397+
if cirq.is_measurement(op):
398+
measured_qubits += list(op.qubits)
399+
# probability of a bitflip error should be Binomial(moments_left,heuristic_1step_bitflip_error)
400+
delayed_measurement_error = (
401+
1
402+
- (1 - 2 * heuristic_1step_bitflip_error) ** (remaining_cz_moments[idx])
403+
) / 2
404+
interleaved_moments.append(
405+
cirq.Moment(
406+
cirq.bit_flip(delayed_measurement_error).on_each(measured_qubits)
407+
)
408+
)
409+
interleaved_moments.append(moments_measurement[idx])
336410

337411
interleaved_circuit = cirq.Circuit.from_moments(*interleaved_moments)
338412

@@ -368,14 +442,21 @@ def noisy_moment(self, moment, system_qubits):
368442
"all qubits in the circuit must be defined as cirq.GridQubit objects."
369443
)
370444
# Check if the moment is empty
371-
if len(moment.operations) == 0:
445+
if len(moment.operations) == 0 or cirq.is_measurement(moment.operations[0]):
372446
move_moments = []
373447
gate_noise_ops = []
374448
# Check if the moment contains 1-qubit gates or 2-qubit gates
375449
elif len(moment.operations[0].qubits) == 1:
376-
gate_noise_ops, _ = self._single_qubit_moment_noise_ops(
377-
moment, system_qubits
378-
)
450+
if (
451+
(isinstance(moment.operations[0].gate, cirq.ResetChannel))
452+
or (cirq.is_measurement(moment.operations[0]))
453+
or (isinstance(moment.operations[0].gate, cirq.BitFlipChannel))
454+
):
455+
gate_noise_ops = []
456+
else:
457+
gate_noise_ops, _ = self._single_qubit_moment_noise_ops(
458+
moment, system_qubits
459+
)
379460
move_moments = []
380461
elif len(moment.operations[0].qubits) == 2:
381462
cg = OneZoneConflictGraph(moment)

src/bloqade/cirq_utils/parallelize.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,11 @@ def auto_similarity(
119119
flattened_circuit: list[GateOperation] = list(cirq.flatten_op_tree(circuit))
120120
weights = {}
121121
for i in range(len(flattened_circuit)):
122+
if not cirq.has_unitary(flattened_circuit[i]):
123+
continue
122124
for j in range(i + 1, len(flattened_circuit)):
125+
if not cirq.has_unitary(flattened_circuit[j]):
126+
continue
123127
op1 = flattened_circuit[i]
124128
op2 = flattened_circuit[j]
125129
if can_be_parallel(op1, op2):
@@ -297,14 +301,20 @@ def colorize(
297301
for epoch in epochs:
298302
oneq_gates = []
299303
twoq_gates = []
304+
nonunitary_gates = []
300305
for gate in epoch:
301-
if len(gate.val.qubits) == 1:
306+
if not cirq.has_unitary(gate.val):
307+
nonunitary_gates.append(gate.val)
308+
elif len(gate.val.qubits) == 1:
302309
oneq_gates.append(gate.val)
303310
elif len(gate.val.qubits) == 2:
304311
twoq_gates.append(gate.val)
305312
else:
306313
raise RuntimeError("Unsupported gate type")
307314

315+
if len(nonunitary_gates) > 0:
316+
yield nonunitary_gates
317+
308318
if len(oneq_gates) > 0:
309319
yield oneq_gates
310320

test/analysis/address/test_qubit_analysis.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
from util import collect_address_types
33
from kirin.analysis import const
4-
from kirin.dialects import ilist
4+
from kirin.dialects import scf, ilist
55

66
from bloqade import qubit, squin
77
from bloqade.analysis import address
@@ -265,3 +265,30 @@ def main():
265265

266266
assert ret == address.AddressReg(data=tuple(range(20)))
267267
assert analysis.qubit_count == 20
268+
269+
270+
def test_for_loop_body_values():
271+
@squin.kernel
272+
def main():
273+
q = squin.qalloc(4)
274+
for i in range(1, len(q)):
275+
squin.cx(q[0], q[i])
276+
277+
address_analysis = address.AddressAnalysis(main.dialects)
278+
frame, result = address_analysis.run(main)
279+
main.print(analysis=frame.entries)
280+
281+
(for_stmt,) = tuple(
282+
stmt for stmt in main.callable_region.walk() if isinstance(stmt, scf.For)
283+
)
284+
285+
for_analysis = [
286+
value
287+
for stmt in for_stmt.body.walk()
288+
for value in frame.get_values(stmt.results)
289+
]
290+
291+
assert address.AddressQubit(data=0) in for_analysis
292+
assert address.ConstResult(const.Value(0)) in for_analysis
293+
assert address.ConstResult(const.Value(None)) in for_analysis
294+
assert address.Unknown() in for_analysis

test/cirq_utils/noise/test_noise_models.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
)
1515

1616

17-
def create_ghz_circuit(qubits):
17+
def create_ghz_circuit(qubits, measurements: bool = False):
1818
n = len(qubits)
1919
circuit = cirq.Circuit()
2020

@@ -24,26 +24,41 @@ def create_ghz_circuit(qubits):
2424
# Step 2: CNOT chain from qubit i to i+1
2525
for i in range(n - 1):
2626
circuit.append(cirq.CNOT(qubits[i], qubits[i + 1]))
27+
if measurements:
28+
circuit.append(cirq.measure(qubits[i]))
29+
circuit.append(cirq.reset(qubits[i]))
30+
31+
if measurements:
32+
circuit.append(cirq.measure(qubits[-1]))
33+
circuit.append(cirq.reset(qubits[-1]))
2734

2835
return circuit
2936

3037

3138
@pytest.mark.parametrize(
32-
"model,qubits",
39+
"model,qubits,measurements",
3340
[
34-
(GeminiOneZoneNoiseModel(), None),
41+
(GeminiOneZoneNoiseModel(), None, False),
42+
(
43+
GeminiOneZoneNoiseModelConflictGraphMoves(),
44+
cirq.GridQubit.rect(rows=1, cols=2),
45+
False,
46+
),
47+
(GeminiTwoZoneNoiseModel(), None, False),
48+
(GeminiOneZoneNoiseModel(), None, True),
3549
(
3650
GeminiOneZoneNoiseModelConflictGraphMoves(),
3751
cirq.GridQubit.rect(rows=1, cols=2),
52+
True,
3853
),
39-
(GeminiTwoZoneNoiseModel(), None),
54+
(GeminiTwoZoneNoiseModel(), None, True),
4055
],
4156
)
42-
def test_simple_model(model: cirq.NoiseModel, qubits):
57+
def test_simple_model(model: cirq.NoiseModel, qubits, measurements: bool):
4358
if qubits is None:
4459
qubits = cirq.LineQubit.range(2)
4560

46-
circuit = create_ghz_circuit(qubits)
61+
circuit = create_ghz_circuit(qubits, measurements=measurements)
4762

4863
with pytest.raises(ValueError):
4964
# make sure only native gate set is supported
@@ -74,13 +89,25 @@ def test_simple_model(model: cirq.NoiseModel, qubits):
7489
for i in range(4):
7590
pops_bloqade[i] += abs(ket[i]) ** 2 / nshots
7691

77-
for pops in (pops_bloqade, pops_cirq):
78-
assert math.isclose(pops[0], 0.5, abs_tol=1e-1)
79-
assert math.isclose(pops[3], 0.5, abs_tol=1e-1)
80-
assert math.isclose(pops[1], 0.0, abs_tol=1e-1)
81-
assert math.isclose(pops[2], 0.0, abs_tol=1e-1)
82-
83-
assert pops[0] < 0.5001
84-
assert pops[3] < 0.5001
85-
assert pops[1] >= 0.0
86-
assert pops[2] >= 0.0
92+
if measurements is True:
93+
for pops in (pops_bloqade, pops_cirq):
94+
assert math.isclose(pops[0], 1.0, abs_tol=1e-1)
95+
assert math.isclose(pops[3], 0.0, abs_tol=1e-1)
96+
assert math.isclose(pops[1], 0.0, abs_tol=1e-1)
97+
assert math.isclose(pops[2], 0.0, abs_tol=1e-1)
98+
99+
assert pops[0] > 0.99
100+
assert pops[3] >= 0.0
101+
assert pops[1] >= 0.0
102+
assert pops[2] >= 0.0
103+
else:
104+
for pops in (pops_bloqade, pops_cirq):
105+
assert math.isclose(pops[0], 0.5, abs_tol=1e-1)
106+
assert math.isclose(pops[3], 0.5, abs_tol=1e-1)
107+
assert math.isclose(pops[1], 0.0, abs_tol=1e-1)
108+
assert math.isclose(pops[2], 0.0, abs_tol=1e-1)
109+
110+
assert pops[0] < 0.5001
111+
assert pops[3] < 0.5001
112+
assert pops[1] >= 0.0
113+
assert pops[2] >= 0.0

0 commit comments

Comments
 (0)