Skip to content
Merged
1 change: 1 addition & 0 deletions src/bloqade/analysis/fidelity/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .analysis import FidelityAnalysis as FidelityAnalysis
69 changes: 69 additions & 0 deletions src/bloqade/analysis/fidelity/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import Any
from dataclasses import field

from kirin import ir
from kirin.lattice import EmptyLattice
from kirin.analysis import Forward
from kirin.interp.value import Successor
from kirin.analysis.forward import ForwardFrame

from ..address import AddressAnalysis


class FidelityAnalysis(Forward):
"""
This analysis pass can be used to track the global addresses of qubits and wires.
"""

keys = ["circuit.fidelity"]
lattice = EmptyLattice

"""
The fidelity of the gate set described by the analysed program. It reduces whenever a noise channel is encountered.
"""
gate_fidelity: float = 1.0

_current_gate_fidelity: float = field(init=False)

"""
The probabilities that each of the atoms in the register survive the duration of the analysed program. The order of the list follows the order they are in the register.
"""
atom_survival_probability: list[float] = field(init=False)

_current_atom_survival_probability: list[float] = field(init=False)

addr_frame: ForwardFrame = field(init=False)

def initialize(self):
super().initialize()
self._current_gate_fidelity = 1.0
self._current_atom_survival_probability = [
1.0 for _ in range(len(self.atom_survival_probability))
]
return self

def posthook_succ(self, frame: ForwardFrame, succ: Successor):
self.gate_fidelity *= self._current_gate_fidelity
for i, _current_survival in enumerate(self._current_atom_survival_probability):
self.atom_survival_probability[i] *= _current_survival

def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement):
# NOTE: default is to conserve fidelity, so do nothing here
return

def run_method(self, method: ir.Method, args: tuple[EmptyLattice, ...]):
return self.run_callable(method.code, (self.lattice.bottom(),) + args)

def run_analysis(
self, method: ir.Method, args: tuple | None = None, *, no_raise: bool = True
) -> tuple[ForwardFrame, Any]:
self._run_address_analysis(method, no_raise=no_raise)
return super().run_analysis(method, args, no_raise=no_raise)

def _run_address_analysis(self, method: ir.Method, no_raise: bool):
addr_analysis = AddressAnalysis(self.dialects)
addr_frame, _ = addr_analysis.run_analysis(method=method, no_raise=no_raise)
self.addr_frame = addr_frame

# NOTE: make sure we have as many probabilities as we have addresses
self.atom_survival_probability = [1.0] * addr_analysis.qubit_count
3 changes: 2 additions & 1 deletion src/bloqade/noise/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from . import native as native
# NOTE: just to register methods
from . import native as native, fidelity as fidelity
51 changes: 51 additions & 0 deletions src/bloqade/noise/fidelity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from kirin import interp
from kirin.lattice import EmptyLattice

from bloqade.analysis.fidelity import FidelityAnalysis

from .native import dialect as native
from .native.stmts import PauliChannel, CZPauliChannel, AtomLossChannel
from ..analysis.address import AddressQubit, AddressTuple


@native.register(key="circuit.fidelity")
class FidelityMethodTable(interp.MethodTable):

@interp.impl(PauliChannel)
@interp.impl(CZPauliChannel)
def pauli_channel(
self,
interp: FidelityAnalysis,
frame: interp.Frame[EmptyLattice],
stmt: PauliChannel | CZPauliChannel,
):
probs = stmt.probabilities
try:
ps, ps_ctrl = probs
except ValueError:
(ps,) = probs
ps_ctrl = ()

p = sum(ps)
p_ctrl = sum(ps_ctrl)

# NOTE: fidelity is just the inverse probability of any noise to occur
fid = (1 - p) * (1 - p_ctrl)

interp._current_gate_fidelity *= fid

@interp.impl(AtomLossChannel)
def atom_loss(
self,
interp: FidelityAnalysis,
frame: interp.Frame[EmptyLattice],
stmt: AtomLossChannel,
):
# NOTE: since AtomLossChannel acts on IList[Qubit], we know the assigned address is a tuple
addresses: AddressTuple = interp.addr_frame.get(stmt.qargs)

# NOTE: get the corresponding index and reduce survival probability accordingly
for qbit_address in addresses.data:
assert isinstance(qbit_address, AddressQubit)
index = qbit_address.data
interp._current_atom_survival_probability[index] *= 1 - stmt.prob
246 changes: 246 additions & 0 deletions test/analysis/fidelity/test_fidelity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
import math

import pytest

from bloqade import qasm2
from bloqade.noise import native
from bloqade.analysis.fidelity import FidelityAnalysis
from bloqade.qasm2.passes.noise import NoisePass

noise_main = qasm2.extended.add(native.dialect)


class NoiseTestModel(native.MoveNoiseModelABC):
def parallel_cz_errors(self, ctrls, qargs, rest):
return {(0.01, 0.01, 0.01, 0.01): ctrls + qargs + rest}


def test_basic_noise():

@noise_main
def main():
q = qasm2.qreg(2)
qasm2.x(q[0])
return q

main.print()

fid_analysis = FidelityAnalysis(main.dialects)
fid_analysis.run_analysis(main, no_raise=False)

assert fid_analysis.gate_fidelity == fid_analysis._current_gate_fidelity == 1

px = 0.01
py = 0.01
pz = 0.01
p_loss = 0.01

noise_params = native.GateNoiseParams(
global_loss_prob=p_loss,
global_px=px,
global_py=py,
global_pz=pz,
local_px=0.002,
)

model = NoiseTestModel()

NoisePass(main.dialects, noise_model=model, gate_noise_params=noise_params)(main)

main.print()

fid_analysis = FidelityAnalysis(main.dialects)
fid_analysis.run_analysis(main, no_raise=False)

p_noise = noise_params.local_px + noise_params.local_py + noise_params.local_pz
assert (
fid_analysis.gate_fidelity
== fid_analysis._current_gate_fidelity
== (1 - p_noise)
)

assert 0.9 < fid_analysis.atom_survival_probability[0] < 1
assert fid_analysis.atom_survival_probability[0] == 1 - noise_params.local_loss_prob
assert fid_analysis.atom_survival_probability[1] == 1


def test_c_noise():
@noise_main
def main():
q = qasm2.qreg(2)
qasm2.cz(q[0], q[1])
return q

main.print()

fid_analysis = FidelityAnalysis(main.dialects)
fid_analysis.run_analysis(main, no_raise=False)

assert fid_analysis.gate_fidelity == fid_analysis._current_gate_fidelity == 1

px = 0.01
py = 0.01
pz = 0.01
p_loss = 0.01

noise_params = native.GateNoiseParams(
global_loss_prob=p_loss,
global_px=px,
global_py=py,
global_pz=pz,
local_px=0.002,
)

model = NoiseTestModel()

NoisePass(main.dialects, noise_model=model, gate_noise_params=noise_params)(main)

main.print()

fid_analysis = FidelityAnalysis(main.dialects)
fid_analysis.run_analysis(main, no_raise=False)

# two cz channels (**2 for each one since we look at both control & target)
fid_cz = (1 - 3 * noise_params.cz_paired_gate_px) ** 4

# one pauli channel
fid_cz *= 1 - noise_params.global_px * 3

assert fid_analysis.gate_fidelity == fid_analysis._current_gate_fidelity
assert math.isclose(fid_cz, fid_analysis.gate_fidelity, abs_tol=1e-14)

assert 0.9 < fid_analysis.atom_survival_probability[0] < 1
assert fid_analysis.atom_survival_probability[0] == (
1 - noise_params.cz_gate_loss_prob
) * (1 - p_loss)


@pytest.mark.xfail
def test_if():

@noise_main
def main():
q = qasm2.qreg(1)
c = qasm2.creg(1)
qasm2.h(q[0])
qasm2.measure(q, c)
qasm2.x(q[0])
qasm2.measure(q, c)

return c

@noise_main
def main_if():
q = qasm2.qreg(1)
c = qasm2.creg(1)
qasm2.h(q[0])
qasm2.measure(q, c)

if c[0] == 0:
qasm2.x(q[0])

qasm2.measure(q, c)
return c

px = 0.01
py = 0.01
pz = 0.01
p_loss = 0.01

noise_params = native.GateNoiseParams(
global_loss_prob=p_loss,
global_px=px,
global_py=py,
global_pz=pz,
local_px=0.002,
)

model = NoiseTestModel()
NoisePass(main.dialects, noise_model=model, gate_noise_params=noise_params)(main)
fid_analysis = FidelityAnalysis(main.dialects)
fid_analysis.run_analysis(main, no_raise=False)

model = NoiseTestModel()
NoisePass(main_if.dialects, noise_model=model, gate_noise_params=noise_params)(
main_if
)
fid_if_analysis = FidelityAnalysis(main_if.dialects)
fid_if_analysis.run_analysis(main_if, no_raise=False)

assert 0 < fid_if_analysis.gate_fidelity == fid_analysis.gate_fidelity < 1
assert (
0
< fid_if_analysis.atom_survival_probability[0]
== fid_analysis.atom_survival_probability[0]
< 1
)


@pytest.mark.xfail
def test_for():

@noise_main
def main():
q = qasm2.qreg(1)
c = qasm2.creg(1)
qasm2.h(q[0])
qasm2.measure(q, c)

# unrolled for loop
qasm2.x(q[0])
qasm2.x(q[0])
qasm2.x(q[0])

qasm2.measure(q, c)

return c

@noise_main
def main_for():
q = qasm2.qreg(1)
c = qasm2.creg(1)
qasm2.h(q[0])
qasm2.measure(q, c)

for _ in range(3):
qasm2.x(q[0])

qasm2.measure(q, c)
return c

px = 0.01
py = 0.01
pz = 0.01
p_loss = 0.01

noise_params = native.GateNoiseParams(
global_loss_prob=p_loss,
global_px=px,
global_py=py,
global_pz=pz,
local_px=0.002,
local_loss_prob=0.03,
)

model = NoiseTestModel()
NoisePass(main.dialects, noise_model=model, gate_noise_params=noise_params)(main)
fid_analysis = FidelityAnalysis(main.dialects)
fid_analysis.run_analysis(main, no_raise=False)

model = NoiseTestModel()
NoisePass(main_for.dialects, noise_model=model, gate_noise_params=noise_params)(
main_for
)

main_for.print()

fid_for_analysis = FidelityAnalysis(main_for.dialects)
fid_for_analysis.run_analysis(main_for, no_raise=False)

assert 0 < fid_for_analysis.gate_fidelity == fid_analysis.gate_fidelity < 1
assert (
0
< fid_for_analysis.atom_survival_probability[0]
== fid_analysis.atom_survival_probability[0]
< 1
)
4 changes: 1 addition & 3 deletions test/qasm2/passes/test_heuristic_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@


class NoiseTestModel(native.MoveNoiseModelABC):

@classmethod
def parallel_cz_errors(cls, ctrls, qargs, rest):
def parallel_cz_errors(self, ctrls, qargs, rest):
return {(0.01, 0.01, 0.01, 0.01): ctrls + qargs + rest}


Expand Down