Skip to content

Commit 6a35f85

Browse files
david-plweinbe58
andauthored
Add a simple circuit fidelity analysis pass (#218)
Co-authored-by: Phillip Weinberg <[email protected]>
1 parent 65acb76 commit 6a35f85

File tree

6 files changed

+370
-4
lines changed

6 files changed

+370
-4
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .analysis import FidelityAnalysis as FidelityAnalysis
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from typing import Any
2+
from dataclasses import field
3+
4+
from kirin import ir
5+
from kirin.lattice import EmptyLattice
6+
from kirin.analysis import Forward
7+
from kirin.interp.value import Successor
8+
from kirin.analysis.forward import ForwardFrame
9+
10+
from ..address import AddressAnalysis
11+
12+
13+
class FidelityAnalysis(Forward):
14+
"""
15+
This analysis pass can be used to track the global addresses of qubits and wires.
16+
"""
17+
18+
keys = ["circuit.fidelity"]
19+
lattice = EmptyLattice
20+
21+
"""
22+
The fidelity of the gate set described by the analysed program. It reduces whenever a noise channel is encountered.
23+
"""
24+
gate_fidelity: float = 1.0
25+
26+
_current_gate_fidelity: float = field(init=False)
27+
28+
"""
29+
The probabilities that each of the atoms in the register survive the duration of the analysed program. The order of the list follows the order they are in the register.
30+
"""
31+
atom_survival_probability: list[float] = field(init=False)
32+
33+
_current_atom_survival_probability: list[float] = field(init=False)
34+
35+
addr_frame: ForwardFrame = field(init=False)
36+
37+
def initialize(self):
38+
super().initialize()
39+
self._current_gate_fidelity = 1.0
40+
self._current_atom_survival_probability = [
41+
1.0 for _ in range(len(self.atom_survival_probability))
42+
]
43+
return self
44+
45+
def posthook_succ(self, frame: ForwardFrame, succ: Successor):
46+
self.gate_fidelity *= self._current_gate_fidelity
47+
for i, _current_survival in enumerate(self._current_atom_survival_probability):
48+
self.atom_survival_probability[i] *= _current_survival
49+
50+
def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement):
51+
# NOTE: default is to conserve fidelity, so do nothing here
52+
return
53+
54+
def run_method(self, method: ir.Method, args: tuple[EmptyLattice, ...]):
55+
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
56+
57+
def run_analysis(
58+
self, method: ir.Method, args: tuple | None = None, *, no_raise: bool = True
59+
) -> tuple[ForwardFrame, Any]:
60+
self._run_address_analysis(method, no_raise=no_raise)
61+
return super().run_analysis(method, args, no_raise=no_raise)
62+
63+
def _run_address_analysis(self, method: ir.Method, no_raise: bool):
64+
addr_analysis = AddressAnalysis(self.dialects)
65+
addr_frame, _ = addr_analysis.run_analysis(method=method, no_raise=no_raise)
66+
self.addr_frame = addr_frame
67+
68+
# NOTE: make sure we have as many probabilities as we have addresses
69+
self.atom_survival_probability = [1.0] * addr_analysis.qubit_count

src/bloqade/noise/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from . import native as native
1+
# NOTE: just to register methods
2+
from . import native as native, fidelity as fidelity

src/bloqade/noise/fidelity.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from kirin import interp
2+
from kirin.lattice import EmptyLattice
3+
4+
from bloqade.analysis.fidelity import FidelityAnalysis
5+
6+
from .native import dialect as native
7+
from .native.stmts import PauliChannel, CZPauliChannel, AtomLossChannel
8+
from ..analysis.address import AddressQubit, AddressTuple
9+
10+
11+
@native.register(key="circuit.fidelity")
12+
class FidelityMethodTable(interp.MethodTable):
13+
14+
@interp.impl(PauliChannel)
15+
@interp.impl(CZPauliChannel)
16+
def pauli_channel(
17+
self,
18+
interp: FidelityAnalysis,
19+
frame: interp.Frame[EmptyLattice],
20+
stmt: PauliChannel | CZPauliChannel,
21+
):
22+
probs = stmt.probabilities
23+
try:
24+
ps, ps_ctrl = probs
25+
except ValueError:
26+
(ps,) = probs
27+
ps_ctrl = ()
28+
29+
p = sum(ps)
30+
p_ctrl = sum(ps_ctrl)
31+
32+
# NOTE: fidelity is just the inverse probability of any noise to occur
33+
fid = (1 - p) * (1 - p_ctrl)
34+
35+
interp._current_gate_fidelity *= fid
36+
37+
@interp.impl(AtomLossChannel)
38+
def atom_loss(
39+
self,
40+
interp: FidelityAnalysis,
41+
frame: interp.Frame[EmptyLattice],
42+
stmt: AtomLossChannel,
43+
):
44+
# NOTE: since AtomLossChannel acts on IList[Qubit], we know the assigned address is a tuple
45+
addresses: AddressTuple = interp.addr_frame.get(stmt.qargs)
46+
47+
# NOTE: get the corresponding index and reduce survival probability accordingly
48+
for qbit_address in addresses.data:
49+
assert isinstance(qbit_address, AddressQubit)
50+
index = qbit_address.data
51+
interp._current_atom_survival_probability[index] *= 1 - stmt.prob
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
import math
2+
3+
import pytest
4+
5+
from bloqade import qasm2
6+
from bloqade.noise import native
7+
from bloqade.analysis.fidelity import FidelityAnalysis
8+
from bloqade.qasm2.passes.noise import NoisePass
9+
10+
noise_main = qasm2.extended.add(native.dialect)
11+
12+
13+
class NoiseTestModel(native.MoveNoiseModelABC):
14+
def parallel_cz_errors(self, ctrls, qargs, rest):
15+
return {(0.01, 0.01, 0.01, 0.01): ctrls + qargs + rest}
16+
17+
18+
def test_basic_noise():
19+
20+
@noise_main
21+
def main():
22+
q = qasm2.qreg(2)
23+
qasm2.x(q[0])
24+
return q
25+
26+
main.print()
27+
28+
fid_analysis = FidelityAnalysis(main.dialects)
29+
fid_analysis.run_analysis(main, no_raise=False)
30+
31+
assert fid_analysis.gate_fidelity == fid_analysis._current_gate_fidelity == 1
32+
33+
px = 0.01
34+
py = 0.01
35+
pz = 0.01
36+
p_loss = 0.01
37+
38+
noise_params = native.GateNoiseParams(
39+
global_loss_prob=p_loss,
40+
global_px=px,
41+
global_py=py,
42+
global_pz=pz,
43+
local_px=0.002,
44+
)
45+
46+
model = NoiseTestModel()
47+
48+
NoisePass(main.dialects, noise_model=model, gate_noise_params=noise_params)(main)
49+
50+
main.print()
51+
52+
fid_analysis = FidelityAnalysis(main.dialects)
53+
fid_analysis.run_analysis(main, no_raise=False)
54+
55+
p_noise = noise_params.local_px + noise_params.local_py + noise_params.local_pz
56+
assert (
57+
fid_analysis.gate_fidelity
58+
== fid_analysis._current_gate_fidelity
59+
== (1 - p_noise)
60+
)
61+
62+
assert 0.9 < fid_analysis.atom_survival_probability[0] < 1
63+
assert fid_analysis.atom_survival_probability[0] == 1 - noise_params.local_loss_prob
64+
assert fid_analysis.atom_survival_probability[1] == 1
65+
66+
67+
def test_c_noise():
68+
@noise_main
69+
def main():
70+
q = qasm2.qreg(2)
71+
qasm2.cz(q[0], q[1])
72+
return q
73+
74+
main.print()
75+
76+
fid_analysis = FidelityAnalysis(main.dialects)
77+
fid_analysis.run_analysis(main, no_raise=False)
78+
79+
assert fid_analysis.gate_fidelity == fid_analysis._current_gate_fidelity == 1
80+
81+
px = 0.01
82+
py = 0.01
83+
pz = 0.01
84+
p_loss = 0.01
85+
86+
noise_params = native.GateNoiseParams(
87+
global_loss_prob=p_loss,
88+
global_px=px,
89+
global_py=py,
90+
global_pz=pz,
91+
local_px=0.002,
92+
)
93+
94+
model = NoiseTestModel()
95+
96+
NoisePass(main.dialects, noise_model=model, gate_noise_params=noise_params)(main)
97+
98+
main.print()
99+
100+
fid_analysis = FidelityAnalysis(main.dialects)
101+
fid_analysis.run_analysis(main, no_raise=False)
102+
103+
# two cz channels (**2 for each one since we look at both control & target)
104+
fid_cz = (1 - 3 * noise_params.cz_paired_gate_px) ** 4
105+
106+
# one pauli channel
107+
fid_cz *= 1 - noise_params.global_px * 3
108+
109+
assert fid_analysis.gate_fidelity == fid_analysis._current_gate_fidelity
110+
assert math.isclose(fid_cz, fid_analysis.gate_fidelity, abs_tol=1e-14)
111+
112+
assert 0.9 < fid_analysis.atom_survival_probability[0] < 1
113+
assert fid_analysis.atom_survival_probability[0] == (
114+
1 - noise_params.cz_gate_loss_prob
115+
) * (1 - p_loss)
116+
117+
118+
@pytest.mark.xfail
119+
def test_if():
120+
121+
@noise_main
122+
def main():
123+
q = qasm2.qreg(1)
124+
c = qasm2.creg(1)
125+
qasm2.h(q[0])
126+
qasm2.measure(q, c)
127+
qasm2.x(q[0])
128+
qasm2.measure(q, c)
129+
130+
return c
131+
132+
@noise_main
133+
def main_if():
134+
q = qasm2.qreg(1)
135+
c = qasm2.creg(1)
136+
qasm2.h(q[0])
137+
qasm2.measure(q, c)
138+
139+
if c[0] == 0:
140+
qasm2.x(q[0])
141+
142+
qasm2.measure(q, c)
143+
return c
144+
145+
px = 0.01
146+
py = 0.01
147+
pz = 0.01
148+
p_loss = 0.01
149+
150+
noise_params = native.GateNoiseParams(
151+
global_loss_prob=p_loss,
152+
global_px=px,
153+
global_py=py,
154+
global_pz=pz,
155+
local_px=0.002,
156+
)
157+
158+
model = NoiseTestModel()
159+
NoisePass(main.dialects, noise_model=model, gate_noise_params=noise_params)(main)
160+
fid_analysis = FidelityAnalysis(main.dialects)
161+
fid_analysis.run_analysis(main, no_raise=False)
162+
163+
model = NoiseTestModel()
164+
NoisePass(main_if.dialects, noise_model=model, gate_noise_params=noise_params)(
165+
main_if
166+
)
167+
fid_if_analysis = FidelityAnalysis(main_if.dialects)
168+
fid_if_analysis.run_analysis(main_if, no_raise=False)
169+
170+
assert 0 < fid_if_analysis.gate_fidelity == fid_analysis.gate_fidelity < 1
171+
assert (
172+
0
173+
< fid_if_analysis.atom_survival_probability[0]
174+
== fid_analysis.atom_survival_probability[0]
175+
< 1
176+
)
177+
178+
179+
@pytest.mark.xfail
180+
def test_for():
181+
182+
@noise_main
183+
def main():
184+
q = qasm2.qreg(1)
185+
c = qasm2.creg(1)
186+
qasm2.h(q[0])
187+
qasm2.measure(q, c)
188+
189+
# unrolled for loop
190+
qasm2.x(q[0])
191+
qasm2.x(q[0])
192+
qasm2.x(q[0])
193+
194+
qasm2.measure(q, c)
195+
196+
return c
197+
198+
@noise_main
199+
def main_for():
200+
q = qasm2.qreg(1)
201+
c = qasm2.creg(1)
202+
qasm2.h(q[0])
203+
qasm2.measure(q, c)
204+
205+
for _ in range(3):
206+
qasm2.x(q[0])
207+
208+
qasm2.measure(q, c)
209+
return c
210+
211+
px = 0.01
212+
py = 0.01
213+
pz = 0.01
214+
p_loss = 0.01
215+
216+
noise_params = native.GateNoiseParams(
217+
global_loss_prob=p_loss,
218+
global_px=px,
219+
global_py=py,
220+
global_pz=pz,
221+
local_px=0.002,
222+
local_loss_prob=0.03,
223+
)
224+
225+
model = NoiseTestModel()
226+
NoisePass(main.dialects, noise_model=model, gate_noise_params=noise_params)(main)
227+
fid_analysis = FidelityAnalysis(main.dialects)
228+
fid_analysis.run_analysis(main, no_raise=False)
229+
230+
model = NoiseTestModel()
231+
NoisePass(main_for.dialects, noise_model=model, gate_noise_params=noise_params)(
232+
main_for
233+
)
234+
235+
main_for.print()
236+
237+
fid_for_analysis = FidelityAnalysis(main_for.dialects)
238+
fid_for_analysis.run_analysis(main_for, no_raise=False)
239+
240+
assert 0 < fid_for_analysis.gate_fidelity == fid_analysis.gate_fidelity < 1
241+
assert (
242+
0
243+
< fid_for_analysis.atom_survival_probability[0]
244+
== fid_analysis.atom_survival_probability[0]
245+
< 1
246+
)

test/qasm2/passes/test_heuristic_noise.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212

1313

1414
class NoiseTestModel(native.MoveNoiseModelABC):
15-
16-
@classmethod
17-
def parallel_cz_errors(cls, ctrls, qargs, rest):
15+
def parallel_cz_errors(self, ctrls, qargs, rest):
1816
return {(0.01, 0.01, 0.01, 0.01): ctrls + qargs + rest}
1917

2018

0 commit comments

Comments
 (0)