@@ -34,13 +34,32 @@ class NoisePass(Pass):
3434 def __post_init__ (self ):
3535 self .address_analysis = address .AddressAnalysis (self .dialects )
3636
37+ def get_qubit_values (self , mt : ir .Method ):
38+ frame , _ = self .address_analysis .run_analysis (mt , no_raise = self .no_raise )
39+ qubit_ssa_values = {}
40+ # Traverse statements in block order to fine the first SSA value for each qubit
41+ for block in mt .callable_region .blocks :
42+ for stmt in block .stmts :
43+ if len (stmt .results ) != 1 :
44+ continue
45+
46+ addr = frame .entries .get (result := stmt .results [0 ])
47+ if (
48+ isinstance (addr , address .AddressQubit )
49+ and (index := addr .data ) not in qubit_ssa_values
50+ ):
51+ qubit_ssa_values [index ] = result
52+
53+ return qubit_ssa_values , frame .entries
54+
3755 def unsafe_run (self , mt : ir .Method ):
3856 result = LiftQubits (self .dialects ).unsafe_run (mt )
39- frame , _ = self .address_analysis . run_analysis (mt , no_raise = self . no_raise )
57+ qubit_ssa_value , address_analysis = self .get_qubit_values (mt )
4058 result = (
4159 Walk (
4260 NoiseRewriteRule (
43- address_analysis = frame .entries ,
61+ qubit_ssa_value = qubit_ssa_value ,
62+ address_analysis = address_analysis ,
4463 noise_model = self .noise_model ,
4564 gate_noise_params = self .gate_noise_params ,
4665 ),
@@ -49,5 +68,6 @@ def unsafe_run(self, mt: ir.Method):
4968 .rewrite (mt .code )
5069 .join (result )
5170 )
71+
5272 result = Fixpoint (Walk (DeadCodeElimination ())).rewrite (mt .code ).join (result )
5373 return result
0 commit comments