Skip to content

Commit 4cc1e65

Browse files
committed
Fidelity calculation for IfElse
1 parent d78e25a commit 4cc1e65

File tree

3 files changed

+134
-37
lines changed

3 files changed

+134
-37
lines changed

src/bloqade/analysis/fidelity/analysis.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class FidelityAnalysis(Forward):
1818
# TODO: this should be a tuple[float, float] = (mean, max)
1919
current_fidelity: float = field(init=False)
2020
global_fidelity: float = 1.0
21+
# TODO: atom loss
2122

2223
def initialize(self):
2324
super().initialize()
@@ -28,12 +29,12 @@ def posthook_succ(self, frame: ForwardFrame, succ: Successor):
2829
self.global_fidelity *= self.current_fidelity
2930

3031
def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement):
31-
print(
32-
"no implementation for stmt "
33-
+ stmt.print_str(end="")
34-
+ " from "
35-
+ str(type(self))
36-
)
32+
# print(
33+
# "no implementation for stmt "
34+
# + stmt.print_str(end="")
35+
# + " from "
36+
# + str(type(self))
37+
# )
3738
return
3839

3940
def run_method(self, method: ir.Method, args: tuple[EmptyLattice, ...]):

src/bloqade/noise/fidelity.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
11
from kirin import interp
22
from kirin.lattice import EmptyLattice
3+
from kirin.dialects.scf import dialect as scf
4+
from kirin.dialects.scf.stmts import IfElse
35

46
from bloqade.analysis.fidelity import FidelityAnalysis
57

6-
from .native import dialect
8+
from .native import dialect as native
79
from .native.stmts import PauliChannel, CZPauliChannel
810

911

10-
@dialect.register(key="circuit.fidelity")
12+
@native.register(key="circuit.fidelity")
1113
class FidelityMethodTable(interp.MethodTable):
1214

1315
@interp.impl(PauliChannel)
1416
@interp.impl(CZPauliChannel)
15-
def single_qubit_gate(
17+
def pauli_channel(
1618
self,
1719
interp: FidelityAnalysis,
1820
frame: interp.Frame[EmptyLattice],
19-
stmt: PauliChannel,
21+
stmt: PauliChannel | CZPauliChannel,
2022
):
2123
probs = stmt.probabilities
2224
try:
@@ -32,3 +34,43 @@ def single_qubit_gate(
3234
fid = (1 - p) * (1 - p_ctrl)
3335

3436
interp.current_fidelity *= fid
37+
38+
39+
@scf.register(key="circuit.fidelity")
40+
class ScfFidelityMethodTable(FidelityMethodTable):
41+
42+
@interp.impl(IfElse)
43+
def if_else(
44+
self,
45+
interp: FidelityAnalysis,
46+
frame: interp.Frame[EmptyLattice],
47+
stmt: IfElse,
48+
):
49+
# NOTE: store current fidelity for later
50+
current_fidelity = interp.current_fidelity
51+
52+
for s in stmt.then_body.stmts():
53+
stmt_impl = interp.lookup_registry(frame=frame, stmt=s)
54+
if stmt_impl is None:
55+
continue
56+
57+
stmt_impl(interp=interp, frame=frame, stmt=s)
58+
59+
then_fidelity = interp.current_fidelity
60+
61+
# NOTE: reset fidelity of interp to check if the else body results in a worse fidelity
62+
interp.current_fidelity = current_fidelity
63+
64+
for s in stmt.else_body.stmts():
65+
stmt_impl = interp.lookup_registry(frame=frame, stmt=s)
66+
if stmt_impl is None:
67+
continue
68+
69+
stmt_impl(interp=interp, frame=frame, stmt=s)
70+
71+
else_fidelity = interp.current_fidelity
72+
73+
if then_fidelity < else_fidelity:
74+
interp.current_fidelity = then_fidelity
75+
else:
76+
interp.current_fidelity = else_fidelity

test/analysis/fidelity/test_fidelity.py

Lines changed: 81 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,44 +11,98 @@ def parallel_cz_errors(self, ctrls, qargs, rest):
1111
return {(0.01, 0.01, 0.01, 0.01): ctrls + qargs + rest}
1212

1313

14-
@noise_main
15-
def main():
16-
q = qasm2.qreg(2)
17-
qasm2.x(q[0])
18-
return q
14+
def test_basic_noise():
1915

16+
@noise_main
17+
def main():
18+
q = qasm2.qreg(2)
19+
qasm2.x(q[0])
20+
return q
2021

21-
main.print()
22+
main.print()
2223

23-
fid_analysis = FidelityAnalysis(main.dialects)
24-
fid_analysis.run_analysis(main)
24+
fid_analysis = FidelityAnalysis(main.dialects)
25+
fid_analysis.run_analysis(main)
2526

26-
assert fid_analysis.global_fidelity == fid_analysis.current_fidelity == 1
27+
assert fid_analysis.global_fidelity == fid_analysis.current_fidelity == 1
2728

29+
px = 0.01
30+
py = 0.01
31+
pz = 0.01
32+
p_loss = 0.01
2833

29-
px = 0.01
30-
py = 0.01
31-
pz = 0.01
32-
p_loss = 0.01
34+
noise_params = native.GateNoiseParams(
35+
global_loss_prob=p_loss,
36+
global_px=px,
37+
global_py=py,
38+
global_pz=pz,
39+
local_px=0.002,
40+
)
3341

34-
noise_params = native.GateNoiseParams(
35-
global_loss_prob=p_loss,
36-
global_px=px,
37-
global_py=py,
38-
global_pz=pz,
39-
local_px=0.002,
40-
)
42+
model = NoiseTestModel()
4143

42-
model = NoiseTestModel()
44+
NoisePass(main.dialects, noise_model=model, gate_noise_params=noise_params)(main)
4345

46+
main.print()
4447

45-
NoisePass(main.dialects, noise_model=model, gate_noise_params=noise_params)(main)
48+
fid_analysis = FidelityAnalysis(main.dialects)
49+
fid_analysis.run_analysis(main, no_raise=False)
4650

51+
p_noise = noise_params.local_px + noise_params.local_py + noise_params.local_pz
52+
assert (
53+
fid_analysis.global_fidelity == fid_analysis.current_fidelity == (1 - p_noise)
54+
)
4755

48-
main.print()
4956

50-
fid_analysis = FidelityAnalysis(main.dialects)
51-
fid_analysis.run_analysis(main, no_raise=False)
57+
def test_if():
5258

53-
p_noise = noise_params.local_px + noise_params.local_py + noise_params.local_pz
54-
assert fid_analysis.global_fidelity == fid_analysis.current_fidelity == (1 - p_noise)
59+
@noise_main
60+
def main():
61+
q = qasm2.qreg(1)
62+
c = qasm2.creg(1)
63+
qasm2.h(q[0])
64+
qasm2.measure(q, c)
65+
qasm2.x(q[0])
66+
qasm2.measure(q, c)
67+
68+
return c
69+
70+
@noise_main
71+
def main_if():
72+
q = qasm2.qreg(1)
73+
c = qasm2.creg(1)
74+
qasm2.h(q[0])
75+
qasm2.measure(q, c)
76+
77+
if c[0] == 0:
78+
qasm2.x(q[0])
79+
80+
qasm2.measure(q, c)
81+
return c
82+
83+
px = 0.01
84+
py = 0.01
85+
pz = 0.01
86+
p_loss = 0.01
87+
88+
noise_params = native.GateNoiseParams(
89+
global_loss_prob=p_loss,
90+
global_px=px,
91+
global_py=py,
92+
global_pz=pz,
93+
local_px=0.002,
94+
)
95+
96+
model = NoiseTestModel()
97+
NoisePass(main.dialects, noise_model=model, gate_noise_params=noise_params)(main)
98+
fid_analysis = FidelityAnalysis(main.dialects)
99+
fid_analysis.run_analysis(main, no_raise=False)
100+
101+
model = NoiseTestModel()
102+
NoisePass(main_if.dialects, noise_model=model, gate_noise_params=noise_params)(
103+
main_if
104+
)
105+
fid_if_analysis = FidelityAnalysis(main_if.dialects)
106+
fid_if_analysis.run_analysis(main_if, no_raise=False)
107+
108+
assert 0 < fid_if_analysis.global_fidelity == fid_analysis.global_fidelity < 1

0 commit comments

Comments
 (0)