Skip to content

Commit 0958f24

Browse files
david-plweinbe58
andauthored
Properly implement fidelity analysis with SCF (#640)
Closes #592 . We are now mostly re-using the `AddressAnalysis` for walking the IR. Except for `IfElse` statements, where we need to take care when merging results from branches. To do this, I had to make a small change to `AddressAnalysis`, but that shouldn't affect anything else. Probably needs more tests, but basics seem to work quite well. Also, we might need some wrapper around running this analysis to have a decent UX when trying to study the fidelity of a given program. @weinbe58 I'm also not sure if I'm being to strict about expecting an `AddressReg` for `qubtis: IList[Qubit]` arguments. --------- Co-authored-by: Phillip Weinberg <[email protected]>
1 parent 4e469e3 commit 0958f24

File tree

9 files changed

+667
-106
lines changed

9 files changed

+667
-106
lines changed

src/bloqade/analysis/address/analysis.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,20 @@ class AddressAnalysis(Forward[Address]):
1818
keys = ("qubit.address",)
1919
_const_prop: const.Propagate
2020
lattice = Address
21-
next_address: int = field(init=False)
21+
_next_address: int = field(init=False)
22+
23+
# NOTE: the following are properties so we can hook into the setter in FidelityAnalysis
24+
@property
25+
def next_address(self) -> int:
26+
return self._next_address
27+
28+
@next_address.setter
29+
def next_address(self, value: int):
30+
self._next_address = value
2231

2332
def initialize(self):
2433
super().initialize()
25-
self.next_address: int = 0
34+
self.next_address = 0
2635
self._const_prop = const.Propagate(self.dialects)
2736
self._const_prop.initialize()
2837
return self
@@ -127,7 +136,9 @@ def run_lattice(
127136
case _:
128137
return Address.top()
129138

130-
def get_const_value(self, addr: Address, typ: Type[T]) -> T | None:
139+
def get_const_value(
140+
self, addr: Address, typ: Type[T] | tuple[Type[T], ...]
141+
) -> T | None:
131142
if not isinstance(addr, ConstResult):
132143
return None
133144

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1-
from .analysis import FidelityAnalysis as FidelityAnalysis
1+
from . import impls as impls
2+
from .analysis import (
3+
FidelityRange as FidelityRange,
4+
FidelityAnalysis as FidelityAnalysis,
5+
)
Lines changed: 87 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,116 @@
1-
from typing import Any
2-
from dataclasses import field
1+
from dataclasses import field, dataclass
32

4-
from kirin import ir
5-
from kirin.lattice import EmptyLattice
6-
from kirin.analysis import Forward
7-
from kirin.analysis.forward import ForwardFrame
3+
from ..address import AddressReg, AddressAnalysis
84

9-
from ..address import Address, AddressAnalysis
105

6+
@dataclass
7+
class FidelityRange:
8+
"""Range of fidelity for a qubit as pair of (min, max) values"""
119

12-
class FidelityAnalysis(Forward):
10+
min: float
11+
max: float
12+
13+
14+
@dataclass
15+
class FidelityAnalysis(AddressAnalysis):
1316
"""
1417
This analysis pass can be used to track the global addresses of qubits and wires.
1518
1619
## Usage examples
1720
1821
```
19-
from bloqade import qasm2
20-
from bloqade.noise import native
22+
from bloqade import squin
2123
from bloqade.analysis.fidelity import FidelityAnalysis
22-
from bloqade.qasm2.passes.noise import NoisePass
23-
24-
noise_main = qasm2.extended.add(native.dialect)
2524
26-
@noise_main
25+
@squin.kernel
2726
def main():
28-
q = qasm2.qreg(2)
29-
qasm2.x(q[0])
27+
q = squin.qalloc(1)
28+
squin.x(q[0])
29+
squin.depolarize(0.1, q[0])
3030
return q
3131
32-
NoisePass(main.dialects)(main)
33-
3432
fid_analysis = FidelityAnalysis(main.dialects)
35-
fid_analysis.run_analysis(main, no_raise=False)
33+
fid_analysis.run(main)
3634
37-
gate_fidelity = fid_analysis.gate_fidelity
38-
atom_survival_probs = fid_analysis.atom_survival_probability
35+
gate_fidelities = fid_analysis.gate_fidelities
36+
qubit_survival_probs = fid_analysis.qubit_survival_fidelities
3937
```
4038
"""
4139

42-
keys = ["circuit.fidelity"]
43-
lattice = EmptyLattice
40+
keys = ("circuit.fidelity", "qubit.address")
4441

45-
gate_fidelity: float = 1.0
46-
"""
47-
The fidelity of the gate set described by the analysed program. It reduces whenever a noise channel is encountered.
48-
"""
42+
gate_fidelities: list[FidelityRange] = field(init=False, default_factory=list)
43+
"""Gate fidelities of each qubit as (min, max) pairs to provide a range"""
4944

50-
atom_survival_probability: list[float] = field(init=False)
51-
"""
52-
The probabilities that each of the atoms in the register survive the duration of the analysed program. The order of the list follows the order they are in the register.
53-
"""
45+
qubit_survival_fidelities: list[FidelityRange] = field(
46+
init=False, default_factory=list
47+
)
48+
"""Qubit survival fidelity given as (min, max) pairs"""
5449

55-
addr_frame: ForwardFrame[Address] = field(init=False)
50+
@property
51+
def next_address(self) -> int:
52+
return self._next_address
5653

57-
def initialize(self):
58-
super().initialize()
59-
self._current_gate_fidelity = 1.0
60-
self._current_atom_survival_probability = [
61-
1.0 for _ in range(len(self.atom_survival_probability))
62-
]
63-
return self
54+
@next_address.setter
55+
def next_address(self, value: int):
56+
# NOTE: hook into setter to make sure we always have fidelities of the correct length
57+
self._next_address = value
58+
self.extend_fidelities()
59+
60+
def extend_fidelities(self):
61+
"""Extend both fidelity lists so their length matches the number of qubits"""
62+
63+
self.extend_fidelity(self.gate_fidelities)
64+
self.extend_fidelity(self.qubit_survival_fidelities)
65+
66+
def extend_fidelity(self, fidelities: list[FidelityRange]):
67+
"""Extend a list of fidelities so its length matches the number of qubits"""
6468

65-
def eval_fallback(self, frame: ForwardFrame, node: ir.Statement):
66-
# NOTE: default is to conserve fidelity, so do nothing here
67-
return
69+
n = self.qubit_count
70+
fidelities.extend([FidelityRange(1.0, 1.0) for _ in range(n - len(fidelities))])
6871

69-
def run(self, method: ir.Method, *args, **kwargs) -> tuple[ForwardFrame, Any]:
70-
self._run_address_analysis(method)
71-
return super().run(method, *args, **kwargs)
72+
def reset_fidelities(self):
73+
"""Reset fidelities to unity for all qubits"""
7274

73-
def _run_address_analysis(self, method: ir.Method):
74-
addr_analysis = AddressAnalysis(self.dialects)
75-
addr_frame, _ = addr_analysis.run(method=method)
76-
self.addr_frame = addr_frame
75+
self.gate_fidelities = [
76+
FidelityRange(1.0, 1.0) for _ in range(self.qubit_count)
77+
]
78+
self.qubit_survival_fidelities = [
79+
FidelityRange(1.0, 1.0) for _ in range(self.qubit_count)
80+
]
7781

78-
# NOTE: make sure we have as many probabilities as we have addresses
79-
self.atom_survival_probability = [1.0] * addr_analysis.qubit_count
82+
@staticmethod
83+
def update_fidelities(
84+
fidelities: list[FidelityRange], fidelity: float, addresses: AddressReg
85+
):
86+
"""short-hand to update both (min, max) values"""
87+
88+
for idx in addresses.data:
89+
fidelities[idx].min *= fidelity
90+
fidelities[idx].max *= fidelity
91+
92+
def update_branched_fidelities(
93+
self,
94+
fidelities: list[FidelityRange],
95+
current_fidelities: list[FidelityRange],
96+
then_fidelities: list[FidelityRange],
97+
else_fidelities: list[FidelityRange],
98+
):
99+
"""Update fidelity (min, max) values after evaluating differing branches such as IfElse"""
100+
# NOTE: make sure they are all of the same length
101+
map(
102+
self.extend_fidelity,
103+
(fidelities, current_fidelities, then_fidelities, else_fidelities),
104+
)
105+
106+
# NOTE: now we update min / max accordingly
107+
for fid, current_fid, then_fid, else_fid in zip(
108+
fidelities, current_fidelities, then_fidelities, else_fidelities
109+
):
110+
fid.min = current_fid.min * min(then_fid.min, else_fid.min)
111+
fid.max = current_fid.max * max(then_fid.max, else_fid.max)
80112

81-
def method_self(self, method: ir.Method) -> EmptyLattice:
82-
return self.lattice.bottom()
113+
def initialize(self):
114+
super().initialize()
115+
self.reset_fidelities()
116+
return self
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from kirin import interp
2+
from kirin.analysis import ForwardFrame, const
3+
from kirin.dialects import scf
4+
5+
from bloqade.analysis.address import Address, ConstResult
6+
7+
from .analysis import FidelityAnalysis
8+
9+
10+
@scf.dialect.register(key="circuit.fidelity")
11+
class __ScfMethods(interp.MethodTable):
12+
@interp.impl(scf.IfElse)
13+
def if_else(
14+
self, interp_: FidelityAnalysis, frame: ForwardFrame[Address], stmt: scf.IfElse
15+
):
16+
17+
# NOTE: store a copy of the fidelities
18+
current_gate_fidelities = interp_.gate_fidelities
19+
current_survival_fidelities = interp_.qubit_survival_fidelities
20+
21+
address_cond = frame.get(stmt.cond)
22+
23+
# NOTE: if the condition is known at compile time, run specific branch
24+
if isinstance(address_cond, ConstResult) and isinstance(
25+
const_cond := address_cond.result, const.Value
26+
):
27+
body = stmt.then_body if const_cond.data else stmt.else_body
28+
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
29+
ret = interp_.frame_call_region(body_frame, stmt, body, address_cond)
30+
return ret
31+
32+
# NOTE: runtime condition, evaluate both
33+
with interp_.new_frame(stmt, has_parent_access=True) as then_frame:
34+
# NOTE: reset fidelities before stepping into the then-body
35+
interp_.reset_fidelities()
36+
37+
then_results = interp_.frame_call_region(
38+
then_frame,
39+
stmt,
40+
stmt.then_body,
41+
address_cond,
42+
)
43+
then_fids = interp_.gate_fidelities
44+
then_survival = interp_.qubit_survival_fidelities
45+
46+
with interp_.new_frame(stmt, has_parent_access=True) as else_frame:
47+
# NOTE: reset again before stepping into else-body
48+
interp_.reset_fidelities()
49+
50+
else_results = interp_.frame_call_region(
51+
else_frame,
52+
stmt,
53+
stmt.else_body,
54+
address_cond,
55+
)
56+
57+
else_fids = interp_.gate_fidelities
58+
else_survival = interp_.qubit_survival_fidelities
59+
60+
# NOTE: reset one last time
61+
interp_.reset_fidelities()
62+
63+
# NOTE: now update min / max pairs accordingly
64+
interp_.update_branched_fidelities(
65+
interp_.gate_fidelities, current_gate_fidelities, then_fids, else_fids
66+
)
67+
interp_.update_branched_fidelities(
68+
interp_.qubit_survival_fidelities,
69+
current_survival_fidelities,
70+
then_survival,
71+
else_survival,
72+
)
73+
74+
# TODO: pick the non-return value
75+
if isinstance(then_results, interp.ReturnValue) and isinstance(
76+
else_results, interp.ReturnValue
77+
):
78+
return interp.ReturnValue(then_results.value.join(else_results.value))
79+
elif isinstance(then_results, interp.ReturnValue):
80+
ret = else_results
81+
elif isinstance(else_results, interp.ReturnValue):
82+
ret = then_results
83+
else:
84+
ret = interp_.join_results(then_results, else_results)
85+
86+
return ret
Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from kirin import interp
2-
from kirin.lattice import EmptyLattice
2+
from kirin.analysis import ForwardFrame
33

4+
from bloqade.analysis.address import Address, AddressReg
45
from bloqade.analysis.fidelity import FidelityAnalysis
56

67
from .stmts import PauliChannel, CZPauliChannel, AtomLossChannel
@@ -11,37 +12,68 @@
1112
class FidelityMethodTable(interp.MethodTable):
1213

1314
@interp.impl(PauliChannel)
14-
@interp.impl(CZPauliChannel)
1515
def pauli_channel(
1616
self,
17-
interp: FidelityAnalysis,
18-
frame: interp.Frame[EmptyLattice],
19-
stmt: PauliChannel | CZPauliChannel,
17+
interp_: FidelityAnalysis,
18+
frame: ForwardFrame[Address],
19+
stmt: PauliChannel,
20+
):
21+
(ps,) = stmt.probabilities
22+
fidelity = 1 - sum(ps)
23+
24+
addresses = frame.get(stmt.qargs)
25+
26+
if not isinstance(addresses, AddressReg):
27+
return ()
28+
29+
interp_.update_fidelities(interp_.gate_fidelities, fidelity, addresses)
30+
31+
return ()
32+
33+
@interp.impl(CZPauliChannel)
34+
def cz_pauli_channel(
35+
self,
36+
interp_: FidelityAnalysis,
37+
frame: ForwardFrame[Address],
38+
stmt: CZPauliChannel,
2039
):
21-
probs = stmt.probabilities
22-
try:
23-
ps, ps_ctrl = probs
24-
except ValueError:
25-
(ps,) = probs
26-
ps_ctrl = ()
40+
ps_ctrl, ps_target = stmt.probabilities
41+
42+
fidelity_ctrl = 1 - sum(ps_ctrl)
43+
fidelity_target = 1 - sum(ps_target)
44+
45+
addresses_ctrl = frame.get(stmt.ctrls)
46+
addresses_target = frame.get(stmt.qargs)
2747

28-
p = sum(ps)
29-
p_ctrl = sum(ps_ctrl)
48+
if not isinstance(addresses_ctrl, AddressReg) or not isinstance(
49+
addresses_target, AddressReg
50+
):
51+
return ()
3052

31-
# NOTE: fidelity is just the inverse probability of any noise to occur
32-
fid = (1 - p) * (1 - p_ctrl)
53+
interp_.update_fidelities(
54+
interp_.gate_fidelities, fidelity_ctrl, addresses_ctrl
55+
)
56+
interp_.update_fidelities(
57+
interp_.gate_fidelities, fidelity_target, addresses_target
58+
)
3359

34-
interp.gate_fidelity *= fid
60+
return ()
3561

3662
@interp.impl(AtomLossChannel)
3763
def atom_loss(
3864
self,
39-
interp: FidelityAnalysis,
40-
frame: interp.Frame[EmptyLattice],
65+
interp_: FidelityAnalysis,
66+
frame: ForwardFrame[Address],
4167
stmt: AtomLossChannel,
4268
):
43-
# NOTE: since AtomLossChannel acts on IList[Qubit], we know the assigned address is a tuple
44-
addresses = interp.addr_frame.get(stmt.qargs)
45-
# NOTE: get the corresponding index and reduce survival probability accordingly
46-
for index in addresses.data:
47-
interp.atom_survival_probability[index] *= 1 - stmt.prob
69+
addresses = frame.get(stmt.qargs)
70+
71+
if not isinstance(addresses, AddressReg):
72+
return ()
73+
74+
fidelity = 1 - stmt.prob
75+
interp_.update_fidelities(
76+
interp_.qubit_survival_fidelities, fidelity, addresses
77+
)
78+
79+
return ()

0 commit comments

Comments
 (0)