Skip to content

Commit feeecae

Browse files
JerryChen97dime10paul0403
authored
Replacing device(..., shots=...) with set_shots (#1952)
**Context:** PennyLane is [deprecating](PennyLaneAI/pennylane#7979) the interface of setting shots on device level. Hopefully this PR can help Catalyst to avoid `PennyLaneDeprecationWarning`. We will start from tests and then carefully check if within source code there's any places to modify **Description of the Change:** Source code: - `get_device_capabilities` accepts an extra kwarg `shots=` in order to be able to receive `tape.shots`. - The `capabilities` of `QJITDevice` has been heavily adjusted, but only to extend to accept `shots` information from `tape(s)`. As the typical deprecation cycle intended all the time, the previous `QJITDevice.shots` and `QJITDevice.capabilities` passes keep as the same as before if a user want to use them. - `validate_measurements` arg `shots` has been changed from positional to keyword. Internally, this help prefers the shots of its input `tape` - `QFunc` does not re-initialize the copied devices with shots any more. Tests: - A lot. But just to transit from `device(..., shots=...)` to `set_shots`. - Tests related to `capture` remain the same but with warnings caught. **Benefits:** **Possible Drawbacks:** Unlike a typical deprecation, where we completely remove the deprecated usage, this one will be co-existing with the new UI. **Related GitHub Issues:** [sc-95431] --------- Co-authored-by: David Ittah <[email protected]> Co-authored-by: Paul <[email protected]>
1 parent 91df026 commit feeecae

33 files changed

+478
-299
lines changed

.dep-versions

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ enzyme=v0.0.186
99
# Always remove custom PL/LQ versions before release.
1010

1111
# For a custom PL version, update the package version here and at
12-
pennylane=0.43.0.dev24
12+
# 'doc/requirements.txt'
13+
pennylane=0.43.0.dev30
1314

1415
# For a custom LQ/LK version, update the package version here and at
1516
# 'doc/requirements.txt'

doc/releases/changelog-dev.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@
5454

5555
<h3>Deprecations 👋</h3>
5656

57+
* Deprecated usages of `Device.shots` along with setting `device(..., shots=...)`.
58+
Heavily adjusted frontend pipelines within qfunc, tracer, verification and QJITDevice to account for this change.
59+
[(#1952)](https://github.com/PennyLaneAI/catalyst/pull/1952)
60+
5761
<h3>Bug fixes 🐛</h3>
5862

5963
* Fix wrong handling of partitioned shots in the decomposition pass of `measurements_from_samples`.
@@ -162,6 +166,7 @@
162166
This release contains contributions from (in alphabetical order):
163167

164168
Joey Carter,
169+
Yushao Chen,
165170
Sengthai Heng,
166171
David Ittah,
167172
Christina Lee,

doc/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,4 @@ lxml_html_clean
3333
--extra-index-url https://test.pypi.org/simple/
3434
pennylane-lightning-kokkos==0.43.0.dev12
3535
pennylane-lightning==0.43.0.dev12
36-
pennylane==0.43.0.dev24
36+
pennylane==0.43.0.dev30

frontend/catalyst/device/qjit_device.py

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -306,20 +306,10 @@ def __init__(self, original_device):
306306

307307
check_device_wires(original_device.wires)
308308

309-
super().__init__(wires=original_device.wires, shots=original_device.shots)
309+
super().__init__(wires=original_device.wires)
310310

311311
# Capability loading
312-
device_capabilities = get_device_capabilities(original_device)
313-
314-
# TODO: This is a temporary measure to ensure consistency of behaviour. Remove this
315-
# when customizable multi-pathway decomposition is implemented. (Epic 74474)
316-
if hasattr(original_device, "_to_matrix_ops"):
317-
_to_matrix_ops = getattr(original_device, "_to_matrix_ops")
318-
setattr(device_capabilities, "to_matrix_ops", _to_matrix_ops)
319-
if _to_matrix_ops and not device_capabilities.supports_operation("QubitUnitary"):
320-
raise CompileError(
321-
"The device that specifies to_matrix_ops must support QubitUnitary."
322-
)
312+
device_capabilities = get_device_capabilities(original_device, self.original_device.shots)
323313

324314
backend = QJITDevice.extract_backend_info(original_device)
325315

@@ -333,6 +323,7 @@ def preprocess(
333323
self,
334324
ctx,
335325
execution_config: Optional[qml.devices.ExecutionConfig] = None,
326+
shots=None,
336327
):
337328
"""This function defines the device transform program to be applied and an updated device
338329
configuration. The transform program will be created and applied to the tape before
@@ -357,22 +348,27 @@ def preprocess(
357348

358349
if execution_config is None:
359350
execution_config = qml.devices.ExecutionConfig()
360-
361351
_, config = self.original_device.preprocess(execution_config)
362352

363353
program = TransformProgram()
354+
if shots is None:
355+
capabilities = self.capabilities
356+
else:
357+
# recompute device capabilities if shots were provided through set_shots
358+
device_caps = get_device_capabilities(self.original_device, shots)
359+
capabilities = get_qjit_device_capabilities(device_caps)
364360

365361
# measurement transforms may change operations on the tape to accommodate
366362
# measurement transformations, so must occur before decomposition
367-
measurement_transforms = self._measurement_transform_program()
363+
measurement_transforms = self._measurement_transform_program(capabilities)
368364
config = replace(config, device_options=deepcopy(config.device_options))
369365
program = program + measurement_transforms
370366

371367
# decomposition to supported ops/measurements
372368
program.add_transform(
373369
catalyst_decompose,
374370
ctx=ctx,
375-
capabilities=self.capabilities,
371+
capabilities=capabilities,
376372
grad_method=config.gradient_method,
377373
)
378374

@@ -382,9 +378,9 @@ def preprocess(
382378
)
383379
program.add_transform(
384380
validate_measurements,
385-
self.capabilities,
381+
capabilities,
386382
self.original_device.name,
387-
self.original_device.shots,
383+
shots,
388384
)
389385

390386
if config.gradient_method is not None:
@@ -396,47 +392,47 @@ def preprocess(
396392

397393
return program, config
398394

399-
def _measurement_transform_program(self):
400-
395+
def _measurement_transform_program(self, capabilities=None):
396+
capabilities = capabilities or self.capabilities
401397
measurement_program = TransformProgram()
402398
if isinstance(self.original_device, SoftwareQQPP):
403399
return measurement_program
404400

405-
supports_sum_observables = "Sum" in self.capabilities.observables
401+
supports_sum_observables = "Sum" in capabilities.observables
406402

407-
if self.capabilities.non_commuting_observables is False:
403+
if capabilities.non_commuting_observables is False:
408404
measurement_program.add_transform(split_non_commuting)
409405
elif not supports_sum_observables:
410406
measurement_program.add_transform(split_to_single_terms)
411407

412408
# if no observables are supported, we apply a transform to convert *everything* to the
413409
# readout basis, using either sample or counts based on device specification
414-
if not self.capabilities.observables:
410+
if not capabilities.observables:
415411
if not split_non_commuting in measurement_program:
416412
# this *should* be redundant, a TOML that doesn't have observables should have
417413
# a False non_commuting_observables flag, but we aren't enforcing that
418414
measurement_program.add_transform(split_non_commuting)
419-
if "SampleMP" in self.capabilities.measurement_processes:
415+
if "SampleMP" in capabilities.measurement_processes:
420416
measurement_program.add_transform(measurements_from_samples, self.wires)
421-
elif "CountsMP" in self.capabilities.measurement_processes:
417+
elif "CountsMP" in capabilities.measurement_processes:
422418
measurement_program.add_transform(measurements_from_counts, self.wires)
423419
else:
424420
raise RuntimeError("The device does not support observables or sample/counts")
425421

426-
elif not self.capabilities.measurement_processes.keys() - {"CountsMP", "SampleMP"}:
422+
elif not capabilities.measurement_processes.keys() - {"CountsMP", "SampleMP"}:
427423
# ToDo: this branch should become unnecessary when selective conversion of
428424
# unsupported MPs is finished, see ToDo below
429425
if not split_non_commuting in measurement_program: # pragma: no branch
430426
measurement_program.add_transform(split_non_commuting)
431427
mp_transform = (
432428
measurements_from_samples
433-
if "SampleMP" in self.capabilities.measurement_processes
429+
if "SampleMP" in capabilities.measurement_processes
434430
else measurements_from_counts
435431
)
436432
measurement_program.add_transform(mp_transform, self.wires)
437433

438434
# if only some observables are supported, we try to diagonalize those that aren't
439-
elif not {"PauliX", "PauliY", "PauliZ", "Hadamard"}.issubset(self.capabilities.observables):
435+
elif not {"PauliX", "PauliY", "PauliZ", "Hadamard"}.issubset(capabilities.observables):
440436
if not split_non_commuting in measurement_program:
441437
# the device might support non commuting measurements but not all the
442438
# Pauli + Hadamard observables, so here it is needed
@@ -449,7 +445,7 @@ def _measurement_transform_program(self):
449445
}
450446
# checking which base observables are unsupported and need to be diagonalized
451447
supported_observables = {"PauliX", "PauliY", "PauliZ", "Hadamard"}.intersection(
452-
self.capabilities.observables
448+
capabilities.observables
453449
)
454450
supported_observables = [_obs_dict[obs] for obs in supported_observables]
455451

@@ -520,15 +516,23 @@ def _load_device_capabilities(device) -> DeviceCapabilities:
520516
return capabilities
521517

522518

523-
def get_device_capabilities(device) -> DeviceCapabilities:
519+
def get_device_capabilities(device, shots=None) -> DeviceCapabilities:
524520
"""Get or load the original DeviceCapabilities from device"""
525521

526522
assert not isinstance(device, QJITDevice)
527523

528-
shots_present = bool(device.shots)
529-
device_capabilities = _load_device_capabilities(device)
524+
shots_present = bool(shots)
525+
device_capabilities = _load_device_capabilities(device).filter(finite_shots=shots_present)
526+
527+
# TODO: This is a temporary measure to ensure consistency of behaviour. Remove this
528+
# when customizable multi-pathway decomposition is implemented. (Epic 74474)
529+
if hasattr(device, "_to_matrix_ops"):
530+
_to_matrix_ops = getattr(device, "_to_matrix_ops")
531+
setattr(device_capabilities, "to_matrix_ops", _to_matrix_ops)
532+
if _to_matrix_ops and not device_capabilities.supports_operation("QubitUnitary"):
533+
raise CompileError("The device that specifies to_matrix_ops must support QubitUnitary.")
530534

531-
return device_capabilities.filter(finite_shots=shots_present)
535+
return device_capabilities
532536

533537

534538
def is_dynamic_wires(wires: qml.wires.Wires):

frontend/catalyst/device/verification.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
with the compiler and device.
1818
"""
1919

20-
from typing import Any, Callable, List, Sequence, Union
20+
from typing import Any, Callable, List, Sequence
2121

2222
from pennylane import transform
2323
from pennylane.devices.capabilities import DeviceCapabilities, OperatorProperties
@@ -30,7 +30,7 @@
3030
VarianceMP,
3131
VnEntropyMP,
3232
)
33-
from pennylane.measurements.shots import Shots
33+
from pennylane.measurements.shots import Shots, ShotsLike
3434
from pennylane.operation import Operation
3535
from pennylane.ops import (
3636
Adjoint,
@@ -268,7 +268,7 @@ def _obs_checker(obs):
268268

269269
@transform
270270
def validate_measurements(
271-
tape: QuantumTape, capabilities: DeviceCapabilities, name: str, shots: Union[int, Shots]
271+
tape: QuantumTape, capabilities: DeviceCapabilities, name: str, shots: ShotsLike = None
272272
) -> (Sequence[QuantumTape], Callable):
273273
"""Validates the observables and measurements for a circuit against the capabilites
274274
from the TOML file.
@@ -287,6 +287,7 @@ def validate_measurements(
287287
CompileError: if a measurement is not supported by the given device with Catalyst
288288
289289
"""
290+
shots = tape.shots if shots is None else Shots(shots)
290291

291292
def _obs_checker(obs):
292293
if not obs.name in capabilities.observables:

frontend/catalyst/from_plxpr/from_plxpr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from pennylane.transforms import single_qubit_fusion as pl_single_qubit_fusion
4444
from pennylane.transforms import unitary_to_rot as pl_unitary_to_rot
4545

46-
from catalyst.device import extract_backend_info, get_device_capabilities
46+
from catalyst.device import extract_backend_info
4747
from catalyst.from_plxpr.qreg_manager import QregManager
4848
from catalyst.jax_extras import jaxpr_pad_consts, make_jaxpr2, transient_jax_config
4949
from catalyst.jax_primitives import (

frontend/catalyst/jax_tracer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1382,7 +1382,9 @@ def is_leaf(obj):
13821382
)
13831383
if isinstance(device, qml.devices.Device):
13841384
config = _make_execution_config(qnode)
1385-
device_program, config = device.preprocess(ctx, config)
1385+
device_program, config = device.preprocess(
1386+
ctx, execution_config=config, shots=shots
1387+
)
13861388
else:
13871389
device_program = TransformProgram()
13881390

frontend/catalyst/qfunc.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,6 @@ def __call__(self, *args, **kwargs):
141141
return Function(dynamic_one_shot(self, mcm_config=mcm_config))(*args, **kwargs)
142142

143143
new_device = copy(self.device)
144-
new_device._shots = self._shots # pylint: disable=protected-access
145144
qjit_device = QJITDevice(new_device)
146145

147146
static_argnums = kwargs.pop("static_argnums", ())
@@ -264,18 +263,13 @@ def processing_fn(results):
264263
return dynamic_one_shot_partial(qnode)
265264

266265
single_shot_qnode = transform_to_single_shot(qnode)
266+
single_shot_qnode = qml.set_shots(single_shot_qnode, shots=1)
267267
if mcm_config is not None:
268268
single_shot_qnode.execute_kwargs["postselect_mode"] = mcm_config.postselect_mode
269269
single_shot_qnode.execute_kwargs["mcm_method"] = mcm_config.mcm_method
270270
single_shot_qnode._dynamic_one_shot_called = True
271-
dev = qnode.device
272271
total_shots = _get_total_shots(qnode)
273272

274-
new_dev = copy(dev)
275-
new_dev._shots = qml.measurements.Shots(1)
276-
single_shot_qnode.device = new_dev
277-
single_shot_qnode._set_shots(qml.measurements.Shots(1)) # pylint: disable=protected-access
278-
279273
def one_shot_wrapper(*args, **kwargs):
280274
def wrap_single_shot_qnode(*_):
281275
return single_shot_qnode(*args, **kwargs)

frontend/test/lit/test_mitigation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@
2626
@qjit(target="mlir")
2727
def mcm_method_with_zne():
2828
"""Test that the dynamic_one_shot works with ZNE."""
29-
dev = qml.device("lightning.qubit", wires=1, shots=5)
29+
dev = qml.device("lightning.qubit", wires=1)
3030

3131
def circuit():
3232
return qml.expval(qml.PauliZ(0))
3333

3434
s = [1, 3]
35-
g = qml.QNode(circuit, dev, mcm_method="one-shot")
35+
g = qml.set_shots(qml.QNode(circuit, dev, mcm_method="one-shot"), shots=5)
3636
return mitigate_with_zne(g, scale_factors=s)()
3737

3838

frontend/test/pytest/device/test_decomposition.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,9 @@ def matrix(self):
156156
dtype=np.complex128,
157157
)
158158

159-
dev = NoUnitaryDevice(4, wires=4)
159+
dev = NoUnitaryDevice(wires=4)
160160

161+
@qml.set_shots(4)
161162
@qml.qnode(dev)
162163
def f():
163164
ctrl(UnknownOp(wires=[0, 1]), control=[2, 3])

0 commit comments

Comments
 (0)