Skip to content

Commit aac5ac8

Browse files
Capture PPR/PPM passes and register MLIR pass names (#8519)
**Context:** Currently, we can run Catalyst/MLIR passes with the unified compiler alongside xDSL passes, _if they are represented in the MLIR_. However, for current MBQC workloads, we want to be able to use them the `capture` enables, and currently they can't be captured (or converted to MLIR). **Description of the Change:** We add dummy versions of the transforms that just raise a `NotImplemented` error if executed, but are `transform` objects and are captured in `plxpr`. We register them with the relevant Catalyst pass name so they can be converted from plxpr to MLIR. **Benefits:** We unblock current MBQC work with the PPR/PPM compilation pipeline. **Possible Drawbacks:** A more long-term solution to unifying the API for the various transform/pass types is on the roadmap, and ideally we would use that instead of manually add the transforms we are interested in to the `ftqc` module. This is intended to be a temporary patch to unblock current MBQC work. --------- Co-authored-by: Joey Carter <[email protected]>
1 parent 1a7542c commit aac5ac8

File tree

3 files changed

+184
-0
lines changed

3 files changed

+184
-0
lines changed

doc/releases/changelog-dev.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,10 @@
168168
[(#8486)](https://github.com/PennyLaneAI/pennylane/pull/8486)
169169
[(#8495)](https://github.com/PennyLaneAI/pennylane/pull/8495)
170170

171+
* The `ftqc` module now includes dummy transforms for several Catalyst/MLIR passes (`to-ppr`, `commute-ppr`, `merge-ppr-ppm`, `pprm-to-mbqc`
172+
and `reduce-t-depth`), to allow them to be captured as primitives in PLxPR and mapped to the MLIR passes in Catalyst. This enables using the passes with the unified compiler and program capture.
173+
[(#8519)](https://github.com/PennyLaneAI/pennylane/pull/8519)
174+
171175
* The decompositions for several templates have been updated to use
172176
:class:`~.ops.op_math.ChangeOpBasis`, which makes their decompositions more resource efficient
173177
by eliminating unnecessary controlled operations. The templates include :class:`~.PhaseAdder`,
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2025 Xanadu Quantum Technologies Inc.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""This module contains wrapper transforms to provide an API to bind primitives for
15+
Catalyst passes when using capture and the unified compiler. This is a temporary fix
16+
to manually add passes relevant for ongoing MBQC work. It can be removed one a more
17+
general solution for all Catalyst passes is in place."""
18+
19+
from catalyst.from_plxpr import register_transform
20+
21+
from ..transforms.core import transform
22+
23+
24+
@transform
25+
def to_ppr(tape):
26+
"""A wrapper that allows us to register a primitive that represents the transform during capture.
27+
The transform itself is only implemented in Catalyst. This is just to enable capture."""
28+
raise NotImplementedError("The to_ppm pass is only implemented when using capture and QJIT.")
29+
30+
31+
register_transform(to_ppr, "to-ppr", False)
32+
33+
34+
@transform
35+
def commute_ppr(tape):
36+
"""A wrapper that allows us to register a primitive that represents the transform during capture.
37+
The transform itself is only implemented in Catalyst. This is just to enable capture."""
38+
raise NotImplementedError(
39+
"The commute_ppr pass is only implemented when using capture and QJIT."
40+
)
41+
42+
43+
register_transform(commute_ppr, "commute-ppr", False)
44+
45+
46+
@transform
47+
def merge_ppr_ppm(tape):
48+
"""A wrapper that allows us to register a primitive that represents the transform during capture.
49+
The transform itself is only implemented in Catalyst. This is just to enable capture."""
50+
raise NotImplementedError(
51+
"The merge_ppr_ppm pass is only implemented when using capture and QJIT."
52+
)
53+
54+
55+
register_transform(merge_ppr_ppm, "merge-ppr-ppm", False)
56+
57+
58+
@transform
59+
def ppm_to_mbqc(tape):
60+
"""A wrapper that allows us to register a primitive that represents the transform during capture.
61+
The transform itself is only implemented in Catalyst. This is just to enable capture."""
62+
raise NotImplementedError(
63+
"The ppm_to_mbqc pass is only implemented when using capture and QJIT."
64+
)
65+
66+
67+
register_transform(ppm_to_mbqc, "ppm-to-mbqc", False)
68+
69+
70+
@transform
71+
def reduce_t_depth(tape):
72+
"""A wrapper that allows us to register a primitive that represents the transform during capture.
73+
The transform itself is only implemented in Catalyst. This is just to enable capture."""
74+
raise NotImplementedError(
75+
"The reduce_t_depth pass is only implemented when using capture and QJIT."
76+
)
77+
78+
79+
register_transform(reduce_t_depth, "reduce-t-depth", False)
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2025 Xanadu Quantum Technologies Inc.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Unit tests for wrappers for capturing catalyst passes"""
15+
16+
import pytest
17+
18+
pytest.importorskip("catalyst")
19+
# pylint: disable=wrong-import-position
20+
21+
import pennylane as qml
22+
from pennylane.capture import make_plxpr
23+
from pennylane.ftqc.catalyst_pass_aliases import (
24+
commute_ppr,
25+
merge_ppr_ppm,
26+
ppm_to_mbqc,
27+
reduce_t_depth,
28+
to_ppr,
29+
)
30+
31+
pytestmark = pytest.mark.external
32+
33+
34+
@pytest.mark.catalyst
35+
@pytest.mark.usefixtures("enable_disable_plxpr")
36+
@pytest.mark.parametrize(
37+
"pass_fn", [to_ppr, commute_ppr, merge_ppr_ppm, ppm_to_mbqc, reduce_t_depth]
38+
)
39+
def test_pass_is_captured(pass_fn):
40+
41+
@pass_fn
42+
@qml.qnode(qml.device("lightning.qubit", wires=3), shots=1000)
43+
def circ():
44+
qml.H(0)
45+
qml.S(0)
46+
qml.T(1)
47+
qml.CNOT([0, 1])
48+
return qml.sample()
49+
50+
plxpr = make_plxpr(circ)()
51+
prim = plxpr.eqns[0].primitive
52+
assert prim.name == pass_fn.__name__ + "_transform"
53+
54+
55+
@pytest.mark.catalyst
56+
@pytest.mark.usefixtures("enable_disable_plxpr")
57+
@pytest.mark.parametrize(
58+
"pass_fn, pass_name",
59+
[
60+
(to_ppr, "to-ppr"),
61+
(commute_ppr, "commute-ppr"),
62+
(merge_ppr_ppm, "merge-ppr-ppm"),
63+
(ppm_to_mbqc, "ppm-to-mbqc"),
64+
(reduce_t_depth, "reduce-t-depth"),
65+
],
66+
)
67+
def test_converstion_to_mlir(pass_fn, pass_name):
68+
"""Test that we can generate MLIR from the captured circuit and that the generated MLIR
69+
includes the pass name we are mapping to"""
70+
71+
@qml.qjit(target="mlir")
72+
@pass_fn
73+
@qml.qnode(qml.device("lightning.qubit", wires=3), shots=1000)
74+
def circ():
75+
qml.H(0)
76+
qml.S(0)
77+
qml.T(1)
78+
qml.CNOT([0, 1])
79+
return qml.sample()
80+
81+
assert pass_name in circ.mlir
82+
83+
84+
@pytest.mark.catalyst
85+
@pytest.mark.parametrize(
86+
"pass_fn", [to_ppr, commute_ppr, merge_ppr_ppm, ppm_to_mbqc, reduce_t_depth]
87+
)
88+
def test_pass_without_qjit_raises_error(pass_fn):
89+
"""Test that trying to apply the transform without QJIT raises an error"""
90+
91+
@pass_fn
92+
@qml.qnode(qml.device("lightning.qubit", wires=3), shots=1000)
93+
def circ():
94+
qml.H(0)
95+
qml.S(0)
96+
qml.T(1)
97+
qml.CNOT([0, 1])
98+
return qml.sample()
99+
100+
with pytest.raises(NotImplementedError, match="only implemented when using capture and QJIT"):
101+
circ()

0 commit comments

Comments
 (0)