Skip to content

Commit 1f0d942

Browse files
authored
feat: Program sets (#293)
1 parent 42986ba commit 1f0d942

File tree

5 files changed

+508
-95
lines changed

5 files changed

+508
-95
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
packages=find_namespace_packages(where="src", exclude=("test",)),
3535
package_dir={"": "src"},
3636
install_requires=[
37-
"amazon-braket-sdk>=1.87.0",
37+
"amazon-braket-sdk>=1.97.0",
3838
"autoray>=0.6.11",
3939
"pennylane>=0.34.0",
4040
],

src/braket/pennylane_plugin/braket_device.py

Lines changed: 105 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,14 @@
8383
translate_result,
8484
translate_result_type,
8585
)
86+
from braket.program_sets import ProgramSet
8687
from braket.simulator import BraketSimulator
8788
from braket.tasks import GateModelQuantumTaskResult, QuantumTask
8889
from braket.tasks.local_quantum_task_batch import LocalQuantumTaskBatch
8990

9091
from ._version import __version__
9192

92-
RETURN_TYPES = (ExpectationMP, VarianceMP, SampleMP, ProbabilityMP, StateMP, CountsMP)
93+
RETURN_TYPES = (ExpectationMP, VarianceMP, SampleMP, ProbabilityMP, StateMP, CountsMP)
9394
MIN_SIMULATOR_BILLED_MS = 3000
9495
OBS_LIST = (qml.PauliX, qml.PauliY, qml.PauliZ)
9596

@@ -168,6 +169,10 @@ def __init__(
168169
self._supported_obs = supported_observables(self._device, self.shots)
169170
self._check_supported_result_types()
170171
self._verbatim = verbatim
172+
self._supports_program_sets = (
173+
DeviceActionType.OPENQASM_PROGRAM_SET in self._device.properties.action
174+
and self._shots is not None
175+
)
171176

172177
if noise_model:
173178
self._validate_noise_model_support()
@@ -202,7 +207,7 @@ def parallel(self) -> bool:
202207
return self._parallel
203208

204209
def batch_execute(self, circuits, **run_kwargs):
205-
if not self._parallel:
210+
if not self._parallel and not self._supports_program_sets:
206211
return super().batch_execute(circuits)
207212

208213
for circuit in circuits:
@@ -220,6 +225,7 @@ def batch_execute(self, circuits, **run_kwargs):
220225
self._pl_to_braket_circuit(
221226
circuit,
222227
trainable_indices=frozenset(trainable.keys()),
228+
add_observables=not self._supports_program_sets,
223229
**run_kwargs,
224230
)
225231
)
@@ -232,18 +238,15 @@ def batch_execute(self, circuits, **run_kwargs):
232238
else []
233239
)
234240

235-
braket_results_batch = self._run_task_batch(braket_circuits, batch_shots, batch_inputs)
236-
237-
return [
238-
self._braket_to_pl_result(braket_result, circuit)
239-
for braket_result, circuit in zip(braket_results_batch, circuits)
240-
]
241+
return self._run_task_batch(braket_circuits, circuits, batch_shots, batch_inputs)
241242

242243
def _pl_to_braket_circuit(
243244
self,
244245
circuit: QuantumTape,
245246
compute_gradient: bool = False,
246247
trainable_indices: frozenset[int] = None,
248+
*,
249+
add_observables: bool = True,
247250
**run_kwargs,
248251
):
249252
"""Converts a PennyLane circuit to a Braket circuit"""
@@ -259,17 +262,28 @@ def _pl_to_braket_circuit(
259262
if compute_gradient:
260263
braket_circuit = self._apply_gradient_result_type(circuit, braket_circuit)
261264
elif not isinstance(circuit.measurements[0], MeasurementTransform):
262-
for measurement in circuit.measurements:
263-
translated = translate_result_type(
264-
measurement.map_wires(self.wire_map),
265-
None,
266-
self._braket_result_types,
265+
if add_observables:
266+
for measurement in circuit.measurements:
267+
translated = translate_result_type(
268+
measurement.map_wires(self.wire_map),
269+
None,
270+
self._braket_result_types,
271+
)
272+
if isinstance(translated, tuple):
273+
for result_type in translated:
274+
braket_circuit.add_result_type(result_type)
275+
else:
276+
braket_circuit.add_result_type(translated)
277+
else:
278+
groups = qml.pauli.group_observables(
279+
[measurement.obs for measurement in circuit.measurements], grouping_type="qwc"
267280
)
268-
if isinstance(translated, tuple):
269-
for result_type in translated:
270-
braket_circuit.add_result_type(result_type)
271-
else:
272-
braket_circuit.add_result_type(translated)
281+
if len(groups) > 1:
282+
raise ValueError(
283+
f"Observables need to mutually commute, but found {len(groups)}: {groups}"
284+
)
285+
diagonalizing_ops = qml.pauli.diagonalize_qwc_pauli_words(groups[0])[0]
286+
braket_circuit += self.apply(diagonalizing_ops, apply_identities=False)
273287

274288
return braket_circuit
275289

@@ -316,7 +330,7 @@ def _update_tracker_for_batch(
316330
self.tracker.update(batches=1, executions=total_executions, shots=total_shots)
317331
self.tracker.record()
318332

319-
def statistics(
333+
def _statistics(
320334
self,
321335
braket_result: GateModelQuantumTaskResult,
322336
measurements: Sequence[MeasurementProcess],
@@ -338,14 +352,18 @@ def statistics(
338352
for mp in measurements:
339353
if not isinstance(mp, RETURN_TYPES):
340354
raise QuantumFunctionError("Unsupported return type: {}".format(type(mp)))
341-
results.append(self._get_statistic(braket_result, mp))
355+
results.append(
356+
translate_result(
357+
braket_result, mp.map_wires(self.wire_map), None, self._braket_result_types
358+
)
359+
)
342360
return results
343361

344362
def _braket_to_pl_result(self, braket_result, circuit):
345363
"""Calculates the PennyLane results from a Braket task result. A PennyLane circuit
346364
also determines the output observables."""
347365
# Compute the required statistics
348-
results = self.statistics(braket_result, circuit.measurements)
366+
results = self._statistics(braket_result, circuit.measurements)
349367
ag_results = [
350368
result
351369
for result in braket_result.result_types
@@ -378,6 +396,25 @@ def _braket_to_pl_result(self, braket_result, circuit):
378396
return onp.array(results).squeeze()
379397
return tuple(onp.array(result).squeeze() for result in results)
380398

399+
def _braket_program_set_to_pl_result(self, program_set_result, circuits):
400+
results = []
401+
for program_result, circuit in zip(program_set_result, circuits):
402+
# Only one executable per program
403+
measurements = program_result[0].measurements
404+
405+
# Program sets require shots > 0,
406+
# so the circuit's measurements are guaranteed to be SampleMeasurements
407+
executable_results = [
408+
measurement.process_samples(measurements, wire_order=measurement.wires)
409+
for measurement in circuit.measurements
410+
]
411+
results.append(
412+
onp.array(executable_results).squeeze()
413+
if len(circuit.measurements) == 1
414+
else tuple(onp.array(result).squeeze() for result in executable_results)
415+
)
416+
return results
417+
381418
@staticmethod
382419
def _tracking_data(task):
383420
if task.state() == "COMPLETED":
@@ -410,8 +447,6 @@ def classical_shadow(self, obs, circuit):
410447
rng = np.random.default_rng(seed)
411448
recipes = rng.integers(0, 3, size=(n_snapshots, n_qubits))
412449

413-
outcomes = np.zeros((n_snapshots, n_qubits))
414-
415450
snapshot_rotations = [
416451
[
417452
rot
@@ -484,6 +519,7 @@ def apply(
484519
use_unique_params: bool = False,
485520
*,
486521
trainable_indices: Optional[frozenset[int]] = None,
522+
apply_identities: bool = True,
487523
**run_kwargs,
488524
) -> Circuit:
489525
"""Instantiate Braket Circuit object."""
@@ -518,8 +554,9 @@ def apply(
518554
unused = set(range(self.num_wires)) - {int(qubit) for qubit in circuit.qubits}
519555

520556
# To ensure the results have the right number of qubits
521-
for qubit in sorted(unused):
522-
circuit.i(qubit)
557+
if apply_identities:
558+
for qubit in sorted(unused):
559+
circuit.i(qubit)
523560

524561
if self._noise_model:
525562
circuit = self._noise_model.apply(circuit)
@@ -552,14 +589,12 @@ def _validate_noise_model_support(self):
552589
def _run_task(self, circuit, inputs=None):
553590
raise NotImplementedError("Need to implement task runner")
554591

592+
def _run_task_batch(self, braket_circuits, pl_circuits, circuit_shots, mapped_wires):
593+
raise NotImplementedError("Need to implement batch runner")
594+
555595
def _run_snapshots(self, snapshot_circuits, n_qubits, mapped_wires):
556596
raise NotImplementedError("Need to implement snapshots runner")
557597

558-
def _get_statistic(self, braket_result, mp):
559-
return translate_result(
560-
braket_result, mp.map_wires(self.wire_map), None, self._braket_result_types
561-
)
562-
563598
@staticmethod
564599
def _get_trainable_parameters(tape: QuantumTape) -> dict[int, numbers.Number]:
565600
trainable_indices = sorted(tape.trainable_params)
@@ -663,9 +698,24 @@ def use_grouping(self) -> bool:
663698
caps = self.capabilities()
664699
return not ("provides_jacobian" in caps and caps["provides_jacobian"])
665700

666-
def _run_task_batch(self, batch_circuits, batch_shots: int, inputs):
701+
def _run_task_batch(self, braket_circuits, pl_circuits, batch_shots: int, inputs):
702+
if self._supports_program_sets:
703+
program_set = (
704+
ProgramSet.zip(braket_circuits, input_sets=inputs)
705+
if inputs
706+
else ProgramSet(braket_circuits)
707+
)
708+
task = self._device.run(
709+
program_set,
710+
s3_destination_folder=self._s3_folder,
711+
shots=len(program_set) * batch_shots,
712+
poll_timeout_seconds=self._poll_timeout_seconds,
713+
poll_interval_seconds=self._poll_interval_seconds,
714+
**self._run_kwargs,
715+
)
716+
return self._braket_program_set_to_pl_result(task.result(), pl_circuits)
667717
task_batch = self._device.run_batch(
668-
batch_circuits,
718+
braket_circuits,
669719
s3_destination_folder=self._s3_folder,
670720
shots=batch_shots,
671721
max_parallel=self._max_parallel,
@@ -687,7 +737,10 @@ def _run_task_batch(self, batch_circuits, batch_shots: int, inputs):
687737
if self.tracker.active:
688738
self._update_tracker_for_batch(task_batch, batch_shots)
689739

690-
return braket_results_batch
740+
return [
741+
self._braket_to_pl_result(braket_result, circuit)
742+
for braket_result, circuit in zip(braket_results_batch, pl_circuits)
743+
]
691744

692745
def _run_task(self, circuit, inputs=None):
693746
return self._device.run(
@@ -703,7 +756,19 @@ def _run_task(self, circuit, inputs=None):
703756
def _run_snapshots(self, snapshot_circuits, n_qubits, mapped_wires):
704757
n_snapshots = len(snapshot_circuits)
705758
outcomes = np.zeros((n_snapshots, n_qubits))
706-
if self._parallel:
759+
if self._supports_program_sets:
760+
program_set = ProgramSet(snapshot_circuits)
761+
task = self._device.run(
762+
program_set,
763+
s3_destination_folder=self._s3_folder,
764+
shots=len(program_set),
765+
poll_timeout_seconds=self._poll_timeout_seconds,
766+
poll_interval_seconds=self._poll_interval_seconds,
767+
**self._run_kwargs,
768+
)
769+
for t, result in enumerate(task.result()):
770+
outcomes[t] = np.array(result[0].measurements[0])[mapped_wires]
771+
elif self._parallel:
707772
task_batch = self._device.run_batch(
708773
snapshot_circuits,
709774
s3_destination_folder=self._s3_folder,
@@ -1041,9 +1106,9 @@ def __init__(
10411106
device = LocalSimulator(backend)
10421107
super().__init__(wires, device, shots=shots, **run_kwargs)
10431108

1044-
def _run_task_batch(self, batch_circuits, batch_shots: int, inputs):
1109+
def _run_task_batch(self, braket_circuits, pl_circuits, batch_shots: int, inputs):
10451110
task_batch = self._device.run_batch(
1046-
batch_circuits,
1111+
braket_circuits,
10471112
shots=batch_shots,
10481113
max_parallel=self._max_parallel,
10491114
inputs=inputs,
@@ -1057,7 +1122,10 @@ def _run_task_batch(self, batch_circuits, batch_shots: int, inputs):
10571122
if self.tracker.active:
10581123
self._update_tracker_for_batch(task_batch, batch_shots)
10591124

1060-
return braket_results_batch
1125+
return [
1126+
self._braket_to_pl_result(braket_result, circuit)
1127+
for braket_result, circuit in zip(braket_results_batch, pl_circuits)
1128+
]
10611129

10621130
def _run_task(self, circuit, inputs=None):
10631131
return self._device.run(

src/braket/pennylane_plugin/translation.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,9 @@ def translate_result_type( # noqa: C901
594594
if isinstance(observable, qml.ops.LinearCombination):
595595
if isinstance(measurement, qml.measurements.ExpectationMP):
596596
return tuple(Expectation(_translate_observable(op)) for op in observable.terms()[1])
597-
raise NotImplementedError(f"Return type {type(measurement)} unsupported for LinearCombination")
597+
raise NotImplementedError(
598+
f"Return type {type(measurement)} unsupported for LinearCombination"
599+
)
598600

599601
braket_observable = _translate_observable(observable)
600602
if isinstance(measurement, qml.measurements.ExpectationMP):
@@ -722,7 +724,11 @@ def translate_result(
722724
]
723725

724726
targets = targets or measurement.wires.tolist()
725-
if isinstance(measurement, qml.measurements.CountsMP) and not measurement.all_outcomes and observable is None:
727+
if (
728+
isinstance(measurement, qml.measurements.CountsMP)
729+
and not measurement.all_outcomes
730+
and observable is None
731+
):
726732
if targets:
727733
new_dict = {}
728734
for key, value in braket_result.measurement_counts.items():

0 commit comments

Comments
 (0)