Skip to content

Commit 93b5490

Browse files
astralcaiJerryChen97mudit2812
authored
[PPM] Implement the PauliMeasure class and pauli_measure (#8461)
**Context:** **Description of the Change:** **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** [sc-99878] [sc-99879] --------- Co-authored-by: Yushao Chen (Jerry) <[email protected]> Co-authored-by: Mudit Pandey <[email protected]>
1 parent 8920c55 commit 93b5490

File tree

14 files changed

+323
-52
lines changed

14 files changed

+323
-52
lines changed

doc/releases/changelog-dev.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
that produces a set of gate names to be used as the target gate set in decompositions.
77
[(#8522)](https://github.com/PennyLaneAI/pennylane/pull/8522)
88

9+
* Added a :func:`~pennylane.measurements.pauli_measure` that takes a Pauli product measurement.
10+
[(#8461)](https://github.com/PennyLaneAI/pennylane/pull/8461)
11+
912
<h3>Improvements 🛠</h3>
1013

1114
* Added a keyword argument ``recursive`` to ``qml.transforms.cancel_inverses`` that enables

pennylane/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
counts,
8282
density_matrix,
8383
measure,
84+
pauli_measure,
8485
expval,
8586
probs,
8687
sample,

pennylane/capture/primitives.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from pennylane.control_flow.while_loop import _get_while_loop_qfunc_prim
2323
from pennylane.measurements.capture_measurements import _get_abstract_measurement
2424
from pennylane.measurements.mid_measure import _create_mid_measure_primitive
25+
from pennylane.measurements.pauli_measure import _create_pauli_measure_primitive
2526
from pennylane.operation import _get_abstract_operator
2627
from pennylane.ops.op_math.adjoint import _get_adjoint_qfunc_prim
2728
from pennylane.ops.op_math.condition import _get_cond_qfunc_prim
@@ -38,6 +39,7 @@
3839
for_loop_prim = _get_for_loop_qfunc_prim()
3940
while_loop_prim = _get_while_loop_qfunc_prim()
4041
measure_prim = _create_mid_measure_primitive()
42+
pauli_measure_prim = _create_pauli_measure_primitive()
4143

4244
__all__ = [
4345
"AbstractOperator",
@@ -51,4 +53,5 @@
5153
"for_loop_prim",
5254
"while_loop_prim",
5355
"measure_prim",
56+
"pauli_measure_prim",
5457
]

pennylane/measurements/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ def circuit(x):
280280
StateMeasurement,
281281
)
282282
from .mid_measure import MidMeasureMP, find_post_processed_mcms, get_mcm_predicates, measure
283+
from .pauli_measure import PauliMeasure, pauli_measure
283284
from .mutual_info import MutualInfoMP, mutual_info
284285
from .null_measurement import NullMeasurement
285286
from .probs import ProbabilityMP, probs

pennylane/measurements/measurement_value.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,24 @@
1414
"""
1515
Defines the MeasurementValue class
1616
"""
17-
from collections.abc import Callable
18-
from typing import Generic, TypeVar
17+
from collections.abc import Callable, Generator
1918

2019
from pennylane import math
2120
from pennylane.wires import Wires
2221

23-
T = TypeVar("T")
24-
2522

2623
def no_processing(results):
2724
"""A postprocessing function with no effect."""
2825
return results
2926

3027

31-
class MeasurementValue(Generic[T]):
28+
class MeasurementValue:
3229
"""A class representing unknown measurement outcomes in the qubit model.
3330
34-
Measurements on a single qubit in the computational basis are assumed.
35-
3631
Args:
37-
measurements (list[.MidMeasureMP]): The measurement(s) that this object depends on.
32+
measurements (list[MidMeasureMP | PauliMeasure]): The measurement(s) that this object depends on.
3833
processing_fn (callable | None): A lazy transformation applied to the measurement values.
34+
3935
"""
4036

4137
name = "MeasurementValue"
@@ -56,14 +52,14 @@ def processing_fn(self) -> Callable:
5652
return no_processing
5753
return self._processing_fn
5854

59-
def items(self):
55+
def items(self) -> Generator:
6056
"""A generator representing all the possible outcomes of the MeasurementValue."""
6157
num_meas = len(self.measurements)
6258
for i in range(2**num_meas):
6359
branch = tuple(int(b) for b in f"{i:0{num_meas}b}")
6460
yield branch, self.processing_fn(*branch)
6561

66-
def postselected_items(self):
62+
def postselected_items(self) -> Generator:
6763
"""A generator representing all the possible outcomes of the MeasurementValue,
6864
taking postselection into account."""
6965
# pylint: disable=stop-iteration-return

pennylane/measurements/mid_measure.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""
1515
This module contains the qml.measure measurement.
1616
"""
17+
1718
import uuid
1819
from collections.abc import Hashable
1920
from functools import lru_cache
@@ -192,13 +193,7 @@ def label(self, decimals=None, base_label=None, cache=None):
192193
@property
193194
def hash(self):
194195
"""int: Returns an integer hash uniquely representing the measurement process"""
195-
fingerprint = (
196-
self.__class__.__name__,
197-
tuple(self.wires.tolist()),
198-
self.id,
199-
)
200-
201-
return hash(fingerprint)
196+
return hash((self.__class__.__name__, tuple(self.wires.tolist()), self.id))
202197

203198

204199
def measure(
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright 2025 Xanadu Quantum Technologies Inc.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Implements the pauli measurement.
16+
"""
17+
18+
import uuid
19+
from functools import lru_cache
20+
21+
from pennylane import math
22+
from pennylane.capture import enabled as capture_enabled
23+
from pennylane.operation import Operator
24+
from pennylane.wires import Wires, WiresLike
25+
26+
from .measurement_value import MeasurementValue
27+
28+
_VALID_PAULI_CHARS = "XYZ"
29+
30+
31+
class PauliMeasure(Operator):
32+
"""A Pauli product measurement."""
33+
34+
resource_keys = {"pauli_word"}
35+
36+
def __init__(
37+
self,
38+
pauli_word: str,
39+
wires: WiresLike,
40+
postselect: int | None = None,
41+
id: str | None = None,
42+
):
43+
if not all(c in _VALID_PAULI_CHARS for c in pauli_word):
44+
raise ValueError(
45+
f'The given Pauli word "{pauli_word}" contains characters that '
46+
"are not allowed. Allowed characters are X, Y and Z."
47+
)
48+
49+
wires = Wires(wires)
50+
if len(pauli_word) != len(wires):
51+
raise ValueError(
52+
"The number of wires must be equal to the length of the Pauli "
53+
f"word. The Pauli word {pauli_word} has length {len(pauli_word)} "
54+
f"and {len(wires)} wires were given: {wires}."
55+
)
56+
super().__init__(wires=wires, id=id)
57+
self.hyperparameters["pauli_word"] = pauli_word
58+
self.hyperparameters["postselect"] = postselect
59+
60+
@property
61+
def pauli_word(self) -> str:
62+
"""The Pauli word for the measurement."""
63+
return self.hyperparameters["pauli_word"]
64+
65+
@property
66+
def postselect(self) -> int | None:
67+
"""Which outcome to postselect after the measurement."""
68+
return self.hyperparameters["postselect"]
69+
70+
@classmethod
71+
def _primitive_bind_call(cls, *args, **kwargs):
72+
return type.__call__(cls, *args, **kwargs)
73+
74+
def __repr__(self) -> str:
75+
return f"PauliMeasure('{self.pauli_word}', wires={self.wires.tolist()})"
76+
77+
@property
78+
def resource_params(self) -> dict:
79+
return {"pauli_word": self.hyperparameters["pauli_word"]}
80+
81+
@property
82+
def hash(self) -> int:
83+
"""int: An integer hash uniquely representing the measurement."""
84+
return hash((self.__class__.__name__, self.pauli_word, tuple(self.wires.tolist()), self.id))
85+
86+
87+
def _pauli_measure_impl(wires: WiresLike, pauli_word: str, postselect: int | None = None):
88+
"""Concrete implementation of the pauli_measure primitive."""
89+
measurement_id = str(uuid.uuid4())
90+
measurement = PauliMeasure(pauli_word, wires, postselect, measurement_id)
91+
return MeasurementValue([measurement])
92+
93+
94+
@lru_cache
95+
def _create_pauli_measure_primitive():
96+
"""Create a primitive corresponding to a Pauli product measurement."""
97+
98+
# pylint: disable=import-outside-toplevel
99+
import jax
100+
101+
from pennylane.capture.custom_primitives import QmlPrimitive
102+
103+
pauli_measure_p = QmlPrimitive("pauli_measure")
104+
105+
@pauli_measure_p.def_impl
106+
def _pauli_measure_primitive_impl(*wires, pauli_word="", postselect=None):
107+
return _pauli_measure_impl(wires, pauli_word=pauli_word, postselect=postselect)
108+
109+
@pauli_measure_p.def_abstract_eval
110+
def _pauli_measure_primitive_abstract_eval(*_, **__):
111+
dtype = jax.numpy.int64 if jax.config.jax_enable_x64 else jax.numpy.int32
112+
return jax.core.ShapedArray((), dtype)
113+
114+
return pauli_measure_p
115+
116+
117+
def pauli_measure(pauli_word: str, wires: WiresLike, postselect: int | None = None):
118+
"""Perform a Pauli product measurement."""
119+
120+
if capture_enabled():
121+
primitive = _create_pauli_measure_primitive()
122+
wires = (wires,) if math.shape(wires) == () else tuple(wires)
123+
return primitive.bind(*wires, pauli_word=pauli_word, postselect=postselect)
124+
125+
return _pauli_measure_impl(wires, pauli_word, postselect)

pennylane/operation.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,14 @@
179179
~ops.qubit.attributes.symmetric_over_control_wires
180180
181181
"""
182+
182183
# pylint: disable=access-member-before-definition
183184
import abc
184185
import copy
185186
import warnings
186187
from collections.abc import Callable, Hashable, Iterable
187188
from functools import lru_cache
188-
from typing import Any, Literal, Optional, Union
189+
from typing import Any, ClassVar, Literal, Optional, Union
189190

190191
import numpy as np
191192
from scipy.sparse import spmatrix
@@ -680,6 +681,21 @@ def compute_decomposition(theta, wires):
680681
Optional[jax.extend.core.Primitive]
681682
"""
682683

684+
resource_keys: ClassVar[set | frozenset] = set()
685+
"""The set of parameters that affects the resource requirement of the operator.
686+
687+
All decomposition rules for this operator class are expected to have a resource function
688+
that accepts keyword arguments that match these keys exactly. The :func:`~pennylane.resource_rep`
689+
function will also expect keyword arguments that match these keys when called with this
690+
operator type.
691+
692+
The default implementation is an empty set, which is suitable for most operators.
693+
694+
.. seealso::
695+
:meth:`~.Operator.resource_params`
696+
697+
"""
698+
683699
def __init_subclass__(cls, **_):
684700
register_pytree(cls, cls._flatten, cls._unflatten)
685701
cls._primitive = create_operator_primitive(cls)
@@ -1092,13 +1108,7 @@ def _format(x):
10921108
return f"{op_label}"
10931109
return f"{op_label}\n({inner_string})"
10941110

1095-
def __init__(
1096-
self,
1097-
*params: TensorLike,
1098-
wires: WiresLike | None = None,
1099-
id: str | None = None,
1100-
):
1101-
1111+
def __init__(self, *params: TensorLike, wires: WiresLike | None = None, id: str | None = None):
11021112
self._name: str = self.__class__.__name__ #: str: name of the operator
11031113
self._id: str = id
11041114
self._pauli_rep: qml.pauli.PauliSentence | None = (
@@ -1398,23 +1408,6 @@ def compute_qfunc_decomposition(*args, **hyperparameters) -> None:
13981408

13991409
raise DecompositionUndefinedError
14001410

1401-
@classproperty
1402-
def resource_keys(self) -> set | frozenset: # pylint: disable=no-self-use
1403-
"""The set of parameters that affects the resource requirement of the operator.
1404-
1405-
All decomposition rules for this operator class are expected to have a resource function
1406-
that accepts keyword arguments that match these keys exactly. The :func:`~pennylane.resource_rep`
1407-
function will also expect keyword arguments that match these keys when called with this
1408-
operator type.
1409-
1410-
The default implementation is an empty set, which is suitable for most operators.
1411-
1412-
.. seealso::
1413-
:meth:`~.Operator.resource_params`
1414-
1415-
"""
1416-
return set()
1417-
14181411
@property
14191412
def resource_params(self) -> dict:
14201413
"""A dictionary containing the minimal information needed to compute a

pennylane/ops/cv.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,10 @@
3636
# As the qubit based ``decomposition``, ``_matrix``, ``diagonalizing_gates``
3737
# abstract methods are not defined in the CV case, disabling the related check
3838

39-
import math
40-
4139
import numpy as np
4240
from scipy.linalg import block_diag
4341

44-
from pennylane import math as qml_math
42+
from pennylane import math
4543
from pennylane.operation import CVObservable, CVOperation
4644

4745
from .identity import I, Identity # pylint: disable=unused-import
@@ -688,7 +686,7 @@ def _heisenberg_rep(p):
688686

689687
def adjoint(self):
690688
U = self.parameters[0]
691-
return InterferometerUnitary(qml_math.T(qml_math.conj(U)), wires=self.wires)
689+
return InterferometerUnitary(math.T(math.conj(U)), wires=self.wires)
692690

693691
def label(self, decimals=None, base_label=None, cache=None):
694692
return super().label(decimals=decimals, base_label=base_label or "U", cache=cache)
@@ -886,9 +884,9 @@ def label(self, decimals=None, base_label=None, cache=None):
886884
if base_label is not None:
887885
if decimals is None:
888886
return base_label
889-
p = format(qml_math.asarray(self.parameters[0]), ".0f")
887+
p = format(math.asarray(self.parameters[0]), ".0f")
890888
return base_label + f"\n({p})"
891-
return f"|{qml_math.asarray(self.parameters[0])}⟩"
889+
return f"|{math.asarray(self.parameters[0])}⟩"
892890

893891

894892
class FockStateVector(CVOperation):
@@ -1286,7 +1284,7 @@ def label(self, decimals=None, base_label=None, cache=None):
12861284
if decimals is None:
12871285
p = "φ"
12881286
else:
1289-
p = format(qml_math.array(self.parameters[0]), f".{decimals}f")
1287+
p = format(math.array(self.parameters[0]), f".{decimals}f")
12901288
return f"cos({p})x\n+sin({p})p"
12911289

12921290

0 commit comments

Comments
 (0)