1- from typing import List , TypeVar , ParamSpec
1+ from typing import Any , TypeVar , Iterator
22from dataclasses import field , dataclass
33
44from kirin import ir
1313)
1414from bloqade .analysis .address import AnyAddress , AddressAnalysis
1515
16- Params = ParamSpec ("Params" )
17- RetType = TypeVar ("RetType" )
18-
1916
2017@dataclass
2118class PyQrack :
@@ -36,7 +33,9 @@ def __post_init__(self):
3633 {** _default_pyqrack_args (), ** self .pyqrack_options }
3734 )
3835
39- def _get_interp (self , mt : ir .Method [Params , RetType ]):
36+ RetType = TypeVar ("RetType" )
37+
38+ def _get_interp (self , mt : ir .Method [..., RetType ]):
4039 if self .dynamic_qubits :
4140
4241 options = self .pyqrack_options .copy ()
@@ -64,49 +63,51 @@ def _get_interp(self, mt: ir.Method[Params, RetType]):
6463
6564 def run (
6665 self ,
67- mt : ir .Method [Params , RetType ],
68- * args : Params .args ,
69- ** kwargs : Params .kwargs ,
70- ) -> RetType :
66+ mt : ir .Method [..., RetType ],
67+ shots : int = 1 ,
68+ args : tuple [Any , ...] = (),
69+ kwargs : dict [str , Any ] = {},
70+ return_iterator : bool = False ,
71+ ) -> RetType | list [RetType ] | Iterator [RetType ]:
7172 """Run the given kernel method on the PyQrack simulator.
7273
7374 Args
7475 mt (Method):
7576 The kernel method to run.
77+ shots (int):
78+ The number of shots to run the simulation for.
79+ Defaults to 1.
80+ args (tuple[Any, ...]):
81+ Positional arguments to pass to the kernel method.
82+ Defaults to ().
83+ kwargs (dict[str, Any]):
84+ Keyword arguments to pass to the kernel method.
85+ Defaults to {}.
86+ return_iterator (bool):
87+ Whether to return an iterator that yields results for each shot.
88+ Defaults to False. if False, a list of results is returned.
7689
7790 Returns
78- The result of the kernel method, if any.
79-
80- """
81- fold = Fold (mt .dialects )
82- fold (mt )
83- return self ._get_interp (mt ).run (mt , args , kwargs )
84-
85- def multi_run (
86- self ,
87- mt : ir .Method [Params , RetType ],
88- _shots : int ,
89- * args : Params .args ,
90- ** kwargs : Params .kwargs ,
91- ) -> List [RetType ]:
92- """Run the given kernel method on the PyQrack `_shots` times, caching analysis results.
93-
94- Args
95- mt (Method):
96- The kernel method to run.
97- _shots (int):
98- The number of times to run the kernel method.
99-
100- Returns
101- List of results of the kernel method, one for each shot.
91+ RetType | list[RetType] | Iterator[RetType]:
92+ The result of the simulation. If `return_iterator` is True,
93+ an iterator that yields results for each shot is returned.
94+ Otherwise, a list of results is returned if `shots > 1`, or
95+ a single result is returned if `shots == 1`.
10296
10397 """
10498 fold = Fold (mt .dialects )
10599 fold (mt )
106100
107101 interpreter = self ._get_interp (mt )
108- batched_results = []
109- for _ in range (_shots ):
110- batched_results .append (interpreter .run (mt , args , kwargs ))
111102
112- return batched_results
103+ def run_shots ():
104+ for _ in range (shots ):
105+ yield interpreter .run (mt , args , kwargs )
106+
107+ if shots == 1 :
108+ return interpreter .run (mt , args , kwargs )
109+ else :
110+ if return_iterator :
111+ return run_shots ()
112+ else :
113+ return list (run_shots ())
0 commit comments