Skip to content

Commit 0922fdd

Browse files
committed
Track atom loss per qubit rather than globally
1 parent f98cd08 commit 0922fdd

File tree

3 files changed

+45
-14
lines changed

3 files changed

+45
-14
lines changed

src/bloqade/analysis/fidelity/analysis.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Any
12
from dataclasses import field
23

34
from kirin import ir
@@ -6,6 +7,8 @@
67
from kirin.interp.value import Successor
78
from kirin.analysis.forward import ForwardFrame
89

10+
from ..address import AddressAnalysis
11+
912

1013
class FidelityAnalysis(Forward):
1114
"""
@@ -25,23 +28,42 @@ class FidelityAnalysis(Forward):
2528
"""
2629
The probability that all atoms survive for the entirety of the analysed program. Decreases whenever an atomic loss channel is encountered.
2730
"""
28-
atom_survival_probability: float = 1.0
31+
atom_survival_probability: list[float] = field(init=False)
32+
33+
_current_atom_survival_probability: list[float] = field(init=False)
2934

30-
_current_atom_survival_probability: float = field(init=False)
35+
addr_frame: ForwardFrame = field(init=False)
3136

3237
def initialize(self):
3338
super().initialize()
3439
self._current_gate_fidelity = 1.0
35-
self._current_atom_survival_probability = 1.0
40+
self._current_atom_survival_probability = [
41+
1.0 for _ in range(len(self.atom_survival_probability))
42+
]
3643
return self
3744

3845
def posthook_succ(self, frame: ForwardFrame, succ: Successor):
3946
self.gate_fidelity *= self._current_gate_fidelity
40-
self.atom_survival_probability *= self._current_atom_survival_probability
47+
for i, _current_survival in enumerate(self._current_atom_survival_probability):
48+
self.atom_survival_probability[i] *= _current_survival
4149

4250
def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement):
4351
# NOTE: default is to conserve fidelity, so do nothing here
4452
return
4553

4654
def run_method(self, method: ir.Method, args: tuple[EmptyLattice, ...]):
4755
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
56+
57+
def run_analysis(
58+
self, method: ir.Method, args: tuple | None = None, *, no_raise: bool = True
59+
) -> tuple[ForwardFrame, Any]:
60+
self._run_address_analysis(method, no_raise=no_raise)
61+
return super().run_analysis(method, args, no_raise=no_raise)
62+
63+
def _run_address_analysis(self, method: ir.Method, no_raise: bool):
64+
addr_analysis = AddressAnalysis(self.dialects)
65+
addr_frame, _ = addr_analysis.run_analysis(method=method, no_raise=no_raise)
66+
self.addr_frame = addr_frame
67+
68+
# NOTE: make sure we have as many probabilities as we have addresses
69+
self.atom_survival_probability = [1.0] * addr_analysis.qubit_count

src/bloqade/noise/fidelity.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from .native import dialect as native
77
from .native.stmts import PauliChannel, CZPauliChannel, AtomLossChannel
8+
from ..analysis.address import AddressQubit, AddressTuple
89

910

1011
@native.register(key="circuit.fidelity")
@@ -40,4 +41,11 @@ def atom_loss(
4041
frame: interp.Frame[EmptyLattice],
4142
stmt: AtomLossChannel,
4243
):
43-
interp._current_atom_survival_probability *= 1 - stmt.prob
44+
# NOTE: since AtomLossChannel acts on IList[Qubit], we know the assigned address is a tuple
45+
addresses: AddressTuple = interp.addr_frame.get(stmt.qargs)
46+
47+
# NOTE: get the corresponding index and reduce survival probability accordingly
48+
for qbit_address in addresses.data:
49+
assert isinstance(qbit_address, AddressQubit)
50+
index = qbit_address.data
51+
interp._current_atom_survival_probability[index] *= 1 - stmt.prob

test/analysis/fidelity/test_fidelity.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,9 @@ def main():
5959
== (1 - p_noise)
6060
)
6161

62-
assert 0.9 < fid_analysis.atom_survival_probability < 1
63-
assert fid_analysis.atom_survival_probability == 1 - noise_params.local_loss_prob
62+
assert 0.9 < fid_analysis.atom_survival_probability[0] < 1
63+
assert fid_analysis.atom_survival_probability[0] == 1 - noise_params.local_loss_prob
64+
assert fid_analysis.atom_survival_probability[1] == 1
6465

6566

6667
def test_c_noise():
@@ -108,10 +109,10 @@ def main():
108109
assert fid_analysis.gate_fidelity == fid_analysis._current_gate_fidelity
109110
assert math.isclose(fid_cz, fid_analysis.gate_fidelity, abs_tol=1e-14)
110111

111-
assert 0.9 < fid_analysis.atom_survival_probability < 1
112-
assert fid_analysis.atom_survival_probability == (
112+
assert 0.9 < fid_analysis.atom_survival_probability[0] < 1
113+
assert fid_analysis.atom_survival_probability[0] == (
113114
1 - noise_params.cz_gate_loss_prob
114-
) ** 2 * (1 - p_loss)
115+
) * (1 - p_loss)
115116

116117

117118
@pytest.mark.xfail
@@ -169,8 +170,8 @@ def main_if():
169170
assert 0 < fid_if_analysis.gate_fidelity == fid_analysis.gate_fidelity < 1
170171
assert (
171172
0
172-
< fid_if_analysis.atom_survival_probability
173-
== fid_analysis.atom_survival_probability
173+
< fid_if_analysis.atom_survival_probability[0]
174+
== fid_analysis.atom_survival_probability[0]
174175
< 1
175176
)
176177

@@ -239,7 +240,7 @@ def main_for():
239240
assert 0 < fid_for_analysis.gate_fidelity == fid_analysis.gate_fidelity < 1
240241
assert (
241242
0
242-
< fid_for_analysis.atom_survival_probability
243-
== fid_analysis.atom_survival_probability
243+
< fid_for_analysis.atom_survival_probability[0]
244+
== fid_analysis.atom_survival_probability[0]
244245
< 1
245246
)

0 commit comments

Comments
 (0)