Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions src/bloqade/analysis/address/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,20 @@ class AddressAnalysis(Forward[Address]):
keys = ("qubit.address",)
_const_prop: const.Propagate
lattice = Address
next_address: int = field(init=False)
_next_address: int = field(init=False)

# NOTE: the following are properties so we can hook into the setter in FidelityAnalysis
@property
def next_address(self) -> int:
return self._next_address

@next_address.setter
def next_address(self, value: int):
self._next_address = value

def initialize(self):
super().initialize()
self.next_address: int = 0
self.next_address = 0
self._const_prop = const.Propagate(self.dialects)
self._const_prop.initialize()
return self
Expand Down Expand Up @@ -127,7 +136,9 @@ def run_lattice(
case _:
return Address.top()

def get_const_value(self, addr: Address, typ: Type[T]) -> T | None:
def get_const_value(
self, addr: Address, typ: Type[T] | tuple[Type[T], ...]
) -> T | None:
if not isinstance(addr, ConstResult):
return None

Expand Down
6 changes: 5 additions & 1 deletion src/bloqade/analysis/fidelity/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
from .analysis import FidelityAnalysis as FidelityAnalysis
from . import impls as impls
from .analysis import (
FidelityRange as FidelityRange,
FidelityAnalysis as FidelityAnalysis,
)
140 changes: 87 additions & 53 deletions src/bloqade/analysis/fidelity/analysis.py
Original file line number Diff line number Diff line change
@@ -1,82 +1,116 @@
from typing import Any
from dataclasses import field
from dataclasses import field, dataclass

from kirin import ir
from kirin.lattice import EmptyLattice
from kirin.analysis import Forward
from kirin.analysis.forward import ForwardFrame
from ..address import AddressReg, AddressAnalysis

from ..address import Address, AddressAnalysis

@dataclass
class FidelityRange:
"""Range of fidelity for a qubit as pair of (min, max) values"""

class FidelityAnalysis(Forward):
min: float
max: float


@dataclass
class FidelityAnalysis(AddressAnalysis):
"""
This analysis pass can be used to track the global addresses of qubits and wires.

## Usage examples

```
from bloqade import qasm2
from bloqade.noise import native
from bloqade import squin
from bloqade.analysis.fidelity import FidelityAnalysis
from bloqade.qasm2.passes.noise import NoisePass

noise_main = qasm2.extended.add(native.dialect)

@noise_main
@squin.kernel
def main():
q = qasm2.qreg(2)
qasm2.x(q[0])
q = squin.qalloc(1)
squin.x(q[0])
squin.depolarize(0.1, q[0])
return q

NoisePass(main.dialects)(main)

fid_analysis = FidelityAnalysis(main.dialects)
fid_analysis.run_analysis(main, no_raise=False)
fid_analysis.run(main)

gate_fidelity = fid_analysis.gate_fidelity
atom_survival_probs = fid_analysis.atom_survival_probability
gate_fidelities = fid_analysis.gate_fidelities
qubit_survival_probs = fid_analysis.qubit_survival_fidelities
```
"""

keys = ["circuit.fidelity"]
lattice = EmptyLattice
keys = ("circuit.fidelity", "qubit.address")

gate_fidelity: float = 1.0
"""
The fidelity of the gate set described by the analysed program. It reduces whenever a noise channel is encountered.
"""
gate_fidelities: list[FidelityRange] = field(init=False, default_factory=list)
"""Gate fidelities of each qubit as (min, max) pairs to provide a range"""

atom_survival_probability: list[float] = field(init=False)
"""
The probabilities that each of the atoms in the register survive the duration of the analysed program. The order of the list follows the order they are in the register.
"""
qubit_survival_fidelities: list[FidelityRange] = field(
init=False, default_factory=list
)
"""Qubit survival fidelity given as (min, max) pairs"""

addr_frame: ForwardFrame[Address] = field(init=False)
@property
def next_address(self) -> int:
return self._next_address

def initialize(self):
super().initialize()
self._current_gate_fidelity = 1.0
self._current_atom_survival_probability = [
1.0 for _ in range(len(self.atom_survival_probability))
]
return self
@next_address.setter
def next_address(self, value: int):
# NOTE: hook into setter to make sure we always have fidelities of the correct length
self._next_address = value
self.extend_fidelities()

def extend_fidelities(self):
"""Extend both fidelity lists so their length matches the number of qubits"""

self.extend_fidelity(self.gate_fidelities)
self.extend_fidelity(self.qubit_survival_fidelities)

def extend_fidelity(self, fidelities: list[FidelityRange]):
"""Extend a list of fidelities so its length matches the number of qubits"""

def eval_fallback(self, frame: ForwardFrame, node: ir.Statement):
# NOTE: default is to conserve fidelity, so do nothing here
return
n = self.qubit_count
fidelities.extend([FidelityRange(1.0, 1.0) for _ in range(n - len(fidelities))])

def run(self, method: ir.Method, *args, **kwargs) -> tuple[ForwardFrame, Any]:
self._run_address_analysis(method)
return super().run(method, *args, **kwargs)
def reset_fidelities(self):
"""Reset fidelities to unity for all qubits"""

def _run_address_analysis(self, method: ir.Method):
addr_analysis = AddressAnalysis(self.dialects)
addr_frame, _ = addr_analysis.run(method=method)
self.addr_frame = addr_frame
self.gate_fidelities = [
FidelityRange(1.0, 1.0) for _ in range(self.qubit_count)
]
self.qubit_survival_fidelities = [
FidelityRange(1.0, 1.0) for _ in range(self.qubit_count)
]

# NOTE: make sure we have as many probabilities as we have addresses
self.atom_survival_probability = [1.0] * addr_analysis.qubit_count
@staticmethod
def update_fidelities(
fidelities: list[FidelityRange], fidelity: float, addresses: AddressReg
):
"""short-hand to update both (min, max) values"""

for idx in addresses.data:
fidelities[idx].min *= fidelity
fidelities[idx].max *= fidelity

def update_branched_fidelities(
self,
fidelities: list[FidelityRange],
current_fidelities: list[FidelityRange],
then_fidelities: list[FidelityRange],
else_fidelities: list[FidelityRange],
):
"""Update fidelity (min, max) values after evaluating differing branches such as IfElse"""
# NOTE: make sure they are all of the same length
map(
self.extend_fidelity,
(fidelities, current_fidelities, then_fidelities, else_fidelities),
)

# NOTE: now we update min / max accordingly
for fid, current_fid, then_fid, else_fid in zip(
fidelities, current_fidelities, then_fidelities, else_fidelities
):
fid.min = current_fid.min * min(then_fid.min, else_fid.min)
fid.max = current_fid.max * max(then_fid.max, else_fid.max)

def method_self(self, method: ir.Method) -> EmptyLattice:
return self.lattice.bottom()
def initialize(self):
super().initialize()
self.reset_fidelities()
return self
86 changes: 86 additions & 0 deletions src/bloqade/analysis/fidelity/impls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from kirin import interp
from kirin.analysis import ForwardFrame, const
from kirin.dialects import scf

from bloqade.analysis.address import Address, ConstResult

from .analysis import FidelityAnalysis


@scf.dialect.register(key="circuit.fidelity")
class __ScfMethods(interp.MethodTable):
@interp.impl(scf.IfElse)
def if_else(
self, interp_: FidelityAnalysis, frame: ForwardFrame[Address], stmt: scf.IfElse
):

# NOTE: store a copy of the fidelities
current_gate_fidelities = interp_.gate_fidelities
current_survival_fidelities = interp_.qubit_survival_fidelities

address_cond = frame.get(stmt.cond)

# NOTE: if the condition is known at compile time, run specific branch
if isinstance(address_cond, ConstResult) and isinstance(
const_cond := address_cond.result, const.Value
):
body = stmt.then_body if const_cond.data else stmt.else_body
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
ret = interp_.frame_call_region(body_frame, stmt, body, address_cond)
return ret

# NOTE: runtime condition, evaluate both
with interp_.new_frame(stmt, has_parent_access=True) as then_frame:
# NOTE: reset fidelities before stepping into the then-body
interp_.reset_fidelities()

then_results = interp_.frame_call_region(
then_frame,
stmt,
stmt.then_body,
address_cond,
)
then_fids = interp_.gate_fidelities
then_survival = interp_.qubit_survival_fidelities

with interp_.new_frame(stmt, has_parent_access=True) as else_frame:
# NOTE: reset again before stepping into else-body
interp_.reset_fidelities()

else_results = interp_.frame_call_region(
else_frame,
stmt,
stmt.else_body,
address_cond,
)

else_fids = interp_.gate_fidelities
else_survival = interp_.qubit_survival_fidelities

# NOTE: reset one last time
interp_.reset_fidelities()

# NOTE: now update min / max pairs accordingly
interp_.update_branched_fidelities(
interp_.gate_fidelities, current_gate_fidelities, then_fids, else_fids
)
interp_.update_branched_fidelities(
interp_.qubit_survival_fidelities,
current_survival_fidelities,
then_survival,
else_survival,
)

# TODO: pick the non-return value
if isinstance(then_results, interp.ReturnValue) and isinstance(
else_results, interp.ReturnValue
):
return interp.ReturnValue(then_results.value.join(else_results.value))
elif isinstance(then_results, interp.ReturnValue):
ret = else_results
elif isinstance(else_results, interp.ReturnValue):
ret = then_results
else:
ret = interp_.join_results(then_results, else_results)

return ret
78 changes: 55 additions & 23 deletions src/bloqade/qasm2/dialects/noise/fidelity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from kirin import interp
from kirin.lattice import EmptyLattice
from kirin.analysis import ForwardFrame

from bloqade.analysis.address import Address, AddressReg
from bloqade.analysis.fidelity import FidelityAnalysis

from .stmts import PauliChannel, CZPauliChannel, AtomLossChannel
Expand All @@ -11,37 +12,68 @@
class FidelityMethodTable(interp.MethodTable):

@interp.impl(PauliChannel)
@interp.impl(CZPauliChannel)
def pauli_channel(
self,
interp: FidelityAnalysis,
frame: interp.Frame[EmptyLattice],
stmt: PauliChannel | CZPauliChannel,
interp_: FidelityAnalysis,
frame: ForwardFrame[Address],
stmt: PauliChannel,
):
(ps,) = stmt.probabilities
fidelity = 1 - sum(ps)

addresses = frame.get(stmt.qargs)

if not isinstance(addresses, AddressReg):
return ()

interp_.update_fidelities(interp_.gate_fidelities, fidelity, addresses)

return ()

@interp.impl(CZPauliChannel)
def cz_pauli_channel(
self,
interp_: FidelityAnalysis,
frame: ForwardFrame[Address],
stmt: CZPauliChannel,
):
probs = stmt.probabilities
try:
ps, ps_ctrl = probs
except ValueError:
(ps,) = probs
ps_ctrl = ()
ps_ctrl, ps_target = stmt.probabilities

fidelity_ctrl = 1 - sum(ps_ctrl)
fidelity_target = 1 - sum(ps_target)

addresses_ctrl = frame.get(stmt.ctrls)
addresses_target = frame.get(stmt.qargs)

p = sum(ps)
p_ctrl = sum(ps_ctrl)
if not isinstance(addresses_ctrl, AddressReg) or not isinstance(
addresses_target, AddressReg
):
return ()

# NOTE: fidelity is just the inverse probability of any noise to occur
fid = (1 - p) * (1 - p_ctrl)
interp_.update_fidelities(
interp_.gate_fidelities, fidelity_ctrl, addresses_ctrl
)
interp_.update_fidelities(
interp_.gate_fidelities, fidelity_target, addresses_target
)

interp.gate_fidelity *= fid
return ()

@interp.impl(AtomLossChannel)
def atom_loss(
self,
interp: FidelityAnalysis,
frame: interp.Frame[EmptyLattice],
interp_: FidelityAnalysis,
frame: ForwardFrame[Address],
stmt: AtomLossChannel,
):
# NOTE: since AtomLossChannel acts on IList[Qubit], we know the assigned address is a tuple
addresses = interp.addr_frame.get(stmt.qargs)
# NOTE: get the corresponding index and reduce survival probability accordingly
for index in addresses.data:
interp.atom_survival_probability[index] *= 1 - stmt.prob
addresses = frame.get(stmt.qargs)

if not isinstance(addresses, AddressReg):
return ()

fidelity = 1 - stmt.prob
interp_.update_fidelities(
interp_.qubit_survival_fidelities, fidelity, addresses
)

return ()
Loading
Loading