Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
packages=find_namespace_packages(where="src", exclude=("test",)),
package_dir={"": "src"},
install_requires=[
"amazon-braket-sdk>=1.87.0",
"amazon-braket-sdk>=1.97.0",
"autoray>=0.6.11",
"pennylane>=0.34.0",
],
Expand Down
142 changes: 105 additions & 37 deletions src/braket/pennylane_plugin/braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,14 @@
translate_result,
translate_result_type,
)
from braket.program_sets import ProgramSet
from braket.simulator import BraketSimulator
from braket.tasks import GateModelQuantumTaskResult, QuantumTask
from braket.tasks.local_quantum_task_batch import LocalQuantumTaskBatch

from ._version import __version__

RETURN_TYPES = (ExpectationMP, VarianceMP, SampleMP, ProbabilityMP, StateMP, CountsMP)
RETURN_TYPES = (ExpectationMP, VarianceMP, SampleMP, ProbabilityMP, StateMP, CountsMP)
MIN_SIMULATOR_BILLED_MS = 3000
OBS_LIST = (qml.PauliX, qml.PauliY, qml.PauliZ)

Expand Down Expand Up @@ -168,6 +169,10 @@ def __init__(
self._supported_obs = supported_observables(self._device, self.shots)
self._check_supported_result_types()
self._verbatim = verbatim
self._supports_program_sets = (
DeviceActionType.OPENQASM_PROGRAM_SET in self._device.properties.action
and self._shots is not None
)

if noise_model:
self._validate_noise_model_support()
Expand Down Expand Up @@ -202,7 +207,7 @@ def parallel(self) -> bool:
return self._parallel

def batch_execute(self, circuits, **run_kwargs):
if not self._parallel:
if not self._parallel and not self._supports_program_sets:
return super().batch_execute(circuits)

for circuit in circuits:
Expand All @@ -220,6 +225,7 @@ def batch_execute(self, circuits, **run_kwargs):
self._pl_to_braket_circuit(
circuit,
trainable_indices=frozenset(trainable.keys()),
add_observables=not self._supports_program_sets,
**run_kwargs,
)
)
Expand All @@ -232,18 +238,15 @@ def batch_execute(self, circuits, **run_kwargs):
else []
)

braket_results_batch = self._run_task_batch(braket_circuits, batch_shots, batch_inputs)

return [
self._braket_to_pl_result(braket_result, circuit)
for braket_result, circuit in zip(braket_results_batch, circuits)
]
return self._run_task_batch(braket_circuits, circuits, batch_shots, batch_inputs)

def _pl_to_braket_circuit(
self,
circuit: QuantumTape,
compute_gradient: bool = False,
trainable_indices: frozenset[int] = None,
*,
add_observables: bool = True,
**run_kwargs,
):
"""Converts a PennyLane circuit to a Braket circuit"""
Expand All @@ -259,17 +262,28 @@ def _pl_to_braket_circuit(
if compute_gradient:
braket_circuit = self._apply_gradient_result_type(circuit, braket_circuit)
elif not isinstance(circuit.measurements[0], MeasurementTransform):
for measurement in circuit.measurements:
translated = translate_result_type(
measurement.map_wires(self.wire_map),
None,
self._braket_result_types,
if add_observables:
for measurement in circuit.measurements:
translated = translate_result_type(
measurement.map_wires(self.wire_map),
None,
self._braket_result_types,
)
if isinstance(translated, tuple):
for result_type in translated:
braket_circuit.add_result_type(result_type)
else:
braket_circuit.add_result_type(translated)
else:
groups = qml.pauli.group_observables(
[measurement.obs for measurement in circuit.measurements], grouping_type="qwc"
)
if isinstance(translated, tuple):
for result_type in translated:
braket_circuit.add_result_type(result_type)
else:
braket_circuit.add_result_type(translated)
if len(groups) > 1:
raise ValueError(
f"Observables need to mutually commute, but found {len(groups)}: {groups}"
)
diagonalizing_ops = qml.pauli.diagonalize_qwc_pauli_words(groups[0])[0]
braket_circuit += self.apply(diagonalizing_ops, apply_identities=False)

return braket_circuit

Expand Down Expand Up @@ -316,7 +330,7 @@ def _update_tracker_for_batch(
self.tracker.update(batches=1, executions=total_executions, shots=total_shots)
self.tracker.record()

def statistics(
def _statistics(
self,
braket_result: GateModelQuantumTaskResult,
measurements: Sequence[MeasurementProcess],
Expand All @@ -338,14 +352,18 @@ def statistics(
for mp in measurements:
if not isinstance(mp, RETURN_TYPES):
raise QuantumFunctionError("Unsupported return type: {}".format(type(mp)))
results.append(self._get_statistic(braket_result, mp))
results.append(
translate_result(
braket_result, mp.map_wires(self.wire_map), None, self._braket_result_types
)
)
return results

def _braket_to_pl_result(self, braket_result, circuit):
"""Calculates the PennyLane results from a Braket task result. A PennyLane circuit
also determines the output observables."""
# Compute the required statistics
results = self.statistics(braket_result, circuit.measurements)
results = self._statistics(braket_result, circuit.measurements)
ag_results = [
result
for result in braket_result.result_types
Expand Down Expand Up @@ -378,6 +396,25 @@ def _braket_to_pl_result(self, braket_result, circuit):
return onp.array(results).squeeze()
return tuple(onp.array(result).squeeze() for result in results)

def _braket_program_set_to_pl_result(self, program_set_result, circuits):
results = []
for program_result, circuit in zip(program_set_result, circuits):
# Only one executable per program
measurements = program_result[0].measurements

# Program sets require shots > 0,
# so the circuit's measurements are guaranteed to be SampleMeasurements
executable_results = [
measurement.process_samples(measurements, wire_order=measurement.wires)
for measurement in circuit.measurements
]
results.append(
onp.array(executable_results).squeeze()
if len(circuit.measurements) == 1
else tuple(onp.array(result).squeeze() for result in executable_results)
)
return results

@staticmethod
def _tracking_data(task):
if task.state() == "COMPLETED":
Expand Down Expand Up @@ -410,8 +447,6 @@ def classical_shadow(self, obs, circuit):
rng = np.random.default_rng(seed)
recipes = rng.integers(0, 3, size=(n_snapshots, n_qubits))

outcomes = np.zeros((n_snapshots, n_qubits))

snapshot_rotations = [
[
rot
Expand Down Expand Up @@ -484,6 +519,7 @@ def apply(
use_unique_params: bool = False,
*,
trainable_indices: Optional[frozenset[int]] = None,
apply_identities: bool = True,
**run_kwargs,
) -> Circuit:
"""Instantiate Braket Circuit object."""
Expand Down Expand Up @@ -518,8 +554,9 @@ def apply(
unused = set(range(self.num_wires)) - {int(qubit) for qubit in circuit.qubits}

# To ensure the results have the right number of qubits
for qubit in sorted(unused):
circuit.i(qubit)
if apply_identities:
for qubit in sorted(unused):
circuit.i(qubit)

if self._noise_model:
circuit = self._noise_model.apply(circuit)
Expand Down Expand Up @@ -552,14 +589,12 @@ def _validate_noise_model_support(self):
def _run_task(self, circuit, inputs=None):
raise NotImplementedError("Need to implement task runner")

def _run_task_batch(self, braket_circuits, pl_circuits, circuit_shots, mapped_wires):
raise NotImplementedError("Need to implement batch runner")

def _run_snapshots(self, snapshot_circuits, n_qubits, mapped_wires):
raise NotImplementedError("Need to implement snapshots runner")

def _get_statistic(self, braket_result, mp):
return translate_result(
braket_result, mp.map_wires(self.wire_map), None, self._braket_result_types
)

@staticmethod
def _get_trainable_parameters(tape: QuantumTape) -> dict[int, numbers.Number]:
trainable_indices = sorted(tape.trainable_params)
Expand Down Expand Up @@ -663,9 +698,24 @@ def use_grouping(self) -> bool:
caps = self.capabilities()
return not ("provides_jacobian" in caps and caps["provides_jacobian"])

def _run_task_batch(self, batch_circuits, batch_shots: int, inputs):
def _run_task_batch(self, braket_circuits, pl_circuits, batch_shots: int, inputs):
if self._supports_program_sets:
program_set = (
ProgramSet.zip(braket_circuits, input_sets=inputs)
if inputs
else ProgramSet(braket_circuits)
)
task = self._device.run(
program_set,
s3_destination_folder=self._s3_folder,
shots=len(program_set) * batch_shots,
poll_timeout_seconds=self._poll_timeout_seconds,
poll_interval_seconds=self._poll_interval_seconds,
**self._run_kwargs,
)
return self._braket_program_set_to_pl_result(task.result(), pl_circuits)
task_batch = self._device.run_batch(
batch_circuits,
braket_circuits,
s3_destination_folder=self._s3_folder,
shots=batch_shots,
max_parallel=self._max_parallel,
Expand All @@ -687,7 +737,10 @@ def _run_task_batch(self, batch_circuits, batch_shots: int, inputs):
if self.tracker.active:
self._update_tracker_for_batch(task_batch, batch_shots)

return braket_results_batch
return [
self._braket_to_pl_result(braket_result, circuit)
for braket_result, circuit in zip(braket_results_batch, pl_circuits)
]

def _run_task(self, circuit, inputs=None):
return self._device.run(
Expand All @@ -703,7 +756,19 @@ def _run_task(self, circuit, inputs=None):
def _run_snapshots(self, snapshot_circuits, n_qubits, mapped_wires):
n_snapshots = len(snapshot_circuits)
outcomes = np.zeros((n_snapshots, n_qubits))
if self._parallel:
if self._supports_program_sets:
program_set = ProgramSet(snapshot_circuits)
task = self._device.run(
program_set,
s3_destination_folder=self._s3_folder,
shots=len(program_set),
poll_timeout_seconds=self._poll_timeout_seconds,
poll_interval_seconds=self._poll_interval_seconds,
**self._run_kwargs,
)
for t, result in enumerate(task.result()):
outcomes[t] = np.array(result[0].measurements[0])[mapped_wires]
elif self._parallel:
task_batch = self._device.run_batch(
snapshot_circuits,
s3_destination_folder=self._s3_folder,
Expand Down Expand Up @@ -1041,9 +1106,9 @@ def __init__(
device = LocalSimulator(backend)
super().__init__(wires, device, shots=shots, **run_kwargs)

def _run_task_batch(self, batch_circuits, batch_shots: int, inputs):
def _run_task_batch(self, braket_circuits, pl_circuits, batch_shots: int, inputs):
task_batch = self._device.run_batch(
batch_circuits,
braket_circuits,
shots=batch_shots,
max_parallel=self._max_parallel,
inputs=inputs,
Expand All @@ -1057,7 +1122,10 @@ def _run_task_batch(self, batch_circuits, batch_shots: int, inputs):
if self.tracker.active:
self._update_tracker_for_batch(task_batch, batch_shots)

return braket_results_batch
return [
self._braket_to_pl_result(braket_result, circuit)
for braket_result, circuit in zip(braket_results_batch, pl_circuits)
]

def _run_task(self, circuit, inputs=None):
return self._device.run(
Expand Down
10 changes: 8 additions & 2 deletions src/braket/pennylane_plugin/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,9 @@ def translate_result_type( # noqa: C901
if isinstance(observable, qml.ops.LinearCombination):
if isinstance(measurement, qml.measurements.ExpectationMP):
return tuple(Expectation(_translate_observable(op)) for op in observable.terms()[1])
raise NotImplementedError(f"Return type {type(measurement)} unsupported for LinearCombination")
raise NotImplementedError(
f"Return type {type(measurement)} unsupported for LinearCombination"
)

braket_observable = _translate_observable(observable)
if isinstance(measurement, qml.measurements.ExpectationMP):
Expand Down Expand Up @@ -722,7 +724,11 @@ def translate_result(
]

targets = targets or measurement.wires.tolist()
if isinstance(measurement, qml.measurements.CountsMP) and not measurement.all_outcomes and observable is None:
if (
isinstance(measurement, qml.measurements.CountsMP)
and not measurement.all_outcomes
and observable is None
):
if targets:
new_dict = {}
for key, value in braket_result.measurement_counts.items():
Expand Down
Loading
Loading