8383 translate_result ,
8484 translate_result_type ,
8585)
86+ from braket .program_sets import ProgramSet
8687from braket .simulator import BraketSimulator
8788from braket .tasks import GateModelQuantumTaskResult , QuantumTask
8889from braket .tasks .local_quantum_task_batch import LocalQuantumTaskBatch
8990
9091from ._version import __version__
9192
92- RETURN_TYPES = (ExpectationMP , VarianceMP , SampleMP , ProbabilityMP , StateMP , CountsMP )
93+ RETURN_TYPES = (ExpectationMP , VarianceMP , SampleMP , ProbabilityMP , StateMP , CountsMP )
9394MIN_SIMULATOR_BILLED_MS = 3000
9495OBS_LIST = (qml .PauliX , qml .PauliY , qml .PauliZ )
9596
@@ -168,6 +169,10 @@ def __init__(
168169 self ._supported_obs = supported_observables (self ._device , self .shots )
169170 self ._check_supported_result_types ()
170171 self ._verbatim = verbatim
172+ self ._supports_program_sets = (
173+ DeviceActionType .OPENQASM_PROGRAM_SET in self ._device .properties .action
174+ and self ._shots is not None
175+ )
171176
172177 if noise_model :
173178 self ._validate_noise_model_support ()
@@ -202,7 +207,7 @@ def parallel(self) -> bool:
202207 return self ._parallel
203208
204209 def batch_execute (self , circuits , ** run_kwargs ):
205- if not self ._parallel :
210+ if not self ._parallel and not self . _supports_program_sets :
206211 return super ().batch_execute (circuits )
207212
208213 for circuit in circuits :
@@ -220,6 +225,7 @@ def batch_execute(self, circuits, **run_kwargs):
220225 self ._pl_to_braket_circuit (
221226 circuit ,
222227 trainable_indices = frozenset (trainable .keys ()),
228+ add_observables = not self ._supports_program_sets ,
223229 ** run_kwargs ,
224230 )
225231 )
@@ -232,18 +238,15 @@ def batch_execute(self, circuits, **run_kwargs):
232238 else []
233239 )
234240
235- braket_results_batch = self ._run_task_batch (braket_circuits , batch_shots , batch_inputs )
236-
237- return [
238- self ._braket_to_pl_result (braket_result , circuit )
239- for braket_result , circuit in zip (braket_results_batch , circuits )
240- ]
241+ return self ._run_task_batch (braket_circuits , circuits , batch_shots , batch_inputs )
241242
242243 def _pl_to_braket_circuit (
243244 self ,
244245 circuit : QuantumTape ,
245246 compute_gradient : bool = False ,
246247 trainable_indices : frozenset [int ] = None ,
248+ * ,
249+ add_observables : bool = True ,
247250 ** run_kwargs ,
248251 ):
249252 """Converts a PennyLane circuit to a Braket circuit"""
@@ -259,17 +262,28 @@ def _pl_to_braket_circuit(
259262 if compute_gradient :
260263 braket_circuit = self ._apply_gradient_result_type (circuit , braket_circuit )
261264 elif not isinstance (circuit .measurements [0 ], MeasurementTransform ):
262- for measurement in circuit .measurements :
263- translated = translate_result_type (
264- measurement .map_wires (self .wire_map ),
265- None ,
266- self ._braket_result_types ,
265+ if add_observables :
266+ for measurement in circuit .measurements :
267+ translated = translate_result_type (
268+ measurement .map_wires (self .wire_map ),
269+ None ,
270+ self ._braket_result_types ,
271+ )
272+ if isinstance (translated , tuple ):
273+ for result_type in translated :
274+ braket_circuit .add_result_type (result_type )
275+ else :
276+ braket_circuit .add_result_type (translated )
277+ else :
278+ groups = qml .pauli .group_observables (
279+ [measurement .obs for measurement in circuit .measurements ], grouping_type = "qwc"
267280 )
268- if isinstance (translated , tuple ):
269- for result_type in translated :
270- braket_circuit .add_result_type (result_type )
271- else :
272- braket_circuit .add_result_type (translated )
281+ if len (groups ) > 1 :
282+ raise ValueError (
283+ f"Observables need to mutually commute, but found { len (groups )} : { groups } "
284+ )
285+ diagonalizing_ops = qml .pauli .diagonalize_qwc_pauli_words (groups [0 ])[0 ]
286+ braket_circuit += self .apply (diagonalizing_ops , apply_identities = False )
273287
274288 return braket_circuit
275289
@@ -316,7 +330,7 @@ def _update_tracker_for_batch(
316330 self .tracker .update (batches = 1 , executions = total_executions , shots = total_shots )
317331 self .tracker .record ()
318332
319- def statistics (
333+ def _statistics (
320334 self ,
321335 braket_result : GateModelQuantumTaskResult ,
322336 measurements : Sequence [MeasurementProcess ],
@@ -338,14 +352,18 @@ def statistics(
338352 for mp in measurements :
339353 if not isinstance (mp , RETURN_TYPES ):
340354 raise QuantumFunctionError ("Unsupported return type: {}" .format (type (mp )))
341- results .append (self ._get_statistic (braket_result , mp ))
355+ results .append (
356+ translate_result (
357+ braket_result , mp .map_wires (self .wire_map ), None , self ._braket_result_types
358+ )
359+ )
342360 return results
343361
344362 def _braket_to_pl_result (self , braket_result , circuit ):
345363 """Calculates the PennyLane results from a Braket task result. A PennyLane circuit
346364 also determines the output observables."""
347365 # Compute the required statistics
348- results = self .statistics (braket_result , circuit .measurements )
366+ results = self ._statistics (braket_result , circuit .measurements )
349367 ag_results = [
350368 result
351369 for result in braket_result .result_types
@@ -378,6 +396,25 @@ def _braket_to_pl_result(self, braket_result, circuit):
378396 return onp .array (results ).squeeze ()
379397 return tuple (onp .array (result ).squeeze () for result in results )
380398
399+ def _braket_program_set_to_pl_result (self , program_set_result , circuits ):
400+ results = []
401+ for program_result , circuit in zip (program_set_result , circuits ):
402+ # Only one executable per program
403+ measurements = program_result [0 ].measurements
404+
405+ # Program sets require shots > 0,
406+ # so the circuit's measurements are guaranteed to be SampleMeasurements
407+ executable_results = [
408+ measurement .process_samples (measurements , wire_order = measurement .wires )
409+ for measurement in circuit .measurements
410+ ]
411+ results .append (
412+ onp .array (executable_results ).squeeze ()
413+ if len (circuit .measurements ) == 1
414+ else tuple (onp .array (result ).squeeze () for result in executable_results )
415+ )
416+ return results
417+
381418 @staticmethod
382419 def _tracking_data (task ):
383420 if task .state () == "COMPLETED" :
@@ -410,8 +447,6 @@ def classical_shadow(self, obs, circuit):
410447 rng = np .random .default_rng (seed )
411448 recipes = rng .integers (0 , 3 , size = (n_snapshots , n_qubits ))
412449
413- outcomes = np .zeros ((n_snapshots , n_qubits ))
414-
415450 snapshot_rotations = [
416451 [
417452 rot
@@ -484,6 +519,7 @@ def apply(
484519 use_unique_params : bool = False ,
485520 * ,
486521 trainable_indices : Optional [frozenset [int ]] = None ,
522+ apply_identities : bool = True ,
487523 ** run_kwargs ,
488524 ) -> Circuit :
489525 """Instantiate Braket Circuit object."""
@@ -518,8 +554,9 @@ def apply(
518554 unused = set (range (self .num_wires )) - {int (qubit ) for qubit in circuit .qubits }
519555
520556 # To ensure the results have the right number of qubits
521- for qubit in sorted (unused ):
522- circuit .i (qubit )
557+ if apply_identities :
558+ for qubit in sorted (unused ):
559+ circuit .i (qubit )
523560
524561 if self ._noise_model :
525562 circuit = self ._noise_model .apply (circuit )
@@ -552,14 +589,12 @@ def _validate_noise_model_support(self):
552589 def _run_task (self , circuit , inputs = None ):
553590 raise NotImplementedError ("Need to implement task runner" )
554591
592+ def _run_task_batch (self , braket_circuits , pl_circuits , circuit_shots , mapped_wires ):
593+ raise NotImplementedError ("Need to implement batch runner" )
594+
555595 def _run_snapshots (self , snapshot_circuits , n_qubits , mapped_wires ):
556596 raise NotImplementedError ("Need to implement snapshots runner" )
557597
558- def _get_statistic (self , braket_result , mp ):
559- return translate_result (
560- braket_result , mp .map_wires (self .wire_map ), None , self ._braket_result_types
561- )
562-
563598 @staticmethod
564599 def _get_trainable_parameters (tape : QuantumTape ) -> dict [int , numbers .Number ]:
565600 trainable_indices = sorted (tape .trainable_params )
@@ -663,9 +698,24 @@ def use_grouping(self) -> bool:
663698 caps = self .capabilities ()
664699 return not ("provides_jacobian" in caps and caps ["provides_jacobian" ])
665700
666- def _run_task_batch (self , batch_circuits , batch_shots : int , inputs ):
701+ def _run_task_batch (self , braket_circuits , pl_circuits , batch_shots : int , inputs ):
702+ if self ._supports_program_sets :
703+ program_set = (
704+ ProgramSet .zip (braket_circuits , input_sets = inputs )
705+ if inputs
706+ else ProgramSet (braket_circuits )
707+ )
708+ task = self ._device .run (
709+ program_set ,
710+ s3_destination_folder = self ._s3_folder ,
711+ shots = len (program_set ) * batch_shots ,
712+ poll_timeout_seconds = self ._poll_timeout_seconds ,
713+ poll_interval_seconds = self ._poll_interval_seconds ,
714+ ** self ._run_kwargs ,
715+ )
716+ return self ._braket_program_set_to_pl_result (task .result (), pl_circuits )
667717 task_batch = self ._device .run_batch (
668- batch_circuits ,
718+ braket_circuits ,
669719 s3_destination_folder = self ._s3_folder ,
670720 shots = batch_shots ,
671721 max_parallel = self ._max_parallel ,
@@ -687,7 +737,10 @@ def _run_task_batch(self, batch_circuits, batch_shots: int, inputs):
687737 if self .tracker .active :
688738 self ._update_tracker_for_batch (task_batch , batch_shots )
689739
690- return braket_results_batch
740+ return [
741+ self ._braket_to_pl_result (braket_result , circuit )
742+ for braket_result , circuit in zip (braket_results_batch , pl_circuits )
743+ ]
691744
692745 def _run_task (self , circuit , inputs = None ):
693746 return self ._device .run (
@@ -703,7 +756,19 @@ def _run_task(self, circuit, inputs=None):
703756 def _run_snapshots (self , snapshot_circuits , n_qubits , mapped_wires ):
704757 n_snapshots = len (snapshot_circuits )
705758 outcomes = np .zeros ((n_snapshots , n_qubits ))
706- if self ._parallel :
759+ if self ._supports_program_sets :
760+ program_set = ProgramSet (snapshot_circuits )
761+ task = self ._device .run (
762+ program_set ,
763+ s3_destination_folder = self ._s3_folder ,
764+ shots = len (program_set ),
765+ poll_timeout_seconds = self ._poll_timeout_seconds ,
766+ poll_interval_seconds = self ._poll_interval_seconds ,
767+ ** self ._run_kwargs ,
768+ )
769+ for t , result in enumerate (task .result ()):
770+ outcomes [t ] = np .array (result [0 ].measurements [0 ])[mapped_wires ]
771+ elif self ._parallel :
707772 task_batch = self ._device .run_batch (
708773 snapshot_circuits ,
709774 s3_destination_folder = self ._s3_folder ,
@@ -1041,9 +1106,9 @@ def __init__(
10411106 device = LocalSimulator (backend )
10421107 super ().__init__ (wires , device , shots = shots , ** run_kwargs )
10431108
1044- def _run_task_batch (self , batch_circuits , batch_shots : int , inputs ):
1109+ def _run_task_batch (self , braket_circuits , pl_circuits , batch_shots : int , inputs ):
10451110 task_batch = self ._device .run_batch (
1046- batch_circuits ,
1111+ braket_circuits ,
10471112 shots = batch_shots ,
10481113 max_parallel = self ._max_parallel ,
10491114 inputs = inputs ,
@@ -1057,7 +1122,10 @@ def _run_task_batch(self, batch_circuits, batch_shots: int, inputs):
10571122 if self .tracker .active :
10581123 self ._update_tracker_for_batch (task_batch , batch_shots )
10591124
1060- return braket_results_batch
1125+ return [
1126+ self ._braket_to_pl_result (braket_result , circuit )
1127+ for braket_result , circuit in zip (braket_results_batch , pl_circuits )
1128+ ]
10611129
10621130 def _run_task (self , circuit , inputs = None ):
10631131 return self ._device .run (
0 commit comments