diff --git a/src/bloqade/analysis/fidelity/__init__.py b/src/bloqade/analysis/fidelity/__init__.py new file mode 100644 index 00000000..496b1b5d --- /dev/null +++ b/src/bloqade/analysis/fidelity/__init__.py @@ -0,0 +1 @@ +from .analysis import FidelityAnalysis as FidelityAnalysis diff --git a/src/bloqade/analysis/fidelity/analysis.py b/src/bloqade/analysis/fidelity/analysis.py new file mode 100644 index 00000000..8d0571c3 --- /dev/null +++ b/src/bloqade/analysis/fidelity/analysis.py @@ -0,0 +1,69 @@ +from typing import Any +from dataclasses import field + +from kirin import ir +from kirin.lattice import EmptyLattice +from kirin.analysis import Forward +from kirin.interp.value import Successor +from kirin.analysis.forward import ForwardFrame + +from ..address import AddressAnalysis + + +class FidelityAnalysis(Forward): + """ + This analysis pass can be used to track the global addresses of qubits and wires. + """ + + keys = ["circuit.fidelity"] + lattice = EmptyLattice + + """ + The fidelity of the gate set described by the analysed program. It reduces whenever a noise channel is encountered. + """ + gate_fidelity: float = 1.0 + + _current_gate_fidelity: float = field(init=False) + + """ + The probabilities that each of the atoms in the register survive the duration of the analysed program. The order of the list follows the order they are in the register. + """ + atom_survival_probability: list[float] = field(init=False) + + _current_atom_survival_probability: list[float] = field(init=False) + + addr_frame: ForwardFrame = field(init=False) + + def initialize(self): + super().initialize() + self._current_gate_fidelity = 1.0 + self._current_atom_survival_probability = [ + 1.0 for _ in range(len(self.atom_survival_probability)) + ] + return self + + def posthook_succ(self, frame: ForwardFrame, succ: Successor): + self.gate_fidelity *= self._current_gate_fidelity + for i, _current_survival in enumerate(self._current_atom_survival_probability): + self.atom_survival_probability[i] *= _current_survival + + def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement): + # NOTE: default is to conserve fidelity, so do nothing here + return + + def run_method(self, method: ir.Method, args: tuple[EmptyLattice, ...]): + return self.run_callable(method.code, (self.lattice.bottom(),) + args) + + def run_analysis( + self, method: ir.Method, args: tuple | None = None, *, no_raise: bool = True + ) -> tuple[ForwardFrame, Any]: + self._run_address_analysis(method, no_raise=no_raise) + return super().run_analysis(method, args, no_raise=no_raise) + + def _run_address_analysis(self, method: ir.Method, no_raise: bool): + addr_analysis = AddressAnalysis(self.dialects) + addr_frame, _ = addr_analysis.run_analysis(method=method, no_raise=no_raise) + self.addr_frame = addr_frame + + # NOTE: make sure we have as many probabilities as we have addresses + self.atom_survival_probability = [1.0] * addr_analysis.qubit_count diff --git a/src/bloqade/noise/__init__.py b/src/bloqade/noise/__init__.py index 0994acfd..caeacc0e 100644 --- a/src/bloqade/noise/__init__.py +++ b/src/bloqade/noise/__init__.py @@ -1 +1,2 @@ -from . import native as native +# NOTE: just to register methods +from . import native as native, fidelity as fidelity diff --git a/src/bloqade/noise/fidelity.py b/src/bloqade/noise/fidelity.py new file mode 100644 index 00000000..c489dffc --- /dev/null +++ b/src/bloqade/noise/fidelity.py @@ -0,0 +1,51 @@ +from kirin import interp +from kirin.lattice import EmptyLattice + +from bloqade.analysis.fidelity import FidelityAnalysis + +from .native import dialect as native +from .native.stmts import PauliChannel, CZPauliChannel, AtomLossChannel +from ..analysis.address import AddressQubit, AddressTuple + + +@native.register(key="circuit.fidelity") +class FidelityMethodTable(interp.MethodTable): + + @interp.impl(PauliChannel) + @interp.impl(CZPauliChannel) + def pauli_channel( + self, + interp: FidelityAnalysis, + frame: interp.Frame[EmptyLattice], + stmt: PauliChannel | CZPauliChannel, + ): + probs = stmt.probabilities + try: + ps, ps_ctrl = probs + except ValueError: + (ps,) = probs + ps_ctrl = () + + p = sum(ps) + p_ctrl = sum(ps_ctrl) + + # NOTE: fidelity is just the inverse probability of any noise to occur + fid = (1 - p) * (1 - p_ctrl) + + interp._current_gate_fidelity *= fid + + @interp.impl(AtomLossChannel) + def atom_loss( + self, + interp: FidelityAnalysis, + frame: interp.Frame[EmptyLattice], + stmt: AtomLossChannel, + ): + # NOTE: since AtomLossChannel acts on IList[Qubit], we know the assigned address is a tuple + addresses: AddressTuple = interp.addr_frame.get(stmt.qargs) + + # NOTE: get the corresponding index and reduce survival probability accordingly + for qbit_address in addresses.data: + assert isinstance(qbit_address, AddressQubit) + index = qbit_address.data + interp._current_atom_survival_probability[index] *= 1 - stmt.prob diff --git a/test/analysis/fidelity/test_fidelity.py b/test/analysis/fidelity/test_fidelity.py new file mode 100644 index 00000000..644cda36 --- /dev/null +++ b/test/analysis/fidelity/test_fidelity.py @@ -0,0 +1,246 @@ +import math + +import pytest + +from bloqade import qasm2 +from bloqade.noise import native +from bloqade.analysis.fidelity import FidelityAnalysis +from bloqade.qasm2.passes.noise import NoisePass + +noise_main = qasm2.extended.add(native.dialect) + + +class NoiseTestModel(native.MoveNoiseModelABC): + def parallel_cz_errors(self, ctrls, qargs, rest): + return {(0.01, 0.01, 0.01, 0.01): ctrls + qargs + rest} + + +def test_basic_noise(): + + @noise_main + def main(): + q = qasm2.qreg(2) + qasm2.x(q[0]) + return q + + main.print() + + fid_analysis = FidelityAnalysis(main.dialects) + fid_analysis.run_analysis(main, no_raise=False) + + assert fid_analysis.gate_fidelity == fid_analysis._current_gate_fidelity == 1 + + px = 0.01 + py = 0.01 + pz = 0.01 + p_loss = 0.01 + + noise_params = native.GateNoiseParams( + global_loss_prob=p_loss, + global_px=px, + global_py=py, + global_pz=pz, + local_px=0.002, + ) + + model = NoiseTestModel() + + NoisePass(main.dialects, noise_model=model, gate_noise_params=noise_params)(main) + + main.print() + + fid_analysis = FidelityAnalysis(main.dialects) + fid_analysis.run_analysis(main, no_raise=False) + + p_noise = noise_params.local_px + noise_params.local_py + noise_params.local_pz + assert ( + fid_analysis.gate_fidelity + == fid_analysis._current_gate_fidelity + == (1 - p_noise) + ) + + assert 0.9 < fid_analysis.atom_survival_probability[0] < 1 + assert fid_analysis.atom_survival_probability[0] == 1 - noise_params.local_loss_prob + assert fid_analysis.atom_survival_probability[1] == 1 + + +def test_c_noise(): + @noise_main + def main(): + q = qasm2.qreg(2) + qasm2.cz(q[0], q[1]) + return q + + main.print() + + fid_analysis = FidelityAnalysis(main.dialects) + fid_analysis.run_analysis(main, no_raise=False) + + assert fid_analysis.gate_fidelity == fid_analysis._current_gate_fidelity == 1 + + px = 0.01 + py = 0.01 + pz = 0.01 + p_loss = 0.01 + + noise_params = native.GateNoiseParams( + global_loss_prob=p_loss, + global_px=px, + global_py=py, + global_pz=pz, + local_px=0.002, + ) + + model = NoiseTestModel() + + NoisePass(main.dialects, noise_model=model, gate_noise_params=noise_params)(main) + + main.print() + + fid_analysis = FidelityAnalysis(main.dialects) + fid_analysis.run_analysis(main, no_raise=False) + + # two cz channels (**2 for each one since we look at both control & target) + fid_cz = (1 - 3 * noise_params.cz_paired_gate_px) ** 4 + + # one pauli channel + fid_cz *= 1 - noise_params.global_px * 3 + + assert fid_analysis.gate_fidelity == fid_analysis._current_gate_fidelity + assert math.isclose(fid_cz, fid_analysis.gate_fidelity, abs_tol=1e-14) + + assert 0.9 < fid_analysis.atom_survival_probability[0] < 1 + assert fid_analysis.atom_survival_probability[0] == ( + 1 - noise_params.cz_gate_loss_prob + ) * (1 - p_loss) + + +@pytest.mark.xfail +def test_if(): + + @noise_main + def main(): + q = qasm2.qreg(1) + c = qasm2.creg(1) + qasm2.h(q[0]) + qasm2.measure(q, c) + qasm2.x(q[0]) + qasm2.measure(q, c) + + return c + + @noise_main + def main_if(): + q = qasm2.qreg(1) + c = qasm2.creg(1) + qasm2.h(q[0]) + qasm2.measure(q, c) + + if c[0] == 0: + qasm2.x(q[0]) + + qasm2.measure(q, c) + return c + + px = 0.01 + py = 0.01 + pz = 0.01 + p_loss = 0.01 + + noise_params = native.GateNoiseParams( + global_loss_prob=p_loss, + global_px=px, + global_py=py, + global_pz=pz, + local_px=0.002, + ) + + model = NoiseTestModel() + NoisePass(main.dialects, noise_model=model, gate_noise_params=noise_params)(main) + fid_analysis = FidelityAnalysis(main.dialects) + fid_analysis.run_analysis(main, no_raise=False) + + model = NoiseTestModel() + NoisePass(main_if.dialects, noise_model=model, gate_noise_params=noise_params)( + main_if + ) + fid_if_analysis = FidelityAnalysis(main_if.dialects) + fid_if_analysis.run_analysis(main_if, no_raise=False) + + assert 0 < fid_if_analysis.gate_fidelity == fid_analysis.gate_fidelity < 1 + assert ( + 0 + < fid_if_analysis.atom_survival_probability[0] + == fid_analysis.atom_survival_probability[0] + < 1 + ) + + +@pytest.mark.xfail +def test_for(): + + @noise_main + def main(): + q = qasm2.qreg(1) + c = qasm2.creg(1) + qasm2.h(q[0]) + qasm2.measure(q, c) + + # unrolled for loop + qasm2.x(q[0]) + qasm2.x(q[0]) + qasm2.x(q[0]) + + qasm2.measure(q, c) + + return c + + @noise_main + def main_for(): + q = qasm2.qreg(1) + c = qasm2.creg(1) + qasm2.h(q[0]) + qasm2.measure(q, c) + + for _ in range(3): + qasm2.x(q[0]) + + qasm2.measure(q, c) + return c + + px = 0.01 + py = 0.01 + pz = 0.01 + p_loss = 0.01 + + noise_params = native.GateNoiseParams( + global_loss_prob=p_loss, + global_px=px, + global_py=py, + global_pz=pz, + local_px=0.002, + local_loss_prob=0.03, + ) + + model = NoiseTestModel() + NoisePass(main.dialects, noise_model=model, gate_noise_params=noise_params)(main) + fid_analysis = FidelityAnalysis(main.dialects) + fid_analysis.run_analysis(main, no_raise=False) + + model = NoiseTestModel() + NoisePass(main_for.dialects, noise_model=model, gate_noise_params=noise_params)( + main_for + ) + + main_for.print() + + fid_for_analysis = FidelityAnalysis(main_for.dialects) + fid_for_analysis.run_analysis(main_for, no_raise=False) + + assert 0 < fid_for_analysis.gate_fidelity == fid_analysis.gate_fidelity < 1 + assert ( + 0 + < fid_for_analysis.atom_survival_probability[0] + == fid_analysis.atom_survival_probability[0] + < 1 + ) diff --git a/test/qasm2/passes/test_heuristic_noise.py b/test/qasm2/passes/test_heuristic_noise.py index 85d1ed6f..8b69e5c8 100644 --- a/test/qasm2/passes/test_heuristic_noise.py +++ b/test/qasm2/passes/test_heuristic_noise.py @@ -12,9 +12,7 @@ class NoiseTestModel(native.MoveNoiseModelABC): - - @classmethod - def parallel_cz_errors(cls, ctrls, qargs, rest): + def parallel_cz_errors(self, ctrls, qargs, rest): return {(0.01, 0.01, 0.01, 0.01): ctrls + qargs + rest}