diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index f97ea81ee4..7800584539 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -73,6 +73,18 @@
Improvements ðŸ›
+* The frontend no longer maintains a hardcoded list of runtime operations,
+ allowing arbitrary PennyLane gates with Quantum dialect-compatible
+ hyperparameters to be captured and represented in MLIR.
+ Users of the legacy compilation pipeline are unaffected,
+ as Catalyst continues to decompose unsupported gates
+ based on device capabilities before lowering to MLIR.
+ Gates that cannot be represented as MLIR operators will temporarily
+ raise a `CompileError` during program capture.
+ The long-term solution will integrate the new decomposition framework
+ with capture-enabled compilation.
+ [(#2215)](https://github.com/PennyLaneAI/catalyst/pull/2215)
+
* Catalyst can now use the new `pass_name` property of pennylane transform objects. Passes can now
be created using `qml.transform(pass_name=pass_name)` instead of `PassPipelineWrapper`.
[(#2149](https://github.com/PennyLaneAI/catalyst/pull/2149)
diff --git a/frontend/catalyst/device/decomposition.py b/frontend/catalyst/device/decomposition.py
index 22c66deac3..d72921ef31 100644
--- a/frontend/catalyst/device/decomposition.py
+++ b/frontend/catalyst/device/decomposition.py
@@ -42,6 +42,7 @@
is_controllable,
is_differentiable,
is_invertible,
+ is_lowering_compatible,
is_supported,
)
from catalyst.jax_tracer import HybridOpRegion, has_nested_tapes
@@ -227,7 +228,7 @@ def catalyst_acceptance(
if match and is_controllable(op.base, capabilities):
return match
- elif is_supported(op, capabilities):
+ elif is_supported(op, capabilities) and is_lowering_compatible(op):
return op.name
return None
diff --git a/frontend/catalyst/device/op_support.py b/frontend/catalyst/device/op_support.py
index f86f0cc2c9..0df047c9fd 100644
--- a/frontend/catalyst/device/op_support.py
+++ b/frontend/catalyst/device/op_support.py
@@ -41,6 +41,28 @@ def is_supported(op: Operator, capabilities: DeviceCapabilities) -> bool:
return op.name in capabilities.operations
+def is_lowering_compatible(op: Operator) -> bool:
+ """Check if an operation can be lowered to MLIR using JAX primitives."""
+ # Exceptions for operations that are not quantum instructions but are allowed
+ # via custom lowering rules.
+ # TODO: Revisit this as more explicit ops will be added to Catalyst Compiler.
+ if isinstance(op, (qml.Snapshot, qml.PCPhase, qml.MultiRZ)):
+ return True
+
+ # Accepted hyperparameters for quantum instructions bind calls
+ _accepted_hyperparams = {
+ "base",
+ "n_wires",
+ "num_wires",
+ "control_wires",
+ "control_values",
+ "work_wires",
+ "work_wire_type",
+ }
+
+ return set(op.hyperparameters).issubset(_accepted_hyperparams)
+
+
def _is_grad_recipe_same_as_catalyst(op):
"""Checks that the grad_recipe for the op matches the hard coded one in Catalyst."""
diff --git a/frontend/catalyst/device/qjit_device.py b/frontend/catalyst/device/qjit_device.py
index b4f92bf022..d6708b3851 100644
--- a/frontend/catalyst/device/qjit_device.py
+++ b/frontend/catalyst/device/qjit_device.py
@@ -56,44 +56,6 @@
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
-RUNTIME_OPERATIONS = [
- "CNOT",
- "ControlledPhaseShift",
- "CRot",
- "CRX",
- "CRY",
- "CRZ",
- "CSWAP",
- "CY",
- "CZ",
- "Hadamard",
- "Identity",
- "IsingXX",
- "IsingXY",
- "IsingYY",
- "IsingZZ",
- "SingleExcitation",
- "DoubleExcitation",
- "ISWAP",
- "MultiRZ",
- "PauliX",
- "PauliY",
- "PauliZ",
- "PCPhase",
- "PhaseShift",
- "PSWAP",
- "QubitUnitary",
- "Rot",
- "RX",
- "RY",
- "RZ",
- "S",
- "SWAP",
- "T",
- "Toffoli",
- "GlobalPhase",
-]
-
RUNTIME_OBSERVABLES = [
"Identity",
"PauliX",
@@ -109,11 +71,9 @@
RUNTIME_MPS = ["ExpectationMP", "SampleMP", "VarianceMP", "CountsMP", "StateMP", "ProbabilityMP"]
-# The runtime interface does not care about specific gate properties, so set them all to True.
-RUNTIME_OPERATIONS = {
- op: OperatorProperties(invertible=True, controllable=True, differentiable=True)
- for op in RUNTIME_OPERATIONS
-}
+# A list of custom operations supported by the Catalyst compiler.
+# This is useful especially for testing a device with custom operations.
+CUSTOM_OPERATIONS = {}
RUNTIME_OBSERVABLES = {
obs: OperatorProperties(invertible=True, controllable=True, differentiable=True)
@@ -199,6 +159,13 @@ def extract_backend_info(device: qml.devices.QubitDevice) -> BackendInfo:
return BackendInfo(dname, device_name, device_lpath, device_kwargs)
+def union_operations(
+ a: Dict[str, OperatorProperties], b: Dict[str, OperatorProperties]
+) -> Dict[str, OperatorProperties]:
+ """Union of two sets of operator properties"""
+ return {**a, **b}
+
+
def intersect_operations(
a: Dict[str, OperatorProperties], b: Dict[str, OperatorProperties]
) -> Dict[str, OperatorProperties]:
@@ -223,8 +190,8 @@ def get_qjit_device_capabilities(target_capabilities: DeviceCapabilities) -> Dev
qjit_capabilities = deepcopy(target_capabilities)
# Intersection of gates and observables supported by the device and by Catalyst runtime.
- qjit_capabilities.operations = intersect_operations(
- target_capabilities.operations, RUNTIME_OPERATIONS
+ qjit_capabilities.operations = union_operations(
+ target_capabilities.operations, CUSTOM_OPERATIONS
)
qjit_capabilities.observables = intersect_operations(
target_capabilities.observables, RUNTIME_OBSERVABLES
diff --git a/frontend/catalyst/from_plxpr/qfunc_interpreter.py b/frontend/catalyst/from_plxpr/qfunc_interpreter.py
index 59c6e516c7..177408addc 100644
--- a/frontend/catalyst/from_plxpr/qfunc_interpreter.py
+++ b/frontend/catalyst/from_plxpr/qfunc_interpreter.py
@@ -36,6 +36,7 @@
from pennylane.ftqc.primitives import measure_in_basis_prim as plxpr_measure_in_basis_prim
from pennylane.measurements import CountsMP
+from catalyst.device.op_support import is_lowering_compatible
from catalyst.jax_extras import jaxpr_pad_consts
from catalyst.jax_primitives import (
AbstractQbit,
@@ -183,6 +184,14 @@ def interpret_operation(self, op, is_adjoint=False, control_values=(), control_w
if (fn := _special_op_bind_call.get(type(op))) is not None:
bind_fn = partial(fn, hyperparameters=op.hyperparameters)
else:
+ # TODO: Remove this after enabling the graph-based decomposition by default
+ # With graph enabled, all unsupported templates and operations will be decomposed
+ # away resulting the same behaviour with capture disabled in Catalyst.
+ if not is_lowering_compatible(op):
+ raise CompileError(
+ f"Operation {op.name} with hyperparameters {list(op.hyperparameters.keys())} "
+ "is not compatible with quantum instructions."
+ )
bind_fn = qinst_p.bind
out_qubits = bind_fn(
diff --git a/frontend/test/lit/test_from_plxpr.py b/frontend/test/lit/test_from_plxpr.py
index 8cd9ff764c..922b9dbd78 100644
--- a/frontend/test/lit/test_from_plxpr.py
+++ b/frontend/test/lit/test_from_plxpr.py
@@ -457,3 +457,33 @@ def circuit2():
test_two_qnodes_with_different_passes_in_one_workflow()
+
+
+def test_capture_custom_op():
+ """Test capture of a custom op"""
+
+ dev = qml.device("lightning.qubit", wires=2)
+
+ class MuCustomOp(qml.operation.Operator):
+ """A custom operator for testing."""
+
+ def __init__(self, theta, wires):
+ """Initialize the custom operator."""
+ super().__init__(theta, wires=wires)
+
+ qml.capture.enable()
+
+ @qml.qjit(target="mlir")
+ @qml.qnode(dev)
+ def circuit():
+ # CHECK: [[QUBIT_1:%.+]] = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
+ # CHECK-NEXT: [[QUBIT_2:%.+]] = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit
+ # CHECK-NEXT: {{%.+}} = quantum.custom "MuCustomOp"({{%.+}}) [[QUBIT_1]], [[QUBIT_2]] : !quantum.bit, !quantum.bit
+ MuCustomOp(0.5, wires=[0, 1])
+ return qml.state()
+
+ print(circuit.mlir)
+ qml.capture.disable()
+
+
+test_capture_custom_op()
diff --git a/frontend/test/pytest/from_plxpr/test_from_plxpr.py b/frontend/test/pytest/from_plxpr/test_from_plxpr.py
index c93ddb0b0d..e969ee9aa4 100644
--- a/frontend/test/pytest/from_plxpr/test_from_plxpr.py
+++ b/frontend/test/pytest/from_plxpr/test_from_plxpr.py
@@ -36,6 +36,8 @@
pytestmark = pytest.mark.usefixtures("disable_capture")
+# pylint: disable=too-many-lines
+
def catalyst_execute_jaxpr(jaxpr):
"""Create a function capable of executing the provided catalyst-variant jaxpr."""
@@ -194,6 +196,29 @@ def c():
):
from_plxpr(jaxpr)()
+ def test_unsupported_op(self):
+ """Test that a CompileError is raised when an unsupported op is encountered."""
+
+ dev = qml.device("lightning.qubit", wires=5)
+
+ @qml.qnode(dev)
+ def circuit():
+ qml.QROM(
+ bitstrings=["010", "111", "110", "000"],
+ control_wires=[0, 1],
+ target_wires=[2, 3, 4],
+ work_wires=[5, 6, 7],
+ )
+ return qml.state()
+
+ jaxpr = jax.make_jaxpr(circuit)()
+
+ with pytest.raises(
+ catalyst.utils.exceptions.CompileError,
+ match="Operation QROM with hyperparameters",
+ ):
+ from_plxpr(jaxpr)()
+
class TestCatalystCompareJaxpr:
"""Test comparing catalyst and pennylane jaxpr for a variety of situations."""
diff --git a/frontend/test/pytest/test_preprocess.py b/frontend/test/pytest/test_preprocess.py
index a918ec09d4..916e3e7c97 100644
--- a/frontend/test/pytest/test_preprocess.py
+++ b/frontend/test/pytest/test_preprocess.py
@@ -118,14 +118,14 @@ def test_decompose_integration(self):
@qml.qnode(dev)
def circuit(theta: float):
- qml.SingleExcitationPlus(theta, wires=[0, 1])
+ qml.OrbitalRotation(theta, wires=[0, 1, 2, 3])
return qml.state()
mlir = qjit(circuit, target="mlir").mlir
+ assert "SingleExcitation" in mlir
assert "Hadamard" in mlir
- assert "CNOT" in mlir
- assert "RY" in mlir
- assert "SingleExcitationPlus" not in mlir
+ assert "RX" in mlir
+ assert "OrbitalRotation" not in mlir
def test_decompose_ops_to_unitary(self):
"""Test the decompose ops to unitary transform."""
diff --git a/frontend/test/pytest/test_verification.py b/frontend/test/pytest/test_verification.py
index e50d5c10ce..b88cf26642 100644
--- a/frontend/test/pytest/test_verification.py
+++ b/frontend/test/pytest/test_verification.py
@@ -37,7 +37,7 @@
from catalyst.api_extensions import HybridAdjoint, HybridCtrl
from catalyst.compiler import get_lib_path
from catalyst.device import get_device_capabilities
-from catalyst.device.qjit_device import RUNTIME_OPERATIONS, get_qjit_device_capabilities
+from catalyst.device.qjit_device import CUSTOM_OPERATIONS, get_qjit_device_capabilities
from catalyst.device.verification import validate_measurements
# pylint: disable = unused-argument, unnecessary-lambda-assignment, unnecessary-lambda
@@ -290,7 +290,7 @@ def test_non_controllable_gate_hybridctrl(self):
# Note: The HybridCtrl operator is not currently supported with the QJIT device, but the
# verification structure is in place, so we test the verification of its nested operators by
# adding HybridCtrl to the list of native gates for the custom base device and by patching
- # the list of RUNTIME_OPERATIONS for the QJIT device to include HybridCtrl for this test.
+ # the list of CUSTOM_OPERATIONS for the QJIT device to include HybridCtrl for this test.
@qml.qnode(
get_custom_device(
@@ -302,12 +302,12 @@ def f(x: float):
assert isinstance(op, HybridCtrl), f"op expected to be HybridCtrl but got {type(op)}"
return qml.expval(qml.PauliX(0))
- runtime_ops_with_qctrl = deepcopy(RUNTIME_OPERATIONS)
+ runtime_ops_with_qctrl = deepcopy(CUSTOM_OPERATIONS)
runtime_ops_with_qctrl["HybridCtrl"] = OperatorProperties(
invertible=True, controllable=True, differentiable=True
)
- with patch("catalyst.device.qjit_device.RUNTIME_OPERATIONS", runtime_ops_with_qctrl):
+ with patch("catalyst.device.qjit_device.CUSTOM_OPERATIONS", runtime_ops_with_qctrl):
with pytest.raises(CompileError, match="PauliZ is not controllable"):
qjit(f)(1.2)
@@ -321,7 +321,7 @@ def test_hybridctrl_raises_error(self):
"""Test that a HybridCtrl operator is rejected by the verification."""
# TODO: If you are deleting this test because HybridCtrl support has been added, consider
- # updating the tests that patch RUNTIME_OPERATIONS to inclue HybridCtrl accordingly
+ # updating the tests that patch CUSTOM_OPERATIONS to inclue HybridCtrl accordingly
@qml.qnode(get_custom_device(non_controllable_gates={"PauliZ"}, wires=4))
def f(x: float):
@@ -391,7 +391,7 @@ def test_hybrid_ctrl_containing_adjoint(self, adjoint_type, unsupported_gate_att
# Note: The HybridCtrl operator is not currently supported with the QJIT device, but the
# verification structure is in place, so we test the verification of its nested operators by
# adding HybridCtrl to the list of native gates for the custom base device and by patching
- # the list of RUNTIME_OPERATIONS for the QJIT device to include HybridCtrl for this test.
+ # the list of CUSTOM_OPERATIONS for the QJIT device to include HybridCtrl for this test.
def _ops(x, wires):
if adjoint_type == HybridAdjoint:
@@ -410,12 +410,12 @@ def f(x: float):
assert isinstance(base, adjoint_type), f"expected {adjoint_type} but got {type(op)}"
return qml.expval(qml.PauliX(0))
- runtime_ops_with_qctrl = deepcopy(RUNTIME_OPERATIONS)
+ runtime_ops_with_qctrl = deepcopy(CUSTOM_OPERATIONS)
runtime_ops_with_qctrl["HybridCtrl"] = OperatorProperties(
invertible=True, controllable=True, differentiable=True
)
- with patch("catalyst.device.qjit_device.RUNTIME_OPERATIONS", runtime_ops_with_qctrl):
+ with patch("catalyst.device.qjit_device.CUSTOM_OPERATIONS", runtime_ops_with_qctrl):
with pytest.raises(CompileError, match=f"PauliZ is not {unsupported_gate_attribute}"):
qjit(f)(1.2)
@@ -434,7 +434,7 @@ def test_hybrid_adjoint_containing_hybrid_ctrl(self, ctrl_type, unsupported_gate
# Note: The HybridCtrl operator is not currently supported with the QJIT device, but the
# verification structure is in place, so we test the verification of its nested operators by
# adding HybridCtrl to the list of native gates for the custom base device and by patching
- # the list of RUNTIME_OPERATIONS for the QJIT device to include HybridCtrl for this test.
+ # the list of CUSTOM_OPERATIONS for the QJIT device to include HybridCtrl for this test.
def _ops(x, wires):
if ctrl_type == HybridCtrl:
@@ -453,12 +453,12 @@ def f(x: float):
assert isinstance(base, ctrl_type), f"expected {ctrl_type} but got {type(op)}"
return qml.expval(qml.PauliX(0))
- runtime_ops_with_qctrl = deepcopy(RUNTIME_OPERATIONS)
+ runtime_ops_with_qctrl = deepcopy(CUSTOM_OPERATIONS)
runtime_ops_with_qctrl["HybridCtrl"] = OperatorProperties(
invertible=True, controllable=True, differentiable=True
)
- with patch("catalyst.device.qjit_device.RUNTIME_OPERATIONS", runtime_ops_with_qctrl):
+ with patch("catalyst.device.qjit_device.CUSTOM_OPERATIONS", runtime_ops_with_qctrl):
with pytest.raises(CompileError, match=f"PauliZ is not {unsupported_gate_attribute}"):
qjit(f)(1.2)