Skip to content

Commit d56df5f

Browse files
weinbe58Roger-luo
andauthored
Task Interface (#228)
Addresses QuEraComputing/bloqade#225 Currently looking for rfc: QuEraComputing/bloqade#234 --------- Co-authored-by: Xiu-zhe (Roger) Luo <[email protected]> Co-authored-by: Roger-luo <[email protected]>
1 parent 893a9e9 commit d56df5f

File tree

9 files changed

+455
-15
lines changed

9 files changed

+455
-15
lines changed

src/bloqade/device.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import abc
2+
from typing import Any, Generic, TypeVar, ParamSpec
3+
4+
from kirin import ir
5+
6+
from bloqade.task import (
7+
BatchFuture,
8+
AbstractTask,
9+
AbstractRemoteTask,
10+
AbstractSimulatorTask,
11+
DeviceTaskExpectMixin,
12+
)
13+
14+
Params = ParamSpec("Params")
15+
RetType = TypeVar("RetType")
16+
ObsType = TypeVar("ObsType")
17+
18+
19+
TaskType = TypeVar("TaskType", bound=AbstractTask)
20+
21+
22+
class AbstractDevice(abc.ABC, Generic[TaskType]):
23+
"""Abstract base class for devices. Defines the minimum interface for devices."""
24+
25+
@abc.abstractmethod
26+
def task(
27+
self,
28+
kernel: ir.Method[Params, RetType],
29+
args: tuple[Any, ...] = (),
30+
kwargs: dict[str, Any] | None = None,
31+
) -> TaskType:
32+
"""Creates a remote task for the device."""
33+
34+
35+
ExpectTaskType = TypeVar("ExpectTaskType", bound=DeviceTaskExpectMixin)
36+
37+
38+
class ExpectationDeviceMixin(AbstractDevice[ExpectTaskType]):
39+
def expect(
40+
self,
41+
kernel: ir.Method[Params, RetType],
42+
observable: ir.Method[[RetType], ObsType],
43+
args: tuple[Any, ...] = (),
44+
kwargs: dict[str, Any] | None = None,
45+
*,
46+
shots: int = 1,
47+
) -> ObsType:
48+
"""Returns the expectation value of the given observable after running the task."""
49+
return self.task(kernel, args, kwargs).expect(observable, shots)
50+
51+
52+
RemoteTaskType = TypeVar("RemoteTaskType", bound=AbstractRemoteTask)
53+
54+
55+
class AbstractRemoteDevice(AbstractDevice[RemoteTaskType]):
56+
"""Abstract base class for remote devices."""
57+
58+
def run(
59+
self,
60+
kernel: ir.Method[Params, RetType],
61+
args: tuple[Any, ...] = (),
62+
kwargs: dict[str, Any] | None = None,
63+
*,
64+
shots: int = 1,
65+
timeout: float | None = None,
66+
) -> list[RetType]:
67+
"""Runs the kernel and returns the result.
68+
69+
Args:
70+
kernel (ir.Method):
71+
The kernel method to run.
72+
args (tuple[Any, ...]):
73+
Positional arguments to pass to the kernel method.
74+
kwargs (dict[str, Any] | None):
75+
Keyword arguments to pass to the kernel method.
76+
shots (int):
77+
The number of times to run the kernel method.
78+
timeout (float | None):
79+
Timeout in seconds for the asynchronous execution. If None, wait indefinitely.
80+
81+
Returns:
82+
list[RetType]:
83+
The result of the kernel method, if any.
84+
85+
"""
86+
return self.task(kernel, args, kwargs).run(shots=shots, timeout=timeout)
87+
88+
def run_async(
89+
self,
90+
kernel: ir.Method[Params, RetType],
91+
args: tuple[Any, ...] = (),
92+
kwargs: dict[str, Any] | None = None,
93+
*,
94+
shots: int = 1,
95+
) -> BatchFuture[RetType]:
96+
"""Runs the kernel asynchronously and returns a Future object.
97+
98+
Args:
99+
kernel (ir.Method):
100+
The kernel method to run.
101+
args (tuple[Any, ...]):
102+
Positional arguments to pass to the kernel method.
103+
kwargs (dict[str, Any] | None):
104+
Keyword arguments to pass to the kernel method.
105+
shots (int):
106+
The number of times to run the kernel method.
107+
108+
Returns:
109+
Future[list[RetType]]:
110+
The Future for all executions of the kernel method.
111+
112+
113+
"""
114+
return self.task(kernel, args, kwargs).run_async(shots=shots)
115+
116+
117+
SimulatorTaskType = TypeVar("SimulatorTaskType", bound=AbstractSimulatorTask)
118+
119+
120+
class AbstractSimulatorDevice(AbstractDevice[SimulatorTaskType]):
121+
"""Abstract base class for simulator devices."""
122+
123+
def run(
124+
self,
125+
kernel: ir.Method[Params, RetType],
126+
args: tuple[Any, ...] = (),
127+
kwargs: dict[str, Any] | None = None,
128+
) -> RetType:
129+
"""Runs the kernel and returns the result."""
130+
return self.task(kernel, args, kwargs).run()

src/bloqade/pyqrack/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,14 @@
1111
DynamicMemory as DynamicMemory,
1212
PyQrackInterpreter as PyQrackInterpreter,
1313
)
14+
from .task import PyQrackSimulatorTask as PyQrackSimulatorTask
1415

1516
# NOTE: The following import is for registering the method tables
1617
from .noise import native as native
1718
from .qasm2 import uop as uop, core as core, glob as glob, parallel as parallel
1819
from .squin import op as op, qubit as qubit
20+
from .device import (
21+
StackMemorySimulator as StackMemorySimulator,
22+
DynamicMemorySimulator as DynamicMemorySimulator,
23+
)
1924
from .target import PyQrack as PyQrack

src/bloqade/pyqrack/base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,13 @@ def allocate(self, n_qubits: int):
133133
return tuple(range(start, start + n_qubits))
134134

135135

136+
MemoryType = typing.TypeVar("MemoryType", bound=MemoryABC)
137+
138+
136139
@dataclass
137-
class PyQrackInterpreter(Interpreter):
140+
class PyQrackInterpreter(Interpreter, typing.Generic[MemoryType]):
138141
keys = ["pyqrack", "main"]
139-
memory: MemoryABC = field(kw_only=True)
142+
memory: MemoryType = field(kw_only=True)
140143
rng_state: np.random.Generator = field(
141144
default_factory=np.random.default_rng, kw_only=True
142145
)

src/bloqade/pyqrack/device.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
from typing import Any, TypeVar, ParamSpec
2+
from dataclasses import field, dataclass
3+
4+
import numpy as np
5+
from kirin import ir
6+
7+
from pyqrack.pauli import Pauli
8+
from bloqade.device import AbstractSimulatorDevice
9+
from bloqade.pyqrack.reg import Measurement, PyQrackQubit
10+
from bloqade.pyqrack.base import (
11+
MemoryABC,
12+
StackMemory,
13+
DynamicMemory,
14+
PyQrackOptions,
15+
PyQrackInterpreter,
16+
_default_pyqrack_args,
17+
)
18+
from bloqade.pyqrack.task import PyQrackSimulatorTask
19+
from bloqade.analysis.address.lattice import AnyAddress
20+
from bloqade.analysis.address.analysis import AddressAnalysis
21+
22+
RetType = TypeVar("RetType")
23+
Params = ParamSpec("Params")
24+
25+
26+
@dataclass
27+
class PyQrackSimulatorBase(AbstractSimulatorDevice[PyQrackSimulatorTask]):
28+
options: PyQrackOptions = field(default_factory=_default_pyqrack_args)
29+
loss_m_result: Measurement = field(default=Measurement.One, kw_only=True)
30+
rng_state: np.random.Generator = field(
31+
default_factory=np.random.default_rng, kw_only=True
32+
)
33+
34+
MemoryType = TypeVar("MemoryType", bound=MemoryABC)
35+
36+
def __post_init__(self):
37+
self.options = PyQrackOptions({**_default_pyqrack_args(), **self.options})
38+
39+
def new_task(
40+
self,
41+
mt: ir.Method[Params, RetType],
42+
args: tuple[Any, ...],
43+
kwargs: dict[str, Any],
44+
memory: MemoryType,
45+
) -> PyQrackSimulatorTask[Params, RetType, MemoryType]:
46+
interp = PyQrackInterpreter(
47+
mt.dialects,
48+
memory=memory,
49+
rng_state=self.rng_state,
50+
loss_m_result=self.loss_m_result,
51+
)
52+
return PyQrackSimulatorTask(
53+
kernel=mt, args=args, kwargs=kwargs, pyqrack_interp=interp
54+
)
55+
56+
def state_vector(
57+
self,
58+
kernel: ir.Method[Params, RetType],
59+
args: tuple[Any, ...] = (),
60+
kwargs: dict[str, Any] | None = None,
61+
) -> list[complex]:
62+
"""Runs task and returns the state vector."""
63+
task = self.task(kernel, args, kwargs)
64+
task.run()
65+
return task.state.sim_reg.out_ket()
66+
67+
@staticmethod
68+
def pauli_expectation(pauli: list[Pauli], qubits: list[PyQrackQubit]) -> float:
69+
"""Returns the expectation value of the given Pauli operator given a list of Pauli operators and qubits.
70+
71+
Args:
72+
pauli (list[Pauli]):
73+
List of Pauli operators to compute the expectation value for.
74+
qubits (list[PyQrackQubit]):
75+
List of qubits corresponding to the Pauli operators.
76+
77+
returns:
78+
float:
79+
The expectation value of the Pauli operator.
80+
81+
"""
82+
83+
if len(pauli) == 0:
84+
return 0.0
85+
86+
if len(pauli) != len(qubits):
87+
raise ValueError("Length of Pauli and qubits must match.")
88+
89+
sim_reg = qubits[0].sim_reg
90+
91+
if any(qubit.sim_reg is not sim_reg for qubit in qubits):
92+
raise ValueError("All qubits must belong to the same simulator register.")
93+
94+
qubit_ids = [qubit.addr for qubit in qubits]
95+
96+
if len(qubit_ids) != len(set(qubit_ids)):
97+
raise ValueError("Qubits must be unique.")
98+
99+
return sim_reg.pauli_expectation(pauli, qubit_ids)
100+
101+
102+
@dataclass
103+
class StackMemorySimulator(PyQrackSimulatorBase):
104+
"""PyQrack simulator device with precalculated stack of qubits."""
105+
106+
min_qubits: int = field(default=0, kw_only=True)
107+
108+
def task(
109+
self,
110+
kernel: ir.Method[Params, RetType],
111+
args: tuple[Any, ...] = (),
112+
kwargs: dict[str, Any] | None = None,
113+
):
114+
if kwargs is None:
115+
kwargs = {}
116+
117+
address_analysis = AddressAnalysis(dialects=kernel.dialects)
118+
frame, _ = address_analysis.run_analysis(kernel)
119+
if self.min_qubits == 0 and any(
120+
isinstance(a, AnyAddress) for a in frame.entries.values()
121+
):
122+
raise ValueError(
123+
"All addresses must be resolved. Or set min_qubits to a positive integer."
124+
)
125+
126+
num_qubits = max(address_analysis.qubit_count, self.min_qubits)
127+
options = self.options.copy()
128+
options["qubitCount"] = num_qubits
129+
memory = StackMemory(
130+
options,
131+
total=num_qubits,
132+
)
133+
134+
return self.new_task(kernel, args, kwargs, memory)
135+
136+
137+
@dataclass
138+
class DynamicMemorySimulator(PyQrackSimulatorBase):
139+
"""PyQrack simulator device with dynamic qubit allocation."""
140+
141+
def task(
142+
self,
143+
kernel: ir.Method[Params, RetType],
144+
args: tuple[Any, ...] = (),
145+
kwargs: dict[str, Any] | None = None,
146+
):
147+
if kwargs is None:
148+
kwargs = {}
149+
150+
memory = DynamicMemory(self.options.copy())
151+
return self.new_task(kernel, args, kwargs, memory)
152+
153+
154+
def test():
155+
from bloqade.qasm2 import extended
156+
157+
@extended
158+
def main():
159+
return 1
160+
161+
@extended
162+
def obs(result: int) -> int:
163+
return result
164+
165+
res = DynamicMemorySimulator().task(main)
166+
return res.run()

src/bloqade/pyqrack/target.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List, TypeVar, ParamSpec
2+
from warnings import warn
23
from dataclasses import field, dataclass
34

45
from kirin import ir
@@ -32,6 +33,12 @@ class PyQrack:
3233
"""Options to pass to the QrackSimulator object, node `qubitCount` will be overwritten."""
3334

3435
def __post_init__(self):
36+
warn(
37+
"The PyQrack target is deprecated and will be removed "
38+
"in a future release. Please use the DynamicMemorySimulator / "
39+
"StackMemorySimulator instead."
40+
)
41+
3542
self.pyqrack_options = PyQrackOptions(
3643
{**_default_pyqrack_args(), **self.pyqrack_options}
3744
)

src/bloqade/pyqrack/task.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from typing import TypeVar, ParamSpec
2+
from dataclasses import dataclass
3+
4+
from bloqade.task import AbstractSimulatorTask
5+
from bloqade.pyqrack.base import (
6+
MemoryABC,
7+
PyQrackInterpreter,
8+
)
9+
10+
RetType = TypeVar("RetType")
11+
Param = ParamSpec("Param")
12+
MemoryType = TypeVar("MemoryType", bound=MemoryABC)
13+
14+
15+
@dataclass
16+
class PyQrackSimulatorTask(AbstractSimulatorTask[Param, RetType, MemoryType]):
17+
"""PyQrack simulator task for Bloqade."""
18+
19+
pyqrack_interp: PyQrackInterpreter[MemoryType]
20+
21+
def run(self) -> RetType:
22+
return self.pyqrack_interp.run(
23+
self.kernel,
24+
args=self.args,
25+
kwargs=self.kwargs,
26+
)
27+
28+
@property
29+
def state(self) -> MemoryType:
30+
return self.pyqrack_interp.memory

0 commit comments

Comments
 (0)