Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/braket/pennylane_plugin/ahs_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def check_validity(self, queue, observables):
Args:
queue (Iterable[~.operation.Operation]): quantum operation objects which are intended
to be applied on the device
observables (Iterable[~.operation.Observable]): observables which are intended
observables (Iterable[~.operation.Operator]): observables which are intended
to be evaluated on the device

Raises:
Expand Down
28 changes: 14 additions & 14 deletions src/braket/pennylane_plugin/braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,15 @@
from pennylane.devices import QubitDevice
from pennylane.gradients import param_shift
from pennylane.measurements import (
Counts,
Expectation,
CountsMP,
ExpectationMP,
MeasurementProcess,
MeasurementTransform,
Probability,
Sample,
ProbabilityMP,
SampleMP,
ShadowExpvalMP,
State,
Variance,
StateMP,
VarianceMP,
)
from pennylane.operation import Operation
from pennylane.ops import Sum
Expand Down Expand Up @@ -89,7 +89,7 @@

from ._version import __version__

RETURN_TYPES = [Expectation, Variance, Sample, Probability, State, Counts]
RETURN_TYPES = (ExpectationMP, VarianceMP, SampleMP, ProbabilityMP, StateMP, CountsMP)
MIN_SIMULATOR_BILLED_MS = 3000
OBS_LIST = (qml.PauliX, qml.PauliY, qml.PauliZ)

Expand Down Expand Up @@ -283,10 +283,10 @@ def _apply_gradient_result_type(self, circuit, braket_circuit):
)
pl_measurements = circuit.measurements[0]
pl_observable = flatten_observable(pl_measurements.obs)
if pl_measurements.return_type != Expectation:
if not isinstance(pl_measurements, ExpectationMP):
raise ValueError(
f"Braket can only compute gradients for circuits with a single expectation"
f" observable, not a {pl_measurements.return_type} observable."
f" observable, not a {type(pl_measurements)} measurement."
)
if isinstance(pl_observable, Sum):
targets = [self.map_wires(op.wires) for op in pl_observable.terms()[1]]
Expand Down Expand Up @@ -328,16 +328,16 @@ def statistics(
measurements (Sequence[MeasurementProcess]): the list of measurements

Raises:
QuantumFunctionError: if the value of :attr:`~.MeasurementProcess.return_type` is
QuantumFunctionError: if the type of :attr:`~.MeasurementProcess` is
not supported.

Returns:
list[float]: the corresponding statistics
"""
results = []
for mp in measurements:
if mp.return_type not in RETURN_TYPES:
raise QuantumFunctionError("Unsupported return type: {}".format(mp.return_type))
if not isinstance(mp, RETURN_TYPES):
raise QuantumFunctionError("Unsupported return type: {}".format(type(mp)))
results.append(self._get_statistic(braket_result, mp))
return results

Expand Down Expand Up @@ -835,7 +835,7 @@ def check_validity(self, queue, observables):
Args:
queue (Iterable[~.operation.Operation]): quantum operation objects which are intended
to be applied on the device
observables (Iterable[~.operation.Observable]): observables which are intended
observables (Iterable[~.operation.Operator]): observables which are intended
to be evaluated on the device

Raises:
Expand Down Expand Up @@ -880,7 +880,7 @@ def execute_and_gradients(self, circuits, **kwargs):
new_res = self.execute(circuit, compute_gradient=False)
# don't bother computing a gradient when there aren't any trainable parameters.
new_jac = np.tensor([])
elif len(observables) != 1 or measurements[0].return_type != Expectation:
elif len(observables) != 1 or not isinstance(measurements[0], ExpectationMP):
gradient_circuits, post_processing_fn = param_shift(circuit)
warnings.warn(
"This circuit cannot be differentiated with the adjoint method. "
Expand Down
36 changes: 18 additions & 18 deletions src/braket/pennylane_plugin/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import numpy as onp
import pennylane as qml
from pennylane import numpy as np
from pennylane.measurements import MeasurementProcess, ObservableReturnTypes
from pennylane.operation import Observable, Operation
from pennylane.measurements import MeasurementProcess
from pennylane.operation import Operation, Operator
from pennylane.pulse import ParametrizedEvolution

from braket.aws import AwsDevice
Expand Down Expand Up @@ -538,7 +538,7 @@ def supported_observables(device: Device, shots: int) -> frozenset[str]:


def get_adjoint_gradient_result_type(
observable: Observable,
observable: Operator,
targets: Union[list[int], list[list[int]]],
supported_result_types: frozenset[str],
parameters: list[str],
Expand Down Expand Up @@ -571,41 +571,41 @@ def translate_result_type( # noqa: C901
the given observable; if the observable type has multiple terms, for example a Sum,
then this will return a result type for each term.
"""
return_type = measurement.return_type
targets = targets or measurement.wires.tolist()
observable = measurement.obs

if return_type is ObservableReturnTypes.Probability:
if isinstance(measurement, qml.measurements.ProbabilityMP):
return Probability(targets)

if return_type is ObservableReturnTypes.State:
if isinstance(measurement, qml.measurements.StateMP):
if not targets and "StateVector" in supported_result_types:
return StateVector()
elif "DensityMatrix" in supported_result_types:
return DensityMatrix(targets)
raise NotImplementedError(f"Unsupported return type: {return_type}")
raise NotImplementedError(f"Unsupported return type: {type(measurement)}")

if observable is None:
if return_type is ObservableReturnTypes.Counts:
if isinstance(measurement, qml.measurements.CountsMP) and not measurement.all_outcomes:
return tuple(Sample(observables.Z(target)) for target in targets or measurement.wires)
raise NotImplementedError(f"Unsupported return type: {return_type}")
raise NotImplementedError(f"Unsupported return type: {type(measurement)}")

observable = flatten_observable(observable)

if isinstance(observable, qml.ops.LinearCombination):
if return_type is ObservableReturnTypes.Expectation:
if isinstance(measurement, qml.measurements.ExpectationMP):
return tuple(Expectation(_translate_observable(op)) for op in observable.terms()[1])
raise NotImplementedError(f"Return type {return_type} unsupported for LinearCombination")
raise NotImplementedError(f"Return type {type(measurement)} unsupported for LinearCombination")

braket_observable = _translate_observable(observable)
if return_type is ObservableReturnTypes.Expectation:
if isinstance(measurement, qml.measurements.ExpectationMP):
return Expectation(braket_observable)
elif return_type is ObservableReturnTypes.Variance:
if isinstance(measurement, qml.measurements.VarianceMP):
return Variance(braket_observable)
elif return_type in (ObservableReturnTypes.Sample, ObservableReturnTypes.Counts):
if isinstance(measurement, qml.measurements.CountsMP) and not measurement.all_outcomes:
return Sample(braket_observable)
else:
raise NotImplementedError(f"Unsupported return type: {return_type}")
if isinstance(measurement, qml.measurements.SampleMP):
return Sample(braket_observable)
raise NotImplementedError(f"Unsupported return type: {type(measurement)}")


def flatten_observable(observable):
Expand Down Expand Up @@ -722,7 +722,7 @@ def translate_result(
]

targets = targets or measurement.wires.tolist()
if measurement.return_type is ObservableReturnTypes.Counts and observable is None:
if isinstance(measurement, qml.measurements.CountsMP) and not measurement.all_outcomes and observable is None:
if targets:
new_dict = {}
for key, value in braket_result.measurement_counts.items():
Expand All @@ -742,7 +742,7 @@ def translate_result(
coeff * braket_result.get_value_by_result_type(result_type)
for coeff, result_type in zip(coeffs, translated)
)
elif measurement.return_type is ObservableReturnTypes.Counts:
elif isinstance(measurement, qml.measurements.CountsMP) and not measurement.all_outcomes:
return dict(Counter(braket_result.get_value_by_result_type(translated)))
else:
return braket_result.get_value_by_result_type(translated)
2 changes: 1 addition & 1 deletion test/unit_tests/test_ahs_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ class DummyOp(qml.operation.Operator):
],
)
def test_validate_measurement_basis(self, observable, error_expected):
"""Tests that when given an Observable not in the Z basis, _validate_measurement_basis,
"""Tests that when given an Operator not in the Z basis, _validate_measurement_basis,
fails with an error, but otherwise passes"""

dev = qml.device("braket.local.ahs", wires=3)
Expand Down
7 changes: 3 additions & 4 deletions test/unit_tests/test_braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -1467,7 +1467,7 @@ def test_counts_all_outcomes_fails():
qml.CNOT(wires=[0, 1])
qml.counts(all_outcomes=True)

does_not_support = "Unsupported return type: ObservableReturnTypes.AllCounts"
does_not_support = "Unsupported return type: <class 'pennylane.measurements.counts.CountsMP'>"
with pytest.raises(NotImplementedError, match=does_not_support):
dev.execute(circuit)

Expand All @@ -1481,7 +1481,7 @@ def test_sample_fails():
qml.CNOT(wires=[0, 1])
qml.sample()

does_not_support = "Unsupported return type: ObservableReturnTypes.Sample"
does_not_support = "Unsupported return type: <class 'pennylane.measurements.sample.SampleMP'>"
with pytest.raises(NotImplementedError, match=does_not_support):
dev.execute(circuit)

Expand All @@ -1491,14 +1491,13 @@ def test_unsupported_return_type():
dev = _aws_device(wires=2, shots=4)

mock_measurement = Mock()
mock_measurement.return_type = Enum("ObservableReturnTypes", {"Foo": "foo"}).Foo
mock_measurement.obs = qml.PauliZ(0)
mock_measurement.wires = qml.wires.Wires([0])
mock_measurement.map_wires.return_value = mock_measurement

tape = qml.tape.QuantumTape(measurements=[mock_measurement])

does_not_support = "Unsupported return type: ObservableReturnTypes.Foo"
does_not_support = "Unsupported return type: <class 'unittest.mock.Mock'>"
with pytest.raises(NotImplementedError, match=does_not_support):
dev.execute(tape)

Expand Down
7 changes: 1 addition & 6 deletions test/unit_tests/test_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
)
from pennylane import measurements
from pennylane import numpy as pnp
from pennylane.measurements import ObservableReturnTypes
from pennylane.pulse import ParametrizedEvolution, transmon_drive
from pennylane.wires import Wires

Expand Down Expand Up @@ -366,11 +365,7 @@ def _aws_device(
for op in _BRAKET_TO_PENNYLANE_OPERATIONS
}

pl_return_types = [
ObservableReturnTypes.Expectation,
ObservableReturnTypes.Variance,
ObservableReturnTypes.Sample,
]
pl_return_types = [qml.expval, qml.var, qml.sample]

braket_result_types = [
Expectation(observables.H(), [0]),
Expand Down