Skip to content

Commit 7b02a09

Browse files
maliasadirniczh
andauthored
Enable decomposition of MultiRZ at MLIR (#2160)
**Context:** This PR refactors the implementation of `SignatureAnalyzer` in the decompose-lowering pass extending the analyser and the pass to support special operator definitions (e.g., MultiRZ) along with CustomOps as new patterns. **Description of the Change:** - Add a `BaseSignatureAnalyzer` class that servers a base analyzer in decompose-lowering - Add MultiRZ decomposition patterns to the lowering pass - Fix the issue with trying to capture and compile unutilized abstracted Adjoint and Controlled decomposition rules from the graph. (We don't need these rules from the graph as we handle these symbolic ops using Catalyst AdjointOp and ControlledOp). Avoiding this saves compilation time and failures for templated rules. - Include MLIR and Python tests. **Related Issues:** [sc-102152] --------- Co-authored-by: Hong-Sheng Zheng <[email protected]>
1 parent 101c7a7 commit 7b02a09

File tree

7 files changed

+645
-351
lines changed

7 files changed

+645
-351
lines changed

doc/releases/changelog-dev.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44

55
<h3>Improvements 🛠</h3>
66

7+
* The ``decompose-lowering`` MLIR pass now supports ``qml.MultiRZ``
8+
with an arbitrary number of wires. This decomposition is performed
9+
at MLIR when both capture and graph-decomposition are enabled.
10+
[(#2160)](https://github.com/PennyLaneAI/catalyst/pull/2160)
11+
712
* A new option ``use_nameloc`` has been added to :func:`~.qjit` that embeds variable names
813
from Python into the compiler IR, which can make it easier to read when debugging programs.
914
[(#2054)](https://github.com/PennyLaneAI/catalyst/pull/2054)
@@ -21,6 +26,10 @@
2126

2227
<h3>Bug fixes 🐛</h3>
2328

29+
* Fixes the issue with capturing unutilized abstracted adjoint and controlled rules
30+
by the graph in the new decomposition framework.
31+
[(#2160)](https://github.com/PennyLaneAI/catalyst/pull/2160)
32+
2433
* Fixes the translation of plxpr control flow for edge cases where the `consts` were being
2534
reordered.
2635
[(#2128)](https://github.com/PennyLaneAI/catalyst/pull/2128)
@@ -47,6 +56,7 @@
4756

4857
This release contains contributions from (in alphabetical order):
4958

59+
Ali Asadi,
5060
Christina Lee,
5161
Roberto Turrado,
5262
Paul Haochen Wang.

frontend/catalyst/from_plxpr/decompose.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,14 @@ def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess
234234
num_params=num_params,
235235
requires_copy=num_wires == -1,
236236
)
237-
else: # pragma: no cover
237+
elif not any(
238+
keyword in getattr(op.op, "name", "") for keyword in ("Adjoint", "Controlled")
239+
): # pragma: no cover
240+
# Note that the graph-decomposition returns abstracted rules
241+
# for Adjoint and Controlled operations, so we skip them here.
242+
# These abstracted rules cannot be captured and lowered.
243+
# We use MLIR AdjointOp and ControlledOp primitives
244+
# to deal with decomposition of symbolic operations at PLxPR.
238245
raise ValueError(f"Could not capture {op} without the number of wires.")
239246

240247
data, struct = jax.tree_util.tree_flatten(measurement)

frontend/test/pytest/from_plxpr/test_from_plxpr.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,45 @@ def circuit():
10511051
expected = np.array([1, 0, 0, 1]) / np.sqrt(2)
10521052
assert qml.math.allclose(circuit(), expected)
10531053

1054+
expected_resources = {"CZ": 1, "Hadamard": 3}
1055+
resources = qml.specs(circuit, level="device")()["resources"]
1056+
assert resources.gate_types == expected_resources
1057+
1058+
qml.decomposition.disable_graph()
1059+
qml.capture.disable()
1060+
1061+
def test_multirz(self):
1062+
"""Test that multirz decomposition works with from_plxpr."""
1063+
1064+
qml.capture.enable()
1065+
qml.decomposition.enable_graph()
1066+
1067+
@partial(
1068+
qml.transforms.decompose,
1069+
gate_set={"X", "Y", "Z", "S", "H", "CNOT", "RZ", "Rot", "GlobalPhase"},
1070+
)
1071+
@qml.qnode(qml.device("lightning.qubit", wires=4))
1072+
def circuit():
1073+
qml.Hadamard(0)
1074+
qml.ctrl(qml.MultiRZ(0.345, wires=[1, 2]), control=0)
1075+
qml.adjoint(qml.MultiRZ(0.25, wires=[1, 2]))
1076+
qml.MultiRZ(0.5, wires=[0, 1])
1077+
qml.MultiRZ(0.5, wires=[0])
1078+
qml.MultiRZ(0.5, wires=[0, 1, 3])
1079+
return qml.expval(qml.X(0))
1080+
1081+
without_qjit = circuit()
1082+
with_qjit = qml.qjit(circuit)
1083+
1084+
assert qml.math.allclose(without_qjit, with_qjit())
1085+
1086+
# TODO: Remove this static dict when capture & graph enabled support
1087+
# resource counting with qml.specs via from_plxpr conversion.
1088+
expected_resources = {"GlobalPhase": 14, "RZ": 20, "CNOT": 22, "Hadamard": 5}
1089+
1090+
resources = qml.specs(with_qjit, level="device")()["resources"]
1091+
assert resources.gate_types == expected_resources
1092+
10541093
qml.decomposition.disable_graph()
10551094
qml.capture.disable()
10561095

0 commit comments

Comments
 (0)