Skip to content

Commit 09572f4

Browse files
committed
Merge branch 'main' into jcjaskula-aws/add_in_place_modifications
1 parent fb1db5d commit 09572f4

File tree

5 files changed

+127
-85
lines changed

5 files changed

+127
-85
lines changed

src/braket/aws/aws_quantum_task.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def create(
105105
disable_qubit_rewiring: bool = False,
106106
tags: dict[str, str] | None = None,
107107
inputs: dict[str, float] | None = None,
108-
gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]] | None = None,
108+
gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence] | None = None,
109109
quiet: bool = False,
110110
reservation_arn: str | None = None,
111111
*args,
@@ -148,10 +148,9 @@ def create(
148148
IR. If the IR supports inputs, the inputs will be updated with this value.
149149
Default: {}.
150150
151-
gate_definitions (Optional[dict[tuple[Gate, QubitSet], PulseSequence]] | None):
152-
A `Dict` for user defined gate calibration. The calibration is defined for
153-
for a particular `Gate` on a particular `QubitSet` and is represented by
154-
a `PulseSequence`.
151+
gate_definitions (dict[tuple[Gate, QubitSet], PulseSequence] | None): A `dict`
152+
of user defined gate calibrations. Each calibration is defined for a particular
153+
`Gate` on a particular `QubitSet` and is represented by a `PulseSequence`.
155154
Default: None.
156155
157156
quiet (bool): Sets the verbosity of the logger to low and does not report queue
@@ -190,6 +189,7 @@ def create(
190189
if tags is not None:
191190
create_task_kwargs.update({"tags": tags})
192191
inputs = inputs or {}
192+
gate_definitions = gate_definitions or {}
193193

194194
if reservation_arn:
195195
create_task_kwargs.update(
@@ -561,7 +561,7 @@ def _create_internal(
561561
device_parameters: Union[dict, BraketSchemaBase],
562562
disable_qubit_rewiring: bool,
563563
inputs: dict[str, float],
564-
gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]],
564+
gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence],
565565
*args,
566566
**kwargs,
567567
) -> AwsQuantumTask:
@@ -577,7 +577,7 @@ def _(
577577
_device_parameters: Union[dict, BraketSchemaBase], # Not currently used for OpenQasmProgram
578578
_disable_qubit_rewiring: bool,
579579
inputs: dict[str, float],
580-
gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]],
580+
gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence],
581581
*args,
582582
**kwargs,
583583
) -> AwsQuantumTask:
@@ -600,7 +600,7 @@ def _(
600600
device_parameters: Union[dict, BraketSchemaBase],
601601
_disable_qubit_rewiring: bool,
602602
inputs: dict[str, float],
603-
gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]],
603+
gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence],
604604
*args,
605605
**kwargs,
606606
) -> AwsQuantumTask:
@@ -639,7 +639,7 @@ def _(
639639
_device_parameters: Union[dict, BraketSchemaBase],
640640
_disable_qubit_rewiring: bool,
641641
inputs: dict[str, float],
642-
gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]],
642+
gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence],
643643
*args,
644644
**kwargs,
645645
) -> AwsQuantumTask:
@@ -657,7 +657,7 @@ def _(
657657
device_parameters: Union[dict, BraketSchemaBase],
658658
disable_qubit_rewiring: bool,
659659
inputs: dict[str, float],
660-
gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]],
660+
gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence],
661661
*args,
662662
**kwargs,
663663
) -> AwsQuantumTask:
@@ -678,7 +678,7 @@ def _(
678678
if (
679679
disable_qubit_rewiring
680680
or Instruction(StartVerbatimBox()) in circuit.instructions
681-
or gate_definitions is not None
681+
or gate_definitions
682682
or any(isinstance(instruction.operator, PulseGate) for instruction in circuit.instructions)
683683
):
684684
qubit_reference_type = QubitReferenceType.PHYSICAL

src/braket/aws/aws_quantum_task_batch.py

Lines changed: 66 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,11 @@
2323
from braket.aws.aws_quantum_task import AwsQuantumTask
2424
from braket.aws.aws_session import AwsSession
2525
from braket.circuits import Circuit
26+
from braket.circuits.gate import Gate
2627
from braket.ir.blackbird import Program as BlackbirdProgram
2728
from braket.ir.openqasm import Program as OpenQasmProgram
29+
from braket.pulse.pulse_sequence import PulseSequence
30+
from braket.registers.qubit_set import QubitSet
2831
from braket.tasks.quantum_task_batch import QuantumTaskBatch
2932

3033

@@ -61,6 +64,13 @@ def __init__(
6164
poll_timeout_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT,
6265
poll_interval_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL,
6366
inputs: Union[dict[str, float], list[dict[str, float]]] | None = None,
67+
gate_definitions: (
68+
Union[
69+
dict[tuple[Gate, QubitSet], PulseSequence],
70+
list[dict[tuple[Gate, QubitSet], PulseSequence]],
71+
]
72+
| None
73+
) = None,
6474
reservation_arn: str | None = None,
6575
*aws_quantum_task_args: Any,
6676
**aws_quantum_task_kwargs: Any,
@@ -92,6 +102,9 @@ def __init__(
92102
inputs (Union[dict[str, float], list[dict[str, float]]] | None): Inputs to be passed
93103
along with the IR. If the IR supports inputs, the inputs will be updated
94104
with this value. Default: {}.
105+
gate_definitions (Union[dict[tuple[Gate, QubitSet], PulseSequence], list[dict[tuple[Gate, QubitSet], PulseSequence]]] | None): # noqa: E501
106+
User-defined gate calibration. The calibration is defined for a particular `Gate` on a
107+
particular `QubitSet` and is represented by a `PulseSequence`. Default: None.
95108
reservation_arn (str | None): The reservation ARN provided by Braket Direct
96109
to reserve exclusive usage for the device to run the quantum task on.
97110
Note: If you are creating tasks in a job that itself was created reservation ARN,
@@ -111,6 +124,7 @@ def __init__(
111124
poll_timeout_seconds,
112125
poll_interval_seconds,
113126
inputs,
127+
gate_definitions,
114128
reservation_arn,
115129
*aws_quantum_task_args,
116130
**aws_quantum_task_kwargs,
@@ -134,7 +148,7 @@ def __init__(
134148
self._aws_quantum_task_kwargs = aws_quantum_task_kwargs
135149

136150
@staticmethod
137-
def _tasks_and_inputs(
151+
def _tasks_inputs_gatedefs(
138152
task_specifications: Union[
139153
Union[Circuit, Problem, OpenQasmProgram, BlackbirdProgram, AnalogHamiltonianSimulation],
140154
list[
@@ -144,45 +158,55 @@ def _tasks_and_inputs(
144158
],
145159
],
146160
inputs: Union[dict[str, float], list[dict[str, float]]] = None,
161+
gate_definitions: Union[
162+
dict[tuple[Gate, QubitSet], PulseSequence],
163+
list[dict[tuple[Gate, QubitSet], PulseSequence]],
164+
] = None,
147165
) -> list[
148166
tuple[
149167
Union[Circuit, Problem, OpenQasmProgram, BlackbirdProgram, AnalogHamiltonianSimulation],
150168
dict[str, float],
169+
dict[tuple[Gate, QubitSet], PulseSequence],
151170
]
152171
]:
153172
inputs = inputs or {}
154-
155-
max_inputs_tasks = 1
156-
single_task = isinstance(
157-
task_specifications,
158-
(Circuit, Problem, OpenQasmProgram, BlackbirdProgram, AnalogHamiltonianSimulation),
159-
)
160-
single_input = isinstance(inputs, dict)
161-
162-
max_inputs_tasks = (
163-
max(max_inputs_tasks, len(task_specifications)) if not single_task else max_inputs_tasks
164-
)
165-
max_inputs_tasks = (
166-
max(max_inputs_tasks, len(inputs)) if not single_input else max_inputs_tasks
173+
gate_definitions = gate_definitions or {}
174+
175+
single_task_type = (
176+
Circuit,
177+
Problem,
178+
OpenQasmProgram,
179+
BlackbirdProgram,
180+
AnalogHamiltonianSimulation,
167181
)
182+
single_input_type = dict
183+
single_gate_definitions_type = dict
168184

169-
if not single_task and not single_input:
170-
if len(task_specifications) != len(inputs):
171-
raise ValueError("Multiple inputs and task specifications must be equal in number.")
172-
if single_task:
173-
task_specifications = repeat(task_specifications, times=max_inputs_tasks)
185+
args = [task_specifications, inputs, gate_definitions]
186+
single_arg_types = [single_task_type, single_input_type, single_gate_definitions_type]
174187

175-
if single_input:
176-
inputs = repeat(inputs, times=max_inputs_tasks)
188+
batch_length = 1
189+
arg_lengths = []
190+
for arg, single_arg_type in zip(args, single_arg_types):
191+
arg_length = 1 if isinstance(arg, single_arg_type) else len(arg)
192+
arg_lengths.append(arg_length)
177193

178-
tasks_and_inputs = zip(task_specifications, inputs)
194+
if arg_length != 1:
195+
if batch_length != 1 and arg_length != batch_length:
196+
raise ValueError(
197+
"Multiple inputs, task specifications and gate definitions must "
198+
"be equal in length."
199+
)
200+
else:
201+
batch_length = arg_length
179202

180-
if single_task and single_input:
181-
tasks_and_inputs = list(tasks_and_inputs)
203+
for i, arg_length in enumerate(arg_lengths):
204+
if arg_length == 1:
205+
args[i] = repeat(args[i], batch_length)
182206

183-
tasks_and_inputs = list(tasks_and_inputs)
207+
tasks_inputs_definitions = list(zip(*args))
184208

185-
for task_specification, input_map in tasks_and_inputs:
209+
for task_specification, input_map, _gate_definitions in tasks_inputs_definitions:
186210
if isinstance(task_specification, Circuit):
187211
param_names = {param.name for param in task_specification.parameters}
188212
unbounded_parameters = param_names - set(input_map.keys())
@@ -192,7 +216,7 @@ def _tasks_and_inputs(
192216
f"{unbounded_parameters}"
193217
)
194218

195-
return tasks_and_inputs
219+
return tasks_inputs_definitions
196220

197221
@staticmethod
198222
def _execute(
@@ -213,13 +237,22 @@ def _execute(
213237
poll_timeout_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT,
214238
poll_interval_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL,
215239
inputs: Union[dict[str, float], list[dict[str, float]]] = None,
240+
gate_definitions: (
241+
Union[
242+
dict[tuple[Gate, QubitSet], PulseSequence],
243+
list[dict[tuple[Gate, QubitSet], PulseSequence]],
244+
]
245+
| None
246+
) = None,
216247
reservation_arn: str | None = None,
217248
*args,
218249
**kwargs,
219250
) -> list[AwsQuantumTask]:
220-
tasks_and_inputs = AwsQuantumTaskBatch._tasks_and_inputs(task_specifications, inputs)
251+
tasks_inputs_gatedefs = AwsQuantumTaskBatch._tasks_inputs_gatedefs(
252+
task_specifications, inputs, gate_definitions
253+
)
221254
max_threads = min(max_parallel, max_workers)
222-
remaining = [0 for _ in tasks_and_inputs]
255+
remaining = [0 for _ in tasks_inputs_gatedefs]
223256
try:
224257
with ThreadPoolExecutor(max_workers=max_threads) as executor:
225258
task_futures = [
@@ -234,11 +267,12 @@ def _execute(
234267
poll_timeout_seconds=poll_timeout_seconds,
235268
poll_interval_seconds=poll_interval_seconds,
236269
inputs=input_map,
270+
gate_definitions=gatedefs,
237271
reservation_arn=reservation_arn,
238272
*args,
239273
**kwargs,
240274
)
241-
for task, input_map in tasks_and_inputs
275+
for task, input_map, gatedefs in tasks_inputs_gatedefs
242276
]
243277
except KeyboardInterrupt:
244278
# If an exception is thrown before the thread pool has finished,
@@ -266,6 +300,7 @@ def _create_task(
266300
shots: int,
267301
poll_interval_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL,
268302
inputs: dict[str, float] = None,
303+
gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence] | None = None,
269304
reservation_arn: str | None = None,
270305
*args,
271306
**kwargs,
@@ -278,6 +313,7 @@ def _create_task(
278313
shots,
279314
poll_interval_seconds=poll_interval_seconds,
280315
inputs=inputs,
316+
gate_definitions=gate_definitions,
281317
reservation_arn=reservation_arn,
282318
*args,
283319
**kwargs,

0 commit comments

Comments
 (0)