Skip to content

Commit 1ddea9b

Browse files
committed
updating doc string
1 parent badf341 commit 1ddea9b

File tree

1 file changed

+38
-37
lines changed

1 file changed

+38
-37
lines changed

src/bloqade/pyqrack/target.py

Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, TypeVar, ParamSpec
1+
from typing import Any, TypeVar, Iterator
22
from dataclasses import field, dataclass
33

44
from kirin import ir
@@ -13,9 +13,6 @@
1313
)
1414
from bloqade.analysis.address import AnyAddress, AddressAnalysis
1515

16-
Params = ParamSpec("Params")
17-
RetType = TypeVar("RetType")
18-
1916

2017
@dataclass
2118
class PyQrack:
@@ -36,7 +33,9 @@ def __post_init__(self):
3633
{**_default_pyqrack_args(), **self.pyqrack_options}
3734
)
3835

39-
def _get_interp(self, mt: ir.Method[Params, RetType]):
36+
RetType = TypeVar("RetType")
37+
38+
def _get_interp(self, mt: ir.Method[..., RetType]):
4039
if self.dynamic_qubits:
4140

4241
options = self.pyqrack_options.copy()
@@ -64,49 +63,51 @@ def _get_interp(self, mt: ir.Method[Params, RetType]):
6463

6564
def run(
6665
self,
67-
mt: ir.Method[Params, RetType],
68-
*args: Params.args,
69-
**kwargs: Params.kwargs,
70-
) -> RetType:
66+
mt: ir.Method[..., RetType],
67+
shots: int = 1,
68+
args: tuple[Any, ...] = (),
69+
kwargs: dict[str, Any] = {},
70+
return_iterator: bool = False,
71+
) -> RetType | list[RetType] | Iterator[RetType]:
7172
"""Run the given kernel method on the PyQrack simulator.
7273
7374
Args
7475
mt (Method):
7576
The kernel method to run.
77+
shots (int):
78+
The number of shots to run the simulation for.
79+
Defaults to 1.
80+
args (tuple[Any, ...]):
81+
Positional arguments to pass to the kernel method.
82+
Defaults to ().
83+
kwargs (dict[str, Any]):
84+
Keyword arguments to pass to the kernel method.
85+
Defaults to {}.
86+
return_iterator (bool):
87+
Whether to return an iterator that yields results for each shot.
88+
Defaults to False. if False, a list of results is returned.
7689
7790
Returns
78-
The result of the kernel method, if any.
79-
80-
"""
81-
fold = Fold(mt.dialects)
82-
fold(mt)
83-
return self._get_interp(mt).run(mt, args, kwargs)
84-
85-
def multi_run(
86-
self,
87-
mt: ir.Method[Params, RetType],
88-
_shots: int,
89-
*args: Params.args,
90-
**kwargs: Params.kwargs,
91-
) -> List[RetType]:
92-
"""Run the given kernel method on the PyQrack `_shots` times, caching analysis results.
93-
94-
Args
95-
mt (Method):
96-
The kernel method to run.
97-
_shots (int):
98-
The number of times to run the kernel method.
99-
100-
Returns
101-
List of results of the kernel method, one for each shot.
91+
RetType | list[RetType] | Iterator[RetType]:
92+
The result of the simulation. If `return_iterator` is True,
93+
an iterator that yields results for each shot is returned.
94+
Otherwise, a list of results is returned if `shots > 1`, or
95+
a single result is returned if `shots == 1`.
10296
10397
"""
10498
fold = Fold(mt.dialects)
10599
fold(mt)
106100

107101
interpreter = self._get_interp(mt)
108-
batched_results = []
109-
for _ in range(_shots):
110-
batched_results.append(interpreter.run(mt, args, kwargs))
111102

112-
return batched_results
103+
def run_shots():
104+
for _ in range(shots):
105+
yield interpreter.run(mt, args, kwargs)
106+
107+
if shots == 1:
108+
return interpreter.run(mt, args, kwargs)
109+
else:
110+
if return_iterator:
111+
return run_shots()
112+
else:
113+
return list(run_shots())

0 commit comments

Comments
 (0)