Skip to content

Commit c34853a

Browse files
albi3rormshaffer
andauthored
fix: Remove use of MeasurementProcess.return_type (#288)
Co-authored-by: Ryan Shaffer <3620100+rmshaffer@users.noreply.github.com>
1 parent b578a4c commit c34853a

File tree

6 files changed

+38
-44
lines changed

6 files changed

+38
-44
lines changed

src/braket/pennylane_plugin/ahs_device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def check_validity(self, queue, observables):
210210
Args:
211211
queue (Iterable[~.operation.Operation]): quantum operation objects which are intended
212212
to be applied on the device
213-
observables (Iterable[~.operation.Observable]): observables which are intended
213+
observables (Iterable[~.operation.Operator]): observables which are intended
214214
to be evaluated on the device
215215
216216
Raises:

src/braket/pennylane_plugin/braket_device.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,15 @@
4949
from pennylane.devices import QubitDevice
5050
from pennylane.gradients import param_shift
5151
from pennylane.measurements import (
52-
Counts,
53-
Expectation,
52+
CountsMP,
53+
ExpectationMP,
5454
MeasurementProcess,
5555
MeasurementTransform,
56-
Probability,
57-
Sample,
56+
ProbabilityMP,
57+
SampleMP,
5858
ShadowExpvalMP,
59-
State,
60-
Variance,
59+
StateMP,
60+
VarianceMP,
6161
)
6262
from pennylane.operation import Operation
6363
from pennylane.ops import Sum
@@ -89,7 +89,7 @@
8989

9090
from ._version import __version__
9191

92-
RETURN_TYPES = [Expectation, Variance, Sample, Probability, State, Counts]
92+
RETURN_TYPES = (ExpectationMP, VarianceMP, SampleMP, ProbabilityMP, StateMP, CountsMP)
9393
MIN_SIMULATOR_BILLED_MS = 3000
9494
OBS_LIST = (qml.PauliX, qml.PauliY, qml.PauliZ)
9595

@@ -283,10 +283,10 @@ def _apply_gradient_result_type(self, circuit, braket_circuit):
283283
)
284284
pl_measurements = circuit.measurements[0]
285285
pl_observable = flatten_observable(pl_measurements.obs)
286-
if pl_measurements.return_type != Expectation:
286+
if not isinstance(pl_measurements, ExpectationMP):
287287
raise ValueError(
288288
f"Braket can only compute gradients for circuits with a single expectation"
289-
f" observable, not a {pl_measurements.return_type} observable."
289+
f" observable, not a {type(pl_measurements)} measurement."
290290
)
291291
if isinstance(pl_observable, Sum):
292292
targets = [self.map_wires(op.wires) for op in pl_observable.terms()[1]]
@@ -328,16 +328,16 @@ def statistics(
328328
measurements (Sequence[MeasurementProcess]): the list of measurements
329329
330330
Raises:
331-
QuantumFunctionError: if the value of :attr:`~.MeasurementProcess.return_type` is
331+
QuantumFunctionError: if the type of :attr:`~.MeasurementProcess` is
332332
not supported.
333333
334334
Returns:
335335
list[float]: the corresponding statistics
336336
"""
337337
results = []
338338
for mp in measurements:
339-
if mp.return_type not in RETURN_TYPES:
340-
raise QuantumFunctionError("Unsupported return type: {}".format(mp.return_type))
339+
if not isinstance(mp, RETURN_TYPES):
340+
raise QuantumFunctionError("Unsupported return type: {}".format(type(mp)))
341341
results.append(self._get_statistic(braket_result, mp))
342342
return results
343343

@@ -835,7 +835,7 @@ def check_validity(self, queue, observables):
835835
Args:
836836
queue (Iterable[~.operation.Operation]): quantum operation objects which are intended
837837
to be applied on the device
838-
observables (Iterable[~.operation.Observable]): observables which are intended
838+
observables (Iterable[~.operation.Operator]): observables which are intended
839839
to be evaluated on the device
840840
841841
Raises:
@@ -880,7 +880,7 @@ def execute_and_gradients(self, circuits, **kwargs):
880880
new_res = self.execute(circuit, compute_gradient=False)
881881
# don't bother computing a gradient when there aren't any trainable parameters.
882882
new_jac = np.tensor([])
883-
elif len(observables) != 1 or measurements[0].return_type != Expectation:
883+
elif len(observables) != 1 or not isinstance(measurements[0], ExpectationMP):
884884
gradient_circuits, post_processing_fn = param_shift(circuit)
885885
warnings.warn(
886886
"This circuit cannot be differentiated with the adjoint method. "

src/braket/pennylane_plugin/translation.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
import numpy as onp
1919
import pennylane as qml
2020
from pennylane import numpy as np
21-
from pennylane.measurements import MeasurementProcess, ObservableReturnTypes
22-
from pennylane.operation import Observable, Operation
21+
from pennylane.measurements import MeasurementProcess
22+
from pennylane.operation import Operation, Operator
2323
from pennylane.pulse import ParametrizedEvolution
2424

2525
from braket.aws import AwsDevice
@@ -538,7 +538,7 @@ def supported_observables(device: Device, shots: int) -> frozenset[str]:
538538

539539

540540
def get_adjoint_gradient_result_type(
541-
observable: Observable,
541+
observable: Operator,
542542
targets: Union[list[int], list[list[int]]],
543543
supported_result_types: frozenset[str],
544544
parameters: list[str],
@@ -571,41 +571,41 @@ def translate_result_type( # noqa: C901
571571
the given observable; if the observable type has multiple terms, for example a Sum,
572572
then this will return a result type for each term.
573573
"""
574-
return_type = measurement.return_type
575574
targets = targets or measurement.wires.tolist()
576575
observable = measurement.obs
577576

578-
if return_type is ObservableReturnTypes.Probability:
577+
if isinstance(measurement, qml.measurements.ProbabilityMP):
579578
return Probability(targets)
580579

581-
if return_type is ObservableReturnTypes.State:
580+
if isinstance(measurement, qml.measurements.StateMP):
582581
if not targets and "StateVector" in supported_result_types:
583582
return StateVector()
584583
elif "DensityMatrix" in supported_result_types:
585584
return DensityMatrix(targets)
586-
raise NotImplementedError(f"Unsupported return type: {return_type}")
585+
raise NotImplementedError(f"Unsupported return type: {type(measurement)}")
587586

588587
if observable is None:
589-
if return_type is ObservableReturnTypes.Counts:
588+
if isinstance(measurement, qml.measurements.CountsMP) and not measurement.all_outcomes:
590589
return tuple(Sample(observables.Z(target)) for target in targets or measurement.wires)
591-
raise NotImplementedError(f"Unsupported return type: {return_type}")
590+
raise NotImplementedError(f"Unsupported return type: {type(measurement)}")
592591

593592
observable = flatten_observable(observable)
594593

595594
if isinstance(observable, qml.ops.LinearCombination):
596-
if return_type is ObservableReturnTypes.Expectation:
595+
if isinstance(measurement, qml.measurements.ExpectationMP):
597596
return tuple(Expectation(_translate_observable(op)) for op in observable.terms()[1])
598-
raise NotImplementedError(f"Return type {return_type} unsupported for LinearCombination")
597+
raise NotImplementedError(f"Return type {type(measurement)} unsupported for LinearCombination")
599598

600599
braket_observable = _translate_observable(observable)
601-
if return_type is ObservableReturnTypes.Expectation:
600+
if isinstance(measurement, qml.measurements.ExpectationMP):
602601
return Expectation(braket_observable)
603-
elif return_type is ObservableReturnTypes.Variance:
602+
if isinstance(measurement, qml.measurements.VarianceMP):
604603
return Variance(braket_observable)
605-
elif return_type in (ObservableReturnTypes.Sample, ObservableReturnTypes.Counts):
604+
if isinstance(measurement, qml.measurements.CountsMP) and not measurement.all_outcomes:
606605
return Sample(braket_observable)
607-
else:
608-
raise NotImplementedError(f"Unsupported return type: {return_type}")
606+
if isinstance(measurement, qml.measurements.SampleMP):
607+
return Sample(braket_observable)
608+
raise NotImplementedError(f"Unsupported return type: {type(measurement)}")
609609

610610

611611
def flatten_observable(observable):
@@ -722,7 +722,7 @@ def translate_result(
722722
]
723723

724724
targets = targets or measurement.wires.tolist()
725-
if measurement.return_type is ObservableReturnTypes.Counts and observable is None:
725+
if isinstance(measurement, qml.measurements.CountsMP) and not measurement.all_outcomes and observable is None:
726726
if targets:
727727
new_dict = {}
728728
for key, value in braket_result.measurement_counts.items():
@@ -742,7 +742,7 @@ def translate_result(
742742
coeff * braket_result.get_value_by_result_type(result_type)
743743
for coeff, result_type in zip(coeffs, translated)
744744
)
745-
elif measurement.return_type is ObservableReturnTypes.Counts:
745+
elif isinstance(measurement, qml.measurements.CountsMP) and not measurement.all_outcomes:
746746
return dict(Counter(braket_result.get_value_by_result_type(translated)))
747747
else:
748748
return braket_result.get_value_by_result_type(translated)

test/unit_tests/test_ahs_device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,7 @@ class DummyOp(qml.operation.Operator):
646646
],
647647
)
648648
def test_validate_measurement_basis(self, observable, error_expected):
649-
"""Tests that when given an Observable not in the Z basis, _validate_measurement_basis,
649+
"""Tests that when given an Operator not in the Z basis, _validate_measurement_basis,
650650
fails with an error, but otherwise passes"""
651651

652652
dev = qml.device("braket.local.ahs", wires=3)

test/unit_tests/test_braket_device.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1467,7 +1467,7 @@ def test_counts_all_outcomes_fails():
14671467
qml.CNOT(wires=[0, 1])
14681468
qml.counts(all_outcomes=True)
14691469

1470-
does_not_support = "Unsupported return type: ObservableReturnTypes.AllCounts"
1470+
does_not_support = "Unsupported return type: <class 'pennylane.measurements.counts.CountsMP'>"
14711471
with pytest.raises(NotImplementedError, match=does_not_support):
14721472
dev.execute(circuit)
14731473

@@ -1481,7 +1481,7 @@ def test_sample_fails():
14811481
qml.CNOT(wires=[0, 1])
14821482
qml.sample()
14831483

1484-
does_not_support = "Unsupported return type: ObservableReturnTypes.Sample"
1484+
does_not_support = "Unsupported return type: <class 'pennylane.measurements.sample.SampleMP'>"
14851485
with pytest.raises(NotImplementedError, match=does_not_support):
14861486
dev.execute(circuit)
14871487

@@ -1491,14 +1491,13 @@ def test_unsupported_return_type():
14911491
dev = _aws_device(wires=2, shots=4)
14921492

14931493
mock_measurement = Mock()
1494-
mock_measurement.return_type = Enum("ObservableReturnTypes", {"Foo": "foo"}).Foo
14951494
mock_measurement.obs = qml.PauliZ(0)
14961495
mock_measurement.wires = qml.wires.Wires([0])
14971496
mock_measurement.map_wires.return_value = mock_measurement
14981497

14991498
tape = qml.tape.QuantumTape(measurements=[mock_measurement])
15001499

1501-
does_not_support = "Unsupported return type: ObservableReturnTypes.Foo"
1500+
does_not_support = "Unsupported return type: <class 'unittest.mock.Mock'>"
15021501
with pytest.raises(NotImplementedError, match=does_not_support):
15031502
dev.execute(tape)
15041503

test/unit_tests/test_translation.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
)
4848
from pennylane import measurements
4949
from pennylane import numpy as pnp
50-
from pennylane.measurements import ObservableReturnTypes
5150
from pennylane.pulse import ParametrizedEvolution, transmon_drive
5251
from pennylane.wires import Wires
5352

@@ -366,11 +365,7 @@ def _aws_device(
366365
for op in _BRAKET_TO_PENNYLANE_OPERATIONS
367366
}
368367

369-
pl_return_types = [
370-
ObservableReturnTypes.Expectation,
371-
ObservableReturnTypes.Variance,
372-
ObservableReturnTypes.Sample,
373-
]
368+
pl_return_types = [qml.expval, qml.var, qml.sample]
374369

375370
braket_result_types = [
376371
Expectation(observables.H(), [0]),

0 commit comments

Comments
 (0)