1- import random
2- import typing
3- from functools import cached_property
4- from dataclasses import dataclass
5-
61from kirin import interp
7- from kirin .dialects import ilist
82
9- from bloqade .pyqrack import QubitState , PyQrackQubit , PyQrackInterpreter
3+ from bloqade .pyqrack import PyQrackQubit , PyQrackInterpreter
104from bloqade .squin .noise .stmts import (
115 QubitLoss ,
126 Depolarize ,
13- PauliError ,
147 Depolarize2 ,
158 TwoQubitPauliChannel ,
169 SingleQubitPauliChannel ,
17- StochasticUnitaryChannel ,
1810)
1911from bloqade .squin .noise ._dialect import dialect as squin_noise_dialect
2012
21- from ..runtime import KronRuntime , IdentityRuntime , OperatorRuntime , OperatorRuntimeABC
22-
23-
24- @dataclass (frozen = True )
25- class StochasticUnitaryChannelRuntime (OperatorRuntimeABC ):
26- operators : (
27- ilist .IList [OperatorRuntimeABC , typing .Any ] | tuple [OperatorRuntimeABC , ...]
28- )
29- probabilities : ilist .IList [float , typing .Any ] | tuple [float , ...]
30-
31- @property
32- def n_sites (self ) -> int :
33- n = self .operators [0 ].n_sites
34- for op in self .operators [1 :]:
35- assert (
36- op .n_sites == n
37- ), "Encountered a stochastic unitary channel with operators of different size!"
38- return n
39-
40- def apply (self , * qubits : PyQrackQubit , adjoint : bool = False ) -> None :
41- # NOTE: probabilities don't necessarily sum to 1; could be no noise event should occur
42- p_no_op = 1 - sum (self .probabilities )
43- if random .uniform (0.0 , 1.0 ) < p_no_op :
44- return
45-
46- selected_ops = random .choices (self .operators , weights = self .probabilities )
47- for op in selected_ops :
48- op .apply (* qubits , adjoint = adjoint )
49-
50-
51- @dataclass (frozen = True )
52- class QubitLossRuntime (OperatorRuntimeABC ):
53- p : float
54-
55- @property
56- def n_sites (self ) -> int :
57- return 1
58-
59- def apply (self , qubit : PyQrackQubit , adjoint : bool = False ) -> None :
60- if random .uniform (0.0 , 1.0 ) <= self .p :
61- qubit .state = QubitState .Lost
62-
6313
6414@squin_noise_dialect .register (key = "pyqrack" )
6515class PyQrackMethods (interp .MethodTable ):
66- @interp .impl (PauliError )
67- def pauli_error (
68- self , interp : PyQrackInterpreter , frame : interp .Frame , stmt : PauliError
69- ):
70- op = frame .get (stmt .basis )
71- p = frame .get (stmt .p )
72- return (StochasticUnitaryChannelRuntime ((op ,), (p ,)),)
16+
17+ single_pauli_choices = ("i" , "x" , "y" , "z" )
18+ two_pauli_choices = (
19+ "ii" ,
20+ "ix" ,
21+ "iy" ,
22+ "iz" ,
23+ "xi" ,
24+ "xx" ,
25+ "xy" ,
26+ "xz" ,
27+ "yi" ,
28+ "yx" ,
29+ "yy" ,
30+ "yz" ,
31+ "zi" ,
32+ "zx" ,
33+ "zy" ,
34+ "zz" ,
35+ )
7336
7437 @interp .impl (Depolarize )
7538 def depolarize (
7639 self , interp : PyQrackInterpreter , frame : interp .Frame , stmt : Depolarize
7740 ):
7841 p = frame .get (stmt .p )
79- ps = ( p / 3.0 ,) * 3
80- ops = self . single_qubit_paulis
81- return ( StochasticUnitaryChannelRuntime ( ops , ps ), )
42+ ps = [ p / 3.0 ] * 3
43+ qubits = frame . get ( stmt . qubits )
44+ self . apply_single_qubit_pauli_error ( interp , ps , qubits )
8245
8346 @interp .impl (Depolarize2 )
8447 def depolarize2 (
8548 self , interp : PyQrackInterpreter , frame : interp .Frame , stmt : Depolarize2
8649 ):
8750 p = frame .get (stmt .p )
88- ps = (p / 15.0 ,) * 15
89- ops = self .two_qubit_paulis
90- return (StochasticUnitaryChannelRuntime (ops , ps ),)
51+ ps = [p / 15.0 ] * 15
52+ controls = frame .get (stmt .controls )
53+ targets = frame .get (stmt .targets )
54+ self .apply_two_qubit_pauli_error (interp , ps , controls , targets )
9155
9256 @interp .impl (SingleQubitPauliChannel )
9357 def single_qubit_pauli_channel (
@@ -96,9 +60,11 @@ def single_qubit_pauli_channel(
9660 frame : interp .Frame ,
9761 stmt : SingleQubitPauliChannel ,
9862 ):
99- ps = frame .get (stmt .params )
100- ops = self .single_qubit_paulis
101- return (StochasticUnitaryChannelRuntime (ops , ps ),)
63+ px = frame .get (stmt .px )
64+ py = frame .get (stmt .py )
65+ pz = frame .get (stmt .pz )
66+ qubits = frame .get (stmt .qubits )
67+ self .apply_single_qubit_pauli_error (interp , [px , py , pz ], qubits )
10268
10369 @interp .impl (TwoQubitPauliChannel )
10470 def two_qubit_pauli_channel (
@@ -107,43 +73,54 @@ def two_qubit_pauli_channel(
10773 frame : interp .Frame ,
10874 stmt : TwoQubitPauliChannel ,
10975 ):
110- ps = frame .get (stmt .params )
111- ops = self .two_qubit_paulis
112- return (StochasticUnitaryChannelRuntime (ops , ps ),)
113-
114- @interp .impl (StochasticUnitaryChannel )
115- def stochastic_unitary_channel (
116- self ,
117- interp : PyQrackInterpreter ,
118- frame : interp .Frame ,
119- stmt : StochasticUnitaryChannel ,
120- ):
121- operators = frame .get (stmt .operators )
122- probabilities = frame .get (stmt .probabilities )
123-
124- return (StochasticUnitaryChannelRuntime (operators , probabilities ),)
76+ ps = frame .get (stmt .probabilities )
77+ controls = frame .get (stmt .controls )
78+ targets = frame .get (stmt .targets )
79+ self .apply_two_qubit_pauli_error (interp , ps , controls , targets )
12580
12681 @interp .impl (QubitLoss )
12782 def qubit_loss (
12883 self , interp : PyQrackInterpreter , frame : interp .Frame , stmt : QubitLoss
12984 ):
13085 p = frame .get (stmt .p )
131- return (QubitLossRuntime (p ),)
86+ qubits : list [PyQrackQubit ] = frame .get (stmt .qubits )
87+ for qbit in qubits :
88+ if interp .rng_state .uniform (0.0 , 1.0 ) <= p :
89+ qbit .drop ()
13290
133- @cached_property
134- def single_qubit_paulis (self ):
135- return (OperatorRuntime ("x" ), OperatorRuntime ("y" ), OperatorRuntime ("z" ))
91+ def apply_single_qubit_pauli_error (
92+ self ,
93+ interp : PyQrackInterpreter ,
94+ ps : list [float ],
95+ qubits : list [PyQrackQubit ],
96+ ):
97+ pi = 1 - sum (ps )
98+ probs = [pi ] + ps
13699
137- @cached_property
138- def two_qubit_paulis (self ):
139- paulis = (IdentityRuntime (sites = 1 ), * self .single_qubit_paulis )
140- ops : list [KronRuntime ] = []
141- for idx1 , pauli1 in enumerate (paulis ):
142- for idx2 , pauli2 in enumerate (paulis ):
143- if idx1 == idx2 == 0 :
144- # NOTE: 'II'
145- continue
100+ assert all (0 <= x <= 1 for x in probs ), "Invalid Pauli error probabilities"
146101
147- ops .append (KronRuntime (pauli1 , pauli2 ))
102+ for qbit in qubits :
103+ which = interp .rng_state .choice (self .single_pauli_choices , p = probs )
104+ self .apply_pauli_error (which , qbit )
105+
106+ def apply_two_qubit_pauli_error (
107+ self ,
108+ interp : PyQrackInterpreter ,
109+ ps : list [float ],
110+ controls : list [PyQrackQubit ],
111+ targets : list [PyQrackQubit ],
112+ ):
113+ pii = 1 - sum (ps )
114+ probs = [pii ] + ps
115+ assert all (0 <= x <= 1 for x in probs ), "Invalid Pauli error probabilities"
116+
117+ for control , target in zip (controls , targets ):
118+ which = interp .rng_state .choice (self .two_pauli_choices , p = probs )
119+ self .apply_pauli_error (which [0 ], control )
120+ self .apply_pauli_error (which [1 ], target )
121+
122+ def apply_pauli_error (self , which : str , qbit : PyQrackQubit ):
123+ if not qbit .is_active () or which == "i" :
124+ return
148125
149- return tuple ( ops )
126+ getattr ( qbit . sim_reg , which )( qbit . addr )
0 commit comments