Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions src/bloqade/analysis/address/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,20 @@ class AddressAnalysis(Forward[Address]):
keys = ("qubit.address",)
_const_prop: const.Propagate
lattice = Address
next_address: int = field(init=False)
_next_address: int = field(init=False)

# NOTE: the following are properties so we can hook into the setter in FidelityAnalysis
@property
def next_address(self) -> int:
return self._next_address

@next_address.setter
def next_address(self, value: int):
self._next_address = value

def initialize(self):
super().initialize()
self.next_address: int = 0
self.next_address = 0
self._const_prop = const.Propagate(self.dialects)
self._const_prop.initialize()
return self
Expand Down Expand Up @@ -127,7 +136,9 @@ def run_lattice(
case _:
return Address.top()

def get_const_value(self, addr: Address, typ: Type[T]) -> T | None:
def get_const_value(
self, addr: Address, typ: Type[T] | tuple[Type[T], ...]
) -> T | None:
if not isinstance(addr, ConstResult):
return None

Expand Down
6 changes: 5 additions & 1 deletion src/bloqade/analysis/fidelity/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
from .analysis import FidelityAnalysis as FidelityAnalysis
from . import impls as impls
from .analysis import (
FidelityRange as FidelityRange,
FidelityAnalysis as FidelityAnalysis,
)
123 changes: 81 additions & 42 deletions src/bloqade/analysis/fidelity/analysis.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from typing import Any
from dataclasses import field
from dataclasses import field, dataclass

from kirin import ir
from kirin.lattice import EmptyLattice
from kirin.analysis import Forward
from kirin.analysis.forward import ForwardFrame
from ..address import AddressReg, AddressAnalysis

from ..address import Address, AddressAnalysis

@dataclass
class FidelityRange:
"""Range of fidelity for a qubit as pair of (min, max) values"""

class FidelityAnalysis(Forward):
min: float
max: float


@dataclass
class FidelityAnalysis(AddressAnalysis):
"""
This analysis pass can be used to track the global addresses of qubits and wires.

Expand All @@ -34,49 +37,85 @@ def main():
fid_analysis = FidelityAnalysis(main.dialects)
fid_analysis.run_analysis(main, no_raise=False)

gate_fidelity = fid_analysis.gate_fidelity
atom_survival_probs = fid_analysis.atom_survival_probability
gate_fidelities = fid_analysis.gate_fidelities
qubit_survival_probs = fid_analysis.qubit_survival_fidelities
```
"""

keys = ["circuit.fidelity"]
lattice = EmptyLattice
keys = ("circuit.fidelity", "qubit.address")

gate_fidelity: float = 1.0
"""
The fidelity of the gate set described by the analysed program. It reduces whenever a noise channel is encountered.
"""
gate_fidelities: list[FidelityRange] = field(init=False, default_factory=list)
"""Gate fidelities of each qubit as (min, max) pairs to provide a range"""

atom_survival_probability: list[float] = field(init=False)
"""
The probabilities that each of the atoms in the register survive the duration of the analysed program. The order of the list follows the order they are in the register.
"""
qubit_survival_fidelities: list[FidelityRange] = field(
init=False, default_factory=list
)
"""Qubit survival fidelity given as (min, max) pairs"""

addr_frame: ForwardFrame[Address] = field(init=False)
@property
def next_address(self) -> int:
return self._next_address

def initialize(self):
super().initialize()
self._current_gate_fidelity = 1.0
self._current_atom_survival_probability = [
1.0 for _ in range(len(self.atom_survival_probability))
]
return self
@next_address.setter
def next_address(self, value: int):
# NOTE: hook into setter to make sure we always have fidelities of the correct length
self._next_address = value
self.extend_fidelities()

def extend_fidelities(self):
"""Extend both fidelity lists so their length matches the number of qubits"""

self.extend_fidelity(self.gate_fidelities)
self.extend_fidelity(self.qubit_survival_fidelities)

def eval_fallback(self, frame: ForwardFrame, node: ir.Statement):
# NOTE: default is to conserve fidelity, so do nothing here
return
def extend_fidelity(self, fidelities: list[FidelityRange]):
"""Extend a list of fidelities so its length matches the number of qubits"""

def run(self, method: ir.Method, *args, **kwargs) -> tuple[ForwardFrame, Any]:
self._run_address_analysis(method)
return super().run(method, *args, **kwargs)
n = self.qubit_count
fidelities.extend([FidelityRange(1.0, 1.0) for _ in range(n - len(fidelities))])

def _run_address_analysis(self, method: ir.Method):
addr_analysis = AddressAnalysis(self.dialects)
addr_frame, _ = addr_analysis.run(method=method)
self.addr_frame = addr_frame
def reset_fidelities(self):
"""Reset fidelities to unity for all qubits"""

# NOTE: make sure we have as many probabilities as we have addresses
self.atom_survival_probability = [1.0] * addr_analysis.qubit_count
self.gate_fidelities = [
FidelityRange(1.0, 1.0) for _ in range(self.qubit_count)
]
self.qubit_survival_fidelities = [
FidelityRange(1.0, 1.0) for _ in range(self.qubit_count)
]

def method_self(self, method: ir.Method) -> EmptyLattice:
return self.lattice.bottom()
@staticmethod
def update_fidelities(
fidelities: list[FidelityRange], fidelity: float, addresses: AddressReg
):
"""short-hand to update both (min, max) values"""

for idx in addresses.data:
fidelities[idx].min *= fidelity
fidelities[idx].max *= fidelity

def update_branched_fidelities(
self,
fidelities: list[FidelityRange],
current_fidelities: list[FidelityRange],
then_fidelities: list[FidelityRange],
else_fidelities: list[FidelityRange],
):
"""Update fidelity (min, max) values after evaluating differing branches such as IfElse"""
# NOTE: make sure they are all of the same length
map(
self.extend_fidelity,
(fidelities, current_fidelities, then_fidelities, else_fidelities),
)

# NOTE: now we update min / max accordingly
for fid, current_fid, then_fid, else_fid in zip(
fidelities, current_fidelities, then_fidelities, else_fidelities
):
fid.min = current_fid.min * min(then_fid.min, else_fid.min)
fid.max = current_fid.max * max(then_fid.max, else_fid.max)

def initialize(self):
super().initialize()
self.reset_fidelities()
return self
62 changes: 62 additions & 0 deletions src/bloqade/analysis/fidelity/impls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from kirin import interp
from kirin.analysis import ForwardFrame
from kirin.dialects import scf

from bloqade.analysis.address import Address

from .analysis import FidelityAnalysis


@scf.dialect.register(key="circuit.fidelity")
class __ScfMethods(interp.MethodTable):
@interp.impl(scf.IfElse)
def if_else(
self, interp_: FidelityAnalysis, frame: ForwardFrame[Address], stmt: scf.IfElse
):

# NOTE: store a copy of the fidelities
current_gate_fidelities = interp_.gate_fidelities
current_survival_fidelities = interp_.qubit_survival_fidelities

# TODO: check if the condition is constant and fix the branch in that case
# run both branches
with interp_.new_frame(stmt, has_parent_access=True) as then_frame:
# NOTE: reset fidelities before stepping into the then-body
interp_.reset_fidelities()

interp_.frame_call_region(
then_frame,
stmt,
stmt.then_body,
*(interp_.lattice.bottom() for _ in range(len(stmt.args))),
)
then_fids = interp_.gate_fidelities
then_survival = interp_.qubit_survival_fidelities

with interp_.new_frame(stmt, has_parent_access=True) as else_frame:
# NOTE: reset again before stepping into else-body
interp_.reset_fidelities()

interp_.frame_call_region(
else_frame,
stmt,
stmt.else_body,
*(interp_.lattice.bottom() for _ in range(len(stmt.args))),
)

else_fids = interp_.gate_fidelities
else_survival = interp_.qubit_survival_fidelities

# NOTE: reset one last time
interp_.reset_fidelities()

# NOTE: now update min / max pairs accordingly
interp_.update_branched_fidelities(
interp_.gate_fidelities, current_gate_fidelities, then_fids, else_fids
)
interp_.update_branched_fidelities(
interp_.qubit_survival_fidelities,
current_survival_fidelities,
then_survival,
else_survival,
)
78 changes: 55 additions & 23 deletions src/bloqade/qasm2/dialects/noise/fidelity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from kirin import interp
from kirin.lattice import EmptyLattice
from kirin.analysis import ForwardFrame

from bloqade.analysis.address import Address, AddressReg
from bloqade.analysis.fidelity import FidelityAnalysis

from .stmts import PauliChannel, CZPauliChannel, AtomLossChannel
Expand All @@ -11,37 +12,68 @@
class FidelityMethodTable(interp.MethodTable):

@interp.impl(PauliChannel)
@interp.impl(CZPauliChannel)
def pauli_channel(
self,
interp: FidelityAnalysis,
frame: interp.Frame[EmptyLattice],
stmt: PauliChannel | CZPauliChannel,
interp_: FidelityAnalysis,
frame: ForwardFrame[Address],
stmt: PauliChannel,
):
(ps,) = stmt.probabilities
fidelity = 1 - sum(ps)

addresses = frame.get(stmt.qargs)

if not isinstance(addresses, AddressReg):
return ()

interp_.update_fidelities(interp_.gate_fidelities, fidelity, addresses)

return ()

@interp.impl(CZPauliChannel)
def cz_pauli_channel(
self,
interp_: FidelityAnalysis,
frame: ForwardFrame[Address],
stmt: CZPauliChannel,
):
probs = stmt.probabilities
try:
ps, ps_ctrl = probs
except ValueError:
(ps,) = probs
ps_ctrl = ()
ps_ctrl, ps_target = stmt.probabilities

fidelity_ctrl = 1 - sum(ps_ctrl)
fidelity_target = 1 - sum(ps_target)

addresses_ctrl = frame.get(stmt.ctrls)
addresses_target = frame.get(stmt.qargs)

p = sum(ps)
p_ctrl = sum(ps_ctrl)
if not isinstance(addresses_ctrl, AddressReg) or not isinstance(
addresses_target, AddressReg
):
return ()

# NOTE: fidelity is just the inverse probability of any noise to occur
fid = (1 - p) * (1 - p_ctrl)
interp_.update_fidelities(
interp_.gate_fidelities, fidelity_ctrl, addresses_ctrl
)
interp_.update_fidelities(
interp_.gate_fidelities, fidelity_target, addresses_target
)

interp.gate_fidelity *= fid
return ()

@interp.impl(AtomLossChannel)
def atom_loss(
self,
interp: FidelityAnalysis,
frame: interp.Frame[EmptyLattice],
interp_: FidelityAnalysis,
frame: ForwardFrame[Address],
stmt: AtomLossChannel,
):
# NOTE: since AtomLossChannel acts on IList[Qubit], we know the assigned address is a tuple
addresses = interp.addr_frame.get(stmt.qargs)
# NOTE: get the corresponding index and reduce survival probability accordingly
for index in addresses.data:
interp.atom_survival_probability[index] *= 1 - stmt.prob
addresses = frame.get(stmt.qargs)

if not isinstance(addresses, AddressReg):
return ()

fidelity = 1 - stmt.prob
interp_.update_fidelities(
interp_.qubit_survival_fidelities, fidelity, addresses
)

return ()
1 change: 1 addition & 0 deletions src/bloqade/squin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
two_qubit_pauli_channel as two_qubit_pauli_channel,
single_qubit_pauli_channel as single_qubit_pauli_channel,
)
from .analysis.fidelity import impls as impls

# NOTE: it's important to keep these imports here since they import squin.kernel
# we skip isort here
Expand Down
Empty file.
Loading
Loading