diff --git a/src/braket/pennylane_plugin/ahs_device.py b/src/braket/pennylane_plugin/ahs_device.py index 57366221..1cbe9050 100644 --- a/src/braket/pennylane_plugin/ahs_device.py +++ b/src/braket/pennylane_plugin/ahs_device.py @@ -210,7 +210,7 @@ def check_validity(self, queue, observables): Args: queue (Iterable[~.operation.Operation]): quantum operation objects which are intended to be applied on the device - observables (Iterable[~.operation.Observable]): observables which are intended + observables (Iterable[~.operation.Operator]): observables which are intended to be evaluated on the device Raises: diff --git a/src/braket/pennylane_plugin/braket_device.py b/src/braket/pennylane_plugin/braket_device.py index 6a3559ad..24123a87 100644 --- a/src/braket/pennylane_plugin/braket_device.py +++ b/src/braket/pennylane_plugin/braket_device.py @@ -49,15 +49,15 @@ from pennylane.devices import QubitDevice from pennylane.gradients import param_shift from pennylane.measurements import ( - Counts, - Expectation, + CountsMP, + ExpectationMP, MeasurementProcess, MeasurementTransform, - Probability, - Sample, + ProbabilityMP, + SampleMP, ShadowExpvalMP, - State, - Variance, + StateMP, + VarianceMP, ) from pennylane.operation import Operation from pennylane.ops import Sum @@ -89,7 +89,7 @@ from ._version import __version__ -RETURN_TYPES = [Expectation, Variance, Sample, Probability, State, Counts] +RETURN_TYPES = (ExpectationMP, VarianceMP, SampleMP, ProbabilityMP, StateMP, CountsMP) MIN_SIMULATOR_BILLED_MS = 3000 OBS_LIST = (qml.PauliX, qml.PauliY, qml.PauliZ) @@ -283,10 +283,10 @@ def _apply_gradient_result_type(self, circuit, braket_circuit): ) pl_measurements = circuit.measurements[0] pl_observable = flatten_observable(pl_measurements.obs) - if pl_measurements.return_type != Expectation: + if not isinstance(pl_measurements, ExpectationMP): raise ValueError( f"Braket can only compute gradients for circuits with a single expectation" - f" observable, not a {pl_measurements.return_type} observable." + f" observable, not a {type(pl_measurements)} measurement." ) if isinstance(pl_observable, Sum): targets = [self.map_wires(op.wires) for op in pl_observable.terms()[1]] @@ -328,7 +328,7 @@ def statistics( measurements (Sequence[MeasurementProcess]): the list of measurements Raises: - QuantumFunctionError: if the value of :attr:`~.MeasurementProcess.return_type` is + QuantumFunctionError: if the type of :attr:`~.MeasurementProcess` is not supported. Returns: @@ -336,8 +336,8 @@ def statistics( """ results = [] for mp in measurements: - if mp.return_type not in RETURN_TYPES: - raise QuantumFunctionError("Unsupported return type: {}".format(mp.return_type)) + if not isinstance(mp, RETURN_TYPES): + raise QuantumFunctionError("Unsupported return type: {}".format(type(mp))) results.append(self._get_statistic(braket_result, mp)) return results @@ -835,7 +835,7 @@ def check_validity(self, queue, observables): Args: queue (Iterable[~.operation.Operation]): quantum operation objects which are intended to be applied on the device - observables (Iterable[~.operation.Observable]): observables which are intended + observables (Iterable[~.operation.Operator]): observables which are intended to be evaluated on the device Raises: @@ -880,7 +880,7 @@ def execute_and_gradients(self, circuits, **kwargs): new_res = self.execute(circuit, compute_gradient=False) # don't bother computing a gradient when there aren't any trainable parameters. new_jac = np.tensor([]) - elif len(observables) != 1 or measurements[0].return_type != Expectation: + elif len(observables) != 1 or not isinstance(measurements[0], ExpectationMP): gradient_circuits, post_processing_fn = param_shift(circuit) warnings.warn( "This circuit cannot be differentiated with the adjoint method. " diff --git a/src/braket/pennylane_plugin/translation.py b/src/braket/pennylane_plugin/translation.py index 7a9fcd61..3eb8a9ed 100644 --- a/src/braket/pennylane_plugin/translation.py +++ b/src/braket/pennylane_plugin/translation.py @@ -18,8 +18,8 @@ import numpy as onp import pennylane as qml from pennylane import numpy as np -from pennylane.measurements import MeasurementProcess, ObservableReturnTypes -from pennylane.operation import Observable, Operation +from pennylane.measurements import MeasurementProcess +from pennylane.operation import Operation, Operator from pennylane.pulse import ParametrizedEvolution from braket.aws import AwsDevice @@ -538,7 +538,7 @@ def supported_observables(device: Device, shots: int) -> frozenset[str]: def get_adjoint_gradient_result_type( - observable: Observable, + observable: Operator, targets: Union[list[int], list[list[int]]], supported_result_types: frozenset[str], parameters: list[str], @@ -571,41 +571,41 @@ def translate_result_type( # noqa: C901 the given observable; if the observable type has multiple terms, for example a Sum, then this will return a result type for each term. """ - return_type = measurement.return_type targets = targets or measurement.wires.tolist() observable = measurement.obs - if return_type is ObservableReturnTypes.Probability: + if isinstance(measurement, qml.measurements.ProbabilityMP): return Probability(targets) - if return_type is ObservableReturnTypes.State: + if isinstance(measurement, qml.measurements.StateMP): if not targets and "StateVector" in supported_result_types: return StateVector() elif "DensityMatrix" in supported_result_types: return DensityMatrix(targets) - raise NotImplementedError(f"Unsupported return type: {return_type}") + raise NotImplementedError(f"Unsupported return type: {type(measurement)}") if observable is None: - if return_type is ObservableReturnTypes.Counts: + if isinstance(measurement, qml.measurements.CountsMP) and not measurement.all_outcomes: return tuple(Sample(observables.Z(target)) for target in targets or measurement.wires) - raise NotImplementedError(f"Unsupported return type: {return_type}") + raise NotImplementedError(f"Unsupported return type: {type(measurement)}") observable = flatten_observable(observable) if isinstance(observable, qml.ops.LinearCombination): - if return_type is ObservableReturnTypes.Expectation: + if isinstance(measurement, qml.measurements.ExpectationMP): return tuple(Expectation(_translate_observable(op)) for op in observable.terms()[1]) - raise NotImplementedError(f"Return type {return_type} unsupported for LinearCombination") + raise NotImplementedError(f"Return type {type(measurement)} unsupported for LinearCombination") braket_observable = _translate_observable(observable) - if return_type is ObservableReturnTypes.Expectation: + if isinstance(measurement, qml.measurements.ExpectationMP): return Expectation(braket_observable) - elif return_type is ObservableReturnTypes.Variance: + if isinstance(measurement, qml.measurements.VarianceMP): return Variance(braket_observable) - elif return_type in (ObservableReturnTypes.Sample, ObservableReturnTypes.Counts): + if isinstance(measurement, qml.measurements.CountsMP) and not measurement.all_outcomes: return Sample(braket_observable) - else: - raise NotImplementedError(f"Unsupported return type: {return_type}") + if isinstance(measurement, qml.measurements.SampleMP): + return Sample(braket_observable) + raise NotImplementedError(f"Unsupported return type: {type(measurement)}") def flatten_observable(observable): @@ -722,7 +722,7 @@ def translate_result( ] targets = targets or measurement.wires.tolist() - if measurement.return_type is ObservableReturnTypes.Counts and observable is None: + if isinstance(measurement, qml.measurements.CountsMP) and not measurement.all_outcomes and observable is None: if targets: new_dict = {} for key, value in braket_result.measurement_counts.items(): @@ -742,7 +742,7 @@ def translate_result( coeff * braket_result.get_value_by_result_type(result_type) for coeff, result_type in zip(coeffs, translated) ) - elif measurement.return_type is ObservableReturnTypes.Counts: + elif isinstance(measurement, qml.measurements.CountsMP) and not measurement.all_outcomes: return dict(Counter(braket_result.get_value_by_result_type(translated))) else: return braket_result.get_value_by_result_type(translated) diff --git a/test/unit_tests/test_ahs_device.py b/test/unit_tests/test_ahs_device.py index 47840f01..2b8e908d 100644 --- a/test/unit_tests/test_ahs_device.py +++ b/test/unit_tests/test_ahs_device.py @@ -646,7 +646,7 @@ class DummyOp(qml.operation.Operator): ], ) def test_validate_measurement_basis(self, observable, error_expected): - """Tests that when given an Observable not in the Z basis, _validate_measurement_basis, + """Tests that when given an Operator not in the Z basis, _validate_measurement_basis, fails with an error, but otherwise passes""" dev = qml.device("braket.local.ahs", wires=3) diff --git a/test/unit_tests/test_braket_device.py b/test/unit_tests/test_braket_device.py index 11c659a8..789fa69f 100644 --- a/test/unit_tests/test_braket_device.py +++ b/test/unit_tests/test_braket_device.py @@ -1467,7 +1467,7 @@ def test_counts_all_outcomes_fails(): qml.CNOT(wires=[0, 1]) qml.counts(all_outcomes=True) - does_not_support = "Unsupported return type: ObservableReturnTypes.AllCounts" + does_not_support = "Unsupported return type: " with pytest.raises(NotImplementedError, match=does_not_support): dev.execute(circuit) @@ -1481,7 +1481,7 @@ def test_sample_fails(): qml.CNOT(wires=[0, 1]) qml.sample() - does_not_support = "Unsupported return type: ObservableReturnTypes.Sample" + does_not_support = "Unsupported return type: " with pytest.raises(NotImplementedError, match=does_not_support): dev.execute(circuit) @@ -1491,14 +1491,13 @@ def test_unsupported_return_type(): dev = _aws_device(wires=2, shots=4) mock_measurement = Mock() - mock_measurement.return_type = Enum("ObservableReturnTypes", {"Foo": "foo"}).Foo mock_measurement.obs = qml.PauliZ(0) mock_measurement.wires = qml.wires.Wires([0]) mock_measurement.map_wires.return_value = mock_measurement tape = qml.tape.QuantumTape(measurements=[mock_measurement]) - does_not_support = "Unsupported return type: ObservableReturnTypes.Foo" + does_not_support = "Unsupported return type: " with pytest.raises(NotImplementedError, match=does_not_support): dev.execute(tape) diff --git a/test/unit_tests/test_translation.py b/test/unit_tests/test_translation.py index b28470f8..8bb25ab7 100644 --- a/test/unit_tests/test_translation.py +++ b/test/unit_tests/test_translation.py @@ -47,7 +47,6 @@ ) from pennylane import measurements from pennylane import numpy as pnp -from pennylane.measurements import ObservableReturnTypes from pennylane.pulse import ParametrizedEvolution, transmon_drive from pennylane.wires import Wires @@ -366,11 +365,7 @@ def _aws_device( for op in _BRAKET_TO_PENNYLANE_OPERATIONS } -pl_return_types = [ - ObservableReturnTypes.Expectation, - ObservableReturnTypes.Variance, - ObservableReturnTypes.Sample, -] +pl_return_types = [qml.expval, qml.var, qml.sample] braket_result_types = [ Expectation(observables.H(), [0]),