Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@

<h3>Improvements 🛠</h3>

* The frontend no longer maintains a hardcoded list of runtime operations,
allowing arbitrary PennyLane gates with Quantum dialect-compatible
hyperparameters to be captured and represented in MLIR.
Users of the legacy compilation pipeline are unaffected,
as Catalyst continues to decompose unsupported gates
based on device capabilities before lowering to MLIR.
Gates that cannot be represented as MLIR operators will temporarily
raise a `CompileError` during program capture.
The long-term solution will integrate the new decomposition framework
with capture-enabled compilation.
[(#2215)](https://github.com/PennyLaneAI/catalyst/pull/2215)

* Catalyst can now use the new `pass_name` property of pennylane transform objects. Passes can now
be created using `qml.transform(pass_name=pass_name)` instead of `PassPipelineWrapper`.
[(#2149](https://github.com/PennyLaneAI/catalyst/pull/2149)
Expand Down
3 changes: 2 additions & 1 deletion frontend/catalyst/device/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
is_controllable,
is_differentiable,
is_invertible,
is_lowering_compatible,
is_supported,
)
from catalyst.jax_tracer import HybridOpRegion, has_nested_tapes
Expand Down Expand Up @@ -227,7 +228,7 @@ def catalyst_acceptance(
if match and is_controllable(op.base, capabilities):
return match

elif is_supported(op, capabilities):
elif is_supported(op, capabilities) and is_lowering_compatible(op):
return op.name

return None
Expand Down
22 changes: 22 additions & 0 deletions frontend/catalyst/device/op_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,28 @@ def is_supported(op: Operator, capabilities: DeviceCapabilities) -> bool:
return op.name in capabilities.operations


def is_lowering_compatible(op: Operator) -> bool:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not convinced this function is sufficient to guarantee compatibility with the quantum dialect. Looking at the structure of an operator:

Operators are uniquely defined by their name, the wires they act on, their (trainable) parameters,
and their (non-trainable) hyperparameters. The trainable parameters can be tensors of any
supported auto-differentiation framework.

Name and wires are always supported by quantum.custom.

Hyperparameters are being checked below, although I don't know if the check is not too generic (i.e. I don't know if all hyperparameters are supported by all operation types).

The thing that's missing entirely is verifying parameters. Pennylane only requires that they be tensor-like, so this could be an arbitrary sequence of arbitrary tensors. This cannot be mapped to the quantum dialect at the moment. Some special ops support certain tensors (like 2D complex for quantum.unitary), but for the generic quantum.custom we would only support a sequence of scalar floats (or perhaps a single tensor that is flattened into a sequence of floats, but this could get highly inefficient if the tensor is large).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to consider operations like quantum.measure at all, since its signature even has a return value 🤔

"""Check if an operation can be lowered to MLIR using JAX primitives."""
# Exceptions for operations that are not quantum instructions but are allowed
# via custom lowering rules.
# TODO: Revisit this as more explicit ops will be added to Catalyst Compiler.
if isinstance(op, (qml.Snapshot, qml.PCPhase, qml.MultiRZ)):
return True

# Accepted hyperparameters for quantum instructions bind calls
_accepted_hyperparams = {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@albi3ro Here's the minimum set of hyperparams that we always need to check for compatibility with the quantum.custom MLIR Op.

"base",
"n_wires",
"num_wires",
"control_wires",
"control_values",
"work_wires",
"work_wire_type",
Comment on lines +59 to +60
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we handle work wires when mapping to mlir?

}

return set(op.hyperparameters).issubset(_accepted_hyperparams)


def _is_grad_recipe_same_as_catalyst(op):
"""Checks that the grad_recipe for the op matches the hard coded one in Catalyst."""

Expand Down
57 changes: 12 additions & 45 deletions frontend/catalyst/device/qjit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,44 +56,6 @@
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

RUNTIME_OPERATIONS = [
"CNOT",
"ControlledPhaseShift",
"CRot",
"CRX",
"CRY",
"CRZ",
"CSWAP",
"CY",
"CZ",
"Hadamard",
"Identity",
"IsingXX",
"IsingXY",
"IsingYY",
"IsingZZ",
"SingleExcitation",
"DoubleExcitation",
"ISWAP",
"MultiRZ",
"PauliX",
"PauliY",
"PauliZ",
"PCPhase",
"PhaseShift",
"PSWAP",
"QubitUnitary",
"Rot",
"RX",
"RY",
"RZ",
"S",
"SWAP",
"T",
"Toffoli",
"GlobalPhase",
]

RUNTIME_OBSERVABLES = [
"Identity",
"PauliX",
Expand All @@ -109,11 +71,9 @@

RUNTIME_MPS = ["ExpectationMP", "SampleMP", "VarianceMP", "CountsMP", "StateMP", "ProbabilityMP"]

# The runtime interface does not care about specific gate properties, so set them all to True.
RUNTIME_OPERATIONS = {
op: OperatorProperties(invertible=True, controllable=True, differentiable=True)
for op in RUNTIME_OPERATIONS
}
# A list of custom operations supported by the Catalyst compiler.
# This is useful especially for testing a device with custom operations.
CUSTOM_OPERATIONS = {}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a point to keep this empty set, considering its entire purpose is to be unioned with something else?

Or is it just for tests? In that case possibly related: #2114


RUNTIME_OBSERVABLES = {
obs: OperatorProperties(invertible=True, controllable=True, differentiable=True)
Expand Down Expand Up @@ -199,6 +159,13 @@ def extract_backend_info(device: qml.devices.QubitDevice) -> BackendInfo:
return BackendInfo(dname, device_name, device_lpath, device_kwargs)


def union_operations(
a: Dict[str, OperatorProperties], b: Dict[str, OperatorProperties]
) -> Dict[str, OperatorProperties]:
"""Union of two sets of operator properties"""
return {**a, **b}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason why not just

>>> x
{1: 2}
>>> y
{3: 4}
>>> x.update(y)
>>> x
{1: 2, 3: 4}

?



def intersect_operations(
a: Dict[str, OperatorProperties], b: Dict[str, OperatorProperties]
) -> Dict[str, OperatorProperties]:
Expand All @@ -223,8 +190,8 @@ def get_qjit_device_capabilities(target_capabilities: DeviceCapabilities) -> Dev
qjit_capabilities = deepcopy(target_capabilities)

# Intersection of gates and observables supported by the device and by Catalyst runtime.
qjit_capabilities.operations = intersect_operations(
target_capabilities.operations, RUNTIME_OPERATIONS
qjit_capabilities.operations = union_operations(
target_capabilities.operations, CUSTOM_OPERATIONS
)
qjit_capabilities.observables = intersect_operations(
target_capabilities.observables, RUNTIME_OBSERVABLES
Expand Down
9 changes: 9 additions & 0 deletions frontend/catalyst/from_plxpr/qfunc_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from pennylane.ftqc.primitives import measure_in_basis_prim as plxpr_measure_in_basis_prim
from pennylane.measurements import CountsMP

from catalyst.device.op_support import is_lowering_compatible
from catalyst.jax_extras import jaxpr_pad_consts
from catalyst.jax_primitives import (
AbstractQbit,
Expand Down Expand Up @@ -183,6 +184,14 @@ def interpret_operation(self, op, is_adjoint=False, control_values=(), control_w
if (fn := _special_op_bind_call.get(type(op))) is not None:
bind_fn = partial(fn, hyperparameters=op.hyperparameters)
else:
# TODO: Remove this after enabling the graph-based decomposition by default
# With graph enabled, all unsupported templates and operations will be decomposed
# away resulting the same behaviour with capture disabled in Catalyst.
if not is_lowering_compatible(op):
raise CompileError(
f"Operation {op.name} with hyperparameters {list(op.hyperparameters.keys())} "
"is not compatible with quantum instructions."
)
bind_fn = qinst_p.bind

out_qubits = bind_fn(
Expand Down
30 changes: 30 additions & 0 deletions frontend/test/lit/test_from_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,3 +457,33 @@ def circuit2():


test_two_qnodes_with_different_passes_in_one_workflow()


def test_capture_custom_op():
"""Test capture of a custom op"""

dev = qml.device("lightning.qubit", wires=2)

class MuCustomOp(qml.operation.Operator):
"""A custom operator for testing."""

def __init__(self, theta, wires):
"""Initialize the custom operator."""
super().__init__(theta, wires=wires)

qml.capture.enable()

@qml.qjit(target="mlir")
@qml.qnode(dev)
def circuit():
# CHECK: [[QUBIT_1:%.+]] = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
# CHECK-NEXT: [[QUBIT_2:%.+]] = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit
# CHECK-NEXT: {{%.+}} = quantum.custom "MuCustomOp"({{%.+}}) [[QUBIT_1]], [[QUBIT_2]] : !quantum.bit, !quantum.bit
MuCustomOp(0.5, wires=[0, 1])
return qml.state()

print(circuit.mlir)
qml.capture.disable()


test_capture_custom_op()
25 changes: 25 additions & 0 deletions frontend/test/pytest/from_plxpr/test_from_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@

pytestmark = pytest.mark.usefixtures("disable_capture")

# pylint: disable=too-many-lines


def catalyst_execute_jaxpr(jaxpr):
"""Create a function capable of executing the provided catalyst-variant jaxpr."""
Expand Down Expand Up @@ -194,6 +196,29 @@ def c():
):
from_plxpr(jaxpr)()

def test_unsupported_op(self):
"""Test that a CompileError is raised when an unsupported op is encountered."""

dev = qml.device("lightning.qubit", wires=5)

@qml.qnode(dev)
def circuit():
qml.QROM(
bitstrings=["010", "111", "110", "000"],
control_wires=[0, 1],
target_wires=[2, 3, 4],
work_wires=[5, 6, 7],
)
return qml.state()

jaxpr = jax.make_jaxpr(circuit)()

with pytest.raises(
catalyst.utils.exceptions.CompileError,
match="Operation QROM with hyperparameters",
):
from_plxpr(jaxpr)()


class TestCatalystCompareJaxpr:
"""Test comparing catalyst and pennylane jaxpr for a variety of situations."""
Expand Down
8 changes: 4 additions & 4 deletions frontend/test/pytest/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,14 @@ def test_decompose_integration(self):

@qml.qnode(dev)
def circuit(theta: float):
qml.SingleExcitationPlus(theta, wires=[0, 1])
qml.OrbitalRotation(theta, wires=[0, 1, 2, 3])
return qml.state()

mlir = qjit(circuit, target="mlir").mlir
assert "SingleExcitation" in mlir
assert "Hadamard" in mlir
assert "CNOT" in mlir
assert "RY" in mlir
assert "SingleExcitationPlus" not in mlir
assert "RX" in mlir
assert "OrbitalRotation" not in mlir

def test_decompose_ops_to_unitary(self):
"""Test the decompose ops to unitary transform."""
Expand Down
22 changes: 11 additions & 11 deletions frontend/test/pytest/test_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from catalyst.api_extensions import HybridAdjoint, HybridCtrl
from catalyst.compiler import get_lib_path
from catalyst.device import get_device_capabilities
from catalyst.device.qjit_device import RUNTIME_OPERATIONS, get_qjit_device_capabilities
from catalyst.device.qjit_device import CUSTOM_OPERATIONS, get_qjit_device_capabilities
from catalyst.device.verification import validate_measurements

# pylint: disable = unused-argument, unnecessary-lambda-assignment, unnecessary-lambda
Expand Down Expand Up @@ -290,7 +290,7 @@ def test_non_controllable_gate_hybridctrl(self):
# Note: The HybridCtrl operator is not currently supported with the QJIT device, but the
# verification structure is in place, so we test the verification of its nested operators by
# adding HybridCtrl to the list of native gates for the custom base device and by patching
# the list of RUNTIME_OPERATIONS for the QJIT device to include HybridCtrl for this test.
# the list of CUSTOM_OPERATIONS for the QJIT device to include HybridCtrl for this test.

@qml.qnode(
get_custom_device(
Expand All @@ -302,12 +302,12 @@ def f(x: float):
assert isinstance(op, HybridCtrl), f"op expected to be HybridCtrl but got {type(op)}"
return qml.expval(qml.PauliX(0))

runtime_ops_with_qctrl = deepcopy(RUNTIME_OPERATIONS)
runtime_ops_with_qctrl = deepcopy(CUSTOM_OPERATIONS)
runtime_ops_with_qctrl["HybridCtrl"] = OperatorProperties(
invertible=True, controllable=True, differentiable=True
)

with patch("catalyst.device.qjit_device.RUNTIME_OPERATIONS", runtime_ops_with_qctrl):
with patch("catalyst.device.qjit_device.CUSTOM_OPERATIONS", runtime_ops_with_qctrl):
with pytest.raises(CompileError, match="PauliZ is not controllable"):
qjit(f)(1.2)

Expand All @@ -321,7 +321,7 @@ def test_hybridctrl_raises_error(self):
"""Test that a HybridCtrl operator is rejected by the verification."""

# TODO: If you are deleting this test because HybridCtrl support has been added, consider
# updating the tests that patch RUNTIME_OPERATIONS to inclue HybridCtrl accordingly
# updating the tests that patch CUSTOM_OPERATIONS to inclue HybridCtrl accordingly

@qml.qnode(get_custom_device(non_controllable_gates={"PauliZ"}, wires=4))
def f(x: float):
Expand Down Expand Up @@ -391,7 +391,7 @@ def test_hybrid_ctrl_containing_adjoint(self, adjoint_type, unsupported_gate_att
# Note: The HybridCtrl operator is not currently supported with the QJIT device, but the
# verification structure is in place, so we test the verification of its nested operators by
# adding HybridCtrl to the list of native gates for the custom base device and by patching
# the list of RUNTIME_OPERATIONS for the QJIT device to include HybridCtrl for this test.
# the list of CUSTOM_OPERATIONS for the QJIT device to include HybridCtrl for this test.

def _ops(x, wires):
if adjoint_type == HybridAdjoint:
Expand All @@ -410,12 +410,12 @@ def f(x: float):
assert isinstance(base, adjoint_type), f"expected {adjoint_type} but got {type(op)}"
return qml.expval(qml.PauliX(0))

runtime_ops_with_qctrl = deepcopy(RUNTIME_OPERATIONS)
runtime_ops_with_qctrl = deepcopy(CUSTOM_OPERATIONS)
runtime_ops_with_qctrl["HybridCtrl"] = OperatorProperties(
invertible=True, controllable=True, differentiable=True
)

with patch("catalyst.device.qjit_device.RUNTIME_OPERATIONS", runtime_ops_with_qctrl):
with patch("catalyst.device.qjit_device.CUSTOM_OPERATIONS", runtime_ops_with_qctrl):
with pytest.raises(CompileError, match=f"PauliZ is not {unsupported_gate_attribute}"):
qjit(f)(1.2)

Expand All @@ -434,7 +434,7 @@ def test_hybrid_adjoint_containing_hybrid_ctrl(self, ctrl_type, unsupported_gate
# Note: The HybridCtrl operator is not currently supported with the QJIT device, but the
# verification structure is in place, so we test the verification of its nested operators by
# adding HybridCtrl to the list of native gates for the custom base device and by patching
# the list of RUNTIME_OPERATIONS for the QJIT device to include HybridCtrl for this test.
# the list of CUSTOM_OPERATIONS for the QJIT device to include HybridCtrl for this test.

def _ops(x, wires):
if ctrl_type == HybridCtrl:
Expand All @@ -453,12 +453,12 @@ def f(x: float):
assert isinstance(base, ctrl_type), f"expected {ctrl_type} but got {type(op)}"
return qml.expval(qml.PauliX(0))

runtime_ops_with_qctrl = deepcopy(RUNTIME_OPERATIONS)
runtime_ops_with_qctrl = deepcopy(CUSTOM_OPERATIONS)
runtime_ops_with_qctrl["HybridCtrl"] = OperatorProperties(
invertible=True, controllable=True, differentiable=True
)

with patch("catalyst.device.qjit_device.RUNTIME_OPERATIONS", runtime_ops_with_qctrl):
with patch("catalyst.device.qjit_device.CUSTOM_OPERATIONS", runtime_ops_with_qctrl):
with pytest.raises(CompileError, match=f"PauliZ is not {unsupported_gate_attribute}"):
qjit(f)(1.2)

Expand Down