1818import numpy as onp
1919import pennylane as qml
2020from pennylane import numpy as np
21- from pennylane .measurements import MeasurementProcess , ObservableReturnTypes
22- from pennylane .operation import Observable , Operation
21+ from pennylane .measurements import MeasurementProcess
22+ from pennylane .operation import Operation , Operator
2323from pennylane .pulse import ParametrizedEvolution
2424
2525from braket .aws import AwsDevice
@@ -538,7 +538,7 @@ def supported_observables(device: Device, shots: int) -> frozenset[str]:
538538
539539
540540def get_adjoint_gradient_result_type (
541- observable : Observable ,
541+ observable : Operator ,
542542 targets : Union [list [int ], list [list [int ]]],
543543 supported_result_types : frozenset [str ],
544544 parameters : list [str ],
@@ -571,41 +571,41 @@ def translate_result_type( # noqa: C901
571571 the given observable; if the observable type has multiple terms, for example a Sum,
572572 then this will return a result type for each term.
573573 """
574- return_type = measurement .return_type
575574 targets = targets or measurement .wires .tolist ()
576575 observable = measurement .obs
577576
578- if return_type is ObservableReturnTypes . Probability :
577+ if isinstance ( measurement , qml . measurements . ProbabilityMP ) :
579578 return Probability (targets )
580579
581- if return_type is ObservableReturnTypes . State :
580+ if isinstance ( measurement , qml . measurements . StateMP ) :
582581 if not targets and "StateVector" in supported_result_types :
583582 return StateVector ()
584583 elif "DensityMatrix" in supported_result_types :
585584 return DensityMatrix (targets )
586- raise NotImplementedError (f"Unsupported return type: { return_type } " )
585+ raise NotImplementedError (f"Unsupported return type: { type ( measurement ) } " )
587586
588587 if observable is None :
589- if return_type is ObservableReturnTypes . Counts :
588+ if isinstance ( measurement , qml . measurements . CountsMP ) and not measurement . all_outcomes :
590589 return tuple (Sample (observables .Z (target )) for target in targets or measurement .wires )
591- raise NotImplementedError (f"Unsupported return type: { return_type } " )
590+ raise NotImplementedError (f"Unsupported return type: { type ( measurement ) } " )
592591
593592 observable = flatten_observable (observable )
594593
595594 if isinstance (observable , qml .ops .LinearCombination ):
596- if return_type is ObservableReturnTypes . Expectation :
595+ if isinstance ( measurement , qml . measurements . ExpectationMP ) :
597596 return tuple (Expectation (_translate_observable (op )) for op in observable .terms ()[1 ])
598- raise NotImplementedError (f"Return type { return_type } unsupported for LinearCombination" )
597+ raise NotImplementedError (f"Return type { type ( measurement ) } unsupported for LinearCombination" )
599598
600599 braket_observable = _translate_observable (observable )
601- if return_type is ObservableReturnTypes . Expectation :
600+ if isinstance ( measurement , qml . measurements . ExpectationMP ) :
602601 return Expectation (braket_observable )
603- elif return_type is ObservableReturnTypes . Variance :
602+ if isinstance ( measurement , qml . measurements . VarianceMP ) :
604603 return Variance (braket_observable )
605- elif return_type in ( ObservableReturnTypes . Sample , ObservableReturnTypes . Counts ) :
604+ if isinstance ( measurement , qml . measurements . CountsMP ) and not measurement . all_outcomes :
606605 return Sample (braket_observable )
607- else :
608- raise NotImplementedError (f"Unsupported return type: { return_type } " )
606+ if isinstance (measurement , qml .measurements .SampleMP ):
607+ return Sample (braket_observable )
608+ raise NotImplementedError (f"Unsupported return type: { type (measurement )} " )
609609
610610
611611def flatten_observable (observable ):
@@ -722,7 +722,7 @@ def translate_result(
722722 ]
723723
724724 targets = targets or measurement .wires .tolist ()
725- if measurement . return_type is ObservableReturnTypes . Counts and observable is None :
725+ if isinstance ( measurement , qml . measurements . CountsMP ) and not measurement . all_outcomes and observable is None :
726726 if targets :
727727 new_dict = {}
728728 for key , value in braket_result .measurement_counts .items ():
@@ -742,7 +742,7 @@ def translate_result(
742742 coeff * braket_result .get_value_by_result_type (result_type )
743743 for coeff , result_type in zip (coeffs , translated )
744744 )
745- elif measurement . return_type is ObservableReturnTypes . Counts :
745+ elif isinstance ( measurement , qml . measurements . CountsMP ) and not measurement . all_outcomes :
746746 return dict (Counter (braket_result .get_value_by_result_type (translated )))
747747 else :
748748 return braket_result .get_value_by_result_type (translated )
0 commit comments