diff --git a/src/braket/aws/aws_device.py b/src/braket/aws/aws_device.py index d2a29f326..b05c5dc76 100644 --- a/src/braket/aws/aws_device.py +++ b/src/braket/aws/aws_device.py @@ -58,7 +58,7 @@ class AwsDeviceType(str, Enum): QPU = "QPU" -class AwsDevice(Device): +class AwsDevice(Device[AwsQuantumTask]): """Amazon Braket implementation of a device. Use this class to retrieve the latest metadata about the device and to run a quantum task on the device. diff --git a/src/braket/aws/aws_quantum_task_batch.py b/src/braket/aws/aws_quantum_task_batch.py index 2c8a81692..11c54342b 100644 --- a/src/braket/aws/aws_quantum_task_batch.py +++ b/src/braket/aws/aws_quantum_task_batch.py @@ -33,7 +33,7 @@ from braket.tasks.quantum_task_batch import QuantumTaskBatch -class AwsQuantumTaskBatch(QuantumTaskBatch): +class AwsQuantumTaskBatch(QuantumTaskBatch[AwsQuantumTask]): """Executes a batch of quantum tasks in parallel. Using this class can yield vast speedups over executing quantum tasks sequentially, diff --git a/src/braket/devices/device.py b/src/braket/devices/device.py index f84d53e94..bf29ef0c5 100644 --- a/src/braket/devices/device.py +++ b/src/braket/devices/device.py @@ -15,18 +15,18 @@ import warnings from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Generic from braket.device_schema import DeviceActionType from braket.circuits import Circuit, Noise from braket.circuits.noise_model import NoiseModel from braket.circuits.translations import SUPPORTED_NOISE_PRAGMA_TO_NOISE -from braket.tasks.quantum_task import QuantumTask, TaskSpecification +from braket.tasks.quantum_task import QuantumTaskType, TaskSpecification from braket.tasks.quantum_task_batch import QuantumTaskBatch -class Device(ABC): +class Device(ABC, Generic[QuantumTaskType]): """An abstraction over quantum devices that includes quantum computers and simulators.""" def __init__(self, name: str, status: str): @@ -47,7 +47,7 @@ def run( inputs: dict[str, float] | None, *args, **kwargs, - ) -> QuantumTask: + ) -> QuantumTaskType: """Run a quantum task specification on this quantum device. A quantum task can be a circuit or an annealing problem. @@ -75,7 +75,7 @@ def run_batch( inputs: dict[str, float] | list[dict[str, float]] | None, *args: Any, **kwargs: Any, - ) -> QuantumTaskBatch: + ) -> QuantumTaskBatch[QuantumTaskType]: """Executes a batch of quantum tasks in parallel Args: @@ -91,7 +91,7 @@ def run_batch( **kwargs (Any): Arbitrary keyword arguments. Returns: - QuantumTaskBatch: A batch containing all of the qauntum tasks run + QuantumTaskBatch: A batch containing all of the quantum tasks run """ @property diff --git a/src/braket/devices/local_simulator.py b/src/braket/devices/local_simulator.py index b75f0fea1..b58ac1f82 100644 --- a/src/braket/devices/local_simulator.py +++ b/src/braket/devices/local_simulator.py @@ -56,7 +56,7 @@ _simulator_devices = {entry.name: entry for entry in entry_points(group="braket.simulators")} -class LocalSimulator(Device): +class LocalSimulator(Device[LocalQuantumTask]): """A simulator meant to run directly on the user's machine. This class wraps a BraketSimulator object so that it can be run and returns diff --git a/src/braket/emulation/emulator.py b/src/braket/emulation/emulator.py index ec14564c5..adefb3414 100644 --- a/src/braket/emulation/emulator.py +++ b/src/braket/emulation/emulator.py @@ -23,12 +23,11 @@ from braket.devices import Device from braket.emulation.pass_manager import PassManager from braket.emulation.passes import ValidationPass -from braket.tasks import QuantumTask -from braket.tasks.quantum_task import TaskSpecification +from braket.tasks.quantum_task import QuantumTaskType, TaskSpecification from braket.tasks.quantum_task_batch import QuantumTaskBatch -class Emulator(Device): +class Emulator(Device[QuantumTaskType]): """ An emulator is a simulation device that more closely resembles the capabilities and constraints of a real device or of a specific device model. @@ -43,7 +42,7 @@ class Emulator(Device): def __init__( self, - backend: Device, + backend: Device[QuantumTaskType], noise_model: NoiseModel | None = None, passes: Iterable[ValidationPass] | None = None, **kwargs, @@ -60,7 +59,7 @@ def run( inputs: dict[str, float] | None = None, *args: Any, **kwargs: Any, - ) -> QuantumTask: + ) -> QuantumTaskType: """Emulate a quantum task specification on this quantum device emulator. A quantum task can be a circuit or an annealing problem. Emulation involves running all emulator passes on the input program before running @@ -94,7 +93,7 @@ def run_batch( inputs: dict[str, float] | list[dict[str, float]] | None, *args: Any, **kwargs: Any, - ) -> QuantumTaskBatch: + ) -> QuantumTaskBatch[QuantumTaskType]: raise NotImplementedError("Emulator does not support run_batch.") @property diff --git a/src/braket/emulation/local_emulator.py b/src/braket/emulation/local_emulator.py index 543e6f0a5..666ced18c 100644 --- a/src/braket/emulation/local_emulator.py +++ b/src/braket/emulation/local_emulator.py @@ -36,9 +36,10 @@ ResultTypeValidator, _NotImplementedValidator, ) +from braket.tasks.local_quantum_task import LocalQuantumTask -class LocalEmulator(Emulator): +class LocalEmulator(Emulator[LocalQuantumTask]): """ A local emulator that mimics the restrictions and noises of a QPU based on the provided device properties. diff --git a/src/braket/tasks/local_quantum_task_batch.py b/src/braket/tasks/local_quantum_task_batch.py index 6f107a1d4..0440ca440 100644 --- a/src/braket/tasks/local_quantum_task_batch.py +++ b/src/braket/tasks/local_quantum_task_batch.py @@ -12,10 +12,11 @@ # language governing permissions and limitations under the License. from braket.tasks import QuantumTaskBatch +from braket.tasks.local_quantum_task import LocalQuantumTask from braket.tasks.quantum_task import TaskResult -class LocalQuantumTaskBatch(QuantumTaskBatch): +class LocalQuantumTaskBatch(QuantumTaskBatch[LocalQuantumTask]): """Executes a batch of quantum tasks in parallel. Since this class is instantiated with the results, cancel() and run_async() are unsupported. diff --git a/src/braket/tasks/quantum_task.py b/src/braket/tasks/quantum_task.py index 76771a0c2..9277dcac9 100644 --- a/src/braket/tasks/quantum_task.py +++ b/src/braket/tasks/quantum_task.py @@ -13,7 +13,7 @@ import asyncio from abc import ABC, abstractmethod -from typing import Any +from typing import Any, TypeVar from braket.ir.openqasm import Program as OpenQASMProgram from braket.ir.openqasm import ProgramSet as OpenQASMProgramSet @@ -70,9 +70,7 @@ def state(self) -> str: """ @abstractmethod - def result( - self, - ) -> TaskResult: + def result(self) -> TaskResult: """Get the quantum task result. Returns: @@ -99,3 +97,6 @@ def metadata(self, use_cached_value: bool = False) -> dict[str, Any]: # noqa: B dict[str, Any]: The metadata regarding the quantum task. If `use_cached_value` is True, then the value retrieved from the most recent request is used. """ + + +QuantumTaskType = TypeVar("QuantumTaskType", bound=QuantumTask, covariant=True) # noqa: PLC0105 diff --git a/src/braket/tasks/quantum_task_batch.py b/src/braket/tasks/quantum_task_batch.py index e115a3fc1..ab422b010 100644 --- a/src/braket/tasks/quantum_task_batch.py +++ b/src/braket/tasks/quantum_task_batch.py @@ -12,11 +12,13 @@ # language governing permissions and limitations under the License. from abc import ABC, abstractmethod +from typing import Generic from braket.tasks import AnalogHamiltonianSimulationQuantumTaskResult, GateModelQuantumTaskResult +from braket.tasks.quantum_task import QuantumTaskType -class QuantumTaskBatch(ABC): +class QuantumTaskBatch(ABC, Generic[QuantumTaskType]): """An abstraction over a quantum task batch on a quantum device.""" @abstractmethod