|
1 | | -from typing import Any |
2 | | -from dataclasses import field |
| 1 | +from dataclasses import field, dataclass |
3 | 2 |
|
4 | | -from kirin import ir |
5 | | -from kirin.lattice import EmptyLattice |
6 | | -from kirin.analysis import Forward |
7 | | -from kirin.analysis.forward import ForwardFrame |
| 3 | +from ..address import AddressReg, AddressAnalysis |
8 | 4 |
|
9 | | -from ..address import Address, AddressAnalysis |
10 | 5 |
|
| 6 | +@dataclass |
| 7 | +class FidelityRange: |
| 8 | + """Range of fidelity for a qubit as pair of (min, max) values""" |
11 | 9 |
|
12 | | -class FidelityAnalysis(Forward): |
| 10 | + min: float |
| 11 | + max: float |
| 12 | + |
| 13 | + |
| 14 | +@dataclass |
| 15 | +class FidelityAnalysis(AddressAnalysis): |
13 | 16 | """ |
14 | 17 | This analysis pass can be used to track the global addresses of qubits and wires. |
15 | 18 |
|
16 | 19 | ## Usage examples |
17 | 20 |
|
18 | 21 | ``` |
19 | | - from bloqade import qasm2 |
20 | | - from bloqade.noise import native |
| 22 | + from bloqade import squin |
21 | 23 | from bloqade.analysis.fidelity import FidelityAnalysis |
22 | | - from bloqade.qasm2.passes.noise import NoisePass |
23 | | -
|
24 | | - noise_main = qasm2.extended.add(native.dialect) |
25 | 24 |
|
26 | | - @noise_main |
| 25 | + @squin.kernel |
27 | 26 | def main(): |
28 | | - q = qasm2.qreg(2) |
29 | | - qasm2.x(q[0]) |
| 27 | + q = squin.qalloc(1) |
| 28 | + squin.x(q[0]) |
| 29 | + squin.depolarize(0.1, q[0]) |
30 | 30 | return q |
31 | 31 |
|
32 | | - NoisePass(main.dialects)(main) |
33 | | -
|
34 | 32 | fid_analysis = FidelityAnalysis(main.dialects) |
35 | | - fid_analysis.run_analysis(main, no_raise=False) |
| 33 | + fid_analysis.run(main) |
36 | 34 |
|
37 | | - gate_fidelity = fid_analysis.gate_fidelity |
38 | | - atom_survival_probs = fid_analysis.atom_survival_probability |
| 35 | + gate_fidelities = fid_analysis.gate_fidelities |
| 36 | + qubit_survival_probs = fid_analysis.qubit_survival_fidelities |
39 | 37 | ``` |
40 | 38 | """ |
41 | 39 |
|
42 | | - keys = ["circuit.fidelity"] |
43 | | - lattice = EmptyLattice |
| 40 | + keys = ("circuit.fidelity", "qubit.address") |
44 | 41 |
|
45 | | - gate_fidelity: float = 1.0 |
46 | | - """ |
47 | | - The fidelity of the gate set described by the analysed program. It reduces whenever a noise channel is encountered. |
48 | | - """ |
| 42 | + gate_fidelities: list[FidelityRange] = field(init=False, default_factory=list) |
| 43 | + """Gate fidelities of each qubit as (min, max) pairs to provide a range""" |
49 | 44 |
|
50 | | - atom_survival_probability: list[float] = field(init=False) |
51 | | - """ |
52 | | - The probabilities that each of the atoms in the register survive the duration of the analysed program. The order of the list follows the order they are in the register. |
53 | | - """ |
| 45 | + qubit_survival_fidelities: list[FidelityRange] = field( |
| 46 | + init=False, default_factory=list |
| 47 | + ) |
| 48 | + """Qubit survival fidelity given as (min, max) pairs""" |
54 | 49 |
|
55 | | - addr_frame: ForwardFrame[Address] = field(init=False) |
| 50 | + @property |
| 51 | + def next_address(self) -> int: |
| 52 | + return self._next_address |
56 | 53 |
|
57 | | - def initialize(self): |
58 | | - super().initialize() |
59 | | - self._current_gate_fidelity = 1.0 |
60 | | - self._current_atom_survival_probability = [ |
61 | | - 1.0 for _ in range(len(self.atom_survival_probability)) |
62 | | - ] |
63 | | - return self |
| 54 | + @next_address.setter |
| 55 | + def next_address(self, value: int): |
| 56 | + # NOTE: hook into setter to make sure we always have fidelities of the correct length |
| 57 | + self._next_address = value |
| 58 | + self.extend_fidelities() |
| 59 | + |
| 60 | + def extend_fidelities(self): |
| 61 | + """Extend both fidelity lists so their length matches the number of qubits""" |
| 62 | + |
| 63 | + self.extend_fidelity(self.gate_fidelities) |
| 64 | + self.extend_fidelity(self.qubit_survival_fidelities) |
| 65 | + |
| 66 | + def extend_fidelity(self, fidelities: list[FidelityRange]): |
| 67 | + """Extend a list of fidelities so its length matches the number of qubits""" |
64 | 68 |
|
65 | | - def eval_fallback(self, frame: ForwardFrame, node: ir.Statement): |
66 | | - # NOTE: default is to conserve fidelity, so do nothing here |
67 | | - return |
| 69 | + n = self.qubit_count |
| 70 | + fidelities.extend([FidelityRange(1.0, 1.0) for _ in range(n - len(fidelities))]) |
68 | 71 |
|
69 | | - def run(self, method: ir.Method, *args, **kwargs) -> tuple[ForwardFrame, Any]: |
70 | | - self._run_address_analysis(method) |
71 | | - return super().run(method, *args, **kwargs) |
| 72 | + def reset_fidelities(self): |
| 73 | + """Reset fidelities to unity for all qubits""" |
72 | 74 |
|
73 | | - def _run_address_analysis(self, method: ir.Method): |
74 | | - addr_analysis = AddressAnalysis(self.dialects) |
75 | | - addr_frame, _ = addr_analysis.run(method=method) |
76 | | - self.addr_frame = addr_frame |
| 75 | + self.gate_fidelities = [ |
| 76 | + FidelityRange(1.0, 1.0) for _ in range(self.qubit_count) |
| 77 | + ] |
| 78 | + self.qubit_survival_fidelities = [ |
| 79 | + FidelityRange(1.0, 1.0) for _ in range(self.qubit_count) |
| 80 | + ] |
77 | 81 |
|
78 | | - # NOTE: make sure we have as many probabilities as we have addresses |
79 | | - self.atom_survival_probability = [1.0] * addr_analysis.qubit_count |
| 82 | + @staticmethod |
| 83 | + def update_fidelities( |
| 84 | + fidelities: list[FidelityRange], fidelity: float, addresses: AddressReg |
| 85 | + ): |
| 86 | + """short-hand to update both (min, max) values""" |
| 87 | + |
| 88 | + for idx in addresses.data: |
| 89 | + fidelities[idx].min *= fidelity |
| 90 | + fidelities[idx].max *= fidelity |
| 91 | + |
| 92 | + def update_branched_fidelities( |
| 93 | + self, |
| 94 | + fidelities: list[FidelityRange], |
| 95 | + current_fidelities: list[FidelityRange], |
| 96 | + then_fidelities: list[FidelityRange], |
| 97 | + else_fidelities: list[FidelityRange], |
| 98 | + ): |
| 99 | + """Update fidelity (min, max) values after evaluating differing branches such as IfElse""" |
| 100 | + # NOTE: make sure they are all of the same length |
| 101 | + map( |
| 102 | + self.extend_fidelity, |
| 103 | + (fidelities, current_fidelities, then_fidelities, else_fidelities), |
| 104 | + ) |
| 105 | + |
| 106 | + # NOTE: now we update min / max accordingly |
| 107 | + for fid, current_fid, then_fid, else_fid in zip( |
| 108 | + fidelities, current_fidelities, then_fidelities, else_fidelities |
| 109 | + ): |
| 110 | + fid.min = current_fid.min * min(then_fid.min, else_fid.min) |
| 111 | + fid.max = current_fid.max * max(then_fid.max, else_fid.max) |
80 | 112 |
|
81 | | - def method_self(self, method: ir.Method) -> EmptyLattice: |
82 | | - return self.lattice.bottom() |
| 113 | + def initialize(self): |
| 114 | + super().initialize() |
| 115 | + self.reset_fidelities() |
| 116 | + return self |
0 commit comments