Skip to content

Commit 85b1e6a

Browse files
authored
[BUG] Fix measure from sample shot vector condition (#1981)
**Context:** With the new `shots` param added to the test `test_measurement_from_samples_single_measurement_analytic`, it fails immediately. The root cause has been discussed in a [PennyLane discussion](PennyLaneAI/pennylane#7317 (comment)). **Description of the Change:** Use `shots.has_partitioned_shots` instead of `len(shots.shot_vector) > 1` **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** [sc-91172]
1 parent 03cfe73 commit 85b1e6a

File tree

3 files changed

+5
-2
lines changed

3 files changed

+5
-2
lines changed

doc/releases/changelog-dev.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@
5656

5757
<h3>Bug fixes 🐛</h3>
5858

59+
* Fix wrong handling of partitioned shots in the decomposition pass of `measurements_from_samples`.
60+
[(#1981)](https://github.com/PennyLaneAI/catalyst/pull/1981)
61+
5962
* Fix errors in AutoGraph transformed functions when `qml.prod` is used together with other operator
6063
transforms (e.g. `qml.adjoint`).
6164
[(#1910)](https://github.com/PennyLaneAI/catalyst/pull/1910)

frontend/catalyst/device/decomposition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def postprocessing_samples(results):
348348
results_processed = []
349349
for m in tape.measurements:
350350
if isinstance(m, (ExpectationMP, VarianceMP, ProbabilityMP, SampleMP)):
351-
if len(tape.shots.shot_vector) > 1:
351+
if tape.shots.has_partitioned_shots:
352352
res = tuple(m.process_samples(s, measured_wires) for s in samples)
353353
else:
354354
res = m.process_samples(samples, measured_wires)

frontend/test/pytest/test_measurement_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def circuit(theta: float):
478478
),
479479
],
480480
)
481-
@pytest.mark.parametrize("shots", [3000, (3000, 4000), (3000, 3500, 4000)])
481+
@pytest.mark.parametrize("shots", [3000, (3000, 3000), (3000, 4000), (3000, 3500, 4000)])
482482
def test_measurement_from_samples_single_measurement_analytic(
483483
self,
484484
input_measurement,

0 commit comments

Comments
 (0)