1- from typing import Any , TypeVar , ParamSpec
1+ from typing import Any , Generic , TypeVar , ParamSpec , cast
22from dataclasses import field , dataclass
33
44import numpy as np
55from kirin import ir
6+ from kirin .dialects import py , func
67
8+ from bloqade .noise import native
79from pyqrack .pauli import Pauli
810from bloqade .device import AbstractSimulatorDevice
911from bloqade .pyqrack .reg import Measurement , PyQrackQubit
1012from bloqade .pyqrack .base import (
11- MemoryABC ,
1213 StackMemory ,
1314 DynamicMemory ,
1415 PyQrackOptions ,
1516 PyQrackInterpreter ,
1617 _default_pyqrack_args ,
1718)
18- from bloqade .pyqrack .task import PyQrackSimulatorTask
19+ from bloqade .pyqrack .task import PyQrackSimulatorTask , PyQrackNoiseSimulatorTask
20+ from bloqade .qasm2 .passes import NoisePass , QASM2Fold , UOpToParallel
21+ from bloqade .analysis .fidelity import FidelityAnalysis
1922from bloqade .analysis .address .lattice import AnyAddress
2023from bloqade .analysis .address .analysis import AddressAnalysis
2124
2225RetType = TypeVar ("RetType" )
2326Params = ParamSpec ("Params" )
2427
28+ PyQrackSimulatorTaskType = TypeVar (
29+ "PyQrackSimulatorTaskType" ,
30+ bound = PyQrackSimulatorTask ,
31+ )
32+
2533
2634@dataclass
27- class PyQrackSimulatorBase (AbstractSimulatorDevice [PyQrackSimulatorTask ]):
35+ class PyQrackSimulatorBase (AbstractSimulatorDevice [PyQrackSimulatorTaskType ]):
2836 options : PyQrackOptions = field (default_factory = _default_pyqrack_args )
2937 loss_m_result : Measurement = field (default = Measurement .One , kw_only = True )
3038 rng_state : np .random .Generator = field (
3139 default_factory = np .random .default_rng , kw_only = True
3240 )
3341
34- MemoryType = TypeVar ("MemoryType" , bound = MemoryABC )
35-
3642 def __post_init__ (self ):
3743 self .options = PyQrackOptions ({** _default_pyqrack_args (), ** self .options })
3844
39- def new_task (
40- self ,
41- mt : ir .Method [Params , RetType ],
42- args : tuple [Any , ...],
43- kwargs : dict [str , Any ],
44- memory : MemoryType ,
45- ) -> PyQrackSimulatorTask [Params , RetType , MemoryType ]:
46- interp = PyQrackInterpreter (
47- mt .dialects ,
48- memory = memory ,
49- rng_state = self .rng_state ,
50- loss_m_result = self .loss_m_result ,
51- )
52- return PyQrackSimulatorTask (
53- kernel = mt , args = args , kwargs = kwargs , pyqrack_interp = interp
54- )
55-
5645 def state_vector (
5746 self ,
5847 kernel : ir .Method [Params , RetType ],
@@ -98,7 +87,7 @@ def pauli_expectation(pauli: list[Pauli], qubits: list[PyQrackQubit]) -> float:
9887
9988
10089@dataclass
101- class StackMemorySimulator (PyQrackSimulatorBase ):
90+ class StackMemorySimulator (PyQrackSimulatorBase [ PyQrackSimulatorTask ] ):
10291 """PyQrack simulator device with precalculated stack of qubits."""
10392
10493 min_qubits : int = field (default = 0 , kw_only = True )
@@ -129,11 +118,20 @@ def task(
129118 total = num_qubits ,
130119 )
131120
132- return self .new_task (kernel , args , kwargs , memory )
121+ pyqrack_interp = PyQrackInterpreter (
122+ kernel .dialects ,
123+ memory = memory ,
124+ rng_state = self .rng_state ,
125+ loss_m_result = self .loss_m_result ,
126+ )
127+
128+ return PyQrackSimulatorTask (
129+ kernel = kernel , args = args , kwargs = kwargs , pyqrack_interp = pyqrack_interp
130+ )
133131
134132
135133@dataclass
136- class DynamicMemorySimulator (PyQrackSimulatorBase ):
134+ class DynamicMemorySimulator (PyQrackSimulatorBase [ PyQrackSimulatorTask ] ):
137135 """PyQrack simulator device with dynamic qubit allocation."""
138136
139137 def task (
@@ -145,20 +143,114 @@ def task(
145143 if kwargs is None :
146144 kwargs = {}
147145
148- memory = DynamicMemory (self .options .copy ())
149- return self .new_task (kernel , args , kwargs , memory )
146+ pyqrack_interp = PyQrackInterpreter (
147+ kernel .dialects ,
148+ memory = DynamicMemory (self .options .copy ()),
149+ rng_state = self .rng_state ,
150+ loss_m_result = self .loss_m_result ,
151+ )
152+
153+ return PyQrackSimulatorTask (
154+ kernel = kernel ,
155+ args = args ,
156+ kwargs = kwargs ,
157+ pyqrack_interp = pyqrack_interp ,
158+ )
159+
160+
161+ def arg_closure (
162+ kernel : ir .Method [Params , RetType ], args : tuple [Any , ...], kwargs : dict [str , Any ]
163+ ) -> ir .Method [..., RetType ]:
164+ """Create a closure for the arguments of the kernel."""
165+
166+ func_body = ir .Region (block := ir .Block ())
167+ inputs : list [ir .ResultValue ] = []
168+ for arg in args :
169+ block .stmts .append (const_stmt := py .Constant (arg ))
170+ inputs .append (const_stmt .result )
171+
172+ kw_names : list [str ] = []
173+ for key , value in kwargs .items ():
174+ block .stmts .append (const_stmt := py .Constant (value ))
175+ kw_names .append (key )
176+ inputs .append (const_stmt .result )
177+
178+ block .stmts .append (
179+ invoke_stmt := func .Invoke (
180+ inputs = tuple (inputs ),
181+ callee = kernel ,
182+ kwargs = tuple (kw_names ),
183+ purity = False ,
184+ )
185+ )
186+ block .stmts .append (func .Return (invoke_stmt .result ))
187+
188+ code = func .Function (
189+ sym_name = "closure" ,
190+ signature = func .Signature ((), kernel .return_type ),
191+ body = func_body ,
192+ )
193+ return ir .Method (None , None , "closure" , [], kernel .dialects , code )
194+
195+
196+ NoiseModelType = TypeVar ("NoiseModelType" , bound = native .MoveNoiseModelABC )
197+
198+
199+ @dataclass
200+ class NoiseSimulatorBase (
201+ PyQrackSimulatorBase [PyQrackNoiseSimulatorTask ], Generic [NoiseModelType ]
202+ ):
203+ noise_model : NoiseModelType = field (default_factory = native .TwoRowZoneModel )
204+ gate_noise_params : native .GateNoiseParams = field (
205+ default_factory = native .GateNoiseParams
206+ )
207+ optimize_parallel_gates : bool = field (default = True , kw_only = True )
208+ decompose_native_gates : bool = field (default = True , kw_only = True )
209+
210+ def task (
211+ self ,
212+ kernel : ir .Method [Params , RetType ],
213+ args : tuple [Any , ...] = (),
214+ kwargs : dict [str , Any ] | None = None ,
215+ ):
216+ if kwargs is None :
217+ kwargs = {}
218+
219+ if len (args ) > 0 or len (kwargs ) > 0 :
220+ folded_kernel = arg_closure (kernel , args , kwargs )
221+ args = ()
222+ kwargs = {}
223+ else :
224+ folded_kernel = cast (ir .Method [..., RetType ], kernel )
150225
226+ QASM2Fold (folded_kernel .dialects ).fixpoint (folded_kernel )
151227
152- def test () :
153- from bloqade . qasm2 import extended
228+ if self . optimize_parallel_gates :
229+ UOpToParallel ( folded_kernel . dialects )( folded_kernel )
154230
155- @extended
156- def main ():
157- return 1
231+ if native .dialect not in folded_kernel .dialects :
232+ noise_pass = NoisePass (
233+ kernel .dialects ,
234+ self .noise_model ,
235+ self .gate_noise_params ,
236+ )
158237
159- @extended
160- def obs (result : int ) -> int :
161- return result
238+ noise_pass (folded_kernel )
239+ folded_kernel = folded_kernel .similar (
240+ folded_kernel .dialects .add (native .dialect )
241+ )
242+
243+ pyqrack_interp = PyQrackInterpreter (
244+ folded_kernel .dialects ,
245+ memory = DynamicMemory (self .options .copy ()),
246+ rng_state = self .rng_state ,
247+ loss_m_result = self .loss_m_result ,
248+ )
162249
163- res = DynamicMemorySimulator ().task (main )
164- return res .run ()
250+ return PyQrackNoiseSimulatorTask (
251+ kernel = folded_kernel ,
252+ args = args ,
253+ kwargs = kwargs ,
254+ pyqrack_interp = pyqrack_interp ,
255+ fidelity_scorer = FidelityAnalysis (kernel .dialects ),
256+ )
0 commit comments