1+ from typing import Any
12from dataclasses import field
23
34from kirin import ir
67from kirin .interp .value import Successor
78from kirin .analysis .forward import ForwardFrame
89
10+ from ..address import AddressAnalysis
11+
912
1013class FidelityAnalysis (Forward ):
1114 """
@@ -25,23 +28,42 @@ class FidelityAnalysis(Forward):
2528 """
2629 The probability that all atoms survive for the entirety of the analysed program. Decreases whenever an atomic loss channel is encountered.
2730 """
28- atom_survival_probability : float = 1.0
31+ atom_survival_probability : list [float ] = field (init = False )
32+
33+ _current_atom_survival_probability : list [float ] = field (init = False )
2934
30- _current_atom_survival_probability : float = field (init = False )
35+ addr_frame : ForwardFrame = field (init = False )
3136
3237 def initialize (self ):
3338 super ().initialize ()
3439 self ._current_gate_fidelity = 1.0
35- self ._current_atom_survival_probability = 1.0
40+ self ._current_atom_survival_probability = [
41+ 1.0 for _ in range (len (self .atom_survival_probability ))
42+ ]
3643 return self
3744
3845 def posthook_succ (self , frame : ForwardFrame , succ : Successor ):
3946 self .gate_fidelity *= self ._current_gate_fidelity
40- self .atom_survival_probability *= self ._current_atom_survival_probability
47+ for i , _current_survival in enumerate (self ._current_atom_survival_probability ):
48+ self .atom_survival_probability [i ] *= _current_survival
4149
4250 def eval_stmt_fallback (self , frame : ForwardFrame , stmt : ir .Statement ):
4351 # NOTE: default is to conserve fidelity, so do nothing here
4452 return
4553
4654 def run_method (self , method : ir .Method , args : tuple [EmptyLattice , ...]):
4755 return self .run_callable (method .code , (self .lattice .bottom (),) + args )
56+
57+ def run_analysis (
58+ self , method : ir .Method , args : tuple | None = None , * , no_raise : bool = True
59+ ) -> tuple [ForwardFrame , Any ]:
60+ self ._run_address_analysis (method , no_raise = no_raise )
61+ return super ().run_analysis (method , args , no_raise = no_raise )
62+
63+ def _run_address_analysis (self , method : ir .Method , no_raise : bool ):
64+ addr_analysis = AddressAnalysis (self .dialects )
65+ addr_frame , _ = addr_analysis .run_analysis (method = method , no_raise = no_raise )
66+ self .addr_frame = addr_frame
67+
68+ # NOTE: make sure we have as many probabilities as we have addresses
69+ self .atom_survival_probability = [1.0 ] * addr_analysis .qubit_count
0 commit comments