diff --git a/src/bloqade/pyqrack/target.py b/src/bloqade/pyqrack/target.py index 13662315..c55d2117 100644 --- a/src/bloqade/pyqrack/target.py +++ b/src/bloqade/pyqrack/target.py @@ -1,4 +1,4 @@ -from typing import List, TypeVar, ParamSpec +from typing import Any, TypeVar, Iterable from dataclasses import field, dataclass from kirin import ir @@ -13,9 +13,6 @@ ) from bloqade.analysis.address import AnyAddress, AddressAnalysis -Params = ParamSpec("Params") -RetType = TypeVar("RetType") - @dataclass class PyQrack: @@ -36,7 +33,9 @@ def __post_init__(self): {**_default_pyqrack_args(), **self.pyqrack_options} ) - def _get_interp(self, mt: ir.Method[Params, RetType]): + RetType = TypeVar("RetType") + + def _get_interp(self, mt: ir.Method[..., RetType]): if self.dynamic_qubits: options = self.pyqrack_options.copy() @@ -64,49 +63,51 @@ def _get_interp(self, mt: ir.Method[Params, RetType]): def run( self, - mt: ir.Method[Params, RetType], - *args: Params.args, - **kwargs: Params.kwargs, - ) -> RetType: + mt: ir.Method[..., RetType], + *, + shots: int = 1, + args: tuple[Any, ...] = (), + kwargs: dict[str, Any] = {}, + return_iterator: bool = False, + ) -> RetType | list[RetType] | Iterable[RetType]: """Run the given kernel method on the PyQrack simulator. Args mt (Method): The kernel method to run. + shots (int): + The number of shots to run the simulation for. + Defaults to 1. + args (tuple[Any, ...]): + Positional arguments to pass to the kernel method. + Defaults to (). + kwargs (dict[str, Any]): + Keyword arguments to pass to the kernel method. + Defaults to {}. + return_iterator (bool): + Whether to return an iterator that yields results for each shot. + Defaults to False. if False, a list of results is returned. Returns - The result of the kernel method, if any. - - """ - fold = Fold(mt.dialects) - fold(mt) - return self._get_interp(mt).run(mt, args, kwargs) - - def multi_run( - self, - mt: ir.Method[Params, RetType], - _shots: int, - *args: Params.args, - **kwargs: Params.kwargs, - ) -> List[RetType]: - """Run the given kernel method on the PyQrack `_shots` times, caching analysis results. - - Args - mt (Method): - The kernel method to run. - _shots (int): - The number of times to run the kernel method. - - Returns - List of results of the kernel method, one for each shot. + RetType | list[RetType] | Iterable[RetType]: + The result of the simulation. If `return_iterator` is True, + an iterator that yields results for each shot is returned. + Otherwise, a list of results is returned if `shots > 1`, or + a single result is returned if `shots == 1`. """ fold = Fold(mt.dialects) fold(mt) interpreter = self._get_interp(mt) - batched_results = [] - for _ in range(_shots): - batched_results.append(interpreter.run(mt, args, kwargs)) - return batched_results + def run_shots(): + for _ in range(shots): + yield interpreter.run(mt, args, kwargs) + + if shots == 1: + return interpreter.run(mt, args, kwargs) + elif return_iterator: + return run_shots() + else: + return list(run_shots()) diff --git a/test/pyqrack/runtime/test_dyn_memory.py b/test/pyqrack/runtime/test_dyn_memory.py index 03e08a7a..c8ab9919 100644 --- a/test/pyqrack/runtime/test_dyn_memory.py +++ b/test/pyqrack/runtime/test_dyn_memory.py @@ -15,8 +15,7 @@ def ghz(n: int): for i in range(1, n): qasm2.cx(q[0], q[i]) - for i in range(n): - qasm2.measure(q[i], c[i]) + qasm2.measure(q, c) return c @@ -27,6 +26,6 @@ def ghz(n: int): N = 20 - result = target.multi_run(ghz, 100, N) + result = target.run(ghz, shots=100, args=(N,)) result = Counter("".join(str(int(bit)) for bit in bits) for bits in result) assert result.keys() == {"0" * N, "1" * N}