diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index f97ea81ee4..7800584539 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -73,6 +73,18 @@

Improvements 🛠

+* The frontend no longer maintains a hardcoded list of runtime operations, + allowing arbitrary PennyLane gates with Quantum dialect-compatible + hyperparameters to be captured and represented in MLIR. + Users of the legacy compilation pipeline are unaffected, + as Catalyst continues to decompose unsupported gates + based on device capabilities before lowering to MLIR. + Gates that cannot be represented as MLIR operators will temporarily + raise a `CompileError` during program capture. + The long-term solution will integrate the new decomposition framework + with capture-enabled compilation. + [(#2215)](https://github.com/PennyLaneAI/catalyst/pull/2215) + * Catalyst can now use the new `pass_name` property of pennylane transform objects. Passes can now be created using `qml.transform(pass_name=pass_name)` instead of `PassPipelineWrapper`. [(#2149](https://github.com/PennyLaneAI/catalyst/pull/2149) diff --git a/frontend/catalyst/device/decomposition.py b/frontend/catalyst/device/decomposition.py index 22c66deac3..d72921ef31 100644 --- a/frontend/catalyst/device/decomposition.py +++ b/frontend/catalyst/device/decomposition.py @@ -42,6 +42,7 @@ is_controllable, is_differentiable, is_invertible, + is_lowering_compatible, is_supported, ) from catalyst.jax_tracer import HybridOpRegion, has_nested_tapes @@ -227,7 +228,7 @@ def catalyst_acceptance( if match and is_controllable(op.base, capabilities): return match - elif is_supported(op, capabilities): + elif is_supported(op, capabilities) and is_lowering_compatible(op): return op.name return None diff --git a/frontend/catalyst/device/op_support.py b/frontend/catalyst/device/op_support.py index f86f0cc2c9..0df047c9fd 100644 --- a/frontend/catalyst/device/op_support.py +++ b/frontend/catalyst/device/op_support.py @@ -41,6 +41,28 @@ def is_supported(op: Operator, capabilities: DeviceCapabilities) -> bool: return op.name in capabilities.operations +def is_lowering_compatible(op: Operator) -> bool: + """Check if an operation can be lowered to MLIR using JAX primitives.""" + # Exceptions for operations that are not quantum instructions but are allowed + # via custom lowering rules. + # TODO: Revisit this as more explicit ops will be added to Catalyst Compiler. + if isinstance(op, (qml.Snapshot, qml.PCPhase, qml.MultiRZ)): + return True + + # Accepted hyperparameters for quantum instructions bind calls + _accepted_hyperparams = { + "base", + "n_wires", + "num_wires", + "control_wires", + "control_values", + "work_wires", + "work_wire_type", + } + + return set(op.hyperparameters).issubset(_accepted_hyperparams) + + def _is_grad_recipe_same_as_catalyst(op): """Checks that the grad_recipe for the op matches the hard coded one in Catalyst.""" diff --git a/frontend/catalyst/device/qjit_device.py b/frontend/catalyst/device/qjit_device.py index b4f92bf022..d6708b3851 100644 --- a/frontend/catalyst/device/qjit_device.py +++ b/frontend/catalyst/device/qjit_device.py @@ -56,44 +56,6 @@ logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) -RUNTIME_OPERATIONS = [ - "CNOT", - "ControlledPhaseShift", - "CRot", - "CRX", - "CRY", - "CRZ", - "CSWAP", - "CY", - "CZ", - "Hadamard", - "Identity", - "IsingXX", - "IsingXY", - "IsingYY", - "IsingZZ", - "SingleExcitation", - "DoubleExcitation", - "ISWAP", - "MultiRZ", - "PauliX", - "PauliY", - "PauliZ", - "PCPhase", - "PhaseShift", - "PSWAP", - "QubitUnitary", - "Rot", - "RX", - "RY", - "RZ", - "S", - "SWAP", - "T", - "Toffoli", - "GlobalPhase", -] - RUNTIME_OBSERVABLES = [ "Identity", "PauliX", @@ -109,11 +71,9 @@ RUNTIME_MPS = ["ExpectationMP", "SampleMP", "VarianceMP", "CountsMP", "StateMP", "ProbabilityMP"] -# The runtime interface does not care about specific gate properties, so set them all to True. -RUNTIME_OPERATIONS = { - op: OperatorProperties(invertible=True, controllable=True, differentiable=True) - for op in RUNTIME_OPERATIONS -} +# A list of custom operations supported by the Catalyst compiler. +# This is useful especially for testing a device with custom operations. +CUSTOM_OPERATIONS = {} RUNTIME_OBSERVABLES = { obs: OperatorProperties(invertible=True, controllable=True, differentiable=True) @@ -199,6 +159,13 @@ def extract_backend_info(device: qml.devices.QubitDevice) -> BackendInfo: return BackendInfo(dname, device_name, device_lpath, device_kwargs) +def union_operations( + a: Dict[str, OperatorProperties], b: Dict[str, OperatorProperties] +) -> Dict[str, OperatorProperties]: + """Union of two sets of operator properties""" + return {**a, **b} + + def intersect_operations( a: Dict[str, OperatorProperties], b: Dict[str, OperatorProperties] ) -> Dict[str, OperatorProperties]: @@ -223,8 +190,8 @@ def get_qjit_device_capabilities(target_capabilities: DeviceCapabilities) -> Dev qjit_capabilities = deepcopy(target_capabilities) # Intersection of gates and observables supported by the device and by Catalyst runtime. - qjit_capabilities.operations = intersect_operations( - target_capabilities.operations, RUNTIME_OPERATIONS + qjit_capabilities.operations = union_operations( + target_capabilities.operations, CUSTOM_OPERATIONS ) qjit_capabilities.observables = intersect_operations( target_capabilities.observables, RUNTIME_OBSERVABLES diff --git a/frontend/catalyst/from_plxpr/qfunc_interpreter.py b/frontend/catalyst/from_plxpr/qfunc_interpreter.py index 59c6e516c7..177408addc 100644 --- a/frontend/catalyst/from_plxpr/qfunc_interpreter.py +++ b/frontend/catalyst/from_plxpr/qfunc_interpreter.py @@ -36,6 +36,7 @@ from pennylane.ftqc.primitives import measure_in_basis_prim as plxpr_measure_in_basis_prim from pennylane.measurements import CountsMP +from catalyst.device.op_support import is_lowering_compatible from catalyst.jax_extras import jaxpr_pad_consts from catalyst.jax_primitives import ( AbstractQbit, @@ -183,6 +184,14 @@ def interpret_operation(self, op, is_adjoint=False, control_values=(), control_w if (fn := _special_op_bind_call.get(type(op))) is not None: bind_fn = partial(fn, hyperparameters=op.hyperparameters) else: + # TODO: Remove this after enabling the graph-based decomposition by default + # With graph enabled, all unsupported templates and operations will be decomposed + # away resulting the same behaviour with capture disabled in Catalyst. + if not is_lowering_compatible(op): + raise CompileError( + f"Operation {op.name} with hyperparameters {list(op.hyperparameters.keys())} " + "is not compatible with quantum instructions." + ) bind_fn = qinst_p.bind out_qubits = bind_fn( diff --git a/frontend/test/lit/test_from_plxpr.py b/frontend/test/lit/test_from_plxpr.py index 8cd9ff764c..922b9dbd78 100644 --- a/frontend/test/lit/test_from_plxpr.py +++ b/frontend/test/lit/test_from_plxpr.py @@ -457,3 +457,33 @@ def circuit2(): test_two_qnodes_with_different_passes_in_one_workflow() + + +def test_capture_custom_op(): + """Test capture of a custom op""" + + dev = qml.device("lightning.qubit", wires=2) + + class MuCustomOp(qml.operation.Operator): + """A custom operator for testing.""" + + def __init__(self, theta, wires): + """Initialize the custom operator.""" + super().__init__(theta, wires=wires) + + qml.capture.enable() + + @qml.qjit(target="mlir") + @qml.qnode(dev) + def circuit(): + # CHECK: [[QUBIT_1:%.+]] = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + # CHECK-NEXT: [[QUBIT_2:%.+]] = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + # CHECK-NEXT: {{%.+}} = quantum.custom "MuCustomOp"({{%.+}}) [[QUBIT_1]], [[QUBIT_2]] : !quantum.bit, !quantum.bit + MuCustomOp(0.5, wires=[0, 1]) + return qml.state() + + print(circuit.mlir) + qml.capture.disable() + + +test_capture_custom_op() diff --git a/frontend/test/pytest/from_plxpr/test_from_plxpr.py b/frontend/test/pytest/from_plxpr/test_from_plxpr.py index c93ddb0b0d..e969ee9aa4 100644 --- a/frontend/test/pytest/from_plxpr/test_from_plxpr.py +++ b/frontend/test/pytest/from_plxpr/test_from_plxpr.py @@ -36,6 +36,8 @@ pytestmark = pytest.mark.usefixtures("disable_capture") +# pylint: disable=too-many-lines + def catalyst_execute_jaxpr(jaxpr): """Create a function capable of executing the provided catalyst-variant jaxpr.""" @@ -194,6 +196,29 @@ def c(): ): from_plxpr(jaxpr)() + def test_unsupported_op(self): + """Test that a CompileError is raised when an unsupported op is encountered.""" + + dev = qml.device("lightning.qubit", wires=5) + + @qml.qnode(dev) + def circuit(): + qml.QROM( + bitstrings=["010", "111", "110", "000"], + control_wires=[0, 1], + target_wires=[2, 3, 4], + work_wires=[5, 6, 7], + ) + return qml.state() + + jaxpr = jax.make_jaxpr(circuit)() + + with pytest.raises( + catalyst.utils.exceptions.CompileError, + match="Operation QROM with hyperparameters", + ): + from_plxpr(jaxpr)() + class TestCatalystCompareJaxpr: """Test comparing catalyst and pennylane jaxpr for a variety of situations.""" diff --git a/frontend/test/pytest/test_preprocess.py b/frontend/test/pytest/test_preprocess.py index a918ec09d4..916e3e7c97 100644 --- a/frontend/test/pytest/test_preprocess.py +++ b/frontend/test/pytest/test_preprocess.py @@ -118,14 +118,14 @@ def test_decompose_integration(self): @qml.qnode(dev) def circuit(theta: float): - qml.SingleExcitationPlus(theta, wires=[0, 1]) + qml.OrbitalRotation(theta, wires=[0, 1, 2, 3]) return qml.state() mlir = qjit(circuit, target="mlir").mlir + assert "SingleExcitation" in mlir assert "Hadamard" in mlir - assert "CNOT" in mlir - assert "RY" in mlir - assert "SingleExcitationPlus" not in mlir + assert "RX" in mlir + assert "OrbitalRotation" not in mlir def test_decompose_ops_to_unitary(self): """Test the decompose ops to unitary transform.""" diff --git a/frontend/test/pytest/test_verification.py b/frontend/test/pytest/test_verification.py index e50d5c10ce..b88cf26642 100644 --- a/frontend/test/pytest/test_verification.py +++ b/frontend/test/pytest/test_verification.py @@ -37,7 +37,7 @@ from catalyst.api_extensions import HybridAdjoint, HybridCtrl from catalyst.compiler import get_lib_path from catalyst.device import get_device_capabilities -from catalyst.device.qjit_device import RUNTIME_OPERATIONS, get_qjit_device_capabilities +from catalyst.device.qjit_device import CUSTOM_OPERATIONS, get_qjit_device_capabilities from catalyst.device.verification import validate_measurements # pylint: disable = unused-argument, unnecessary-lambda-assignment, unnecessary-lambda @@ -290,7 +290,7 @@ def test_non_controllable_gate_hybridctrl(self): # Note: The HybridCtrl operator is not currently supported with the QJIT device, but the # verification structure is in place, so we test the verification of its nested operators by # adding HybridCtrl to the list of native gates for the custom base device and by patching - # the list of RUNTIME_OPERATIONS for the QJIT device to include HybridCtrl for this test. + # the list of CUSTOM_OPERATIONS for the QJIT device to include HybridCtrl for this test. @qml.qnode( get_custom_device( @@ -302,12 +302,12 @@ def f(x: float): assert isinstance(op, HybridCtrl), f"op expected to be HybridCtrl but got {type(op)}" return qml.expval(qml.PauliX(0)) - runtime_ops_with_qctrl = deepcopy(RUNTIME_OPERATIONS) + runtime_ops_with_qctrl = deepcopy(CUSTOM_OPERATIONS) runtime_ops_with_qctrl["HybridCtrl"] = OperatorProperties( invertible=True, controllable=True, differentiable=True ) - with patch("catalyst.device.qjit_device.RUNTIME_OPERATIONS", runtime_ops_with_qctrl): + with patch("catalyst.device.qjit_device.CUSTOM_OPERATIONS", runtime_ops_with_qctrl): with pytest.raises(CompileError, match="PauliZ is not controllable"): qjit(f)(1.2) @@ -321,7 +321,7 @@ def test_hybridctrl_raises_error(self): """Test that a HybridCtrl operator is rejected by the verification.""" # TODO: If you are deleting this test because HybridCtrl support has been added, consider - # updating the tests that patch RUNTIME_OPERATIONS to inclue HybridCtrl accordingly + # updating the tests that patch CUSTOM_OPERATIONS to inclue HybridCtrl accordingly @qml.qnode(get_custom_device(non_controllable_gates={"PauliZ"}, wires=4)) def f(x: float): @@ -391,7 +391,7 @@ def test_hybrid_ctrl_containing_adjoint(self, adjoint_type, unsupported_gate_att # Note: The HybridCtrl operator is not currently supported with the QJIT device, but the # verification structure is in place, so we test the verification of its nested operators by # adding HybridCtrl to the list of native gates for the custom base device and by patching - # the list of RUNTIME_OPERATIONS for the QJIT device to include HybridCtrl for this test. + # the list of CUSTOM_OPERATIONS for the QJIT device to include HybridCtrl for this test. def _ops(x, wires): if adjoint_type == HybridAdjoint: @@ -410,12 +410,12 @@ def f(x: float): assert isinstance(base, adjoint_type), f"expected {adjoint_type} but got {type(op)}" return qml.expval(qml.PauliX(0)) - runtime_ops_with_qctrl = deepcopy(RUNTIME_OPERATIONS) + runtime_ops_with_qctrl = deepcopy(CUSTOM_OPERATIONS) runtime_ops_with_qctrl["HybridCtrl"] = OperatorProperties( invertible=True, controllable=True, differentiable=True ) - with patch("catalyst.device.qjit_device.RUNTIME_OPERATIONS", runtime_ops_with_qctrl): + with patch("catalyst.device.qjit_device.CUSTOM_OPERATIONS", runtime_ops_with_qctrl): with pytest.raises(CompileError, match=f"PauliZ is not {unsupported_gate_attribute}"): qjit(f)(1.2) @@ -434,7 +434,7 @@ def test_hybrid_adjoint_containing_hybrid_ctrl(self, ctrl_type, unsupported_gate # Note: The HybridCtrl operator is not currently supported with the QJIT device, but the # verification structure is in place, so we test the verification of its nested operators by # adding HybridCtrl to the list of native gates for the custom base device and by patching - # the list of RUNTIME_OPERATIONS for the QJIT device to include HybridCtrl for this test. + # the list of CUSTOM_OPERATIONS for the QJIT device to include HybridCtrl for this test. def _ops(x, wires): if ctrl_type == HybridCtrl: @@ -453,12 +453,12 @@ def f(x: float): assert isinstance(base, ctrl_type), f"expected {ctrl_type} but got {type(op)}" return qml.expval(qml.PauliX(0)) - runtime_ops_with_qctrl = deepcopy(RUNTIME_OPERATIONS) + runtime_ops_with_qctrl = deepcopy(CUSTOM_OPERATIONS) runtime_ops_with_qctrl["HybridCtrl"] = OperatorProperties( invertible=True, controllable=True, differentiable=True ) - with patch("catalyst.device.qjit_device.RUNTIME_OPERATIONS", runtime_ops_with_qctrl): + with patch("catalyst.device.qjit_device.CUSTOM_OPERATIONS", runtime_ops_with_qctrl): with pytest.raises(CompileError, match=f"PauliZ is not {unsupported_gate_attribute}"): qjit(f)(1.2)