Skip to content

Commit 33debf0

Browse files
albi3roandrijapauJerryChen97isaacdevlugt
authored
No longer squeeze singleton dimension in samples (#7944)
**Context:** Similar to #7917 , but doesn't squeeze observables. Currently, we always squeeze out singleton dimension from `qml.sample` measurments. This leads to lots of additional logical branches, and error-prone postprocessing. It can be really easy to loose track of which dimensions are being sliced into. Catalyst also never squeezes out singleton dimensions. If we want to unify the behavior of pennylane and catalyst, we need to either start squeezing in catalyst, or stop squeezing in pennylane. We are opting to stop squeezing in pennylane. **Description of the Change:** Switches to that sampling wires *always* returns an array of shape `(shots, wires)`, regardless of whether or not we have a single shot or a single wire. **Benefits:** Much simpler, more robust code and unified behavior between pennylane and catalyst. **Possible Drawbacks:** This is going to be a breaking change, and there's really no way to soften this to a deprecation. We just have to pull the bandaid off. Need an accompanying issue to lightning. **Related GitHub Issues:** [[sc-95797](https://app.shortcut.com/xanaduai/story/95797)] --------- Co-authored-by: Andrija Paurevic <[email protected]> Co-authored-by: Yushao Chen (Jerry) <[email protected]> Co-authored-by: Isaac De Vlugt <[email protected]>
1 parent cc6e16c commit 33debf0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+366
-326
lines changed

doc/development/plugins.rst

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,21 @@ circuits.
8787
>>> tape1 = qml.tape.QuantumScript([], [qml.sample(wires=0)], shots=10)
8888
>>> dev = qml.device('default.qubit')
8989
>>> dev.execute((tape0, tape1))
90-
(array([0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
90+
(array([[0],
91+
[0],
92+
[0],
93+
[0],
94+
[0]]),
95+
array([[0],
96+
[0],
97+
[0],
98+
[0],
99+
[0],
100+
[0],
101+
[0],
102+
[0],
103+
[0],
104+
[0]]))
91105

92106
The :class:`~.measurements.Shots` class describes the shots. Users can optionally specify a shot vector, or
93107
different numbers of shots to use when calculating the final expectation value.
@@ -457,9 +471,27 @@ the top user level, we aim to allow dynamic configuration of the device.
457471
>>> config = qml.devices.ExecutionConfig(device_options={"rng": 42})
458472
>>> tape = qml.tape.QuantumTape([qml.Hadamard(0)], [qml.sample(wires=0)], shots=10)
459473
>>> dev.execute(tape, config)
460-
array([1, 0, 1, 1, 0, 1, 1, 1, 0, 0])
474+
array([[1],
475+
[1],
476+
[0],
477+
[1],
478+
[0],
479+
[1],
480+
[0],
481+
[1],
482+
[0],
483+
[0]])
461484
>>> dev.execute(tape, config)
462-
array([1, 0, 1, 1, 0, 1, 1, 1, 0, 0])
485+
array([[0],
486+
[1],
487+
[0],
488+
[0],
489+
[0],
490+
[1],
491+
[1],
492+
[0],
493+
[0],
494+
[0]])
463495

464496
By pulling options from this dictionary instead of from device properties, we unlock two key
465497
pieces of functionality:

doc/introduction/dynamic_quantum_circuits.rst

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,13 @@ Executing this QNode with 10 shots yields
9494
.. code-block:: pycon
9595
9696
>>> func(np.pi / 2, shots=10)
97-
array([1, 1, 1, 1, 1, 1, 1])
97+
array([[1],
98+
[1],
99+
[1],
100+
[1],
101+
[1]])
98102
99-
Note that only 7 samples are returned. This is because samples that do not meet the postselection criteria are
103+
Note that less than 10 samples are returned. This is because samples that do not meet the postselection criteria are
100104
discarded. This behaviour can be customized, see the section
101105
:ref:`"Configuring mid-circuit measurements" <mcm_config>`.
102106

doc/introduction/unsupported_gradients.rst

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ error, but the results will be incorrect:
287287
def sample_backward():
288288
dev = qml.device('default.qubit', wires=1)
289289
290-
@partial(qml.set_shots, shots=20)
290+
@partial(qml.set_shots, shots=5)
291291
@qml.qnode(dev)
292292
def circuit(x):
293293
qml.RX(x[0], wires=0)
@@ -297,26 +297,15 @@ error, but the results will be incorrect:
297297
print(qml.jacobian(circuit)(x))
298298
299299
>>> sample_backward()
300-
[[0.5]
301-
[0.5]
302-
[0.5]
303-
[0.5]
304-
[0.5]
305-
[0.5]
306-
[0.5]
307-
[0.5]
308-
[0.5]
309-
[0.5]
310-
[0.5]
311-
[0.5]
312-
[0.5]
313-
[0.5]
314-
[0.5]
315-
[0.5]
316-
[0.5]
317-
[0.5]
318-
[0.5]
319-
[0.5]]
300+
[[[0.5]]
301+
<BLANKLINE>
302+
[[0.5]]
303+
<BLANKLINE>
304+
[[0.5]]
305+
<BLANKLINE>
306+
[[0.5]]
307+
<BLANKLINE>
308+
[[0.5]]]
320309

321310
The forward pass is supported and will work as expected:
322311

@@ -335,4 +324,23 @@ The forward pass is supported and will work as expected:
335324
print(circuit(x))
336325
337326
>>> sample_forward()
338-
[0 1 0 0 0 1 1 0 0 1 1 1 0 0 0 1 1 0 0 0]
327+
[[0]
328+
[0]
329+
[0]
330+
[0]
331+
[1]
332+
[1]
333+
[0]
334+
[0]
335+
[1]
336+
[1]
337+
[1]
338+
[1]
339+
[0]
340+
[1]
341+
[1]
342+
[0]
343+
[1]
344+
[0]
345+
[0]
346+
[1]]

doc/releases/changelog-dev.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,40 @@
291291

292292
<h3>Breaking changes 💔</h3>
293293

294+
* `qml.sample` no longer has singleton dimensions squeezed out for single shots or single wires. This cuts
295+
down on the complexity of post-processing due to having to handle single shot and single wire cases
296+
separately. The return shape will now *always* be `(shots, num_wires)`.
297+
[(#7944)](https://github.com/PennyLaneAI/pennylane/pull/7944)
298+
299+
For a simple qnode:
300+
301+
```pycon
302+
>>> @qml.qnode(qml.device('default.qubit'))
303+
... def c():
304+
... return qml.sample(wires=0)
305+
```
306+
307+
Before the change, we had:
308+
309+
```pycon
310+
>>> qml.set_shots(c, shots=1)()
311+
0
312+
```
313+
314+
and now we have:
315+
316+
```pycon
317+
>>> qml.set_shots(c, shots=1)()
318+
array([[0]])
319+
```
320+
321+
Previous behavior can be recovered by squeezing the output:
322+
323+
```pycon
324+
>>> qml.math.squeeze(qml.set_shots(c, shots=1)())
325+
0
326+
```
327+
294328
* `ExecutionConfig` and `MCMConfig` from `pennylane.devices` are now frozen dataclasses whose fields should be updated with `dataclass.replace`.
295329
[(#7697)](https://github.com/PennyLaneAI/pennylane/pull/7697)
296330

pennylane/devices/_qubit_device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ def statistics(
659659
elif isinstance(m, SampleMP):
660660
samples = self.sample(obs, shot_range=shot_range, bin_size=bin_size, counts=False)
661661
dtype = int if isinstance(obs, SampleMP) else None
662-
result = self._asarray(qml.math.squeeze(samples), dtype=dtype)
662+
result = self._asarray(samples, dtype=dtype)
663663

664664
elif isinstance(m, CountsMP):
665665
result = self.sample(m, shot_range=shot_range, bin_size=bin_size, counts=True)

pennylane/devices/legacy_facade.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,11 @@ class LegacyDeviceFacade(Device):
152152
Shots(total_shots=None, shot_vector=())
153153
>>> tape = qml.tape.QuantumScript([], [qml.sample(wires=0)], shots=5)
154154
>>> new_dev.execute(tape)
155-
array([0, 0, 0, 0, 0])
155+
array([[0],
156+
[0],
157+
[0],
158+
[0],
159+
[0]])
156160
157161
"""
158162

pennylane/devices/preprocess.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -639,31 +639,19 @@ def measurements_from_samples(tape):
639639
diagonalized_tape, measured_wires = _get_diagonalized_tape_and_wires(tape)
640640
new_tape = diagonalized_tape.copy(measurements=[qml.sample(wires=measured_wires)])
641641

642-
def unsqueezed(samples):
643-
"""If the samples have been squeezed to remove the 'extra' dimension in the case where
644-
shots=1 or wires=1, unsqueeze to restore the raw samples format expected by mp.process_samples
645-
"""
646-
647-
# if we stop squeezing out the extra dimension for shots=1 or wires=1 when sampling (as is
648-
# already the case in Catalyst), this problem goes away
649-
650-
if len(samples.shape) == 1:
651-
samples = qml.math.array([[s] for s in samples], like=samples)
652-
return samples
653-
654642
def postprocessing_fn(results):
655643
"""A processing function to get measurement values from samples."""
656644
samples = results[0]
657645
if tape.shots.has_partitioned_shots:
658646
results_processed = []
659647
for s in samples:
660-
res = [m.process_samples(unsqueezed(s), measured_wires) for m in tape.measurements]
648+
res = [m.process_samples(s, measured_wires) for m in tape.measurements]
661649
if len(tape.measurements) == 1:
662650
res = res[0]
663651
results_processed.append(res)
664652
else:
665653
results_processed = [
666-
m.process_samples(unsqueezed(samples), measured_wires) for m in tape.measurements
654+
m.process_samples(samples, measured_wires) for m in tape.measurements
667655
]
668656
if len(tape.measurements) == 1:
669657
results_processed = results_processed[0]

pennylane/devices/qubit/sampling.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import pennylane as qml
2020
from pennylane.measurements import (
2121
ClassicalShadowMP,
22-
CountsMP,
2322
ExpectationMP,
2423
SampleMeasurement,
2524
ShadowExpvalMP,
@@ -308,15 +307,7 @@ def _measure_with_samples_diagonalizing_gates(
308307
wires = qml.wires.Wires(range(total_indices))
309308

310309
def _process_single_shot(samples):
311-
processed = []
312-
for mp in mps:
313-
res = mp.process_samples(samples, wires)
314-
if not isinstance(mp, CountsMP):
315-
res = qml.math.squeeze(res)
316-
317-
processed.append(res)
318-
319-
return tuple(processed)
310+
return tuple(mp.process_samples(samples, wires) for mp in mps)
320311

321312
try:
322313
prng_key, _ = jax_random_split(prng_key)

pennylane/devices/qubit/simulate.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -858,8 +858,6 @@ def combine_measurements(terminal_measurements, results, mcm_samples):
858858
comb_meas = measurement_with_no_shots(circ_meas)
859859
else:
860860
comb_meas = combine_measurements_core(circ_meas, results.pop(0))
861-
if isinstance(circ_meas, SampleMP):
862-
comb_meas = qml.math.squeeze(comb_meas)
863861
final_measurements.append(comb_meas)
864862
return final_measurements[0] if len(final_measurements) == 1 else tuple(final_measurements)
865863

@@ -916,7 +914,7 @@ def _(original_measurement: SampleMP, measures):
916914
new_sample = tuple(
917915
qml.math.atleast_1d(m[1]) for m in measures.values() if m[0] and not m[1] is tuple()
918916
)
919-
return qml.math.squeeze(qml.math.concatenate(new_sample))
917+
return qml.math.concatenate(new_sample)
920918

921919

922920
@debug_logger

pennylane/devices/qubit_mixed/sampling.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@
2020
import numpy as np
2121

2222
import pennylane as qml
23-
from pennylane import math
2423
from pennylane.devices.qubit.sampling import _group_measurements, jax_random_split, sample_probs
25-
from pennylane.measurements import CountsMP, ExpectationMP, SampleMeasurement, Shots
24+
from pennylane.measurements import ExpectationMP, SampleMeasurement, Shots
2625
from pennylane.measurements.classical_shadow import ClassicalShadowMP, ShadowExpvalMP
2726
from pennylane.ops import LinearCombination, Sum
2827
from pennylane.typing import TensorLike
@@ -87,15 +86,7 @@ def _measure_with_samples_diagonalizing_gates(
8786
wires = qml.wires.Wires(range(total_indices))
8887

8988
def _process_single_shot(samples):
90-
processed = []
91-
for mp in mps:
92-
res = mp.process_samples(samples, wires)
93-
if not isinstance(mp, CountsMP):
94-
res = math.squeeze(res)
95-
96-
processed.append(res)
97-
98-
return tuple(processed)
89+
return tuple(mp.process_samples(samples, wires) for mp in mps)
9990

10091
prng_key, _ = jax_random_split(prng_key)
10192
samples = sample_state(

0 commit comments

Comments
 (0)