diff --git a/frontend/catalyst/device/qjit_device.py b/frontend/catalyst/device/qjit_device.py index a129568639..d3787f46eb 100644 --- a/frontend/catalyst/device/qjit_device.py +++ b/frontend/catalyst/device/qjit_device.py @@ -108,6 +108,16 @@ RUNTIME_MPS = ["ExpectationMP", "SampleMP", "VarianceMP", "CountsMP", "StateMP", "ProbabilityMP"] +# A list of operations that can be represented +# in the Catalyst compiler. This will be a superset of +# the operations supported by the runtime. +# FIXME: ops with OpName(params, wires) signatures can be +# represented in the Catalyst compiler. Unfortunately, +# the signature info is not sufficient as there are +# templates with the same signature that should be +# disambiguated. +COMPILER_OPERATIONS = RUNTIME_OPERATIONS + # The runtime interface does not care about specific gate properties, so set them all to True. RUNTIME_OPERATIONS = { op: OperatorProperties(invertible=True, controllable=True, differentiable=True) diff --git a/frontend/catalyst/from_plxpr/decompose.py b/frontend/catalyst/from_plxpr/decompose.py new file mode 100644 index 0000000000..ee7c1f9e07 --- /dev/null +++ b/frontend/catalyst/from_plxpr/decompose.py @@ -0,0 +1,223 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A transform for the new MLIR-based Catalyst decomposition system. +""" + + +from __future__ import annotations + +import inspect +from collections.abc import Callable +from copy import copy +from typing import get_type_hints + +import jax +import pennylane as qml + +# GraphSolutionInterpreter: +from pennylane.decomposition import DecompositionGraph +from pennylane.measurements import MidMeasureMP +from pennylane.wires import WiresLike +import functools + +from catalyst.jax_primitives import decomposition_rule + +COMPILER_OPERATIONS_NUM_WIRES = { + "CNOT": 2, + "ControlledPhaseShift": 2, + "CRot": 2, + "CRX": 2, + "CRY": 2, + "CRZ": 2, + "CSWAP": 3, + "CY": 2, + "CZ": 2, + "Hadamard": 1, + "Identity": 1, + "IsingXX": 2, + "IsingXY": 2, + "IsingYY": 2, + "IsingZZ": 2, + "SingleExcitation": 2, + "DoubleExcitation": 4, + "ISWAP": 2, + "PauliX": 1, + "PauliY": 1, + "PauliZ": 1, + "PhaseShift": 1, + "PSWAP": 2, + "Rot": 1, + "RX": 1, + "RY": 1, + "RZ": 1, + "S": 1, + "SWAP": 2, + "T": 1, + "Toffoli": 3, + "U1": 1, + "U2": 1, + "U3": 1, +} + + +def create_decomposition_rule(func: Callable, op_name: str, num_wires: int): + """Create a decomposition rule from a function.""" + + sig_func = inspect.signature(func) + type_hints = get_type_hints(func) + + args = {} + op = None + for name in sig_func.parameters.keys(): + typ = type_hints.get(name, None) + + # Skip tailing kwargs in the rules + if name == "__": + continue + + if typ is float or name in ("phi", "theta", "omega", "delta"): + args[name] = float + elif typ is int: + args[name] = int + elif typ is WiresLike or name == "wires": + args[name] = qml.math.array([0] * num_wires, like="jax") # How come wires here are preserved? + elif name == "base": + with qml.capture.pause(): + op = qml.adjoint(qml.RZ(float, [0])) # How do i get this from the name? + else: + raise ValueError( + f"Unsupported type annotation {typ} for parameter {name} in func {func}." + ) + + # Update the name of decomposition rule + rule_name = "_rule" if func.__name__[0] == "_" else "_rule_" + func.__name__ = op_name + rule_name + func.__name__ + "_wires_" + str(num_wires) + return decomposition_rule(func, op=op)(**args) + + +# pylint: disable=too-few-public-methods +class GraphSolutionInterpreter(qml.capture.PlxprInterpreter): + + """Interpreter for getting the decomposition graph solution + from a jaxpr when program capture is enabled. + """ + + def __init__( + self, + *, + gate_set=None, + fixed_decomps=None, + alt_decomps=None, + ): # pylint: disable=too-many-arguments + + if not qml.decomposition.enabled_graph(): + raise TypeError( + "The GraphSolutionInterpreter can only be used when" + "graph-based decomposition is enabled." + ) + + self._gate_set = gate_set + self._fixed_decomps = fixed_decomps + self._alt_decomps = alt_decomps + + self._captured = False + self._operations = set() + self._decomp_graph_solution = {} + + def interpret_operation(self, op: "qml.operation.Operator"): + """Interpret a PennyLane operation instance. + + Args: + op (Operator): a pennylane operator instance + + Returns: + Any + + This method is only called when the operator's output is a dropped variable, + so the output will not affect later equations in the circuit. + + We cache the list of operations seen during the interpretation + to build the decomposition graph in the later stages. + + See also: :meth:`~.interpret_operation_eqn`. + + """ + + self._operations.add(op) + data, struct = jax.tree_util.tree_flatten(op) + return jax.tree_util.tree_unflatten(struct, data) + + def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess"): + """Interpret a measurement process instance. + + Args: + measurement (MeasurementProcess): a measurement instance. + + See also :meth:`~.interpret_measurement_eqn`. + + """ + + if not self._captured and not isinstance(measurement, MidMeasureMP): + self._captured = True + self._decomp_graph_solution = _solve_decomposition_graph( + self._operations, + self._gate_set, + fixed_decomps=self._fixed_decomps, + alt_decomps=self._alt_decomps, + ) + + captured_ops = copy(self._operations) + for op, rule in self._decomp_graph_solution.items(): + if (o := next((o for o in captured_ops if o.name == op.op.name), None)) is not None: + create_decomposition_rule(rule, op_name=op.op.name, num_wires=len(o.wires)) + elif op.op.name in COMPILER_OPERATIONS_NUM_WIRES: + num_wires = COMPILER_OPERATIONS_NUM_WIRES[op.op.name] + create_decomposition_rule(rule, op_name=op.op.name, num_wires=num_wires) + else: + raise ValueError(f"Could not capture {op} without the number of wires.") + + data, struct = jax.tree_util.tree_flatten(measurement) + return jax.tree_util.tree_unflatten(struct, data) + + +# pylint: disable=protected-access +def _solve_decomposition_graph(operations, gate_set, fixed_decomps, alt_decomps): + """Get the decomposition graph solution for the given operations and gate set.""" + + # decomp_graph_solution + decomp_graph_solution = {} + + decomp_graph = DecompositionGraph( + operations, + gate_set, + fixed_decomps=fixed_decomps, + alt_decomps=alt_decomps, + ) + + # Find the efficient pathways to the target gate set + solutions = decomp_graph.solve() + + def is_solved_for(op): + return ( + op in solutions._all_op_indices + and solutions._all_op_indices[op] in solutions._visitor.distances + ) + + for op_node, op_node_idx in solutions._all_op_indices.items(): + if is_solved_for(op_node) and op_node_idx in solutions._visitor.predecessors: + d_node_idx = solutions._visitor.predecessors[op_node_idx] + decomp_graph_solution[op_node] = solutions._graph[d_node_idx].rule._impl + + return decomp_graph_solution diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index b33d53404e..d5771a7fd6 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -20,7 +20,6 @@ from typing import Callable import jax -import jax.core import jax.numpy as jnp import pennylane as qml from jax._src.sharding_impls import UNSPECIFIED @@ -45,6 +44,10 @@ from pennylane.transforms import unitary_to_rot as pl_unitary_to_rot from catalyst.device import extract_backend_info +from catalyst.device.qjit_device import COMPILER_OPERATIONS +from catalyst.from_plxpr.decompose import ( + GraphSolutionInterpreter, +) from catalyst.from_plxpr.qubit_handler import QubitHandler from catalyst.jax_extras import jaxpr_pad_consts, make_jaxpr2, transient_jax_config from catalyst.jax_primitives import ( @@ -177,11 +180,16 @@ def f(x): class WorkflowInterpreter(PlxprInterpreter): - """An interpreter that converts a qnode primitive from a plxpr variant to a catalxpr variant.""" + """An interpreter that converts a qnode primitive from a plxpr variant to a catalyst jaxpr variant.""" def __init__(self): self._pass_pipeline = [] self.qubit_handler = None + + # Compiler options for the new decomposition system + self.requires_compiler_decompose = False + self.decompose_gatesets = [] # queue of gatesets + super().__init__() @@ -199,7 +207,16 @@ def handle_qnode( consts = args[shots_len : n_consts + shots_len] non_const_args = args[shots_len + n_consts :] - closed_jaxpr = ClosedJaxpr(qfunc_jaxpr, consts) + closed_jaxpr = ( + ClosedJaxpr(qfunc_jaxpr, consts) + if not self.requires_compiler_decompose + else apply_compiler_decompose_to_plxpr( + inner_jaxpr=qfunc_jaxpr, + consts=consts, + ncargs=non_const_args, + tgatesets=self.decompose_gatesets, + ) + ) def calling_convention(*args): device_init_p.bind( @@ -216,6 +233,10 @@ def calling_convention(*args): device_release_p.bind() return retvals + if self.requires_compiler_decompose: + # Add gate_set attribute to the quantum kernel primitive + setattr(qnode, "decompose_gatesets", self.decompose_gatesets) + return quantum_kernel_p.bind( wrap_init(calling_convention, debug_info=qfunc_jaxpr.debug_info), *non_const_args, @@ -240,6 +261,28 @@ def calling_convention(*args): } +def apply_compiler_decompose_to_plxpr(inner_jaxpr, consts, tgatesets, ncargs): + """Apply the compiler-specific decomposition for a given JAXPR.""" + + # disable the graph decomposition optimization + is_graph = qml.decomposition.enabled_graph() + if is_graph: + qml.decomposition.disable_graph() + + # First perform the pre-mlir decomposition to simplify the jaxpr + # by decomposing high-level gates and templates + gate_set = COMPILER_OPERATIONS + list(set().union(*tgatesets)) + + final_jaxpr = qml.transforms.decompose.plxpr_transform( + inner_jaxpr, consts, (), {"gate_set": gate_set}, *ncargs + ) + + if is_graph: + qml.decomposition.enable_graph() + + return final_jaxpr + + # pylint: disable-next=redefined-outer-name def register_transform(pl_transform, pass_name, decomposition): """Register pennylane transforms and their conversion to Catalyst transforms""" @@ -264,6 +307,61 @@ def handle_transform( non_const_args = args[args_slice] targs = args[targs_slice] + # If the transform is a decomposition transform + # and the graph-based decomposition is enabled + if ( + hasattr(pl_plxpr_transform, "__name__") + and pl_plxpr_transform.__name__ == "decompose_plxpr_to_plxpr" + and qml.decomposition.enabled_graph() + ): + if not self.requires_compiler_decompose: + self.requires_compiler_decompose = True + + # A helper function to get the name of a pennylane operator + def get_operator_name(op): + """Get the name of a pennylane operator, handling wrapped operators. + + Note: Controlled and Adjoint ops aren't supported in `gate_set` + by PennyLane's DecompositionGraph; unit tests were added in PennyLane. + """ + if isinstance(op, str): + return op + + # Return NoNameOp if the operator has no _primitive.name attribute. + # This is to avoid errors when we capture the program + # as we deal with such ops later in the decomposition graph. + return getattr(op._primitive, "name", "NoNameOp") + + # Update the decompose_gatesets to be used by the quantum kernel primitive + tgateset = tkwargs.get("gate_set", []) + + # We treat decompose_gatesets as a queue of gatesets to be used + # by the decompose-lowering pass at MLIR + self.decompose_gatesets.insert(0, [get_operator_name(op) for op in tgateset]) + + # Note. We don't perform the compiler-specific decomposition here + # to be able to support multiple decomposition transforms + # and collect all the required gatesets + # as well as being able to support other transforms in between. + + # The compiler specific transformation will be performed + # in the qnode handler. + + # Add the decompose-lowering pass to the start of the pipeline + self._pass_pipeline.insert(0, Pass("decompose-lowering")) + + # We still need to construct and solve the graph based on + # the current jaxpr based on the current gateset + # but we don't rewrite the jaxpr at this stage. + + gds_interpreter = GraphSolutionInterpreter(*targs, **tkwargs) + + def gds_wrapper(*args): + return gds_interpreter.eval(inner_jaxpr, consts, *args) + + final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args) + return self.eval(final_jaxpr.jaxpr, consts, *non_const_args) + if catalyst_pass_name is None: # Use PL's ExpandTransformsInterpreter to expand this and any embedded # transform according to PL rules. It works by overriding the primitive @@ -283,10 +381,10 @@ def wrapper(*args): ) return self.eval(final_jaxpr.jaxpr, final_jaxpr.consts, *non_const_args) - else: - # Apply the corresponding Catalyst pass counterpart - self._pass_pipeline.insert(0, Pass(catalyst_pass_name)) - return self.eval(inner_jaxpr, consts, *non_const_args) + + # Apply the corresponding Catalyst pass counterpart + self._pass_pipeline.insert(0, Pass(catalyst_pass_name)) + return self.eval(inner_jaxpr, consts, *non_const_args) # This is our registration factory for PL transforms. The loop below iterates @@ -297,6 +395,7 @@ def wrapper(*args): register_transform(pl_transform, pass_name, decomposition) +# pylint: disable=too-many-instance-attributes class PLxPRToQuantumJaxprInterpreter(PlxprInterpreter): """ Unlike the previous interpreters which modified the getattr and setattr @@ -516,7 +615,6 @@ def handle_decomposition_rule(self, *, pyfun, func_jaxpr, is_qreg, num_params): """ Transform a quantum decomposition rule from PLxPR into JAXPR with quantum primitives. """ - if is_qreg: self.qubit_handler.insert_all_dangling_qubits() diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 8f754913e7..2ee080209d 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -395,7 +395,7 @@ def wrapper(*args, **kwargs): return wrapper -def decomposition_rule(func=None, *, is_qreg=False, num_params=0): +def decomposition_rule(func=None, *, is_qreg=True, num_params=0, op: qml.operation.Operator = None): """ Denotes the creation of a quantum definition in the intermediate representation. """ @@ -409,7 +409,11 @@ def decomposition_rule(func=None, *, is_qreg=False, num_params=0): @functools.wraps(func) def wrapper(*args, **kwargs): - jaxpr = jax.make_jaxpr(func)(*args, **kwargs) + if op is not None: + new_func = functools.partial(func,wires=op.wires, **op.hyperparameters) + jaxpr = jax.make_jaxpr(new_func)(*args, **kwargs) + else: + jaxpr = jax.make_jaxpr(func)(*args, **kwargs) decomprule_p.bind(pyfun=func, func_jaxpr=jaxpr, is_qreg=is_qreg, num_params=num_params) return wrapper @@ -590,7 +594,10 @@ def _decomposition_rule_lowering(ctx, *, pyfun, func_jaxpr, **_): """Lower a quantum decomposition rule into MLIR in a single step process. The step is the compilation of the definition of the function fn. """ - lower_callable(ctx, pyfun, func_jaxpr) + + # Set the visibility of the decomposition rule to public + # to avoid the elimination by the compiler + lower_callable(ctx, pyfun, func_jaxpr, public=True) return () diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py index 73d00cb9ef..7faa679c7e 100644 --- a/frontend/catalyst/jax_primitives_utils.py +++ b/frontend/catalyst/jax_primitives_utils.py @@ -26,6 +26,8 @@ from mlir_quantum.dialects._transform_ops_gen import ApplyRegisteredPassOp, NamedSequenceOp, YieldOp from mlir_quantum.dialects.catalyst import LaunchKernelOp +from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval + def get_call_jaxpr(jaxpr): """Extracts the `call_jaxpr` from a JAXPR if it exists.""" "" @@ -44,7 +46,16 @@ def get_call_equation(jaxpr): def lower_jaxpr(ctx, jaxpr, context=None): - """Lowers a call primitive jaxpr, may be either func_p or quantum_kernel_p""" + """Lowers a call primitive jaxpr, may be either func_p or quantum_kernel_p + + Args: + ctx: LoweringRuleContext + jaxpr: JAXPR to be lowered + context: additional context to distinguish different FuncOps + + Returns: + FuncOp + """ equation = get_call_equation(jaxpr) call_jaxpr = equation.params["call_jaxpr"] callable_ = equation.params.get("fn") @@ -54,7 +65,8 @@ def lower_jaxpr(ctx, jaxpr, context=None): return lower_callable(ctx, callable_, call_jaxpr, pipeline=pipeline, context=context) -def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None): +# pylint: disable=too-many-arguments, too-many-positional-arguments +def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None, public=False): """Lowers _callable to MLIR. If callable_ is a qnode, then we will first create a module, then @@ -66,6 +78,8 @@ def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None): ctx: LoweringRuleContext callable_: python function call_jaxpr: jaxpr representing callable_ + public: whether the visibility should be marked public + Returns: FuncOp """ @@ -73,25 +87,49 @@ def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None): pipeline = tuple() if not isinstance(callable_, qml.QNode): - return get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, context=context) + return get_or_create_funcop( + ctx, callable_, call_jaxpr, pipeline, context=context, public=public + ) return get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context=context) -def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, context=None): - """Get funcOp from cache, or create it from scratch""" +# pylint: disable=too-many-arguments, too-many-positional-arguments +def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, context=None, public=False): + """Get funcOp from cache, or create it from scratch + + Args: + ctx: LoweringRuleContext + callable_: python function + call_jaxpr: jaxpr representing callable_ + context: additional context to distinguish different FuncOps + public: whether the visibility should be marked public + + Returns: + FuncOp + """ if context is None: context = tuple() key = (callable_, *context, *pipeline) if func_op := get_cached(ctx, key): return func_op - func_op = lower_callable_to_funcop(ctx, callable_, call_jaxpr) + func_op = lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=public) cache(ctx, key, func_op) return func_op -def lower_callable_to_funcop(ctx, callable_, call_jaxpr): - """Lower callable to either a FuncOp""" +def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False): + """Lower callable to either a FuncOp + + Args: + ctx: LoweringRuleContext + callable_: python function + call_jaxpr: jaxpr representing callable_ + public: whether the visibility should be marked public + + Returns: + FuncOp + """ if isinstance(call_jaxpr, core.Jaxpr): call_jaxpr = core.ClosedJaxpr(call_jaxpr, ()) @@ -101,10 +139,16 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr): name = callable_.__name__ else: name = callable_.func.__name__ + ".partial" + kwargs["name"] = name kwargs["jaxpr"] = call_jaxpr kwargs["effects"] = [] kwargs["name_stack"] = ctx.name_stack + + # Make the visibility of the function public=True + # to avoid elimination by the compiler + kwargs["public"] = public + func_op = mlir.lower_jaxpr_to_fun(**kwargs) if isinstance(callable_, qml.QNode): @@ -135,6 +179,10 @@ def only_single_expval(): func_op.attributes["diff_method"] = ir.StringAttr.get(diff_method) + gateset = getattr(callable_, "decompose_gatesets", []) + if gateset: + func_op.attributes["decompose_gatesets"] = get_mlir_attribute_from_pyval(gateset) + return func_op diff --git a/frontend/catalyst/passes/builtin_passes.py b/frontend/catalyst/passes/builtin_passes.py index eeb810009f..d539d758de 100644 --- a/frontend/catalyst/passes/builtin_passes.py +++ b/frontend/catalyst/passes/builtin_passes.py @@ -394,6 +394,25 @@ def circuit(x: float): return PassPipelineWrapper(qnode, "merge-rotations") +def decompose_lowering(qnode): + """ + Specify that the ``-decompose-lowering`` MLIR compiler pass + for applying the compiled decomposition rules to the QNode + recursively. + + Args: + fn (QNode): the QNode to apply the cancel inverses compiler pass to + + Returns: + ~.QNode: + + **Example** + // TODO: add example here + + """ + return PassPipelineWrapper(qnode, "decompose-lowering") + + def ions_decomposition(qnode): # pragma: nocover """ Specify that the ``--ions-decomposition`` MLIR compiler pass should be diff --git a/frontend/catalyst/passes/pass_api.py b/frontend/catalyst/passes/pass_api.py index 37872a1f0c..2c59d53e68 100644 --- a/frontend/catalyst/passes/pass_api.py +++ b/frontend/catalyst/passes/pass_api.py @@ -377,6 +377,7 @@ def _API_name_to_pass_name(): "disentangle_cnot": "disentangle-CNOT", "disentangle_swap": "disentangle-SWAP", "merge_rotations": "merge-rotations", + "decompose_lowering": "decompose-lowering", "ions_decomposition": "ions-decomposition", "to_ppr": "to-ppr", "commute_ppr": "commute-ppr", diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 0c61109e9a..06a4f526ad 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -3,8 +3,10 @@ import pathlib import platform from copy import deepcopy +from functools import partial import jax +import numpy as np import pennylane as qml from pennylane.devices.capabilities import OperatorProperties from pennylane.typing import TensorLike @@ -273,7 +275,7 @@ def decompose_to_matrix(): def test_decomposition_rule_wire_param(): """Test decomposition rule with passing a parameter that is a wire/integer""" - @decomposition_rule + @decomposition_rule(is_qreg=False) def Hadamard0(wire: WiresLike): qml.Hadamard(wire) @@ -288,7 +290,7 @@ def circuit(_: float): Hadamard0(int) return qml.probs() - # CHECK: func.func private @Hadamard0([[QBIT:%.+]]: !quantum.bit) -> !quantum.bit + # CHECK: func.func public @Hadamard0([[QBIT:%.+]]: !quantum.bit) -> !quantum.bit # CHECK-NEXT: [[QUBIT_OUT:%.+]] = quantum.custom "Hadamard"() [[QBIT]] : !quantum.bit # CHECK-NEXT: return [[QUBIT_OUT]] : !quantum.bit @@ -303,7 +305,7 @@ def circuit(_: float): def test_decomposition_rule_gate_param_param(): """Test decomposition rule with passing a regular parameter""" - @decomposition_rule(num_params=1) + @decomposition_rule(is_qreg=False, num_params=1) def RX_on_wire_0(param: TensorLike, w0: WiresLike): qml.RX(param, wires=w0) @@ -316,7 +318,7 @@ def circuit_2(_: float): RX_on_wire_0(float, int) return qml.probs() - # CHECK: func.func private @RX_on_wire_0([[PARAM_TENSOR:%.+]]: tensor, [[QUBIT:%.+]]: !quantum.bit) -> !quantum.bit + # CHECK: func.func public @RX_on_wire_0([[PARAM_TENSOR:%.+]]: tensor, [[QUBIT:%.+]]: !quantum.bit) -> !quantum.bit # CHECK-NEXT: [[PARAM:%.+]] = tensor.extract [[PARAM_TENSOR]][] : tensor # CHECK-NEXT: [[QUBIT_1:%.+]] = quantum.custom "RX"([[PARAM]]) [[QUBIT]] : !quantum.bit # CHECK-NEXT: return [[QUBIT_1]] : !quantum.bit @@ -336,7 +338,7 @@ def test_multiple_decomposition_rules(): @decomposition_rule def identity(): ... - @decomposition_rule(num_params=1) + @decomposition_rule(is_qreg=True) def all_wires_rx(param: TensorLike, w0: WiresLike, w1: WiresLike, w2: WiresLike): qml.RX(param, wires=w0) qml.RX(param, wires=w1) @@ -355,8 +357,8 @@ def circuit_3(_: float): qml.Hadamard(0) return qml.probs() - # CHECK: func.func private @identity - # CHECK: func.func private @all_wires_rx + # CHECK: func.func public @identity + # CHECK: func.func public @all_wires_rx print(circuit_3.mlir) qml.capture.disable() @@ -384,7 +386,7 @@ def circuit_4(_: float): qml.Hadamard(0) return qml.probs() - # CHECK: func.func private @shaped_wires_rule([[QREG:%.+]]: !quantum.reg, [[PARAM_TENSOR:%.+]]: tensor, [[QUBITS:%.+]]: tensor<3xi64>) -> !quantum.reg + # CHECK: func.func public @shaped_wires_rule([[QREG:%.+]]: !quantum.reg, [[PARAM_TENSOR:%.+]]: tensor, [[QUBITS:%.+]]: tensor<3xi64>) -> !quantum.reg # CHECK-NEXT: [[IDX_0:%.+]] = stablehlo.slice [[QUBITS]] [0:1] : (tensor<3xi64>) -> tensor<1xi64> # CHECK-NEXT: [[RIDX_0:%.+]] = stablehlo.reshape [[IDX_0]] : (tensor<1xi64>) -> tensor # CHECK-NEXT: [[EXTRACTED:%.+]] = tensor.extract [[RIDX_0]][] : tensor @@ -409,7 +411,7 @@ def shaped_wires_rule(param: TensorLike, wires: WiresLike): qml.RX(param, wires=wires[1]) qml.RX(param, wires=wires[2]) - @decomposition_rule(num_params=1, is_qreg=False) + @decomposition_rule(is_qreg=False, num_params=1) def expanded_wires_rule(param: TensorLike, w1, w2, w3): shaped_wires_rule(param, [w1, w2, w3]) @@ -421,7 +423,7 @@ def circuit_5(_: float): qml.Hadamard(0) return qml.probs() - # CHECK: func.func private @expanded_wires_rule(%arg0: tensor, %arg1: !quantum.bit, %arg2: !quantum.bit, %arg3: !quantum.bit) -> (!quantum.bit, !quantum.bit, !quantum.bit) + # CHECK: func.func public @expanded_wires_rule(%arg0: tensor, %arg1: !quantum.bit, %arg2: !quantum.bit, %arg3: !quantum.bit) -> (!quantum.bit, !quantum.bit, !quantum.bit) print(circuit_5.mlir) qml.capture.disable() @@ -452,7 +454,7 @@ def circuit_6(): cond_RX(float, jax.core.ShapedArray((1,), int)) return qml.probs() - # CHECK: func.func private @cond_RX([[QREG:%.+]]: !quantum.reg, [[PARAM_TENSOR:%.+]]: tensor, [[QUBITS:%.+]]: tensor<1xi64>) -> !quantum.reg + # CHECK: func.func public @cond_RX([[QREG:%.+]]: !quantum.reg, [[PARAM_TENSOR:%.+]]: tensor, [[QUBITS:%.+]]: tensor<1xi64>) -> !quantum.reg # CHECK-NEXT: [[ZERO:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor # CHECK-NEXT: [[COND_TENSOR:%.+]] = stablehlo.compare NE, [[PARAM_TENSOR]], [[ZERO]], FLOAT : (tensor, tensor) -> tensor # CHECK-NEXT: [[COND:%.+]] = tensor.extract [[COND_TENSOR]][] : tensor @@ -479,17 +481,17 @@ def test_decomposition_rule_caller(): qml.capture.enable() @decomposition_rule(is_qreg=True) - def Op1_decomp(_: TensorLike, wires: WiresLike): + def rule_op1_decomp(_: TensorLike, wires: WiresLike): qml.Hadamard(wires=wires[0]) qml.Hadamard(wires=[1]) @decomposition_rule(is_qreg=True) - def Op2_decomp(param: TensorLike, wires: WiresLike): + def rule_op2_decomp(param: TensorLike, wires: WiresLike): qml.RX(param, wires=wires[0]) def decomps_caller(param: TensorLike, wires: WiresLike): - Op1_decomp(param, wires) - Op2_decomp(param, wires) + rule_op1_decomp(param, wires) + rule_op2_decomp(param, wires) @qml.qjit(autograph=False) @qml.qnode(qml.device("lightning.qubit", wires=1)) @@ -500,11 +502,361 @@ def circuit_7(): decomps_caller(float, jax.core.ShapedArray((2,), int)) return qml.probs() - # CHECK: func.func private @Op1_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg - # CHECK: func.func private @Op2_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg - + # CHECK: func.func public @rule_op1_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + # CHECK: func.func public @rule_op2_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg print(circuit_7.mlir) qml.capture.disable() test_decomposition_rule_caller() + + +def test_decompose_gateset_without_graph(): + """Test the decompose transform to a target gate set without the graph decomposition.""" + + qml.capture.enable() + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: public @circuit_8() -> tensor attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} + def circuit_8(): + return qml.expval(qml.Z(0)) + + print(circuit_8.mlir) + + qml.capture.disable() + + +test_decompose_gateset_without_graph() + + +def test_decompose_gateset_with_graph(): + """Test the decompose transform to a target gate set with the graph decomposition.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RX"}) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: public @simple_circuit_9() -> tensor attributes {decompose_gatesets + def simple_circuit_9(): + return qml.expval(qml.Z(0)) + + print(simple_circuit_9.mlir) + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" + # CHECK: public @circuit_9() -> tensor attributes {decompose_gatesets + def circuit_9(): + return qml.expval(qml.Z(0)) + + print(circuit_9.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decompose_gateset_with_graph() + + +def test_decompose_gateset_operator_with_graph(): + """Test the decompose transform to a target gate set with the graph decomposition.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={qml.RX}) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: public @simple_circuit_10() -> tensor attributes {decompose_gatesets + def simple_circuit_10(): + return qml.expval(qml.Z(0)) + + print(simple_circuit_10.mlir) + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, gate_set={qml.RX, qml.RZ, "PauliZ", qml.PauliX, qml.Hadamard} + ) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: public @circuit_10() -> tensor attributes {decompose_gatesets + def circuit_10(): + return qml.expval(qml.Z(0)) + + print(circuit_10.mlir) + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, gate_set={qml.RX, qml.RZ, qml.PauliZ, qml.PauliX, qml.Hadamard} + ) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" + # CHECK: public @circuit_11() -> tensor attributes {decompose_gatesets + def circuit_11(): + return qml.expval(qml.Z(0)) + + print(circuit_11.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decompose_gateset_operator_with_graph() + + +def test_decompose_gateset_with_rotxzx(): + """Test the decompose transform with a custom operator with the graph decomposition.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RotXZX"}) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: public @simple_circuit_12() -> tensor attributes {decompose_gatesets + def simple_circuit_12(): + return qml.expval(qml.Z(0)) + + print(simple_circuit_12.mlir) + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={qml.ftqc.RotXZX}) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" + # CHECK: public @circuit_12() -> tensor attributes {decompose_gatesets + def circuit_12(): + return qml.expval(qml.Z(0)) + + print(circuit_12.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decompose_gateset_with_rotxzx() + + +def test_decomposition_rule_name(): + """Test the name of the decomposition rule is not updated with circuit instantiation.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @decomposition_rule + def _ry_to_rz_rx(phi, wires: WiresLike, **__): + """Decomposition of RY gate using RZ and RX gates.""" + qml.RZ(-np.pi / 2, wires=wires) + qml.RX(phi, wires=wires) + qml.RZ(np.pi / 2, wires=wires) + + @decomposition_rule + def _rot_to_rz_ry_rz(phi, theta, omega, wires: WiresLike, **__): + """Decomposition of Rot gate using RZ and RY gates.""" + qml.RZ(phi, wires=wires) + qml.RY(theta, wires=wires) + qml.RZ(omega, wires=wires) + + @decomposition_rule + def _u2_phaseshift_rot_decomposition(phi, delta, wires, **__): + """Decomposition of U2 gate using Rot and PhaseShift gates.""" + pi_half = qml.math.ones_like(delta) * (np.pi / 2) + qml.Rot(delta, pi_half, -delta, wires=wires) + qml.PhaseShift(delta, wires=wires) + qml.PhaseShift(phi, wires=wires) + + @decomposition_rule + def _xzx_decompose(phi, theta, omega, wires, **__): + """Decomposition of Rot gate using RX and RZ gates in XZX format.""" + qml.RX(phi, wires=wires) + qml.RZ(theta, wires=wires) + qml.RX(omega, wires=wires) + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RX", "RZ", "PhaseShift"}) + @qml.qnode(qml.device("lightning.qubit", wires=3)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" + # CHECK: public @circuit_13() -> tensor attributes {decompose_gatesets + def circuit_13(): + _ry_to_rz_rx(float, int) + _rot_to_rz_ry_rz(float, float, float, int) + _u2_phaseshift_rot_decomposition(float, float, int) + _xzx_decompose(float, float, float, int) + return qml.expval(qml.Z(0)) + + # CHECK: func.func public @_ry_to_rz_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor) -> !quantum.reg + # CHECK: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> !quantum.reg + # CHECK: func.func public @_u2_phaseshift_rot_decomposition(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> !quantum.reg + # CHECK: func.func public @_xzx_decompose(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> !quantum.reg + print(circuit_13.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decomposition_rule_name() + + +def test_decomposition_rule_name_update(): + """Test the name of the decomposition rule is updated in the MLIR output.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.register_resources({qml.RZ: 2, qml.RX: 1}) + def rz_rx(phi, wires: WiresLike, **__): + """Decomposition of RY gate using RZ and RX gates.""" + qml.RZ(-np.pi / 2, wires=wires) + qml.RX(phi, wires=wires) + qml.RZ(np.pi / 2, wires=wires) + + @qml.register_resources({qml.RZ: 2, qml.RY: 1}) + def rz_ry_rz(phi, theta, omega, wires: WiresLike, **__): + """Decomposition of Rot gate using RZ and RY gates.""" + qml.RZ(phi, wires=wires) + qml.RY(theta, wires=wires) + qml.RZ(omega, wires=wires) + + @qml.register_resources({qml.RY: 1, qml.PhaseShift: 1}) + def ry_gp(wires: WiresLike, **__): + """Decomposition of PauliY gate using RY and GlobalPhase gates.""" + qml.RY(np.pi, wires=wires) + qml.GlobalPhase(-np.pi / 2, wires=wires) + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RX", "RZ", "PhaseShift"}, + fixed_decomps={ + qml.RY: rz_rx, + qml.Rot: rz_ry_rz, + qml.PauliY: ry_gp, + }, + ) + @qml.qnode(qml.device("lightning.qubit", wires=3)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" + # CHECK: public @circuit_14() -> tensor attributes {decompose_gatesets + def circuit_14(): + qml.RY(0.5, wires=0) + qml.Rot(0.1, 0.2, 0.3, wires=1) + qml.PauliY(wires=2) + return qml.expval(qml.Z(0)) + + # CHECK-DAG: func.func public @Rot_rule_rz_ry_rz_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @RY_rule_rz_rx_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @PauliY_rule_ry_gp_wires_1(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg + print(circuit_14.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decomposition_rule_name_update() + + +def test_decomposition_rule_name_update_multi_qubits(): + """Test the name of the decomposition rule with multi-qubit gates.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RY", "RX", "CNOT", "Hadamard", "GlobalPhase"}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=4)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" + # CHECK: public @circuit_15() -> tensor attributes {decompose_gatesets + def circuit_15(): + qml.SingleExcitation(0.5, wires=[0, 1]) + qml.SingleExcitationPlus(0.5, wires=[0, 1]) + qml.SingleExcitationMinus(0.5, wires=[0, 1]) + qml.DoubleExcitation(0.5, wires=[0, 1, 2, 3]) + return qml.expval(qml.Z(0)) + + # CHECK-DAG: func.func public @SingleExcitationPlus_rule_single_excitation_plus_decomp_wires_2(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @CY_rule_cy_wires_2(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @CRY_rule_cry_wires_2(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @S_rule_s_phaseshift_wires_1(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @PhaseShift_rule_phaseshift_to_rz_gp_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @RZ_rule_rz_to_ry_rx_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @Rot_rule_rot_to_rz_ry_rz_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @DoubleExcitation_rule_doublexcit_wires_4(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<4xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @SingleExcitationMinus_rule_single_excitation_minus_decomp_wires_2(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @SingleExcitation_rule_single_excitation_decomp_wires_2(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + print(circuit_15.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decomposition_rule_name_update_multi_qubits() + + +def test_decomposition_rule_name_adjoint(): + """Test decomposition rule with qml.adjoint.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RY", "RX", "CZ", "GlobalPhase"}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=4)) + # CHECK: public @circuit_16() -> tensor attributes {decompose_gatesets + def circuit_16(): + # CHECK-DAG: %1 = quantum.adjoint(%0) : !quantum.reg + # CHECK-DAG: %2 = quantum.adjoint(%1) : !quantum.reg + # CHECK-DAG: %3 = quantum.adjoint(%2) : !quantum.reg + # CHECK-DAG: %4 = quantum.adjoint(%3) : !quantum.reg + qml.adjoint(qml.CNOT)(wires=[0, 1]) + qml.adjoint(qml.Hadamard)(wires=2) + qml.adjoint(qml.RZ)(0.5, wires=3) + qml.adjoint(qml.SingleExcitation)(0.1, wires=[0, 1]) + return qml.expval(qml.Z(0)) + + # CHECK-DAG: func.func public @CNOT_rule_cnot_to_cz_h_wires_2(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @Hadamard_rule_hadamard_to_rz_ry_wires_1(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @SingleExcitation_rule_SingleExcitation_rule_single_excitation_decomp_wires_2_wires_2(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + print(circuit_16.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decomposition_rule_name_adjoint() + + +def test_decomposition_rule_name_ctrl(): + """Test decomposition rule with qml.ctrl.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RX", "RZ"}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=5)) + # CHECK: public @circuit_17() -> tensor attributes {decompose_gatesets + def circuit_17(): + # CHECK: %out_qubits:2 = quantum.custom "CRY"(%cst) %1, %2 : !quantum.bit, !quantum.bit + qml.ctrl(qml.RY, control=0)(0.5, 1) + qml.ctrl(qml.PauliX, control=0)(1) + return qml.expval(qml.Z(0)) + + # CHECK-DAG: func.func public @RY_rule_ry_to_rz_rx_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg + print(circuit_17.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decomposition_rule_name_ctrl() diff --git a/frontend/test/lit/test_from_plxpr.py b/frontend/test/lit/test_from_plxpr.py index 7fd8618cc6..a6c2bae85e 100644 --- a/frontend/test/lit/test_from_plxpr.py +++ b/frontend/test/lit/test_from_plxpr.py @@ -18,6 +18,8 @@ """Lit tests for the PLxPR to JAXPR with quantum primitives pipeline""" +from functools import partial + import pennylane as qml @@ -362,3 +364,78 @@ def circuit(): test_pass_application() + + +def test_pass_decomposition(): + """Application of pass decorator with decomposition.""" + + dev = qml.device("null.qubit", wires=1) + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @qml.transforms.cancel_inverses + @qml.transforms.merge_rotations + @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) + @qml.qnode(dev) + def circuit1(): + return qml.probs() + + # CHECK: [[first_pass:%.+]] = transform.apply_registered_pass "decompose-lowering" + # CHECK-NEXT: [[second_pass:%.+]] = transform.apply_registered_pass "merge-rotations" + # CHECK-NEXT: transform.apply_registered_pass "remove-chained-self-inverse" to [[second_pass]] + + print(circuit1.mlir) + + @qml.qjit(target="mlir") + @qml.transforms.cancel_inverses + @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) + @qml.transforms.merge_rotations + @qml.qnode(dev) + def circuit2(): + return qml.probs() + + # CHECK: [[first_pass:%.+]] = transform.apply_registered_pass "merge-rotations" + # CHECK-NEXT: [[second_pass:%.+]] = transform.apply_registered_pass "decompose-lowering" + # CHECK-NEXT: transform.apply_registered_pass "remove-chained-self-inverse" to [[second_pass]] + + print(circuit2.mlir) + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) + @qml.transforms.cancel_inverses + @qml.transforms.merge_rotations + @qml.qnode(dev) + def circuit3(): + return qml.probs() + + # CHECK: [[first_pass:%.+]] = transform.apply_registered_pass "merge-rotations" + # CHECK-NEXT: [[second_pass:%.+]] = transform.apply_registered_pass "remove-chained-self-inverse" + # CHECK-NEXT: transform.apply_registered_pass "decompose-lowering" to [[second_pass]] + + print(circuit3.mlir) + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RX"}) + @qml.transforms.cancel_inverses + @partial(qml.transforms.decompose, gate_set={"RZ"}) + @qml.transforms.merge_rotations + @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) + @qml.qnode(dev) + def circuit4(): + return qml.probs() + + # CHECK: [[first_pass:%.+]] = transform.apply_registered_pass "decompose-lowering" + # CHECK-NEXT: [[merge_rot:%.+]] = transform.apply_registered_pass "merge-rotations" to [[first_pass]] + # CHECK-NEXT: [[decomp_to_rz:%.+]] = transform.apply_registered_pass "decompose-lowering" to [[merge_rot]] + # CHECK-NEXT: [[remove_chained:%.+]] = transform.apply_registered_pass "remove-chained-self-inverse" to [[decomp_to_rz]] + # CHECK-NEXT: transform.apply_registered_pass "decompose-lowering" to [[remove_chained]] + + print(circuit4.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_pass_decomposition() diff --git a/frontend/test/pytest/from_plxpr/test_from_plxpr_decompose.py b/frontend/test/pytest/from_plxpr/test_from_plxpr_decompose.py new file mode 100644 index 0000000000..4e21eadff4 --- /dev/null +++ b/frontend/test/pytest/from_plxpr/test_from_plxpr_decompose.py @@ -0,0 +1,314 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests the ``decompose`` transform with the new Catalyst graph-based decomposition system.""" +from functools import partial + +import numpy as np +import pennylane as qml +import pytest + +pytestmark = pytest.mark.usefixtures("disable_capture") + + +class TestDecomposeGraphEnabled: + """Tests the decompose transform with graph enabled.""" + + @pytest.mark.integration + def test_mixed_gate_set_specification(self): + """Tests that the gate_set can be specified as both a type and a string.""" + + qml.decomposition.enable_graph() + + tape = qml.tape.QuantumScript([qml.RX(0.5, wires=[0]), qml.CNOT(wires=[0, 1])]) + [new_tape], _ = qml.transforms.decompose(tape, gate_set={"RX", qml.CNOT}) + assert new_tape.operations == tape.operations + + qml.decomposition.disable_graph() + + @pytest.mark.integration + def test_gate_set_targeted_decompositions(self): + """Tests that a simple circuit is correctly decomposed into different gate sets.""" + + qml.decomposition.enable_graph() + + tape = qml.tape.QuantumScript( + [ + qml.H(0), # non-parametric op + qml.Rot(0.1, 0.2, 0.3, wires=[0]), # parametric single-qubit op + qml.MultiRZ(0.5, wires=[0, 1, 2]), # parametric multi-qubit op + ] + ) + + [new_tape], _ = qml.transforms.decompose(tape, gate_set={"Hadamard", "CNOT", "RZ", "RY"}) + assert new_tape.operations == [ + # H is in the target gate set + qml.H(0), + # Rot decomposes to ZYZ + qml.RZ(0.1, wires=[0]), + qml.RY(0.2, wires=[0]), + qml.RZ(0.3, wires=[0]), + # Decomposition of MultiRZ + qml.CNOT(wires=[2, 1]), + qml.CNOT(wires=[1, 0]), + qml.RZ(0.5, wires=[0]), + qml.CNOT(wires=[1, 0]), + qml.CNOT(wires=[2, 1]), + ] + + [new_tape], _ = qml.transforms.decompose(tape, gate_set={"RY", "RZ", "CZ", "GlobalPhase"}) + assert new_tape.operations == [ + # The H decomposes to RZ and RY + qml.RZ(np.pi, wires=[0]), + qml.RY(np.pi / 2, wires=[0]), + qml.GlobalPhase(-np.pi / 2), + # Rot decomposes to ZYZ + qml.RZ(0.1, wires=[0]), + qml.RY(0.2, wires=[0]), + qml.RZ(0.3, wires=[0]), + # CNOT decomposes to H and CZ, where H decomposes to RZ and RY + qml.RZ(np.pi, wires=[1]), + qml.RY(np.pi / 2, wires=[1]), + qml.GlobalPhase(-np.pi / 2), + qml.CZ(wires=[2, 1]), + qml.RZ(np.pi, wires=[1]), + qml.RY(np.pi / 2, wires=[1]), + qml.GlobalPhase(-np.pi / 2), + # second CNOT + qml.RZ(np.pi, wires=[0]), + qml.RY(np.pi / 2, wires=[0]), + qml.GlobalPhase(-np.pi / 2), + qml.CZ(wires=[1, 0]), + qml.RZ(np.pi, wires=[0]), + qml.RY(np.pi / 2, wires=[0]), + qml.GlobalPhase(-np.pi / 2), + # The middle RZ + qml.RZ(0.5, wires=[0]), + # The last two CNOTs + qml.RZ(np.pi, wires=[0]), + qml.RY(np.pi / 2, wires=[0]), + qml.GlobalPhase(-np.pi / 2), + qml.CZ(wires=[1, 0]), + qml.RZ(np.pi, wires=[0]), + qml.RY(np.pi / 2, wires=[0]), + qml.GlobalPhase(-np.pi / 2), + qml.RZ(np.pi, wires=[1]), + qml.RY(np.pi / 2, wires=[1]), + qml.GlobalPhase(-np.pi / 2), + qml.CZ(wires=[2, 1]), + qml.RZ(np.pi, wires=[1]), + qml.RY(np.pi / 2, wires=[1]), + qml.GlobalPhase(-np.pi / 2), + ] + + qml.decomposition.disable_graph() + + @pytest.mark.integration + def test_fixed_decomp(self): + """Tests that a fixed decomposition rule is used instead of the stock ones.""" + + qml.decomposition.enable_graph() + + @qml.register_resources({qml.RY: 2, qml.CZ: 1, qml.Z: 2}) + def my_cnot(wires, **__): + qml.RY(np.pi / 2, wires[1]) + qml.Z(wires[1]) + qml.CZ(wires=wires) + qml.RY(np.pi / 2, wires[1]) + qml.Z(wires[1]) + + tape = qml.tape.QuantumScript([qml.CNOT(wires=[1, 0])]) + [new_tape], _ = qml.transforms.decompose( + tape, + gate_set={"RY", "RZ", "CZ", "Hadamard", "GlobalPhase"}, + fixed_decomps={qml.CNOT: my_cnot}, + ) + assert new_tape.operations == [ + qml.RY(np.pi / 2, wires=[0]), + qml.RZ(np.pi, wires=[0]), + qml.GlobalPhase(-np.pi / 2), + qml.CZ(wires=[1, 0]), + qml.RY(np.pi / 2, wires=[0]), + qml.RZ(np.pi, wires=[0]), + qml.GlobalPhase(-np.pi / 2), + ] + + qml.decomposition.disable_graph() + + @pytest.mark.integration + def test_alt_decomp_not_used(self): + """Tests that alt_decomp isn't necessarily used if it's not efficient.""" + + qml.decomposition.enable_graph() + + @qml.register_resources({qml.RY: 2, qml.CZ: 1, qml.Z: 2}) + def my_cnot(wires, **__): + qml.RY(np.pi / 2, wires[1]) + qml.Z(wires[1]) + qml.CZ(wires=wires) + qml.RY(np.pi / 2, wires[1]) + qml.Z(wires[1]) + + tape = qml.tape.QuantumScript([qml.CNOT(wires=[1, 0])]) + [new_tape], _ = qml.transforms.decompose( + tape, + gate_set={"RY", "RZ", "CZ", "Hadamard", "GlobalPhase"}, + alt_decomps={qml.CNOT: [my_cnot]}, + ) + assert new_tape.operations == [ + qml.H(0), + qml.CZ(wires=[1, 0]), + qml.H(0), + ] + + qml.decomposition.disable_graph() + + @pytest.mark.integration + def test_alt_decomp(self): + """Tests that alternative decomposition rules are used when applicable.""" + + qml.decomposition.enable_graph() + + @qml.register_resources({qml.RY: 2, qml.CZ: 1, qml.Z: 2}) + def my_cnot(wires, **__): + qml.RY(np.pi / 2, wires[1]) + qml.Z(wires[1]) + qml.CZ(wires=wires) + qml.RY(np.pi / 2, wires[1]) + qml.Z(wires[1]) + + tape = qml.tape.QuantumScript([qml.CNOT(wires=[1, 0])]) + [new_tape], _ = qml.transforms.decompose( + tape, + gate_set={"RY", "RZ", "CZ", "PauliZ", "GlobalPhase"}, + alt_decomps={qml.CNOT: [my_cnot]}, + ) + assert new_tape.operations == [ + qml.RY(np.pi / 2, wires=[0]), + qml.Z(0), + qml.CZ(wires=[1, 0]), + qml.RY(np.pi / 2, wires=[0]), + qml.Z(0), + ] + + qml.decomposition.disable_graph() + + @pytest.mark.integration + def test_fall_back(self): + """Tests that op.decompose() is used for ops unsolved in the graph.""" + + qml.decomposition.enable_graph() + + class CustomOp(qml.operation.Operation): # pylint: disable=too-few-public-methods + """Dummy custom op.""" + + resource_keys = set() + + @property + def resource_params(self): + """Dummy resource params.""" + + return {} + + def decomposition(self): + """Decomposition of CustomOp into H-CNOT-H.""" + + return [qml.H(self.wires[1]), qml.CNOT(self.wires), qml.H(self.wires[1])] + + @qml.register_resources({qml.CZ: 1}) + def my_decomp(wires, **__): + qml.CZ(wires=wires) + + tape = qml.tape.QuantumScript([CustomOp(wires=[0, 1])]) + [new_tape], _ = qml.transforms.decompose( + tape, gate_set={"CNOT", "Hadamard"}, fixed_decomps={CustomOp: my_decomp} + ) + assert new_tape.operations == [qml.H(1), qml.CNOT(wires=[0, 1]), qml.H(1)] + + qml.decomposition.disable_graph() + + # @pytest.mark.integration + # def test_controlled_decomp(self): + # """Tests decomposing a controlled operation.""" + + # # The C(MultiRZ) is decomposed by applying control on the base decomposition. + # # The decomposition of MultiRZ contains two CNOTs + # # So this also tests applying control on an PauliX based operation + # # The decomposition of MultiRZ also contains an RZ gate + # # So this also tests logic involving custom controlled operators. + # ops = [qml.ctrl(qml.MultiRZ(0.5, wires=[0, 1]), control=[2])] + # tape = qml.tape.QuantumScript(ops) + # [new_tape], _ = qml.transforms.decompose(tape, gate_set={"RZ", "CNOT", "Toffoli"}) + # assert new_tape.operations == [ + # # Decomposition of C(CNOT) + # qml.Toffoli(wires=[2, 1, 0]), + # # Decomposition of C(RZ) -> CRZ + # qml.RZ(0.25, wires=[0]), + # qml.CNOT(wires=[2, 0]), + # qml.RZ(-0.25, wires=[0]), + # qml.CNOT(wires=[2, 0]), + # # Decomposition of C(CNOT) + # qml.Toffoli(wires=[2, 1, 0]), + # ] + + # @pytest.mark.integration + # def test_adjoint_decomp(self): + # """Tests decomposing an adjoint operation.""" + + # class CustomOp(qml.operation.Operator): # pylint: disable=too-few-public-methods + + # resource_keys = set() + + # @property + # def resource_params(self) -> dict: + # return {} + + # @qml.register_resources({qml.RX: 1, qml.RY: 1, qml.RZ: 1}) + # def custom_decomp(theta, phi, omega, wires): + # qml.RX(theta, wires[0]) + # qml.RY(phi, wires[0]) + # qml.RZ(omega, wires[0]) + + # tape = qml.tape.QuantumScript( + # [ + # qml.adjoint(qml.RX(0.5, wires=[0])), + # qml.adjoint(qml.adjoint(qml.MultiRZ(0.5, wires=[0, 1]))), + # qml.adjoint(CustomOp(0.1, 0.2, 0.3, wires=[0])), + # ] + # ) + # [new_tape], _ = qml.transforms.decompose( + # tape, gate_set={"CNOT", "RX", "RY", "RZ"}, fixed_decomps={CustomOp: custom_decomp} + # ) + # assert new_tape.operations == [ + # qml.RX(-0.5, wires=[0]), + # qml.CNOT(wires=[1, 0]), + # qml.RZ(0.5, wires=[0]), + # qml.CNOT(wires=[1, 0]), + # qml.RZ(-0.3, wires=[0]), + # qml.RY(-0.2, wires=[0]), + # qml.RX(-0.1, wires=[0]), + # ] + + +def test_decompose_qnode(): + """Tests that the decompose transform works with a QNode.""" + + @partial(qml.transforms.decompose, gate_set={"CZ", "Hadamard"}) + @qml.qnode(qml.device("default.qubit", wires=2)) + def circuit(): + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.PauliZ(0)) + + res = circuit() + assert qml.math.allclose(res, 1.0) diff --git a/mlir/include/Quantum/Transforms/Passes.h b/mlir/include/Quantum/Transforms/Passes.h index 00f33d8fa4..33b25c0179 100644 --- a/mlir/include/Quantum/Transforms/Passes.h +++ b/mlir/include/Quantum/Transforms/Passes.h @@ -30,6 +30,7 @@ std::unique_ptr createRemoveChainedSelfInversePass(); std::unique_ptr createAnnotateFunctionPass(); std::unique_ptr createSplitMultipleTapesPass(); std::unique_ptr createMergeRotationsPass(); +std::unique_ptr createDecomposeLoweringPass(); std::unique_ptr createDisentangleCNOTPass(); std::unique_ptr createDisentangleSWAPPass(); std::unique_ptr createIonsDecompositionPass(); diff --git a/mlir/include/Quantum/Transforms/Passes.td b/mlir/include/Quantum/Transforms/Passes.td index f0a344190e..918b2032dc 100644 --- a/mlir/include/Quantum/Transforms/Passes.td +++ b/mlir/include/Quantum/Transforms/Passes.td @@ -110,6 +110,12 @@ def MergeRotationsPass : Pass<"merge-rotations"> { let constructor = "catalyst::createMergeRotationsPass()"; } +def DecomposeLoweringPass : Pass<"decompose-lowering"> { + let summary = "Replace quantum operations with compiled decomposition rules."; + + let constructor = "catalyst::createDecomposeLoweringPass()"; +} + def DisentangleCNOTPass : Pass<"disentangle-CNOT"> { let summary = "Replace a CNOT gate with two single qubit gates whenever possible."; diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index 0e7be8337b..73f8327216 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -68,6 +68,7 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createRegisterInactiveCallbackPass); mlir::registerPass(catalyst::createRemoveChainedSelfInversePass); mlir::registerPass(catalyst::createMergeRotationsPass); + mlir::registerPass(catalyst::createDecomposeLoweringPass); mlir::registerPass(catalyst::createScatterLoweringPass); mlir::registerPass(catalyst::createStablehloLegalizeControlFlowPass); mlir::registerPass(catalyst::createStablehloLegalizeSortPass); diff --git a/mlir/lib/Quantum/Transforms/CMakeLists.txt b/mlir/lib/Quantum/Transforms/CMakeLists.txt index 3a244ac4d6..ddc54e3148 100644 --- a/mlir/lib/Quantum/Transforms/CMakeLists.txt +++ b/mlir/lib/Quantum/Transforms/CMakeLists.txt @@ -14,6 +14,7 @@ file(GLOB SRC SplitMultipleTapes.cpp merge_rotation.cpp MergeRotationsPatterns.cpp + decompose_lowering.cpp DisentangleSWAP.cpp DisentangleCNOT.cpp ions_decompositions.cpp diff --git a/mlir/lib/Quantum/Transforms/decompose_lowering.cpp b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp new file mode 100644 index 0000000000..8f0d6b638e --- /dev/null +++ b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp @@ -0,0 +1,48 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#define DEBUG_TYPE "decompose-lowering" + +#include "Catalyst/IR/CatalystDialect.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/Debug.h" + +#include "Catalyst/IR/CatalystDialect.h" +#include "Quantum/IR/QuantumOps.h" +#include "Quantum/Transforms/Patterns.h" + +using namespace llvm; +using namespace mlir; +using namespace catalyst::quantum; + +namespace catalyst { +namespace quantum { +#define GEN_PASS_DEF_DECOMPOSELOWERINGPASS +#define GEN_PASS_DECL_DECOMPOSELOWERINGPASS +#include "Quantum/Transforms/Passes.h.inc" + +struct DecomposeLoweringPass : public impl::DecomposeLoweringPassBase { + using impl::DecomposeLoweringPassBase::DecomposeLoweringPassBase; + + void runOnOperation() override { llvm::errs() << "Decompose Lowering Pass!\n"; } +}; + +} // namespace quantum + +std::unique_ptr createDecomposeLoweringPass() +{ + return std::make_unique(); +} + +} // namespace catalyst diff --git a/runtime/include/RuntimeCAPI.h b/runtime/include/RuntimeCAPI.h index 87414df96f..b248979a11 100644 --- a/runtime/include/RuntimeCAPI.h +++ b/runtime/include/RuntimeCAPI.h @@ -63,6 +63,7 @@ void __catalyst__qis__RX(double, QUBIT *, const Modifiers *); void __catalyst__qis__RY(double, QUBIT *, const Modifiers *); void __catalyst__qis__RZ(double, QUBIT *, const Modifiers *); void __catalyst__qis__Rot(double, double, double, QUBIT *, const Modifiers *); +void __catalyst__qis__RotXZX(double, double, double, QUBIT *, const Modifiers *); void __catalyst__qis__CNOT(QUBIT *, QUBIT *, const Modifiers *); void __catalyst__qis__CY(QUBIT *, QUBIT *, const Modifiers *); void __catalyst__qis__CZ(QUBIT *, QUBIT *, const Modifiers *); diff --git a/runtime/lib/capi/RuntimeCAPI.cpp b/runtime/lib/capi/RuntimeCAPI.cpp index 94ed41e13a..893af5a7fd 100644 --- a/runtime/lib/capi/RuntimeCAPI.cpp +++ b/runtime/lib/capi/RuntimeCAPI.cpp @@ -627,6 +627,14 @@ void __catalyst__qis__Rot(double phi, double theta, double omega, QUBIT *qubit, MODIFIERS_ARGS(modifiers)); } +void __catalyst__qis__RotXZX(double phi, double theta, double omega, QUBIT *qubit, + const Modifiers *modifiers) +{ + getQuantumDevicePtr()->NamedOperation("RotXZX", {phi, theta, omega}, + {reinterpret_cast(qubit)}, + MODIFIERS_ARGS(modifiers)); +} + void __catalyst__qis__CNOT(QUBIT *control, QUBIT *target, const Modifiers *modifiers) { RT_FAIL_IF(control == target,