2323from braket .aws .aws_quantum_task import AwsQuantumTask
2424from braket .aws .aws_session import AwsSession
2525from braket .circuits import Circuit
26+ from braket .circuits .gate import Gate
2627from braket .ir .blackbird import Program as BlackbirdProgram
2728from braket .ir .openqasm import Program as OpenQasmProgram
29+ from braket .pulse .pulse_sequence import PulseSequence
30+ from braket .registers .qubit_set import QubitSet
2831from braket .tasks .quantum_task_batch import QuantumTaskBatch
2932
3033
@@ -61,6 +64,13 @@ def __init__(
6164 poll_timeout_seconds : float = AwsQuantumTask .DEFAULT_RESULTS_POLL_TIMEOUT ,
6265 poll_interval_seconds : float = AwsQuantumTask .DEFAULT_RESULTS_POLL_INTERVAL ,
6366 inputs : Union [dict [str , float ], list [dict [str , float ]]] | None = None ,
67+ gate_definitions : (
68+ Union [
69+ dict [tuple [Gate , QubitSet ], PulseSequence ],
70+ list [dict [tuple [Gate , QubitSet ], PulseSequence ]],
71+ ]
72+ | None
73+ ) = None ,
6474 reservation_arn : str | None = None ,
6575 * aws_quantum_task_args : Any ,
6676 ** aws_quantum_task_kwargs : Any ,
@@ -92,6 +102,9 @@ def __init__(
92102 inputs (Union[dict[str, float], list[dict[str, float]]] | None): Inputs to be passed
93103 along with the IR. If the IR supports inputs, the inputs will be updated
94104 with this value. Default: {}.
105+ gate_definitions (Union[dict[tuple[Gate, QubitSet], PulseSequence], list[dict[tuple[Gate, QubitSet], PulseSequence]]] | None): # noqa: E501
106+ User-defined gate calibration. The calibration is defined for a particular `Gate` on a
107+ particular `QubitSet` and is represented by a `PulseSequence`. Default: None.
95108 reservation_arn (str | None): The reservation ARN provided by Braket Direct
96109 to reserve exclusive usage for the device to run the quantum task on.
97110 Note: If you are creating tasks in a job that itself was created reservation ARN,
@@ -111,6 +124,7 @@ def __init__(
111124 poll_timeout_seconds ,
112125 poll_interval_seconds ,
113126 inputs ,
127+ gate_definitions ,
114128 reservation_arn ,
115129 * aws_quantum_task_args ,
116130 ** aws_quantum_task_kwargs ,
@@ -134,7 +148,7 @@ def __init__(
134148 self ._aws_quantum_task_kwargs = aws_quantum_task_kwargs
135149
136150 @staticmethod
137- def _tasks_and_inputs (
151+ def _tasks_inputs_gatedefs (
138152 task_specifications : Union [
139153 Union [Circuit , Problem , OpenQasmProgram , BlackbirdProgram , AnalogHamiltonianSimulation ],
140154 list [
@@ -144,45 +158,55 @@ def _tasks_and_inputs(
144158 ],
145159 ],
146160 inputs : Union [dict [str , float ], list [dict [str , float ]]] = None ,
161+ gate_definitions : Union [
162+ dict [tuple [Gate , QubitSet ], PulseSequence ],
163+ list [dict [tuple [Gate , QubitSet ], PulseSequence ]],
164+ ] = None ,
147165 ) -> list [
148166 tuple [
149167 Union [Circuit , Problem , OpenQasmProgram , BlackbirdProgram , AnalogHamiltonianSimulation ],
150168 dict [str , float ],
169+ dict [tuple [Gate , QubitSet ], PulseSequence ],
151170 ]
152171 ]:
153172 inputs = inputs or {}
154-
155- max_inputs_tasks = 1
156- single_task = isinstance (
157- task_specifications ,
158- (Circuit , Problem , OpenQasmProgram , BlackbirdProgram , AnalogHamiltonianSimulation ),
159- )
160- single_input = isinstance (inputs , dict )
161-
162- max_inputs_tasks = (
163- max (max_inputs_tasks , len (task_specifications )) if not single_task else max_inputs_tasks
164- )
165- max_inputs_tasks = (
166- max (max_inputs_tasks , len (inputs )) if not single_input else max_inputs_tasks
173+ gate_definitions = gate_definitions or {}
174+
175+ single_task_type = (
176+ Circuit ,
177+ Problem ,
178+ OpenQasmProgram ,
179+ BlackbirdProgram ,
180+ AnalogHamiltonianSimulation ,
167181 )
182+ single_input_type = dict
183+ single_gate_definitions_type = dict
168184
169- if not single_task and not single_input :
170- if len (task_specifications ) != len (inputs ):
171- raise ValueError ("Multiple inputs and task specifications must be equal in number." )
172- if single_task :
173- task_specifications = repeat (task_specifications , times = max_inputs_tasks )
185+ args = [task_specifications , inputs , gate_definitions ]
186+ single_arg_types = [single_task_type , single_input_type , single_gate_definitions_type ]
174187
175- if single_input :
176- inputs = repeat (inputs , times = max_inputs_tasks )
188+ batch_length = 1
189+ arg_lengths = []
190+ for arg , single_arg_type in zip (args , single_arg_types ):
191+ arg_length = 1 if isinstance (arg , single_arg_type ) else len (arg )
192+ arg_lengths .append (arg_length )
177193
178- tasks_and_inputs = zip (task_specifications , inputs )
194+ if arg_length != 1 :
195+ if batch_length != 1 and arg_length != batch_length :
196+ raise ValueError (
197+ "Multiple inputs, task specifications and gate definitions must "
198+ "be equal in length."
199+ )
200+ else :
201+ batch_length = arg_length
179202
180- if single_task and single_input :
181- tasks_and_inputs = list (tasks_and_inputs )
203+ for i , arg_length in enumerate (arg_lengths ):
204+ if arg_length == 1 :
205+ args [i ] = repeat (args [i ], batch_length )
182206
183- tasks_and_inputs = list (tasks_and_inputs )
207+ tasks_inputs_definitions = list (zip ( * args ) )
184208
185- for task_specification , input_map in tasks_and_inputs :
209+ for task_specification , input_map , _gate_definitions in tasks_inputs_definitions :
186210 if isinstance (task_specification , Circuit ):
187211 param_names = {param .name for param in task_specification .parameters }
188212 unbounded_parameters = param_names - set (input_map .keys ())
@@ -192,7 +216,7 @@ def _tasks_and_inputs(
192216 f"{ unbounded_parameters } "
193217 )
194218
195- return tasks_and_inputs
219+ return tasks_inputs_definitions
196220
197221 @staticmethod
198222 def _execute (
@@ -213,13 +237,22 @@ def _execute(
213237 poll_timeout_seconds : float = AwsQuantumTask .DEFAULT_RESULTS_POLL_TIMEOUT ,
214238 poll_interval_seconds : float = AwsQuantumTask .DEFAULT_RESULTS_POLL_INTERVAL ,
215239 inputs : Union [dict [str , float ], list [dict [str , float ]]] = None ,
240+ gate_definitions : (
241+ Union [
242+ dict [tuple [Gate , QubitSet ], PulseSequence ],
243+ list [dict [tuple [Gate , QubitSet ], PulseSequence ]],
244+ ]
245+ | None
246+ ) = None ,
216247 reservation_arn : str | None = None ,
217248 * args ,
218249 ** kwargs ,
219250 ) -> list [AwsQuantumTask ]:
220- tasks_and_inputs = AwsQuantumTaskBatch ._tasks_and_inputs (task_specifications , inputs )
251+ tasks_inputs_gatedefs = AwsQuantumTaskBatch ._tasks_inputs_gatedefs (
252+ task_specifications , inputs , gate_definitions
253+ )
221254 max_threads = min (max_parallel , max_workers )
222- remaining = [0 for _ in tasks_and_inputs ]
255+ remaining = [0 for _ in tasks_inputs_gatedefs ]
223256 try :
224257 with ThreadPoolExecutor (max_workers = max_threads ) as executor :
225258 task_futures = [
@@ -234,11 +267,12 @@ def _execute(
234267 poll_timeout_seconds = poll_timeout_seconds ,
235268 poll_interval_seconds = poll_interval_seconds ,
236269 inputs = input_map ,
270+ gate_definitions = gatedefs ,
237271 reservation_arn = reservation_arn ,
238272 * args ,
239273 ** kwargs ,
240274 )
241- for task , input_map in tasks_and_inputs
275+ for task , input_map , gatedefs in tasks_inputs_gatedefs
242276 ]
243277 except KeyboardInterrupt :
244278 # If an exception is thrown before the thread pool has finished,
@@ -266,6 +300,7 @@ def _create_task(
266300 shots : int ,
267301 poll_interval_seconds : float = AwsQuantumTask .DEFAULT_RESULTS_POLL_INTERVAL ,
268302 inputs : dict [str , float ] = None ,
303+ gate_definitions : dict [tuple [Gate , QubitSet ], PulseSequence ] | None = None ,
269304 reservation_arn : str | None = None ,
270305 * args ,
271306 ** kwargs ,
@@ -278,6 +313,7 @@ def _create_task(
278313 shots ,
279314 poll_interval_seconds = poll_interval_seconds ,
280315 inputs = inputs ,
316+ gate_definitions = gate_definitions ,
281317 reservation_arn = reservation_arn ,
282318 * args ,
283319 ** kwargs ,
0 commit comments