@@ -306,20 +306,10 @@ def __init__(self, original_device):
306306
307307 check_device_wires (original_device .wires )
308308
309- super ().__init__ (wires = original_device .wires , shots = original_device . shots )
309+ super ().__init__ (wires = original_device .wires )
310310
311311 # Capability loading
312- device_capabilities = get_device_capabilities (original_device )
313-
314- # TODO: This is a temporary measure to ensure consistency of behaviour. Remove this
315- # when customizable multi-pathway decomposition is implemented. (Epic 74474)
316- if hasattr (original_device , "_to_matrix_ops" ):
317- _to_matrix_ops = getattr (original_device , "_to_matrix_ops" )
318- setattr (device_capabilities , "to_matrix_ops" , _to_matrix_ops )
319- if _to_matrix_ops and not device_capabilities .supports_operation ("QubitUnitary" ):
320- raise CompileError (
321- "The device that specifies to_matrix_ops must support QubitUnitary."
322- )
312+ device_capabilities = get_device_capabilities (original_device , self .original_device .shots )
323313
324314 backend = QJITDevice .extract_backend_info (original_device )
325315
@@ -333,6 +323,7 @@ def preprocess(
333323 self ,
334324 ctx ,
335325 execution_config : Optional [qml .devices .ExecutionConfig ] = None ,
326+ shots = None ,
336327 ):
337328 """This function defines the device transform program to be applied and an updated device
338329 configuration. The transform program will be created and applied to the tape before
@@ -357,22 +348,27 @@ def preprocess(
357348
358349 if execution_config is None :
359350 execution_config = qml .devices .ExecutionConfig ()
360-
361351 _ , config = self .original_device .preprocess (execution_config )
362352
363353 program = TransformProgram ()
354+ if shots is None :
355+ capabilities = self .capabilities
356+ else :
357+ # recompute device capabilities if shots were provided through set_shots
358+ device_caps = get_device_capabilities (self .original_device , shots )
359+ capabilities = get_qjit_device_capabilities (device_caps )
364360
365361 # measurement transforms may change operations on the tape to accommodate
366362 # measurement transformations, so must occur before decomposition
367- measurement_transforms = self ._measurement_transform_program ()
363+ measurement_transforms = self ._measurement_transform_program (capabilities )
368364 config = replace (config , device_options = deepcopy (config .device_options ))
369365 program = program + measurement_transforms
370366
371367 # decomposition to supported ops/measurements
372368 program .add_transform (
373369 catalyst_decompose ,
374370 ctx = ctx ,
375- capabilities = self . capabilities ,
371+ capabilities = capabilities ,
376372 grad_method = config .gradient_method ,
377373 )
378374
@@ -382,9 +378,9 @@ def preprocess(
382378 )
383379 program .add_transform (
384380 validate_measurements ,
385- self . capabilities ,
381+ capabilities ,
386382 self .original_device .name ,
387- self . original_device . shots ,
383+ shots ,
388384 )
389385
390386 if config .gradient_method is not None :
@@ -396,47 +392,47 @@ def preprocess(
396392
397393 return program , config
398394
399- def _measurement_transform_program (self ):
400-
395+ def _measurement_transform_program (self , capabilities = None ):
396+ capabilities = capabilities or self . capabilities
401397 measurement_program = TransformProgram ()
402398 if isinstance (self .original_device , SoftwareQQPP ):
403399 return measurement_program
404400
405- supports_sum_observables = "Sum" in self . capabilities .observables
401+ supports_sum_observables = "Sum" in capabilities .observables
406402
407- if self . capabilities .non_commuting_observables is False :
403+ if capabilities .non_commuting_observables is False :
408404 measurement_program .add_transform (split_non_commuting )
409405 elif not supports_sum_observables :
410406 measurement_program .add_transform (split_to_single_terms )
411407
412408 # if no observables are supported, we apply a transform to convert *everything* to the
413409 # readout basis, using either sample or counts based on device specification
414- if not self . capabilities .observables :
410+ if not capabilities .observables :
415411 if not split_non_commuting in measurement_program :
416412 # this *should* be redundant, a TOML that doesn't have observables should have
417413 # a False non_commuting_observables flag, but we aren't enforcing that
418414 measurement_program .add_transform (split_non_commuting )
419- if "SampleMP" in self . capabilities .measurement_processes :
415+ if "SampleMP" in capabilities .measurement_processes :
420416 measurement_program .add_transform (measurements_from_samples , self .wires )
421- elif "CountsMP" in self . capabilities .measurement_processes :
417+ elif "CountsMP" in capabilities .measurement_processes :
422418 measurement_program .add_transform (measurements_from_counts , self .wires )
423419 else :
424420 raise RuntimeError ("The device does not support observables or sample/counts" )
425421
426- elif not self . capabilities .measurement_processes .keys () - {"CountsMP" , "SampleMP" }:
422+ elif not capabilities .measurement_processes .keys () - {"CountsMP" , "SampleMP" }:
427423 # ToDo: this branch should become unnecessary when selective conversion of
428424 # unsupported MPs is finished, see ToDo below
429425 if not split_non_commuting in measurement_program : # pragma: no branch
430426 measurement_program .add_transform (split_non_commuting )
431427 mp_transform = (
432428 measurements_from_samples
433- if "SampleMP" in self . capabilities .measurement_processes
429+ if "SampleMP" in capabilities .measurement_processes
434430 else measurements_from_counts
435431 )
436432 measurement_program .add_transform (mp_transform , self .wires )
437433
438434 # if only some observables are supported, we try to diagonalize those that aren't
439- elif not {"PauliX" , "PauliY" , "PauliZ" , "Hadamard" }.issubset (self . capabilities .observables ):
435+ elif not {"PauliX" , "PauliY" , "PauliZ" , "Hadamard" }.issubset (capabilities .observables ):
440436 if not split_non_commuting in measurement_program :
441437 # the device might support non commuting measurements but not all the
442438 # Pauli + Hadamard observables, so here it is needed
@@ -449,7 +445,7 @@ def _measurement_transform_program(self):
449445 }
450446 # checking which base observables are unsupported and need to be diagonalized
451447 supported_observables = {"PauliX" , "PauliY" , "PauliZ" , "Hadamard" }.intersection (
452- self . capabilities .observables
448+ capabilities .observables
453449 )
454450 supported_observables = [_obs_dict [obs ] for obs in supported_observables ]
455451
@@ -520,15 +516,23 @@ def _load_device_capabilities(device) -> DeviceCapabilities:
520516 return capabilities
521517
522518
523- def get_device_capabilities (device ) -> DeviceCapabilities :
519+ def get_device_capabilities (device , shots = None ) -> DeviceCapabilities :
524520 """Get or load the original DeviceCapabilities from device"""
525521
526522 assert not isinstance (device , QJITDevice )
527523
528- shots_present = bool (device .shots )
529- device_capabilities = _load_device_capabilities (device )
524+ shots_present = bool (shots )
525+ device_capabilities = _load_device_capabilities (device ).filter (finite_shots = shots_present )
526+
527+ # TODO: This is a temporary measure to ensure consistency of behaviour. Remove this
528+ # when customizable multi-pathway decomposition is implemented. (Epic 74474)
529+ if hasattr (device , "_to_matrix_ops" ):
530+ _to_matrix_ops = getattr (device , "_to_matrix_ops" )
531+ setattr (device_capabilities , "to_matrix_ops" , _to_matrix_ops )
532+ if _to_matrix_ops and not device_capabilities .supports_operation ("QubitUnitary" ):
533+ raise CompileError ("The device that specifies to_matrix_ops must support QubitUnitary." )
530534
531- return device_capabilities . filter ( finite_shots = shots_present )
535+ return device_capabilities
532536
533537
534538def is_dynamic_wires (wires : qml .wires .Wires ):
0 commit comments