diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 318a1afb07..fa1bdf853d 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -2,6 +2,13 @@
New features since last release
+* A new experimental decomposition system is introduced in Catalyst enabling the
+ PennyLane's graph-based decomposition and MLIR-based lowering of decomposition rules.
+ This feature is integrated with PennyLane program capture and graph-based decomposition
+ including support for custom decomposition rules and operators.
+ [(#2001)](https://github.com/PennyLaneAI/catalyst/pull/2001)
+ [(#2029)](https://github.com/PennyLaneAI/catalyst/pull/2029)
+
* Catalyst now supports dynamic wire allocation with ``qml.allocate()`` and
``qml.deallocate()`` when program capture is enabled.
[(#2002)](https://github.com/PennyLaneAI/catalyst/pull/2002)
diff --git a/frontend/catalyst/device/qjit_device.py b/frontend/catalyst/device/qjit_device.py
index df1fc2675d..dec5a8176b 100644
--- a/frontend/catalyst/device/qjit_device.py
+++ b/frontend/catalyst/device/qjit_device.py
@@ -108,6 +108,16 @@
RUNTIME_MPS = ["ExpectationMP", "SampleMP", "VarianceMP", "CountsMP", "StateMP", "ProbabilityMP"]
+# A list of operations that can be represented
+# in the Catalyst compiler. This will be a superset of
+# the operations supported by the runtime.
+# FIXME: ops with OpName(params, wires) signatures can be
+# represented in the Catalyst compiler. Unfortunately,
+# the signature info is not sufficient as there are
+# templates with the same signature that should be
+# disambiguated.
+COMPILER_OPERATIONS = RUNTIME_OPERATIONS
+
# The runtime interface does not care about specific gate properties, so set them all to True.
RUNTIME_OPERATIONS = {
op: OperatorProperties(invertible=True, controllable=True, differentiable=True)
diff --git a/frontend/catalyst/from_plxpr/decompose.py b/frontend/catalyst/from_plxpr/decompose.py
new file mode 100644
index 0000000000..232d2ab886
--- /dev/null
+++ b/frontend/catalyst/from_plxpr/decompose.py
@@ -0,0 +1,395 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+A transform for the new MLIR-based Catalyst decomposition system.
+"""
+
+
+from __future__ import annotations
+
+import functools
+import inspect
+import types
+from collections.abc import Callable
+from typing import get_type_hints
+
+import jax
+import pennylane as qml
+
+# DecompRuleInterpreter:
+from pennylane.decomposition import DecompositionGraph
+from pennylane.typing import TensorLike
+from pennylane.wires import WiresLike
+
+from catalyst.jax_primitives import decomposition_rule
+
+
+# pylint: disable=too-few-public-methods
+class DecompRuleInterpreter(qml.capture.PlxprInterpreter):
+ """Interpreter for getting the decomposition graph solution
+ from a jaxpr when program capture is enabled.
+
+ This interpreter captures all operations seen during the interpretation
+ and builds a decomposition graph to find efficient decomposition pathways
+ to a target gate set.
+
+ This interpreter should be used with `qml.decomposition.enable_graph()`
+ to enable graph-based decomposition.
+
+ Note that this doesn't actually decompose the operations during interpretation.
+ It only captures the operations and builds the decomposition graph.
+ The actual decomposition is done later in the MLIR decomposition pass.
+
+ See also: :class:`~.DecompositionGraph`.
+
+ Args:
+ gate_set (set[Operator] or None): The target gate set to decompose to
+ fixed_decomps (dict or None): A dictionary of fixed decomposition rules
+ to use in the decomposition graph.
+ alt_decomps (dict or None): A dictionary of alternative decomposition rules
+ to use in the decomposition graph.
+
+ Raises:
+ TypeError: if graph-based decomposition is not enabled.
+ """
+
+ # A mapping from operation names to the number of wires they act on
+ # and the number of parameters they have.
+ # This is used when the operation is not in the captured operations
+ # but we still need to create a decomposition rule for it.
+ #
+ # Note that some operations have a variable number of wires,
+ # e.g., MultiRZ, GlobalPhase. For these, we set the number
+ # of wires to -1 to indicate a variable number.
+ #
+ # This will require a copy of the function to be made
+ # when creating the decomposition rule to avoid mutating
+ # the original function with attributes like num_wires.
+ compiler_ops_num_wires: dict[str, tuple[int, int]] = {
+ "CNOT": (2, 0),
+ "ControlledPhaseShift": (2, 1),
+ "CRot": (2, 3),
+ "CRX": (2, 1),
+ "CRY": (2, 1),
+ "CRZ": (2, 1),
+ "CSWAP": (3, 0),
+ "CY": (2, 0),
+ "CZ": (2, 0),
+ "Hadamard": (1, 0),
+ "Identity": (1, 0),
+ "IsingXX": (2, 1),
+ "IsingXY": (2, 1),
+ "IsingYY": (2, 1),
+ "IsingZZ": (2, 1),
+ "SingleExcitation": (2, 1),
+ "DoubleExcitation": (4, 1),
+ "ISWAP": (2, 0),
+ "PauliX": (1, 0),
+ "PauliY": (1, 0),
+ "PauliZ": (1, 0),
+ "PhaseShift": (1, 1),
+ "PSWAP": (2, 1),
+ "Rot": (1, 3),
+ "RX": (1, 1),
+ "RY": (1, 1),
+ "RZ": (1, 1),
+ "S": (1, 0),
+ "SWAP": (2, 0),
+ "T": (1, 0),
+ "Toffoli": (3, 0),
+ "U1": (1, 1),
+ "U2": (1, 2),
+ "U3": (1, 3),
+ "MultiRZ": (-1, 1),
+ "GlobalPhase": (-1, 1),
+ }
+
+ def __init__(
+ self,
+ *,
+ gate_set=None,
+ fixed_decomps=None,
+ alt_decomps=None,
+ ): # pylint: disable=too-many-arguments
+
+ if not qml.decomposition.enabled_graph(): # pragma: no cover
+ raise TypeError(
+ "The DecompRuleInterpreter can only be used when"
+ "graph-based decomposition is enabled."
+ )
+
+ self._gate_set = gate_set
+ self._fixed_decomps = fixed_decomps
+ self._alt_decomps = alt_decomps
+
+ self._captured = False
+ self._operations = set()
+ self._decomp_graph_solution = {}
+
+ def interpret_operation(self, op: "qml.operation.Operator"):
+ """Interpret a PennyLane operation instance.
+
+ Args:
+ op (Operator): a pennylane operator instance
+
+ Returns:
+ Any
+
+ This method is only called when the operator's output is a dropped variable,
+ so the output will not affect later equations in the circuit.
+
+ We cache the list of operations seen during the interpretation
+ to build the decomposition graph in the later stages.
+
+ See also: :meth:`~.interpret_operation_eqn`.
+
+ """
+
+ self._operations.add(op)
+ data, struct = jax.tree_util.tree_flatten(op)
+ return jax.tree_util.tree_unflatten(struct, data)
+
+ def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess"):
+ """Interpret a measurement process instance.
+
+ Args:
+ measurement (MeasurementProcess): a measurement instance.
+
+ See also :meth:`~.interpret_measurement_eqn`.
+
+ """
+
+ # If we haven't captured and compiled the decomposition rules yet,
+ if not self._captured:
+ # Capture the current operations and mark as captured
+ self._captured = True
+
+ # Solve the decomposition graph to get the decomposition rules
+ # for all the captured operations
+ # I know it's a bit hacky to do this here, but it's the only
+ # place where we can be sure that we have seen all operations
+ # in the circuit before the measurement.
+ # TODO: Find a better way to do this.
+ self._decomp_graph_solution = _solve_decomposition_graph(
+ self._operations,
+ self._gate_set,
+ fixed_decomps=self._fixed_decomps,
+ alt_decomps=self._alt_decomps,
+ )
+
+ # Create decomposition rules for each operation in the solution
+ # and compile them to Catalyst JAXPR decomposition rules
+ for op, rule in self._decomp_graph_solution.items():
+ # Get number of wires if exists
+ op_num_wires = op.op.params.get("num_wires", None)
+ if (
+ o := next(
+ (
+ o
+ for o in self._operations
+ if o.name == op.op.name and len(o.wires) == op_num_wires
+ ),
+ None,
+ )
+ ) is not None:
+ num_wires, num_params = self.compiler_ops_num_wires[op.op.name]
+ _create_decomposition_rule(
+ rule,
+ op_name=op.op.name,
+ num_wires=len(o.wires),
+ num_params=num_params,
+ requires_copy=num_wires == -1,
+ )
+ elif op.op.name in self.compiler_ops_num_wires:
+ # In this part, we need to handle the case where an operation in
+ # the decomposition graph solution is not in the captured operations.
+ # This can happen if the operation is not directly called
+ # in the circuit, but is used inside a decomposition rule.
+ # In this case, we fall back to using the compiler_ops_num_wires
+ # dictionary to get the number of wires.
+ num_wires, num_params = self.compiler_ops_num_wires[op.op.name]
+ _create_decomposition_rule(
+ rule,
+ op_name=op.op.name,
+ num_wires=num_wires,
+ num_params=num_params,
+ requires_copy=num_wires == -1,
+ )
+ else: # pragma: no cover
+ raise ValueError(f"Could not capture {op} without the number of wires.")
+
+ data, struct = jax.tree_util.tree_flatten(measurement)
+ return jax.tree_util.tree_unflatten(struct, data)
+
+
+def _create_decomposition_rule(
+ func: Callable, op_name: str, num_wires: int, num_params: int, requires_copy: bool = False
+):
+ """Create a decomposition rule from a callable.
+
+ See also: :func:`~.decomposition_rule`.
+
+ Args:
+ func (Callable): The decomposition function.
+ op_name (str): The name of the operation to decompose.
+ num_wires (int): The number of wires the operation acts on.
+ num_params (int): The number of parameters the operation takes.
+ requires_copy (bool): Whether to create a copy of the function
+ to avoid mutating the original. This is required for operations
+ with a variable number of wires (e.g., MultiRZ, GlobalPhase).
+ """
+
+ sig_func = inspect.signature(func)
+ type_hints = get_type_hints(func)
+
+ args = {}
+ for name in sig_func.parameters.keys():
+ typ = type_hints.get(name, None)
+
+ # Skip tailing args or kwargs in the rules
+ if name in ("__", "_"):
+ continue
+
+ # TODO: This is a temporary solution until all rules have proper type annotations.
+ # Why? Because we need to pass the correct types to the decomposition_rule
+ # function to capture the rule correctly with JAX.
+ possible_names_for_single_param = {
+ "param",
+ "angle",
+ "phi",
+ "omega",
+ "theta",
+ "weight",
+ }
+ possible_names_for_multi_params = {
+ "params",
+ "angles",
+ "weights",
+ }
+
+ # TODO: Support work-wires when it's supported in Catalyst.
+ possible_names_for_wires = {"wires", "wire", "control_wires", "target_wires"}
+
+ if typ is TensorLike or name in possible_names_for_multi_params:
+ args[name] = qml.math.array([0.0] * num_params, like="jax", dtype=float)
+ elif typ is float or name in possible_names_for_single_param:
+ # TensorLike is a Union of float, int, array-like, so we use float here
+ # to cover the most common case as the JAX tracer doesn't like Union types
+ # and we don't have the actual values at this point.
+ args[name] = float
+ elif typ is WiresLike or name in possible_names_for_wires:
+ # Pass a dummy array of zeros with the correct number of wires
+ # This is required for the decomposition_rule to work correctly
+ # as it expects an array-like input for wires
+ args[name] = qml.math.array([0] * num_wires, like="jax")
+ elif typ is int: # pragma: no cover
+ # This is only for cases where the rule has an int parameter
+ # e.g., dimension in some gates. Not that common though!
+ # We cover this when adding end-to-end tests for rules
+ # in the MLIR PR.
+ args[name] = int
+ else: # pragma: no cover
+ raise ValueError(
+ f"Unsupported type annotation {typ} for parameter {name} in func {func}."
+ )
+
+ func_cp = make_def_copy(func) if requires_copy else func
+
+ # Set custom attributes for the decomposition rule
+ # These attributes are used in the MLIR decomposition pass
+ # to identify the target gate and the number of wires
+ setattr(func_cp, "target_gate", op_name)
+ setattr(func_cp, "num_wires", num_wires)
+
+ if requires_copy:
+ # Include number of wires in the function name to avoid name clashes
+ # when the same rule is compiled multiple times with different number of wires
+ # (e.g., MultiRZ, GlobalPhase)
+ func_cp.__name__ += f"_wires_{num_wires}" # pylint: disable=protected-access
+
+ return decomposition_rule(func_cp)(**args)
+
+
+# pylint: disable=protected-access
+def _solve_decomposition_graph(operations, gate_set, fixed_decomps, alt_decomps):
+ """Get the decomposition graph solution for the given operations and gate set.
+
+ TODO: Extend `DecompGraphSolution` API and avoid accessing protected members
+ directly in this function.
+
+ Args:
+ operations (set[Operator]): The set of operations to decompose.
+ gate_set (set[Operator]): The target gate set to decompose to.
+ fixed_decomps (dict or None): A dictionary of fixed decomposition rules
+ to use in the decomposition graph.
+ alt_decomps (dict or None): A dictionary of alternative decomposition rules
+ to use in the decomposition graph.
+
+ Returns:
+ dict: A dictionary mapping operations to their decomposition rules.
+ """
+
+ # decomp_graph_solution
+ decomp_graph_solution = {}
+
+ decomp_graph = DecompositionGraph(
+ operations,
+ gate_set,
+ fixed_decomps=fixed_decomps,
+ alt_decomps=alt_decomps,
+ )
+
+ # Find the efficient pathways to the target gate set
+ solutions = decomp_graph.solve()
+
+ def is_solved_for(op):
+ return (
+ op in solutions._all_op_indices
+ and solutions._all_op_indices[op] in solutions._visitor.distances
+ )
+
+ for op_node, op_node_idx in solutions._all_op_indices.items():
+ if is_solved_for(op_node) and op_node_idx in solutions._visitor.predecessors:
+ d_node_idx = solutions._visitor.predecessors[op_node_idx]
+ decomp_graph_solution[op_node] = solutions._graph[d_node_idx].rule._impl
+
+ return decomp_graph_solution
+
+
+# pylint: disable=protected-access
+def make_def_copy(func):
+ """Create a copy of a Python definition to avoid mutating the original.
+
+ This is especially useful when compiling decomposition rules with
+ parametric number of wires (e.g., MultiRZ, GlobalPhase) multiple times,
+ as the compilation process may add attributes to the function that
+ can interfere with subsequent compilations.
+
+ Args:
+ func (Callable): The function to copy.
+
+ Returns:
+ Callable: A copy of the original function with the same attributes.
+ """
+ # Create a new function object with the same code, globals, name, defaults, and closure
+ func_copy = types.FunctionType(
+ func.__code__,
+ func.__globals__,
+ name=func.__name__,
+ argdefs=func.__defaults__,
+ closure=func.__closure__,
+ )
+
+ # Now, we create and update the wrapper to copy over attributes like docstring, module, etc.
+ return functools.update_wrapper(func_copy, func)
diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py
index aba015a42e..a9fcbf0784 100644
--- a/frontend/catalyst/from_plxpr/from_plxpr.py
+++ b/frontend/catalyst/from_plxpr/from_plxpr.py
@@ -15,12 +15,13 @@
This submodule defines a utility for converting plxpr into Catalyst jaxpr.
"""
# pylint: disable=protected-access
+# pylint: disable=too-many-lines
+
from copy import copy
from functools import partial
from typing import Callable
import jax
-import jax.core
import jax.numpy as jnp
import pennylane as qml
from jax._src.sharding_impls import UNSPECIFIED
@@ -45,6 +46,8 @@
from pennylane.transforms import unitary_to_rot as pl_unitary_to_rot
from catalyst.device import extract_backend_info
+from catalyst.device.qjit_device import COMPILER_OPERATIONS
+from catalyst.from_plxpr.decompose import DecompRuleInterpreter
from catalyst.from_plxpr.qubit_handler import QubitHandler, QubitIndexRecorder, get_in_qubit_values
from catalyst.jax_extras import jaxpr_pad_consts, make_jaxpr2, transient_jax_config
from catalyst.jax_primitives import (
@@ -177,11 +180,16 @@ def f(x):
class WorkflowInterpreter(PlxprInterpreter):
- """An interpreter that converts a qnode primitive from a plxpr variant to a catalxpr variant."""
+ """An interpreter that converts a qnode primitive from a plxpr variant to a catalyst jaxpr variant."""
def __init__(self):
self._pass_pipeline = []
self.init_qreg = None
+
+ # Compiler options for the new decomposition system
+ self.requires_decompose_lowering = False
+ self.decompose_tkwargs = {} # target gateset
+
super().__init__()
@@ -201,7 +209,24 @@ def handle_qnode(
consts = args[shots_len : n_consts + shots_len]
non_const_args = args[shots_len + n_consts :]
- closed_jaxpr = ClosedJaxpr(qfunc_jaxpr, consts)
+ closed_jaxpr = (
+ ClosedJaxpr(qfunc_jaxpr, consts)
+ if not self.requires_decompose_lowering
+ else _apply_compiler_decompose_to_plxpr(
+ inner_jaxpr=qfunc_jaxpr,
+ consts=consts,
+ ncargs=non_const_args,
+ tgateset=list(self.decompose_tkwargs.get("gate_set", [])),
+ )
+ )
+
+ if self.requires_decompose_lowering:
+ closed_jaxpr = _collect_and_compile_graph_solutions(
+ inner_jaxpr=closed_jaxpr.jaxpr,
+ consts=closed_jaxpr.consts,
+ tkwargs=self.decompose_tkwargs,
+ ncargs=non_const_args,
+ )
def calling_convention(*args):
device_init_p.bind(
@@ -220,6 +245,16 @@ def calling_convention(*args):
device_release_p.bind()
return retvals
+ if self.requires_decompose_lowering:
+ # Add gate_set attribute to the quantum kernel primitive
+ # decompose_gatesets is treated as a queue of gatesets to be used
+ # but we only support a single gateset for now in from_plxpr
+ # as supporting multiple gatesets requires an MLIR/C++ graph-decomposition
+ # implementation. The current Python implementation cannot be mixed
+ # with other transforms in between.
+ gateset = [_get_operator_name(op) for op in self.decompose_tkwargs.get("gate_set", [])]
+ setattr(qnode, "decompose_gatesets", [gateset])
+
return quantum_kernel_p.bind(
wrap_init(calling_convention, debug_info=qfunc_jaxpr.debug_info),
*non_const_args,
@@ -268,6 +303,51 @@ def handle_transform(
non_const_args = args[args_slice]
targs = args[targs_slice]
+ # If the transform is a decomposition transform
+ # and the graph-based decomposition is enabled
+ if (
+ hasattr(pl_plxpr_transform, "__name__")
+ and pl_plxpr_transform.__name__ == "decompose_plxpr_to_plxpr"
+ and qml.decomposition.enabled_graph()
+ ):
+ if not self.requires_decompose_lowering:
+ self.requires_decompose_lowering = True
+ else:
+ raise NotImplementedError(
+ "Multiple decomposition transforms are not yet supported."
+ )
+
+ # Update the decompose_gateset to be used by the quantum kernel primitive
+ # TODO: we originally wanted to treat decompose_gateset as a queue of
+ # gatesets to be used by the decompose-lowering pass at MLIR
+ # but this requires a C++ implementation of the graph-based decomposition
+ # which doesn't exist yet.
+ self.decompose_tkwargs = tkwargs
+
+ # Note. We don't perform the compiler-specific decomposition here
+ # to be able to support multiple decomposition transforms
+ # and collect all the required gatesets
+ # as well as being able to support other transforms in between.
+
+ # The compiler specific transformation will be performed
+ # in the qnode handler.
+
+ # Add the decompose-lowering pass to the start of the pipeline
+ self._pass_pipeline.insert(0, Pass("decompose-lowering"))
+
+ # We still need to construct and solve the graph based on
+ # the current jaxpr based on the current gateset
+ # but we don't rewrite the jaxpr at this stage.
+
+ # gds_interpreter = DecompRuleInterpreter(*targs, **tkwargs)
+
+ # def gds_wrapper(*args):
+ # return gds_interpreter.eval(inner_jaxpr, consts, *args)
+
+ # final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args)
+ # return self.eval(final_jaxpr.jaxpr, consts, *non_const_args)
+ return self.eval(inner_jaxpr, consts, *non_const_args)
+
if catalyst_pass_name is None:
# Use PL's ExpandTransformsInterpreter to expand this and any embedded
# transform according to PL rules. It works by overriding the primitive
@@ -287,10 +367,10 @@ def wrapper(*args):
)
return self.eval(final_jaxpr.jaxpr, final_jaxpr.consts, *non_const_args)
- else:
- # Apply the corresponding Catalyst pass counterpart
- self._pass_pipeline.insert(0, Pass(catalyst_pass_name))
- return self.eval(inner_jaxpr, consts, *non_const_args)
+
+ # Apply the corresponding Catalyst pass counterpart
+ self._pass_pipeline.insert(0, Pass(catalyst_pass_name))
+ return self.eval(inner_jaxpr, consts, *non_const_args)
# This is our registration factory for PL transforms. The loop below iterates
@@ -301,6 +381,7 @@ def wrapper(*args):
register_transform(pl_transform, pass_name, decomposition)
+# pylint: disable=too-many-instance-attributes
class PLxPRToQuantumJaxprInterpreter(PlxprInterpreter):
"""
Unlike the previous interpreters which modified the getattr and setattr
@@ -571,7 +652,6 @@ def handle_decomposition_rule(self, *, pyfun, func_jaxpr, is_qreg, num_params):
"""
Transform a quantum decomposition rule from PLxPR into JAXPR with quantum primitives.
"""
-
if is_qreg:
self.init_qreg.insert_all_dangling_qubits()
@@ -848,3 +928,86 @@ def trace_from_pennylane(
jaxpr = from_plxpr(plxpr)(*dynamic_args, **kwargs)
return jaxpr, out_type, out_treedef, sig
+
+
+def _apply_compiler_decompose_to_plxpr(inner_jaxpr, consts, tgateset, ncargs):
+ """Apply the compiler-specific decomposition for a given JAXPR.
+
+ Args:
+ inner_jaxpr (Jaxpr): The input JAXPR to be decomposed.
+ consts (list): The constants used in the JAXPR.
+ tgateset (list): A list of target gateset for decomposition.
+ ncargs (list): Non-constant arguments for the JAXPR.
+ qargs (list): All arguments including constants and non-constants.
+
+ Returns:
+ ClosedJaxpr: The decomposed JAXPR.
+ """
+
+ # Disable the graph decomposition optimization
+
+ # Why? Because for the compiler-specific decomposition we want to
+ # only decompose higher-level gates and templates that only have
+ # a single decomposition, and not do any further optimization
+ # based on the graph solution.
+ # Besides, the graph-based decomposition is not supported
+ # yet in from_plxpr for most gates and templates.
+
+ # TODO: Enable the graph-based decomposition
+ qml.decomposition.disable_graph()
+
+ # First perform the pre-mlir decomposition to simplify the jaxpr
+ # by decomposing high-level gates and templates
+ gate_set = set(COMPILER_OPERATIONS + tgateset)
+
+ final_jaxpr = qml.transforms.decompose.plxpr_transform(
+ inner_jaxpr, consts, (), {"gate_set": gate_set}, *ncargs
+ )
+
+ qml.decomposition.enable_graph()
+
+ return final_jaxpr
+
+
+def _collect_and_compile_graph_solutions(inner_jaxpr, consts, tkwargs, ncargs):
+ """Collect and compile graph solutions for a given JAXPR.
+
+ This function uses the DecompRuleInterpreter to evaluate
+ the input JAXPR and obtain a new JAXPR that incorporates
+ the graph-based decomposition solutions.
+
+ This function doesn't modify the underlying quantum function
+ but rather constructs a new JAXPR with decomposition rules.
+
+ Args:
+ inner_jaxpr (Jaxpr): The input JAXPR to be decomposed.
+ consts (list): The constants used in the JAXPR.
+ tkwargs (list): The keyword arguments of the decompose transform.
+ ncargs (list): Non-constant arguments for the JAXPR.
+
+ Returns:
+ ClosedJaxpr: The decomposed JAXPR.
+ """
+ gds_interpreter = DecompRuleInterpreter(**tkwargs)
+
+ def gds_wrapper(*args):
+ return gds_interpreter.eval(inner_jaxpr, consts, *args)
+
+ final_jaxpr = jax.make_jaxpr(gds_wrapper)(*ncargs)
+
+ return final_jaxpr
+
+
+def _get_operator_name(op):
+ """Get the name of a pennylane operator, handling wrapped operators.
+
+ Note: Controlled and Adjoint ops aren't supported in `gate_set`
+ by PennyLane's DecompositionGraph; unit tests were added in PennyLane.
+ """
+ if isinstance(op, str):
+ return op
+
+ # Return NoNameOp if the operator has no _primitive.name attribute.
+ # This is to avoid errors when we capture the program
+ # as we deal with such ops later in the decomposition graph.
+ return getattr(op._primitive, "name", "NoNameOp")
diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py
index 8f754913e7..fec8e41684 100644
--- a/frontend/catalyst/jax_primitives.py
+++ b/frontend/catalyst/jax_primitives.py
@@ -395,7 +395,7 @@ def wrapper(*args, **kwargs):
return wrapper
-def decomposition_rule(func=None, *, is_qreg=False, num_params=0):
+def decomposition_rule(func=None, *, is_qreg=True, num_params=0):
"""
Denotes the creation of a quantum definition in the intermediate representation.
"""
@@ -590,7 +590,10 @@ def _decomposition_rule_lowering(ctx, *, pyfun, func_jaxpr, **_):
"""Lower a quantum decomposition rule into MLIR in a single step process.
The step is the compilation of the definition of the function fn.
"""
- lower_callable(ctx, pyfun, func_jaxpr)
+
+ # Set the visibility of the decomposition rule to public
+ # to avoid the elimination by the compiler
+ lower_callable(ctx, pyfun, func_jaxpr, public=True)
return ()
diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py
index 411aba7cc6..1227b7f1b5 100644
--- a/frontend/catalyst/jax_primitives_utils.py
+++ b/frontend/catalyst/jax_primitives_utils.py
@@ -26,6 +26,8 @@
from mlir_quantum.dialects._transform_ops_gen import ApplyRegisteredPassOp, NamedSequenceOp, YieldOp
from mlir_quantum.dialects.catalyst import LaunchKernelOp
+from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval
+
def get_call_jaxpr(jaxpr):
"""Extracts the `call_jaxpr` from a JAXPR if it exists.""" ""
@@ -44,7 +46,16 @@ def get_call_equation(jaxpr):
def lower_jaxpr(ctx, jaxpr, context=None):
- """Lowers a call primitive jaxpr, may be either func_p or quantum_kernel_p"""
+ """Lowers a call primitive jaxpr, may be either func_p or quantum_kernel_p
+
+ Args:
+ ctx: LoweringRuleContext
+ jaxpr: JAXPR to be lowered
+ context: additional context to distinguish different FuncOps
+
+ Returns:
+ FuncOp
+ """
equation = get_call_equation(jaxpr)
call_jaxpr = equation.params["call_jaxpr"]
callable_ = equation.params.get("fn")
@@ -54,7 +65,8 @@ def lower_jaxpr(ctx, jaxpr, context=None):
return lower_callable(ctx, callable_, call_jaxpr, pipeline=pipeline, context=context)
-def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None):
+# pylint: disable=too-many-arguments, too-many-positional-arguments
+def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None, public=False):
"""Lowers _callable to MLIR.
If callable_ is a qnode, then we will first create a module, then
@@ -66,6 +78,8 @@ def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None):
ctx: LoweringRuleContext
callable_: python function
call_jaxpr: jaxpr representing callable_
+ public: whether the visibility should be marked public
+
Returns:
FuncOp
"""
@@ -73,25 +87,49 @@ def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None):
pipeline = tuple()
if not isinstance(callable_, qml.QNode):
- return get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, context=context)
+ return get_or_create_funcop(
+ ctx, callable_, call_jaxpr, pipeline, context=context, public=public
+ )
return get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context=context)
-def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, context=None):
- """Get funcOp from cache, or create it from scratch"""
+# pylint: disable=too-many-arguments, too-many-positional-arguments
+def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, context=None, public=False):
+ """Get funcOp from cache, or create it from scratch
+
+ Args:
+ ctx: LoweringRuleContext
+ callable_: python function
+ call_jaxpr: jaxpr representing callable_
+ context: additional context to distinguish different FuncOps
+ public: whether the visibility should be marked public
+
+ Returns:
+ FuncOp
+ """
if context is None:
context = tuple()
key = (callable_, *context, *pipeline)
if func_op := get_cached(ctx, key):
return func_op
- func_op = lower_callable_to_funcop(ctx, callable_, call_jaxpr)
+ func_op = lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=public)
cache(ctx, key, func_op)
return func_op
-def lower_callable_to_funcop(ctx, callable_, call_jaxpr):
- """Lower callable to either a FuncOp"""
+def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False):
+ """Lower callable to either a FuncOp
+
+ Args:
+ ctx: LoweringRuleContext
+ callable_: python function
+ call_jaxpr: jaxpr representing callable_
+ public: whether the visibility should be marked public
+
+ Returns:
+ FuncOp
+ """
if isinstance(call_jaxpr, core.Jaxpr):
call_jaxpr = core.ClosedJaxpr(call_jaxpr, ())
@@ -101,10 +139,16 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr):
name = callable_.__name__
else:
name = callable_.func.__name__ + ".partial"
+
kwargs["name"] = name
kwargs["jaxpr"] = call_jaxpr
kwargs["effects"] = []
kwargs["name_stack"] = ctx.name_stack
+
+ # Make the visibility of the function public=True
+ # to avoid elimination by the compiler
+ kwargs["public"] = public
+
func_op = mlir.lower_jaxpr_to_fun(**kwargs)
if isinstance(callable_, qml.QNode):
@@ -135,6 +179,19 @@ def only_single_expval():
func_op.attributes["diff_method"] = ir.StringAttr.get(diff_method)
+ # Register the decomposition gatesets to the QNode FuncOp
+ # This will set a queue of gatesets that enables support for multiple
+ # levels of decomposition in the MLIR decomposition pass
+ if gateset := getattr(callable_, "decompose_gatesets", []):
+ func_op.attributes["decompose_gatesets"] = get_mlir_attribute_from_pyval(gateset)
+
+ # Extract the target gate and number of wires from decomposition rules
+ # and set them as attributes on the FuncOp for use in the MLIR decomposition pass
+ if target_gate := getattr(callable_, "target_gate", None):
+ func_op.attributes["target_gate"] = get_mlir_attribute_from_pyval(target_gate)
+ if num_wires := getattr(callable_, "num_wires", None):
+ func_op.attributes["num_wires"] = get_mlir_attribute_from_pyval(num_wires)
+
return func_op
diff --git a/frontend/catalyst/passes/__init__.py b/frontend/catalyst/passes/__init__.py
index a972a42a10..4940b49d95 100644
--- a/frontend/catalyst/passes/__init__.py
+++ b/frontend/catalyst/passes/__init__.py
@@ -46,12 +46,14 @@
ppr_to_ppm,
t_layer_reduction,
to_ppr,
+ decompose_lowering,
)
from catalyst.passes.pass_api import Pass, PassPlugin, apply_pass, apply_pass_plugin
__all__ = (
"to_ppr",
"commute_ppr",
+ "decompose_lowering",
"cancel_inverses",
"disentangle_cnot",
"disentangle_swap",
diff --git a/frontend/catalyst/passes/builtin_passes.py b/frontend/catalyst/passes/builtin_passes.py
index 564e75b401..21f4a795a9 100644
--- a/frontend/catalyst/passes/builtin_passes.py
+++ b/frontend/catalyst/passes/builtin_passes.py
@@ -394,6 +394,34 @@ def circuit(x: float):
return PassPipelineWrapper(qnode, "merge-rotations")
+def decompose_lowering(qnode=None, *, rules_path=None):
+ """
+ Specify that the ``-decompose-lowering`` MLIR compiler pass
+ for applying the compiled decomposition rules to the QNode
+ recursively.
+
+ Args:
+ fn (QNode): the QNode to apply the cancel inverses compiler pass to
+ rules_path (str): the path to the decomposition rules MLIR file; if not provided,
+ the decomposition rules will be applied from the main module
+
+ Returns:
+ ~.QNode:
+
+ **Example**
+ // TODO: add example here
+
+ """
+ if qnode is None:
+ return functools.partial(decompose_lowering, rules_path=rules_path)
+
+ if rules_path is None:
+ decompose_lowering_pass = "decompose-lowering"
+ else:
+ decompose_lowering_pass = {"decompose-lowering": {"rules-path": rules_path}}
+ return PassPipelineWrapper(qnode, decompose_lowering_pass) # pragma: no cover
+
+
def ions_decomposition(qnode): # pragma: nocover
"""
Specify that the ``--ions-decomposition`` MLIR compiler pass should be
diff --git a/frontend/catalyst/passes/pass_api.py b/frontend/catalyst/passes/pass_api.py
index 37872a1f0c..33cefda4a7 100644
--- a/frontend/catalyst/passes/pass_api.py
+++ b/frontend/catalyst/passes/pass_api.py
@@ -374,6 +374,7 @@ def dictionary_to_list_of_passes(pass_pipeline: PipelineDict | str, *flags, **va
def _API_name_to_pass_name():
return {
"cancel_inverses": "remove-chained-self-inverse",
+ "decompose_lowering": "decompose-lowering",
"disentangle_cnot": "disentangle-CNOT",
"disentangle_swap": "disentangle-SWAP",
"merge_rotations": "merge-rotations",
diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py
index 0c61109e9a..99b44d5dd6 100644
--- a/frontend/test/lit/test_decomposition.py
+++ b/frontend/test/lit/test_decomposition.py
@@ -3,8 +3,10 @@
import pathlib
import platform
from copy import deepcopy
+from functools import partial
import jax
+import numpy as np
import pennylane as qml
from pennylane.devices.capabilities import OperatorProperties
from pennylane.typing import TensorLike
@@ -29,6 +31,7 @@
# RUN: %PYTHON %s | FileCheck %s
# pylint: disable=line-too-long
+# pylint: disable=too-many-lines
TEST_PATH = os.path.dirname(__file__)
@@ -273,7 +276,7 @@ def decompose_to_matrix():
def test_decomposition_rule_wire_param():
"""Test decomposition rule with passing a parameter that is a wire/integer"""
- @decomposition_rule
+ @decomposition_rule(is_qreg=False)
def Hadamard0(wire: WiresLike):
qml.Hadamard(wire)
@@ -288,7 +291,7 @@ def circuit(_: float):
Hadamard0(int)
return qml.probs()
- # CHECK: func.func private @Hadamard0([[QBIT:%.+]]: !quantum.bit) -> !quantum.bit
+ # CHECK: func.func public @Hadamard0([[QBIT:%.+]]: !quantum.bit) -> !quantum.bit
# CHECK-NEXT: [[QUBIT_OUT:%.+]] = quantum.custom "Hadamard"() [[QBIT]] : !quantum.bit
# CHECK-NEXT: return [[QUBIT_OUT]] : !quantum.bit
@@ -303,7 +306,7 @@ def circuit(_: float):
def test_decomposition_rule_gate_param_param():
"""Test decomposition rule with passing a regular parameter"""
- @decomposition_rule(num_params=1)
+ @decomposition_rule(is_qreg=False, num_params=1)
def RX_on_wire_0(param: TensorLike, w0: WiresLike):
qml.RX(param, wires=w0)
@@ -316,7 +319,7 @@ def circuit_2(_: float):
RX_on_wire_0(float, int)
return qml.probs()
- # CHECK: func.func private @RX_on_wire_0([[PARAM_TENSOR:%.+]]: tensor, [[QUBIT:%.+]]: !quantum.bit) -> !quantum.bit
+ # CHECK: func.func public @RX_on_wire_0([[PARAM_TENSOR:%.+]]: tensor, [[QUBIT:%.+]]: !quantum.bit) -> !quantum.bit
# CHECK-NEXT: [[PARAM:%.+]] = tensor.extract [[PARAM_TENSOR]][] : tensor
# CHECK-NEXT: [[QUBIT_1:%.+]] = quantum.custom "RX"([[PARAM]]) [[QUBIT]] : !quantum.bit
# CHECK-NEXT: return [[QUBIT_1]] : !quantum.bit
@@ -336,7 +339,7 @@ def test_multiple_decomposition_rules():
@decomposition_rule
def identity(): ...
- @decomposition_rule(num_params=1)
+ @decomposition_rule(is_qreg=True)
def all_wires_rx(param: TensorLike, w0: WiresLike, w1: WiresLike, w2: WiresLike):
qml.RX(param, wires=w0)
qml.RX(param, wires=w1)
@@ -355,8 +358,8 @@ def circuit_3(_: float):
qml.Hadamard(0)
return qml.probs()
- # CHECK: func.func private @identity
- # CHECK: func.func private @all_wires_rx
+ # CHECK: func.func public @identity
+ # CHECK: func.func public @all_wires_rx
print(circuit_3.mlir)
qml.capture.disable()
@@ -384,7 +387,7 @@ def circuit_4(_: float):
qml.Hadamard(0)
return qml.probs()
- # CHECK: func.func private @shaped_wires_rule([[QREG:%.+]]: !quantum.reg, [[PARAM_TENSOR:%.+]]: tensor, [[QUBITS:%.+]]: tensor<3xi64>) -> !quantum.reg
+ # CHECK: func.func public @shaped_wires_rule([[QREG:%.+]]: !quantum.reg, [[PARAM_TENSOR:%.+]]: tensor, [[QUBITS:%.+]]: tensor<3xi64>) -> !quantum.reg
# CHECK-NEXT: [[IDX_0:%.+]] = stablehlo.slice [[QUBITS]] [0:1] : (tensor<3xi64>) -> tensor<1xi64>
# CHECK-NEXT: [[RIDX_0:%.+]] = stablehlo.reshape [[IDX_0]] : (tensor<1xi64>) -> tensor
# CHECK-NEXT: [[EXTRACTED:%.+]] = tensor.extract [[RIDX_0]][] : tensor
@@ -409,7 +412,7 @@ def shaped_wires_rule(param: TensorLike, wires: WiresLike):
qml.RX(param, wires=wires[1])
qml.RX(param, wires=wires[2])
- @decomposition_rule(num_params=1, is_qreg=False)
+ @decomposition_rule(is_qreg=False, num_params=1)
def expanded_wires_rule(param: TensorLike, w1, w2, w3):
shaped_wires_rule(param, [w1, w2, w3])
@@ -421,7 +424,7 @@ def circuit_5(_: float):
qml.Hadamard(0)
return qml.probs()
- # CHECK: func.func private @expanded_wires_rule(%arg0: tensor, %arg1: !quantum.bit, %arg2: !quantum.bit, %arg3: !quantum.bit) -> (!quantum.bit, !quantum.bit, !quantum.bit)
+ # CHECK: func.func public @expanded_wires_rule(%arg0: tensor, %arg1: !quantum.bit, %arg2: !quantum.bit, %arg3: !quantum.bit) -> (!quantum.bit, !quantum.bit, !quantum.bit)
print(circuit_5.mlir)
qml.capture.disable()
@@ -452,7 +455,7 @@ def circuit_6():
cond_RX(float, jax.core.ShapedArray((1,), int))
return qml.probs()
- # CHECK: func.func private @cond_RX([[QREG:%.+]]: !quantum.reg, [[PARAM_TENSOR:%.+]]: tensor, [[QUBITS:%.+]]: tensor<1xi64>) -> !quantum.reg
+ # CHECK: func.func public @cond_RX([[QREG:%.+]]: !quantum.reg, [[PARAM_TENSOR:%.+]]: tensor, [[QUBITS:%.+]]: tensor<1xi64>) -> !quantum.reg
# CHECK-NEXT: [[ZERO:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor
# CHECK-NEXT: [[COND_TENSOR:%.+]] = stablehlo.compare NE, [[PARAM_TENSOR]], [[ZERO]], FLOAT : (tensor, tensor) -> tensor
# CHECK-NEXT: [[COND:%.+]] = tensor.extract [[COND_TENSOR]][] : tensor
@@ -479,17 +482,17 @@ def test_decomposition_rule_caller():
qml.capture.enable()
@decomposition_rule(is_qreg=True)
- def Op1_decomp(_: TensorLike, wires: WiresLike):
+ def rule_op1_decomp(_: TensorLike, wires: WiresLike):
qml.Hadamard(wires=wires[0])
qml.Hadamard(wires=[1])
@decomposition_rule(is_qreg=True)
- def Op2_decomp(param: TensorLike, wires: WiresLike):
+ def rule_op2_decomp(param: TensorLike, wires: WiresLike):
qml.RX(param, wires=wires[0])
def decomps_caller(param: TensorLike, wires: WiresLike):
- Op1_decomp(param, wires)
- Op2_decomp(param, wires)
+ rule_op1_decomp(param, wires)
+ rule_op2_decomp(param, wires)
@qml.qjit(autograph=False)
@qml.qnode(qml.device("lightning.qubit", wires=1))
@@ -500,11 +503,643 @@ def circuit_7():
decomps_caller(float, jax.core.ShapedArray((2,), int))
return qml.probs()
- # CHECK: func.func private @Op1_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg
- # CHECK: func.func private @Op2_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg
-
+ # CHECK: func.func public @rule_op1_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg
+ # CHECK: func.func public @rule_op2_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg
print(circuit_7.mlir)
qml.capture.disable()
test_decomposition_rule_caller()
+
+
+def test_decompose_gateset_without_graph():
+ """Test the decompose transform to a target gate set without the graph decomposition."""
+
+ qml.capture.enable()
+
+ @qml.qjit(target="mlir")
+ @partial(qml.transforms.decompose, gate_set={"RX", "RZ"})
+ @qml.qnode(qml.device("lightning.qubit", wires=1))
+ # CHECK: public @circuit_8() -> tensor attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode}
+ def circuit_8():
+ return qml.expval(qml.Z(0))
+
+ print(circuit_8.mlir)
+
+ qml.capture.disable()
+
+
+test_decompose_gateset_without_graph()
+
+
+def test_decompose_gateset_with_graph():
+ """Test the decompose transform to a target gate set with the graph decomposition."""
+
+ qml.capture.enable()
+ qml.decomposition.enable_graph()
+
+ @qml.qjit(target="mlir")
+ @partial(qml.transforms.decompose, gate_set={"RX"})
+ @qml.qnode(qml.device("lightning.qubit", wires=1))
+ # CHECK: public @simple_circuit_9() -> tensor attributes {decompose_gatesets
+ def simple_circuit_9():
+ return qml.expval(qml.Z(0))
+
+ print(simple_circuit_9.mlir)
+
+ @qml.qjit(target="mlir")
+ @partial(qml.transforms.decompose, gate_set={"RX", "RZ"})
+ @qml.qnode(qml.device("lightning.qubit", wires=1))
+ # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering"
+ # CHECK: public @circuit_9() -> tensor attributes {decompose_gatesets
+ def circuit_9():
+ return qml.expval(qml.Z(0))
+
+ print(circuit_9.mlir)
+
+ qml.decomposition.disable_graph()
+ qml.capture.disable()
+
+
+test_decompose_gateset_with_graph()
+
+
+def test_decompose_gateset_operator_with_graph():
+ """Test the decompose transform to a target gate set with the graph decomposition."""
+
+ qml.capture.enable()
+ qml.decomposition.enable_graph()
+
+ @qml.qjit(target="mlir")
+ @partial(qml.transforms.decompose, gate_set={qml.RX})
+ @qml.qnode(qml.device("lightning.qubit", wires=1))
+ # CHECK: public @simple_circuit_10() -> tensor attributes {decompose_gatesets
+ def simple_circuit_10():
+ return qml.expval(qml.Z(0))
+
+ print(simple_circuit_10.mlir)
+
+ @qml.qjit(target="mlir")
+ @partial(
+ qml.transforms.decompose, gate_set={qml.RX, qml.RZ, "PauliZ", qml.PauliX, qml.Hadamard}
+ )
+ @qml.qnode(qml.device("lightning.qubit", wires=1))
+ # CHECK: public @circuit_10() -> tensor attributes {decompose_gatesets
+ def circuit_10():
+ return qml.expval(qml.Z(0))
+
+ print(circuit_10.mlir)
+
+ @qml.qjit(target="mlir")
+ @partial(
+ qml.transforms.decompose, gate_set={qml.RX, qml.RZ, qml.PauliZ, qml.PauliX, qml.Hadamard}
+ )
+ @qml.qnode(qml.device("lightning.qubit", wires=1))
+ # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering"
+ # CHECK: public @circuit_11() -> tensor attributes {decompose_gatesets
+ def circuit_11():
+ return qml.expval(qml.Z(0))
+
+ print(circuit_11.mlir)
+
+ qml.decomposition.disable_graph()
+ qml.capture.disable()
+
+
+test_decompose_gateset_operator_with_graph()
+
+
+def test_decompose_gateset_with_rotxzx():
+ """Test the decompose transform with a custom operator with the graph decomposition."""
+
+ qml.capture.enable()
+ qml.decomposition.enable_graph()
+
+ @qml.qjit(target="mlir")
+ @partial(qml.transforms.decompose, gate_set={"RotXZX"})
+ @qml.qnode(qml.device("lightning.qubit", wires=1))
+ # CHECK: public @simple_circuit_12() -> tensor attributes {decompose_gatesets
+ def simple_circuit_12():
+ return qml.expval(qml.Z(0))
+
+ print(simple_circuit_12.mlir)
+
+ @qml.qjit(target="mlir")
+ @partial(qml.transforms.decompose, gate_set={qml.ftqc.RotXZX})
+ @qml.qnode(qml.device("lightning.qubit", wires=1))
+ # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering"
+ # CHECK: public @circuit_12() -> tensor attributes {decompose_gatesets
+ def circuit_12():
+ return qml.expval(qml.Z(0))
+
+ print(circuit_12.mlir)
+
+ qml.decomposition.disable_graph()
+ qml.capture.disable()
+
+
+test_decompose_gateset_with_rotxzx()
+
+
+def test_decomposition_rule_name():
+ """Test the name of the decomposition rule is not updated with circuit instantiation."""
+
+ qml.capture.enable()
+ qml.decomposition.enable_graph()
+
+ @decomposition_rule
+ def _ry_to_rz_rx(phi, wires: WiresLike, **__):
+ """Decomposition of RY gate using RZ and RX gates."""
+ qml.RZ(-np.pi / 2, wires=wires)
+ qml.RX(phi, wires=wires)
+ qml.RZ(np.pi / 2, wires=wires)
+
+ @decomposition_rule
+ def _rot_to_rz_ry_rz(phi, theta, omega, wires: WiresLike, **__):
+ """Decomposition of Rot gate using RZ and RY gates."""
+ qml.RZ(phi, wires=wires)
+ qml.RY(theta, wires=wires)
+ qml.RZ(omega, wires=wires)
+
+ @decomposition_rule
+ def _u2_phaseshift_rot_decomposition(phi, delta, wires, **__):
+ """Decomposition of U2 gate using Rot and PhaseShift gates."""
+ pi_half = qml.math.ones_like(delta) * (np.pi / 2)
+ qml.Rot(delta, pi_half, -delta, wires=wires)
+ qml.PhaseShift(delta, wires=wires)
+ qml.PhaseShift(phi, wires=wires)
+
+ @decomposition_rule
+ def _xzx_decompose(phi, theta, omega, wires, **__):
+ """Decomposition of Rot gate using RX and RZ gates in XZX format."""
+ qml.RX(phi, wires=wires)
+ qml.RZ(theta, wires=wires)
+ qml.RX(omega, wires=wires)
+
+ @qml.qjit(target="mlir")
+ @partial(qml.transforms.decompose, gate_set={"RX", "RZ", "PhaseShift"})
+ @qml.qnode(qml.device("lightning.qubit", wires=3))
+ # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering"
+ # CHECK: public @circuit_13() -> tensor attributes {decompose_gatesets
+ def circuit_13():
+ _ry_to_rz_rx(float, int)
+ _rot_to_rz_ry_rz(float, float, float, int)
+ _u2_phaseshift_rot_decomposition(float, float, int)
+ _xzx_decompose(float, float, float, int)
+ return qml.expval(qml.Z(0))
+
+ # CHECK: func.func public @_ry_to_rz_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor) -> !quantum.reg
+ # CHECK: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> !quantum.reg
+ # CHECK: func.func public @_u2_phaseshift_rot_decomposition(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> !quantum.reg
+ # CHECK: func.func public @_xzx_decompose(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> !quantum.reg
+ print(circuit_13.mlir)
+
+ qml.decomposition.disable_graph()
+ qml.capture.disable()
+
+
+test_decomposition_rule_name()
+
+
+def test_decomposition_rule_name_update():
+ """Test the name of the decomposition rule is updated in the MLIR output."""
+
+ qml.capture.enable()
+ qml.decomposition.enable_graph()
+
+ @qml.register_resources({qml.RZ: 2, qml.RX: 1})
+ def rz_rx(phi, wires: WiresLike, **__):
+ """Decomposition of RY gate using RZ and RX gates."""
+ qml.RZ(-np.pi / 2, wires=wires)
+ qml.RX(phi, wires=wires)
+ qml.RZ(np.pi / 2, wires=wires)
+
+ @qml.register_resources({qml.RZ: 2, qml.RY: 1})
+ def rz_ry_rz(phi, theta, omega, wires: WiresLike, **__):
+ """Decomposition of Rot gate using RZ and RY gates."""
+ qml.RZ(phi, wires=wires)
+ qml.RY(theta, wires=wires)
+ qml.RZ(omega, wires=wires)
+
+ @qml.register_resources({qml.RY: 1, qml.PhaseShift: 1})
+ def ry_gp(wires: WiresLike, **__):
+ """Decomposition of PauliY gate using RY and GlobalPhase gates."""
+ qml.RY(np.pi, wires=wires)
+ qml.GlobalPhase(-np.pi / 2, wires=wires)
+
+ @qml.qjit(target="mlir")
+ @partial(
+ qml.transforms.decompose,
+ gate_set={"RX", "RZ", "PhaseShift"},
+ fixed_decomps={
+ qml.RY: rz_rx,
+ qml.Rot: rz_ry_rz,
+ qml.PauliY: ry_gp,
+ },
+ )
+ @qml.qnode(qml.device("lightning.qubit", wires=3))
+ # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering"
+ # CHECK: public @circuit_14() -> tensor attributes {decompose_gatesets
+ def circuit_14():
+ qml.RY(0.5, wires=0)
+ qml.Rot(0.1, 0.2, 0.3, wires=1)
+ qml.PauliY(wires=2)
+ return qml.expval(qml.Z(0))
+
+ # CHECK-DAG: func.func public @rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg
+ # CHECK-DAG: func.func public @rz_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg
+ # CHECK-DAG: func.func public @ry_gp(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg
+ print(circuit_14.mlir)
+
+ qml.decomposition.disable_graph()
+ qml.capture.disable()
+
+
+test_decomposition_rule_name_update()
+
+
+def test_decomposition_rule_name_update_multi_qubits():
+ """Test the name of the decomposition rule with multi-qubit gates."""
+
+ qml.capture.enable()
+ qml.decomposition.enable_graph()
+
+ @qml.qjit(target="mlir")
+ @partial(
+ qml.transforms.decompose,
+ gate_set={"RY", "RX", "CNOT", "Hadamard", "GlobalPhase"},
+ )
+ @qml.qnode(qml.device("lightning.qubit", wires=4))
+ # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering"
+ # CHECK: public @circuit_15() -> tensor attributes {decompose_gatesets
+ def circuit_15():
+ qml.SingleExcitation(0.5, wires=[0, 1])
+ qml.SingleExcitationPlus(0.5, wires=[0, 1])
+ qml.SingleExcitationMinus(0.5, wires=[0, 1])
+ qml.DoubleExcitation(0.5, wires=[0, 1, 2, 3])
+ return qml.expval(qml.Z(0))
+
+ # CHECK-DAG: func.func public @_cry(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "CRY"}
+ # CHECK-DAG: func.func public @_s_phaseshift(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "S"}
+ # CHECK-DAG: func.func public @_phaseshift_to_rz_gp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "PhaseShift"}
+ # CHECK-DAG: func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"}
+ # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"}
+ # CHECK-DAG: func.func public @_doublexcit(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<4xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 4 : i64, target_gate = "DoubleExcitation"}
+ # CHECK-DAG: func.func public @_single_excitation_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "SingleExcitation"}
+ print(circuit_15.mlir)
+
+ qml.decomposition.disable_graph()
+ qml.capture.disable()
+
+
+test_decomposition_rule_name_update_multi_qubits()
+
+
+def test_decomposition_rule_name_adjoint():
+ """Test decomposition rule with qml.adjoint."""
+
+ qml.capture.enable()
+ qml.decomposition.enable_graph()
+
+ @qml.qjit(target="mlir")
+ @partial(
+ qml.transforms.decompose,
+ gate_set={"RY", "RX", "CZ", "GlobalPhase"},
+ )
+ @qml.qnode(qml.device("lightning.qubit", wires=4))
+ # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering"
+ # CHECK: public @circuit_16() -> tensor attributes {decompose_gatesets
+ def circuit_16():
+ # CHECK-DAG: %1 = quantum.adjoint(%0) : !quantum.reg
+ # CHECK-DAG: %2 = quantum.adjoint(%1) : !quantum.reg
+ # CHECK-DAG: %3 = quantum.adjoint(%2) : !quantum.reg
+ # CHECK-DAG: %4 = quantum.adjoint(%3) : !quantum.reg
+ qml.adjoint(qml.CNOT)(wires=[0, 1])
+ qml.adjoint(qml.Hadamard)(wires=2)
+ qml.adjoint(qml.RZ)(0.5, wires=3)
+ qml.adjoint(qml.SingleExcitation)(0.1, wires=[0, 1])
+ return qml.expval(qml.Z(0))
+
+ # CHECK-DAG: func.func public @_single_excitation_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "SingleExcitation"}
+ # CHECK-DAG: func.func public @_hadamard_to_rz_ry(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Hadamard"}
+ # CHECK-DAG: func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"}
+ # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"}
+ # CHECK-DAG: func.func public @_cnot_to_cz_h(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "CNOT"}
+ print(circuit_16.mlir)
+
+ qml.decomposition.disable_graph()
+ qml.capture.disable()
+
+
+test_decomposition_rule_name_adjoint()
+
+
+def test_decomposition_rule_name_ctrl():
+ """Test decomposition rule with qml.ctrl."""
+
+ qml.capture.enable()
+ qml.decomposition.enable_graph()
+
+ @qml.qjit(target="mlir")
+ @partial(
+ qml.transforms.decompose,
+ gate_set={"RX", "RZ", "H", "CZ"},
+ )
+ @qml.qnode(qml.device("lightning.qubit", wires=2))
+ # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering"
+ # CHECK{LITERAL}: func.func public @circuit_17() -> tensor attributes {decompose_gatesets
+ def circuit_17():
+ # CHECK: %out_qubits:2 = quantum.custom "CRY"(%cst) %1, %2 : !quantum.bit, !quantum.bit
+ # CHECK-NEXT: %out_qubits_0:2 = quantum.custom "CNOT"() %out_qubits#0, %out_qubits#1 : !quantum.bit, !quantum.bit
+ qml.ctrl(qml.RY, control=0)(0.5, 1)
+ qml.ctrl(qml.PauliX, control=0)(1)
+ return qml.expval(qml.Z(0))
+
+ # CHECK-DAG: func.func public @_cnot_to_cz_h(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "CNOT"}
+ # CHECK-DAG: func.func public @_cry(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "CRY"}
+ # CHECK-DAG: func.func public @_ry_to_rz_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RY"}
+ # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"}
+ print(circuit_17.mlir)
+
+ qml.decomposition.disable_graph()
+ qml.capture.disable()
+
+
+test_decomposition_rule_name_ctrl()
+
+
+def test_qft_decomposition():
+ """Test the decomposition of the QFT"""
+
+ qml.capture.enable()
+ qml.decomposition.enable_graph()
+
+ @qml.qjit(target="mlir")
+ @partial(
+ qml.transforms.decompose,
+ gate_set={"RX", "RY", "CNOT", "GlobalPhase"},
+ )
+ @qml.qnode(qml.device("lightning.qubit", wires=4))
+ # CHECK: %0 = transform.apply_registered_pass "decompose-lowering"
+ # CHECK: func.func public @circuit_18(%arg0: tensor<3xf64>) -> tensor attributes {decompose_gatesets
+ def circuit_18():
+ # %6 = scf.for %arg1 = %c0 to %c4 step %c1 iter_args(%arg2 = %0) -> (!quantum.reg) {
+ # %23 = scf.for %arg3 = %c0 to %22 step %c1 iter_args(%arg4 = %21) -> (!quantum.reg) {
+ # %7 = scf.for %arg1 = %c0 to %c2 step %c1 iter_args(%arg2 = %6) -> (!quantum.reg) {
+ qml.QFT(wires=[0, 1, 2, 3])
+ return qml.expval(qml.Z(0))
+
+ # CHECK-DAG: func.func public @_cphase_to_rz_cnot(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "ControlledPhaseShift"}
+ # CHECK-DAG: func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"}
+ # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"}
+ # CHECK-DAG: func.func public @_swap_to_cnot(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "SWAP"}
+ # CHECK-DAG: func.func public @_hadamard_to_rz_ry(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Hadamard"}
+ print(circuit_18.mlir)
+
+ qml.decomposition.disable_graph()
+ qml.capture.disable()
+
+
+test_qft_decomposition()
+
+
+def test_decompose_lowering_with_other_passes():
+ """Test the decompose lowering pass with other passes in a pass pipeline."""
+
+ qml.capture.enable()
+ qml.decomposition.enable_graph()
+
+ @qml.qjit(target="mlir")
+ @qml.transforms.merge_rotations
+ @qml.transforms.cancel_inverses
+ @partial(
+ qml.transforms.decompose,
+ gate_set={"RZ", "RY", "CNOT", "GlobalPhase"},
+ )
+ @qml.qnode(qml.device("lightning.qubit", wires=1))
+ # CHECK: module attributes {transform.with_named_sequence} {
+ # CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) {
+ # CHECK-NEXT: [[ONE:%.+]] = transform.apply_registered_pass "decompose-lowering" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module">
+ # CHECK-NEXT: [[TWO:%.+]] = transform.apply_registered_pass "remove-chained-self-inverse" to [[ONE]] : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module">
+ # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to [[TWO]] : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module">
+ # CHECK-NEXT: transform.yield
+ # CHECK-NEXT: }
+ def circuit_19():
+
+ # CHECK: [[OUT_0:%.+]] = quantum.custom "PauliX"() %1 : !quantum.bit
+ # CHECK-NEXT: [[OUT_1:%.+]] = quantum.custom "PauliX"() [[OUT_0]] : !quantum.bit
+ # CHECK-NEXT: [[OUT_2:%.+]] = quantum.custom "RX"(%cst_0) [[OUT_1]] : !quantum.bit
+ # CHECK-NEXT: {{%.+}} = quantum.custom "RX"(%cst) [[OUT_2]] : !quantum.bit
+ qml.PauliX(0)
+ qml.PauliX(0)
+ qml.RX(0.1, wires=0)
+ qml.RX(-0.1, wires=0)
+ return qml.expval(qml.PauliX(0))
+
+ # CHECK-DAG: func.func public @_paulix_to_rx(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "PauliX"}
+ # CHECK-DAG: func.func public @_rx_to_rz_ry(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RX"}
+ print(circuit_19.mlir)
+
+ qml.decomposition.disable_graph()
+ qml.capture.disable()
+
+
+test_decompose_lowering_with_other_passes()
+
+
+def test_decompose_lowering_multirz():
+ """Test the decompose lowering pass with MultiRZ in the gate set."""
+
+ qml.capture.enable()
+ qml.decomposition.enable_graph()
+
+ @qml.qjit(target="mlir")
+ @partial(
+ qml.transforms.decompose,
+ gate_set={"CNOT", "RZ"},
+ )
+ @qml.qnode(qml.device("lightning.qubit", wires=3))
+ # CHECK: %0 = transform.apply_registered_pass "decompose-lowering"
+ def circuit_20(x: float):
+ # CHECK: [[EXTRACTED:%.+]] = tensor.extract %arg0[] : tensor
+ # CHECK-NEXT: [[OUT_QUBITS:%.+]] = quantum.multirz([[EXTRACTED]]) %1 : !quantum.bit
+ # CHECK-NEXT: [[BIT_1:%.+]] = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit
+ # CHECK-NEXT: [[EXTRACTED_0:%.+]] = tensor.extract %arg0[] : tensor
+ # CHECK-NEXT: [[OUT_QUBITS_1:%.+]] = quantum.multirz([[EXTRACTED_0]]) [[OUT_QUBITS]], [[BIT_1]] : !quantum.bit, !quantum.bit
+ # CHECK-NEXT: [[BIT_2:%.+]] = quantum.extract %0[ 2] : !quantum.reg -> !quantum.bit
+ # CHECK-NEXT: [[EXTRACTED_2:%.+]] = tensor.extract %arg0[] : tensor
+ # CHECK-NEXT: {{%.+}} = quantum.multirz([[EXTRACTED_2]]) {{%.+}}, {{%.+}}, [[BIT_2]] : !quantum.bit, !quantum.bit, !quantum.bit
+ qml.MultiRZ(x, wires=[0])
+ qml.MultiRZ(x, wires=[0, 1])
+ qml.MultiRZ(x, wires=[1, 0, 2])
+ return qml.expval(qml.PauliX(0))
+
+ # CHECK-DAG: func.func public @_multi_rz_decomposition_wires_1(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "MultiRZ"}
+ # CHECK-DAG: func.func public @_multi_rz_decomposition_wires_2(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "MultiRZ"}
+ # CHECK-DAG: func.func public @_multi_rz_decomposition_wires_3(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<3xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 3 : i64, target_gate = "MultiRZ"}
+ # CHECK-DAG: %0 = scf.for %arg3 = %c0 to %c2 step %c1 iter_args(%arg4 = %arg0) -> (!quantum.reg)
+ # CHECK-DAG: %5 = scf.for %arg3 = %c1 to %c3 step %c1 iter_args(%arg4 = %4) -> (!quantum.reg)
+ print(circuit_20.mlir)
+
+ qml.decomposition.disable_graph()
+ qml.capture.disable()
+
+
+test_decompose_lowering_multirz()
+
+
+def test_decompose_lowering_with_ordered_passes():
+ """Test the decompose lowering pass with other passes in a specific order in a pass pipeline."""
+
+ qml.capture.enable()
+ qml.decomposition.enable_graph()
+
+ @qml.qjit(target="mlir")
+ @partial(
+ qml.transforms.decompose,
+ gate_set={"RZ", "RY", "CNOT", "GlobalPhase"},
+ )
+ @qml.transforms.merge_rotations
+ @qml.transforms.cancel_inverses
+ @qml.qnode(qml.device("lightning.qubit", wires=1))
+ # CHECK: module attributes {transform.with_named_sequence} {
+ # CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) {
+ # CHECK-NEXT: [[FIRST:%.+]] = transform.apply_registered_pass "remove-chained-self-inverse" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module">
+ # CHECK-NEXT: [[SECOND:%.+]] = transform.apply_registered_pass "merge-rotations" to [[FIRST]] : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module">
+ # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "decompose-lowering" to [[SECOND]] : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module">
+ # CHECK-NEXT: transform.yield
+ # CHECK-NEXT: }
+ def circuit_21(x: float):
+ # CHECK: [[OUT:%.+]] = quantum.custom "PauliX"() %1 : !quantum.bit
+ # CHECK-NEXT: [[OUT_0:%.+]] = quantum.custom "PauliX"() [[OUT]] : !quantum.bit
+ # CHECK-NEXT: [[EXTRACTED:%.+]] = tensor.extract %arg0[] : tensor
+ # CHECK-NEXT: [[OUT_1:%.+]] = quantum.custom "RX"([[EXTRACTED]]) [[OUT_0]] : !quantum.bit
+ # CHECK-NEXT: [[NEGATED:%.+]] = stablehlo.negate %arg0 : tensor
+ # CHECK-NEXT: [[EXTRACTED_2:%.+]] = tensor.extract [[NEGATED]][] : tensor
+ # CHECK-NEXT: {{%.+}} = quantum.custom "RX"([[EXTRACTED_2]]) [[OUT_1]] : !quantum.bit
+ qml.PauliX(0)
+ qml.PauliX(0)
+ qml.RX(x, wires=0)
+ qml.RX(-x, wires=0)
+ return qml.expval(qml.PauliX(0))
+
+ # CHECK-DAG: func.func public @_paulix_to_rx(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "PauliX"}
+ # CHECK-DAG: func.func public @_rx_to_rz_ry(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RX"}
+ # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"}
+ print(circuit_21.mlir)
+
+ qml.decomposition.disable_graph()
+ qml.capture.disable()
+
+
+test_decompose_lowering_with_ordered_passes()
+
+
+def test_decompose_lowering_with_gphase():
+ """Test the decompose lowering pass with GlobalPhase."""
+
+ qml.capture.enable()
+ qml.decomposition.enable_graph()
+
+ @qml.qjit(target="mlir")
+ @partial(
+ qml.transforms.decompose,
+ gate_set={"RX", "RY", "GlobalPhase"},
+ )
+ @qml.qnode(qml.device("lightning.qubit", wires=3))
+ # CHECK: %0 = transform.apply_registered_pass "decompose-lowering"
+ def circuit_22():
+ # CHECK: quantum.gphase(%cst_0) :
+ # CHECK-NEXT: [[EXTRACTED:%.+]] = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
+ # CHECK-NEXT: [[OUT_QUBITS:%.+]] = quantum.custom "PhaseShift"(%cst) [[EXTRACTED]] : !quantum.bit
+ # CHECK-NEXT: {{%.+}} = quantum.custom "PhaseShift"(%cst) [[OUT_QUBITS]] : !quantum.bit
+ qml.GlobalPhase(0.5)
+ qml.ctrl(qml.GlobalPhase, control=0)(0.3)
+ qml.ctrl(qml.GlobalPhase, control=0)(phi=0.3, wires=[1, 2])
+ return qml.expval(qml.PauliX(0))
+
+ # CHECK-DAG: func.func public @_phaseshift_to_rz_gp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "PhaseShift"}
+ # CHECK-DAG: func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"}
+ print(circuit_22.mlir)
+
+ qml.decomposition.disable_graph()
+ qml.capture.disable()
+
+
+test_decompose_lowering_with_gphase()
+
+
+def test_decompose_lowering_alt_decomps():
+ """Test the decompose lowering pass with alternative decompositions."""
+
+ qml.capture.enable()
+ qml.decomposition.enable_graph()
+
+ @qml.register_resources({qml.RY: 1})
+ def custom_rot_cheap(params, wires: WiresLike):
+ qml.RY(params[1], wires=wires)
+
+ @qml.qjit(target="mlir")
+ @partial(
+ qml.transforms.decompose,
+ gate_set={"RY", "RZ"},
+ alt_decomps={qml.Rot: [custom_rot_cheap]},
+ )
+ @qml.qnode(qml.device("lightning.qubit", wires=3), shots=1000)
+ def circ(x: float, y: float):
+ qml.Rot(x, y, x + y, wires=1)
+ return qml.expval(qml.PauliZ(0))
+
+ # CHECK-DAG: func.func public @custom_rot_cheap(%arg0: !quantum.reg, %arg1: tensor<3xf64>, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"}
+ print(circ.mlir)
+
+ qml.decomposition.disable_graph()
+ qml.capture.disable()
+
+
+test_decompose_lowering_alt_decomps()
+
+
+def test_decompose_lowering_with_tensorlike():
+ """Test the decompose lowering pass with fixed decompositions
+ using TensorLike parameters."""
+
+ qml.capture.enable()
+ qml.decomposition.enable_graph()
+
+ @qml.register_resources({qml.RZ: 2, qml.RY: 1})
+ def custom_rot(params: TensorLike, wires: WiresLike):
+ qml.RZ(params[0], wires=wires)
+ qml.RY(params[1], wires=wires)
+ qml.RZ(params[2], wires=wires)
+
+ @qml.register_resources({qml.RZ: 1, qml.CNOT: 4})
+ def custom_multirz(params: TensorLike, wires: WiresLike):
+ qml.CNOT(wires=(wires[2], wires[1]))
+ qml.CNOT(wires=(wires[1], wires[0]))
+ qml.RZ(params[0], wires=wires[0])
+ qml.CNOT(wires=(wires[1], wires[0]))
+ qml.CNOT(wires=(wires[2], wires[1]))
+
+ @qml.qjit(target="mlir")
+ @partial(
+ qml.transforms.decompose,
+ gate_set={"RY", "RX", qml.CNOT},
+ fixed_decomps={qml.Rot: custom_rot, qml.MultiRZ: custom_multirz},
+ )
+ @qml.qnode(qml.device("lightning.qubit", wires=3), shots=1000)
+ def circ(x: float, y: float):
+ qml.Rot(x, y, x + y, wires=1)
+ qml.MultiRZ(x + y, wires=[0, 1, 2])
+ return qml.expval(qml.PauliZ(0))
+
+ # CHECK-DAG: func.func public @custom_multirz_wires_3(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<3xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 3 : i64, target_gate = "MultiRZ"}
+ # CHECK-DAG: func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"}
+ # CHECK-DAG: func.func public @custom_rot(%arg0: !quantum.reg, %arg1: tensor<3xf64>, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"}
+ print(circ.mlir)
+
+ qml.decomposition.disable_graph()
+ qml.capture.disable()
+
+
+test_decompose_lowering_with_tensorlike()
diff --git a/frontend/test/lit/test_dynamic_qubit_allocation.py b/frontend/test/lit/test_dynamic_qubit_allocation.py
index 85b3ede88a..8b4d7f114c 100644
--- a/frontend/test/lit/test_dynamic_qubit_allocation.py
+++ b/frontend/test/lit/test_dynamic_qubit_allocation.py
@@ -82,7 +82,8 @@ def test_basic_dynalloc():
# CHECK: [[CNOTout:%.+]]:2 = quantum.custom "CNOT"() [[dyn_bit2]], [[dev_bit1]]
# CHECK: [[insert0:%.+]] = quantum.insert [[dyn_qreg]][ 1], [[Xout]]
# CHECK: [[insert1:%.+]] = quantum.insert [[insert0]][ 2], [[CNOTout]]#0
- # CHECK: quantum.dealloc [[insert1]]
+ # CHECK: [[insert2:%.+]] = quantum.insert [[insert1]][ 3]
+ # CHECK: quantum.dealloc [[insert2]]
with qml.allocate(4) as qs1:
qml.X(qs1[1])
diff --git a/frontend/test/lit/test_from_plxpr.py b/frontend/test/lit/test_from_plxpr.py
index 7fd8618cc6..14cea4a8e6 100644
--- a/frontend/test/lit/test_from_plxpr.py
+++ b/frontend/test/lit/test_from_plxpr.py
@@ -18,6 +18,8 @@
"""Lit tests for the PLxPR to JAXPR with quantum primitives pipeline"""
+from functools import partial
+
import pennylane as qml
@@ -45,7 +47,7 @@ def main():
print(main.mlir)
- qml.capture.enable()
+ qml.capture.disable()
test_conditional_capture()
@@ -362,3 +364,60 @@ def circuit():
test_pass_application()
+
+
+def test_pass_decomposition():
+ """Application of pass decorator with decomposition."""
+
+ dev = qml.device("null.qubit", wires=1)
+
+ qml.capture.enable()
+ qml.decomposition.enable_graph()
+
+ @qml.qjit(target="mlir")
+ @qml.transforms.cancel_inverses
+ @qml.transforms.merge_rotations
+ @partial(qml.transforms.decompose, gate_set={"RX", "RZ"})
+ @qml.qnode(dev)
+ def circuit1():
+ return qml.probs()
+
+ # CHECK: [[first_pass:%.+]] = transform.apply_registered_pass "decompose-lowering"
+ # CHECK-NEXT: [[second_pass:%.+]] = transform.apply_registered_pass "merge-rotations"
+ # CHECK-NEXT: transform.apply_registered_pass "remove-chained-self-inverse" to [[second_pass]]
+
+ print(circuit1.mlir)
+
+ @qml.qjit(target="mlir")
+ @qml.transforms.cancel_inverses
+ @partial(qml.transforms.decompose, gate_set={"RX", "RZ"})
+ @qml.transforms.merge_rotations
+ @qml.qnode(dev)
+ def circuit2():
+ return qml.probs()
+
+ # CHECK: [[first_pass:%.+]] = transform.apply_registered_pass "merge-rotations"
+ # CHECK-NEXT: [[second_pass:%.+]] = transform.apply_registered_pass "decompose-lowering"
+ # CHECK-NEXT: transform.apply_registered_pass "remove-chained-self-inverse" to [[second_pass]]
+
+ print(circuit2.mlir)
+
+ @qml.qjit(target="mlir")
+ @partial(qml.transforms.decompose, gate_set={"RX", "RZ"})
+ @qml.transforms.cancel_inverses
+ @qml.transforms.merge_rotations
+ @qml.qnode(dev)
+ def circuit3():
+ return qml.probs()
+
+ # CHECK: [[first_pass:%.+]] = transform.apply_registered_pass "merge-rotations"
+ # CHECK-NEXT: [[second_pass:%.+]] = transform.apply_registered_pass "remove-chained-self-inverse"
+ # CHECK-NEXT: transform.apply_registered_pass "decompose-lowering" to [[second_pass]]
+
+ print(circuit3.mlir)
+
+ qml.decomposition.disable_graph()
+ qml.capture.disable()
+
+
+test_pass_decomposition()
diff --git a/frontend/test/pytest/from_plxpr/test_from_plxpr.py b/frontend/test/pytest/from_plxpr/test_from_plxpr.py
index 47b0b65be8..9d2ce69a2c 100644
--- a/frontend/test/pytest/from_plxpr/test_from_plxpr.py
+++ b/frontend/test/pytest/from_plxpr/test_from_plxpr.py
@@ -15,6 +15,8 @@
This module tests the from_plxpr conversion function.
"""
+from functools import partial
+
import jax
import numpy as np
import pennylane as qml
@@ -965,5 +967,39 @@ def workflow(x, y):
assert qml.math.allclose(results, expected)
+class TestGraphDecomposition:
+ """Test the new graph-based decomposition integration with from_plxpr."""
+
+ def test_with_multiple_decomps_transforms(self):
+ """Test that a circuit with multiple decompositions and transforms can be converted."""
+
+ qml.capture.enable()
+ qml.decomposition.enable_graph()
+
+ @qml.qjit(target="mlir")
+ @partial(
+ qml.transforms.decompose,
+ gate_set={"RX", "RY"},
+ )
+ @partial(
+ qml.transforms.decompose,
+ gate_set={"NOT", "GlobalPhase"},
+ )
+ @qml.qnode(qml.device("lightning.qubit", wires=0))
+ def circuit(x):
+ qml.GlobalPhase(x)
+ return qml.expval(qml.PauliX(0))
+
+ with pytest.raises(
+ NotImplementedError, match="Multiple decomposition transforms are not yet supported."
+ ):
+ circuit(0.2)
+
+ qml.decomposition.disable_graph()
+ qml.capture.disable()
+
+ assert qml.decomposition.enabled_graph() is False
+
+
if __name__ == "__main__":
pytest.main(["-x", __file__])
diff --git a/mlir/include/Quantum/IR/QuantumOps.td b/mlir/include/Quantum/IR/QuantumOps.td
index 89498b185a..0b3a76296d 100644
--- a/mlir/include/Quantum/IR/QuantumOps.td
+++ b/mlir/include/Quantum/IR/QuantumOps.td
@@ -241,6 +241,7 @@ def ExtractOp : Memory_Op<"extract", [NoMemoryEffect]> {
$qreg `[` ($idx^):($idx_attr)? `]` attr-dict `:` type($qreg) `->` type(results)
}];
+ let hasCanonicalizeMethod = 1;
let hasVerifier = 1;
let hasFolder = 1;
}
diff --git a/mlir/include/Quantum/Transforms/Passes.h b/mlir/include/Quantum/Transforms/Passes.h
index 00f33d8fa4..33b25c0179 100644
--- a/mlir/include/Quantum/Transforms/Passes.h
+++ b/mlir/include/Quantum/Transforms/Passes.h
@@ -30,6 +30,7 @@ std::unique_ptr createRemoveChainedSelfInversePass();
std::unique_ptr createAnnotateFunctionPass();
std::unique_ptr createSplitMultipleTapesPass();
std::unique_ptr createMergeRotationsPass();
+std::unique_ptr createDecomposeLoweringPass();
std::unique_ptr createDisentangleCNOTPass();
std::unique_ptr createDisentangleSWAPPass();
std::unique_ptr createIonsDecompositionPass();
diff --git a/mlir/include/Quantum/Transforms/Passes.td b/mlir/include/Quantum/Transforms/Passes.td
index f0a344190e..b4aeed0dd9 100644
--- a/mlir/include/Quantum/Transforms/Passes.td
+++ b/mlir/include/Quantum/Transforms/Passes.td
@@ -110,6 +110,17 @@ def MergeRotationsPass : Pass<"merge-rotations"> {
let constructor = "catalyst::createMergeRotationsPass()";
}
+def DecomposeLoweringPass : Pass<"decompose-lowering"> {
+ let summary = "Replace quantum operations with compiled decomposition rules.";
+
+ let options = [
+ Option<"rulesPath", "rules-path", "std::string", /*default=*/"",
+ "Path to the MLIR module containing gate decomposition rules.">
+ ];
+
+ let constructor = "catalyst::createDecomposeLoweringPass()";
+}
+
def DisentangleCNOTPass : Pass<"disentangle-CNOT"> {
let summary = "Replace a CNOT gate with two single qubit gates whenever possible.";
diff --git a/mlir/include/Quantum/Transforms/Patterns.h b/mlir/include/Quantum/Transforms/Patterns.h
index a16569c01b..8b8ade74c1 100644
--- a/mlir/include/Quantum/Transforms/Patterns.h
+++ b/mlir/include/Quantum/Transforms/Patterns.h
@@ -15,8 +15,12 @@
#pragma once
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/AllocatorBase.h"
namespace catalyst {
namespace quantum {
@@ -26,6 +30,9 @@ void populateAdjointPatterns(mlir::RewritePatternSet &);
void populateSelfInversePatterns(mlir::RewritePatternSet &);
void populateMergeRotationsPatterns(mlir::RewritePatternSet &);
void populateIonsDecompositionPatterns(mlir::RewritePatternSet &);
+void populateDecomposeLoweringPatterns(mlir::RewritePatternSet &,
+ const llvm::StringMap &,
+ const llvm::StringSet &);
void populateLoopBoundaryPatterns(mlir::RewritePatternSet &, unsigned int mode);
} // namespace quantum
diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp
index 0e7be8337b..88d97a8674 100644
--- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp
+++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp
@@ -38,6 +38,7 @@ void catalyst::registerAllCatalystPasses()
mlir::registerPass(catalyst::createCliffordTToPPRPass);
mlir::registerPass(catalyst::createMergePPRIntoPPMPass);
mlir::registerPass(catalyst::createPPMCompilationPass);
+ mlir::registerPass(catalyst::createDecomposeLoweringPass);
mlir::registerPass(catalyst::createDecomposeNonCliffordPPRPass);
mlir::registerPass(catalyst::createDecomposeCliffordPPRPass);
mlir::registerPass(catalyst::createCountPPMSpecsPass);
diff --git a/mlir/lib/Quantum/IR/QuantumDialect.cpp b/mlir/lib/Quantum/IR/QuantumDialect.cpp
index 7049f58e63..14f5e6e811 100644
--- a/mlir/lib/Quantum/IR/QuantumDialect.cpp
+++ b/mlir/lib/Quantum/IR/QuantumDialect.cpp
@@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "llvm/ADT/TypeSwitch.h" // needed for generated type parser
+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/DialectImplementation.h" // needed for generated type parser
-#include "llvm/ADT/TypeSwitch.h" // needed for generated type parser
+#include "mlir/Transforms/InliningUtils.h"
#include "Quantum/IR/QuantumDialect.h"
#include "Quantum/IR/QuantumOps.h"
@@ -22,6 +25,65 @@
using namespace mlir;
using namespace catalyst::quantum;
+//===----------------------------------------------------------------------===//
+// Quantum Dialect Interfaces
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct QuantumInlinerInterface : public DialectInlinerInterface {
+ using DialectInlinerInterface::DialectInlinerInterface;
+
+ static constexpr StringRef decompAttr = "target_gate";
+
+ /// Returns true if the given operation 'callable' can be inlined into the
+ /// position given by the 'call'. Currently, we always inline quantum
+ /// decomposition functions.
+ bool isLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned) const final
+ {
+ if (auto funcOp = dyn_cast(callable)) {
+ return funcOp->hasAttr(decompAttr);
+ }
+ return false;
+ }
+
+ /// Returns true if the given region 'src' can be inlined into the region
+ /// 'dest'. Only allow for decomposition functions.
+ bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
+ IRMapping &valueMapping) const final
+ {
+ if (auto funcOp = src->getParentOfType()) {
+ return funcOp->hasAttr(decompAttr);
+ }
+ return false;
+ }
+
+ // Allow to inline operations from decomposition functions.
+ bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
+ IRMapping &valueMapping) const final
+ {
+ if (auto funcOp = op->getParentOfType()) {
+ return funcOp->hasAttr(decompAttr);
+ }
+ return false;
+ }
+
+ /// Handle the given inlined terminator by replacing it with a new operation
+ /// as necessary. Required when the region has only one block.
+ void handleTerminator(Operation *op, ValueRange valuesToRepl) const final
+ {
+ auto yieldOp = dyn_cast(op);
+ if (!yieldOp) {
+ return;
+ }
+
+ for (auto retValue : llvm::zip(valuesToRepl, yieldOp.getOperands())) {
+ std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
+ }
+ }
+};
+} // namespace
+
//===----------------------------------------------------------------------===//
// Quantum dialect definitions.
//===----------------------------------------------------------------------===//
@@ -45,6 +107,8 @@ void QuantumDialect::initialize()
#include "Quantum/IR/QuantumOps.cpp.inc"
>();
+ addInterfaces();
+
declarePromisedInterfaces();
diff --git a/mlir/lib/Quantum/IR/QuantumOps.cpp b/mlir/lib/Quantum/IR/QuantumOps.cpp
index 101ed77a57..b0eedf09d3 100644
--- a/mlir/lib/Quantum/IR/QuantumOps.cpp
+++ b/mlir/lib/Quantum/IR/QuantumOps.cpp
@@ -146,6 +146,28 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor)
return nullptr;
}
+LogicalResult ExtractOp::canonicalize(ExtractOp extract, mlir::PatternRewriter &rewriter)
+{
+ // Handle the pattern: %reg2 = insert %reg1[idx], %qubit -> %q = extract %reg2[idx]
+ // Convert to: %q = %qubit, and replace other uses of %reg2 with %reg1
+ if (auto insert = dyn_cast_if_present(extract.getQreg().getDefiningOp())) {
+ bool bothStatic = extract.getIdxAttr().has_value() && insert.getIdxAttr().has_value();
+ bool bothDynamic = !extract.getIdxAttr().has_value() && !insert.getIdxAttr().has_value();
+ bool staticallyEqual = bothStatic && extract.getIdxAttrAttr() == insert.getIdxAttrAttr();
+ bool dynamicallyEqual = bothDynamic && extract.getIdx() == insert.getIdx();
+ // if other users of insert are also `insert`, we are good to go
+ bool valid = llvm::all_of(insert.getResult().getUsers(), [&](Operation *op) {
+ return isa(op) || op == extract.getOperation();
+ });
+ if ((staticallyEqual || dynamicallyEqual) && valid) {
+ rewriter.replaceOp(extract, insert.getQubit());
+ rewriter.replaceOp(insert, insert.getInQreg());
+ return success();
+ }
+ }
+ return failure();
+}
+
LogicalResult InsertOp::canonicalize(InsertOp insert, mlir::PatternRewriter &rewriter)
{
if (auto extract = dyn_cast_if_present(insert.getQubit().getDefiningOp())) {
@@ -153,9 +175,10 @@ LogicalResult InsertOp::canonicalize(InsertOp insert, mlir::PatternRewriter &rew
bool bothDynamic = !extract.getIdxAttr().has_value() && !insert.getIdxAttr().has_value();
bool staticallyEqual = bothStatic && extract.getIdxAttrAttr() == insert.getIdxAttrAttr();
bool dynamicallyEqual = bothDynamic && extract.getIdx() == insert.getIdx();
+ bool sameQreg = extract.getQreg() == insert.getInQreg();
bool oneUse = extract.getResult().hasOneUse();
- if ((staticallyEqual || dynamicallyEqual) && oneUse) {
+ if ((staticallyEqual || dynamicallyEqual) && oneUse && sameQreg) {
rewriter.replaceOp(insert, insert.getInQreg());
rewriter.eraseOp(extract);
return success();
diff --git a/mlir/lib/Quantum/Transforms/CMakeLists.txt b/mlir/lib/Quantum/Transforms/CMakeLists.txt
index 3a244ac4d6..f1ae85d1ff 100644
--- a/mlir/lib/Quantum/Transforms/CMakeLists.txt
+++ b/mlir/lib/Quantum/Transforms/CMakeLists.txt
@@ -14,6 +14,8 @@ file(GLOB SRC
SplitMultipleTapes.cpp
merge_rotation.cpp
MergeRotationsPatterns.cpp
+ decompose_lowering.cpp
+ DecomposeLoweringPatterns.cpp
DisentangleSWAP.cpp
DisentangleCNOT.cpp
ions_decompositions.cpp
@@ -28,6 +30,7 @@ set(LIBS
${dialect_libs}
${conversion_libs}
MLIRQuantum
+ ExternalStablehloLib
)
set(DEPENDS
diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp
new file mode 100644
index 0000000000..2f60e052a1
--- /dev/null
+++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp
@@ -0,0 +1,474 @@
+// Copyright 2025 Xanadu Quantum Technologies Inc.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#define DEBUG_TYPE "decompose-lowering"
+
+#include
+
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringSet.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
+
+#include "Quantum/IR/QuantumOps.h"
+#include "Quantum/Transforms/Patterns.h"
+
+using namespace mlir;
+using namespace catalyst::quantum;
+
+namespace catalyst {
+namespace quantum {
+
+/// A struct to represent qubit indices in quantum operations.
+///
+/// This struct provides a way to handle qubit indices that can be either:
+/// - A runtime Value (for dynamic indices computed at runtime)
+/// - An IntegerAttr (for compile-time constant indices)
+/// - Invalid/uninitialized (represented by std::monostate)
+///
+/// The struct uses std::variant to ensure only one type is active at a time,
+/// preventing invalid states.
+///
+/// Example usage:
+/// QubitIndex dynamicIdx(operandValue); // Runtime qubit index
+/// QubitIndex staticIdx(IntegerAttr::get(...)); // Compile-time constant
+/// QubitIndex invalidIdx; // Uninitialized state
+///
+/// if (dynamicIdx) { // Check if valid
+/// if (dynamicIdx.isValue()) { // Check if runtime value
+/// Value idx = dynamicIdx.getValue(); // Get the Value
+/// }
+/// }
+struct QubitIndex {
+ // use monostate to represent the invalid index
+ std::variant index;
+
+ QubitIndex() : index(std::monostate()) {}
+ QubitIndex(Value val) : index(val) {}
+ QubitIndex(IntegerAttr attr) : index(attr) {}
+
+ bool isValue() const { return std::holds_alternative(index); }
+ bool isAttr() const { return std::holds_alternative(index); }
+ operator bool() const { return isValue() || isAttr(); }
+ Value getValue() const { return isValue() ? std::get(index) : nullptr; }
+ IntegerAttr getAttr() const { return isAttr() ? std::get(index) : nullptr; }
+};
+
+// The goal of this class is to analyze the signature of a custom operation to get the enough
+// information to prepare the call operands and results for replacing the op to calling the
+// decomposition function.
+class OpSignatureAnalyzer {
+ public:
+ OpSignatureAnalyzer() = delete;
+ OpSignatureAnalyzer(CustomOp op, bool enableQregMode)
+ : signature(OpSignature{
+ .params = op.getParams(),
+ .inQubits = op.getInQubits(),
+ .inCtrlQubits = op.getInCtrlQubits(),
+ .inCtrlValues = op.getInCtrlValues(),
+ .outQubits = op.getOutQubits(),
+ .outCtrlQubits = op.getOutCtrlQubits(),
+ })
+ {
+ if (!enableQregMode)
+ return;
+
+ signature.sourceQreg = getSourceQreg(signature.inQubits.front());
+ if (!signature.sourceQreg) {
+ op.emitError("Cannot get source qreg");
+ isValid = false;
+ return;
+ }
+
+ // input wire indices
+ for (Value qubit : signature.inQubits) {
+ const QubitIndex index = getExtractIndex(qubit);
+ if (!index) {
+ op.emitError("Cannot get index for input qubit");
+ isValid = false;
+ return;
+ }
+ signature.inWireIndices.emplace_back(index);
+ }
+
+ // input ctrl wire indices
+ for (Value ctrlQubit : signature.inCtrlQubits) {
+ const QubitIndex index = getExtractIndex(ctrlQubit);
+ if (!index) {
+ op.emitError("Cannot get index for ctrl qubit");
+ isValid = false;
+ return;
+ }
+ signature.inCtrlWireIndices.emplace_back(index);
+ }
+
+ // Output qubit indices are the same as input qubit indices
+ signature.outQubitIndices = signature.inWireIndices;
+ signature.outCtrlQubitIndices = signature.inCtrlWireIndices;
+ }
+
+ operator bool() const { return isValid; }
+
+ // Prepare the operands for calling the decomposition function
+ // There are two cases:
+ // 1. The first input is a qreg, which means the decomposition function is a qreg mode function
+ // 2. Otherwise, the decomposition function is a qubit mode function
+ //
+ // Type signatures:
+ // 1. qreg mode:
+ // - func(qreg, param*, inWires*, inCtrlWires*?, inCtrlValues*?) -> qreg
+ // 2. qubit mode:
+ // - func(param*, inQubits*, inCtrlQubits*?, inCtrlValues*?) -> outQubits*
+ llvm::SmallVector prepareCallOperands(func::FuncOp decompFunc, PatternRewriter &rewriter,
+ Location loc)
+ {
+ auto funcType = decompFunc.getFunctionType();
+ auto funcInputs = funcType.getInputs();
+
+ SmallVector operands(funcInputs.size());
+
+ int operandIdx = 0;
+ if (isa(funcInputs[0])) {
+ Value updatedQreg = signature.sourceQreg;
+ for (auto [i, qubit] : llvm::enumerate(signature.inQubits)) {
+ const QubitIndex &index = signature.inWireIndices[i];
+ updatedQreg =
+ rewriter.create(loc, updatedQreg.getType(), updatedQreg,
+ index.getValue(), index.getAttr(), qubit);
+ }
+
+ operands[operandIdx++] = updatedQreg;
+ if (!signature.params.empty()) {
+ auto [startIdx, endIdx] =
+ findParamTypeRange(funcInputs, signature.params.size(), operandIdx);
+ ArrayRef paramsTypes = funcInputs.slice(startIdx, endIdx - startIdx);
+ auto updatedParams = generateParams(signature.params, paramsTypes, rewriter, loc);
+ for (Value param : updatedParams) {
+ operands[operandIdx++] = param;
+ }
+ }
+
+ if (!signature.inWireIndices.empty()) {
+ operands[operandIdx] = fromTensorOrAsIs(signature.inWireIndices,
+ funcInputs[operandIdx], rewriter, loc);
+ operandIdx++;
+ }
+
+ if (!signature.inCtrlWireIndices.empty()) {
+ operands[operandIdx] = fromTensorOrAsIs(signature.inCtrlWireIndices,
+ funcInputs[operandIdx], rewriter, loc);
+ operandIdx++;
+ }
+ }
+ else {
+ if (!signature.params.empty()) {
+ auto [startIdx, endIdx] =
+ findParamTypeRange(funcInputs, signature.params.size(), operandIdx);
+ ArrayRef paramsTypes = funcInputs.slice(startIdx, endIdx - startIdx);
+ auto updatedParams = generateParams(signature.params, paramsTypes, rewriter, loc);
+ for (Value param : updatedParams) {
+ operands[operandIdx++] = param;
+ }
+ }
+
+ for (auto inQubit : signature.inQubits) {
+ operands[operandIdx] =
+ fromTensorOrAsIs(inQubit, funcInputs[operandIdx], rewriter, loc);
+ operandIdx++;
+ }
+
+ for (auto inCtrlQubit : signature.inCtrlQubits) {
+ operands[operandIdx] =
+ fromTensorOrAsIs(inCtrlQubit, funcInputs[operandIdx], rewriter, loc);
+ operandIdx++;
+ }
+ }
+
+ if (!signature.inCtrlValues.empty()) {
+ operands[operandIdx] =
+ fromTensorOrAsIs(signature.inCtrlValues, funcInputs[operandIdx], rewriter, loc);
+ operandIdx++;
+ }
+
+ return operands;
+ }
+
+ // Prepare the results for the call operation
+ SmallVector prepareCallResultForQreg(func::CallOp callOp, PatternRewriter &rewriter)
+ {
+ assert(callOp.getNumResults() == 1 && "only one qreg result for qreg mode is allowed");
+
+ auto qreg = callOp.getResult(0);
+ assert(isa(qreg.getType()) && "only allow to have qreg result");
+
+ SmallVector newResults;
+ rewriter.setInsertionPointAfter(callOp);
+ for (const QubitIndex &index : signature.outQubitIndices) {
+ auto extractOp = rewriter.create(
+ callOp.getLoc(), rewriter.getType(), qreg, index.getValue(),
+ index.getAttr());
+ newResults.emplace_back(extractOp.getResult());
+ }
+ for (const QubitIndex &index : signature.outCtrlQubitIndices) {
+ auto extractOp = rewriter.create(
+ callOp.getLoc(), rewriter.getType(), qreg, index.getValue(),
+ index.getAttr());
+ newResults.emplace_back(extractOp.getResult());
+ }
+ return newResults;
+ }
+
+ private:
+ bool isValid = true;
+
+ struct OpSignature {
+ ValueRange params;
+ ValueRange inQubits;
+ ValueRange inCtrlQubits;
+ ValueRange inCtrlValues;
+ ValueRange outQubits;
+ ValueRange outCtrlQubits;
+
+ // Qreg mode specific information
+ Value sourceQreg = nullptr;
+ SmallVector inWireIndices;
+ SmallVector inCtrlWireIndices;
+ SmallVector outQubitIndices;
+ SmallVector outCtrlQubitIndices;
+ } signature;
+
+ Value fromTensorOrAsIs(ValueRange values, Type type, PatternRewriter &rewriter, Location loc)
+ {
+ if (isa(type)) {
+ return rewriter.create(loc, type, values);
+ }
+ return values.front();
+ }
+
+ static size_t getElementsCount(Type type)
+ {
+ if (isa(type)) {
+ auto tensorType = cast(type);
+ return tensorType.getNumElements() > 0 ? tensorType.getNumElements() : 1;
+ }
+ return 1;
+ }
+
+ // Helper function to find the range of function input types that correspond to params
+ static std::pair findParamTypeRange(ArrayRef funcInputs,
+ size_t sigParamCount, size_t startIdx = 0)
+ {
+ size_t paramTypeCount = 0;
+ size_t paramTypeEnd = startIdx;
+
+ while (paramTypeCount < sigParamCount) {
+ assert(paramTypeEnd < funcInputs.size() &&
+ "param type end should be less than function input size");
+ paramTypeCount += getElementsCount(funcInputs[paramTypeEnd]);
+ paramTypeEnd++;
+ }
+
+ assert(paramTypeCount == sigParamCount &&
+ "param type count should be equal to signature param count");
+
+ return {startIdx, paramTypeEnd};
+ }
+
+ // generate params for calling the decomposition function based on function type requirements
+ SmallVector generateParams(ValueRange signatureParams, ArrayRef funcParamTypes,
+ PatternRewriter &rewriter, Location loc)
+ {
+ SmallVector operands;
+ size_t sigParamIdx = 0;
+
+ for (Type funcParamType : funcParamTypes) {
+ const size_t numElements = getElementsCount(funcParamType);
+
+ // collect numElements of signature params
+ SmallVector tensorElements;
+ for (size_t i = 0; i < numElements && sigParamIdx < signatureParams.size(); i++) {
+ tensorElements.push_back(signatureParams[sigParamIdx++]);
+ }
+ operands.push_back(fromTensorOrAsIs(tensorElements, funcParamType, rewriter, loc));
+ }
+
+ return operands;
+ }
+
+ Value fromTensorOrAsIs(ArrayRef indices, Type type, PatternRewriter &rewriter,
+ Location loc)
+ {
+ SmallVector values;
+ for (const QubitIndex &index : indices) {
+ if (index.isValue()) {
+ values.emplace_back(index.getValue());
+ }
+ else if (index.isAttr()) {
+ auto attr = index.getAttr();
+ auto constantValue = rewriter.create(loc, attr.getType(), attr);
+ values.emplace_back(constantValue);
+ }
+ }
+
+ if (isa(type)) {
+ return rewriter.create(loc, type, values);
+ }
+
+ assert(values.size() == 1 && "number of values should be 1 for non-tensor type");
+ return values.front();
+ }
+
+ Value getSourceQreg(Value qubit)
+ {
+ while (qubit) {
+ if (auto extractOp = qubit.getDefiningOp()) {
+ return extractOp.getQreg();
+ }
+
+ if (auto customOp = dyn_cast_or_null(qubit.getDefiningOp())) {
+ if (customOp.getQubitOperands().empty()) {
+ break;
+ }
+ qubit = customOp.getQubitOperands()[0];
+ }
+ }
+
+ return nullptr;
+ }
+
+ QubitIndex getExtractIndex(Value qubit)
+ {
+ while (qubit) {
+ if (auto extractOp = qubit.getDefiningOp()) {
+ if (Value idx = extractOp.getIdx()) {
+ return QubitIndex(idx);
+ }
+ if (IntegerAttr idxAttr = extractOp.getIdxAttrAttr()) {
+ return QubitIndex(idxAttr);
+ }
+ }
+
+ if (auto customOp = dyn_cast_or_null(qubit.getDefiningOp())) {
+ auto qubitOperands = customOp.getQubitOperands();
+ auto qubitResults = customOp.getQubitResults();
+ auto it =
+ llvm::find_if(qubitResults, [&](Value result) { return result == qubit; });
+
+ if (it != qubitResults.end()) {
+ size_t resultIndex = std::distance(qubitResults.begin(), it);
+ if (resultIndex < qubitOperands.size()) {
+ qubit = qubitOperands[resultIndex];
+ continue;
+ }
+ }
+ }
+
+ break;
+ }
+
+ return QubitIndex();
+ }
+};
+
+struct DecomposeLoweringRewritePattern : public OpRewritePattern {
+ private:
+ const llvm::StringMap &decompositionRegistry;
+ const llvm::StringSet &targetGateSet;
+
+ public:
+ DecomposeLoweringRewritePattern(MLIRContext *context,
+ const llvm::StringMap ®istry,
+ const llvm::StringSet &gateSet)
+ : OpRewritePattern(context), decompositionRegistry(registry), targetGateSet(gateSet)
+ {
+ }
+
+ LogicalResult matchAndRewrite(CustomOp op, PatternRewriter &rewriter) const override
+ {
+ StringRef gateName = op.getGateName();
+
+ // Only decompose the op if it is not in the target gate set
+ if (targetGateSet.contains(gateName)) {
+ return failure();
+ }
+
+ // Find the corresponding decomposition function for the op
+ auto it = decompositionRegistry.find(gateName);
+ if (it == decompositionRegistry.end()) {
+ return failure();
+ }
+ func::FuncOp decompFunc = it->second;
+
+
+ ModuleOp parentModule = op->getParentOfType();
+ assert(parentModule && "expected parent module for custom op");
+
+ // If the decomposition function is already in the parent module, use the local clone
+ // otherwise, clone the function and insert it into the parent module
+ if (func::FuncOp localClone =
+ parentModule.lookupSymbol(decompFunc.getSymNameAttr().getValue())) {
+ decompFunc = localClone;
+ }
+ else {
+ PatternRewriter::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToEnd(parentModule.getBody());
+ func::FuncOp clonedFunc = decompFunc.clone();
+ rewriter.insert(clonedFunc);
+ decompFunc = clonedFunc;
+ }
+
+ // Here is the assumption that the decomposition function must have at least one input and
+ // one result
+ assert(decompFunc.getFunctionType().getNumInputs() > 0 &&
+ "Decomposition function must have at least one input");
+ assert(decompFunc.getFunctionType().getNumResults() >= 1 &&
+ "Decomposition function must have at least one result");
+
+ auto enableQreg = isa(decompFunc.getFunctionType().getInput(0));
+ auto analyzer = OpSignatureAnalyzer(op, enableQreg);
+ assert(analyzer && "Analyzer should be valid");
+
+ rewriter.setInsertionPointAfter(op);
+ auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc());
+ auto callOp =
+ rewriter.create(op.getLoc(), decompFunc.getFunctionType().getResults(),
+ decompFunc.getSymName(), callOperands);
+
+ // Replace the op with the call op and adjust the insert ops for the qreg mode
+ if (callOp.getNumResults() == 1 && isa(callOp.getResult(0).getType())) {
+ auto results = analyzer.prepareCallResultForQreg(callOp, rewriter);
+ rewriter.replaceOp(op, results);
+ }
+ else {
+ rewriter.replaceOp(op, callOp->getResults());
+ }
+
+ return success();
+ }
+};
+
+void populateDecomposeLoweringPatterns(RewritePatternSet &patterns,
+ const llvm::StringMap &decompositionRegistry,
+ const llvm::StringSet &targetGateSet)
+{
+ patterns.add(patterns.getContext(), decompositionRegistry,
+ targetGateSet);
+}
+
+} // namespace quantum
+} // namespace catalyst
diff --git a/mlir/lib/Quantum/Transforms/decompose_lowering.cpp b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp
new file mode 100644
index 0000000000..80bee07a2d
--- /dev/null
+++ b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp
@@ -0,0 +1,239 @@
+// Copyright 2025 Xanadu Quantum Technologies Inc.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#define DEBUG_TYPE "decompose-lowering"
+
+// When we read the decomposition rules module from file,
+// StablehloDialect may not be registered from start.
+#include "stablehlo/dialect/StablehloOps.h"
+
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/AllocatorBase.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "Quantum/IR/QuantumOps.h"
+#include "Quantum/Transforms/Patterns.h"
+
+using namespace mlir;
+using namespace catalyst::quantum;
+
+namespace catalyst {
+namespace quantum {
+#define GEN_PASS_DEF_DECOMPOSELOWERINGPASS
+#define GEN_PASS_DECL_DECOMPOSELOWERINGPASS
+#include "Quantum/Transforms/Passes.h.inc"
+
+namespace DecompUtils {
+
+static constexpr StringRef target_gate_attr_name = "target_gate";
+static constexpr StringRef decomp_gateset_attr_name = "decomp_gateset";
+
+// Helper function to load MLIR module from file path
+OwningOpRef loadMLIRModule(StringRef filePath, MLIRContext *context)
+{
+ // Parse MLIR file directly
+ OwningOpRef module = parseSourceFile(filePath, context);
+ if (!module) {
+ llvm::errs() << "Failed to parse MLIR file: " << filePath << "\n";
+ return nullptr;
+ }
+ return module;
+}
+
+// Check if a function is a decomposition function
+// It's expected that the decomposition function would have this attribute:
+// `catalyst.decomposition.target_op` And this attribute is set by the `markDecompositionAttributes`
+// functionq The decomposition attribute are used to determine if a function is a decomposition
+// function, and target_op is that the decomposition function want to replace
+bool isDecompositionFunction(func::FuncOp func) { return func->hasAttr(target_gate_attr_name); }
+
+StringRef getTargetGateName(func::FuncOp func)
+{
+ if (auto target_op_attr = func->getAttrOfType(target_gate_attr_name)) {
+ return target_op_attr.getValue();
+ }
+ return StringRef{};
+}
+
+} // namespace DecompUtils
+
+/// A module pass that work through a module, register all decomposition functions, and apply the
+/// decomposition patterns
+struct DecomposeLoweringPass : impl::DecomposeLoweringPassBase {
+ using DecomposeLoweringPassBase::DecomposeLoweringPassBase;
+
+ void getDependentDialects(DialectRegistry ®istry) const override
+ {
+ registry.insert();
+ registry.insert();
+ registry.insert();
+ registry.insert();
+ registry.insert();
+ registry.insert();
+ }
+
+ private:
+ llvm::StringMap decompositionRegistry;
+ llvm::StringSet targetGateSet;
+
+ // Function to discover and register decomposition functions from a module
+ // It's bookkeeping the targetOp and the decomposition function that can decompose the targetOp
+ void discoverAndRegisterDecompositions(ModuleOp module,
+ llvm::StringMap &decompositionRegistry)
+ {
+ module.walk([&](func::FuncOp func) {
+ if (StringRef targetOp = DecompUtils::getTargetGateName(func); !targetOp.empty()) {
+ decompositionRegistry[targetOp] = func;
+ }
+ // No need to walk into the function body
+ return WalkResult::skip();
+ });
+ }
+
+ // Find the target gate set from the module.It's expected that the decomposition function would
+ // have this attribute: `decomp_gateset` And this attribute is set by the frontend, it contains
+ // the target gate set that the circuit function want to finally decompose into. Since each
+ // module only contains one circuit function, we can just find the target gate set from the
+ // function with the `decomp_gateset` attribute
+ void findTargetGateSet(ModuleOp module, llvm::StringSet &targetGateSet)
+ {
+ module.walk([&](func::FuncOp func) {
+ if (auto gate_set_attr =
+ func->getAttrOfType(DecompUtils::decomp_gateset_attr_name)) {
+ for (auto gate : gate_set_attr.getValue()) {
+ StringRef gate_name = cast(gate).getValue();
+ targetGateSet.insert(gate_name);
+ }
+ return WalkResult::interrupt();
+ }
+ // No need to walk into the function body
+ return WalkResult::skip();
+ });
+ }
+
+ // Remove unused decomposition functions:
+ // Since the decomposition functions are marked as public from the frontend,
+ // there is no way to remove them with any DCE pass automatically.
+ // So we need to manually remove them from the module
+ void removeDecompositionFunctions(ModuleOp module,
+ llvm::StringMap &decompositionRegistry)
+ {
+ llvm::DenseSet usedDecompositionFunctions;
+
+ module.walk([&](func::CallOp callOp) {
+ if (auto targetFunc = module.lookupSymbol(callOp.getCallee())) {
+ if (DecompUtils::isDecompositionFunction(targetFunc)) {
+ usedDecompositionFunctions.insert(targetFunc);
+ }
+ }
+ });
+
+ // remove unused decomposition functions
+ module.walk([&](func::FuncOp func) {
+ if (DecompUtils::isDecompositionFunction(func) &&
+ !usedDecompositionFunctions.contains(func)) {
+ func.erase();
+ }
+ return WalkResult::skip();
+ });
+ }
+
+ public:
+ void runOnOperation() final
+ {
+ ModuleOp module = cast(getOperation());
+
+ ModuleOp decompRuleModule = module;
+ OwningOpRef parsedModule;
+ if (!rulesPath.empty()) {
+ parsedModule = DecompUtils::loadMLIRModule(rulesPath, &getContext());
+ if (!parsedModule) {
+ return signalPassFailure();
+ }
+ decompRuleModule = parsedModule.get();
+ }
+
+ // Step 1: Discover and register all decomposition functions from the rules module if
+ // the rules path is provided; otherwise, we use the main module as the rules module
+ discoverAndRegisterDecompositions(decompRuleModule, decompositionRegistry);
+ if (decompositionRegistry.empty()) {
+ return;
+ }
+
+ // Step 1.1: Find the target gate set
+ findTargetGateSet(module, targetGateSet);
+
+ // Step 2: Canonicalize the module
+ RewritePatternSet patternsCanonicalization(&getContext());
+ catalyst::quantum::CustomOp::getCanonicalizationPatterns(patternsCanonicalization,
+ &getContext());
+ if (failed(applyPatternsGreedily(module, std::move(patternsCanonicalization)))) {
+ return signalPassFailure();
+ }
+
+ // Step 3: Apply the decomposition patterns
+ RewritePatternSet decompositionPatterns(&getContext());
+ populateDecomposeLoweringPatterns(decompositionPatterns, decompositionRegistry,
+ targetGateSet);
+ if (failed(applyPatternsGreedily(module, std::move(decompositionPatterns)))) {
+ return signalPassFailure();
+ }
+
+ // Step 4: Inline and canonicalize/CSE the module again
+ PassManager pm(&getContext());
+ pm.addPass(createInlinerPass());
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
+ if (failed(pm.run(module))) {
+ return signalPassFailure();
+ }
+
+ // Step 5. Remove redundant decomposition functions
+ removeDecompositionFunctions(module, decompositionRegistry);
+
+ // Step 6. Canonicalize the extract/insert pair
+ RewritePatternSet patternsInsertExtract(&getContext());
+ catalyst::quantum::InsertOp::getCanonicalizationPatterns(patternsInsertExtract,
+ &getContext());
+ catalyst::quantum::ExtractOp::getCanonicalizationPatterns(patternsInsertExtract,
+ &getContext());
+ if (failed(applyPatternsGreedily(module, std::move(patternsInsertExtract)))) {
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace quantum
+
+std::unique_ptr createDecomposeLoweringPass()
+{
+ return std::make_unique();
+}
+
+} // namespace catalyst
diff --git a/mlir/test/Quantum/CanonicalizationTest.mlir b/mlir/test/Quantum/CanonicalizationTest.mlir
index 4c698ed4e7..4b9620575d 100644
--- a/mlir/test/Quantum/CanonicalizationTest.mlir
+++ b/mlir/test/Quantum/CanonicalizationTest.mlir
@@ -83,8 +83,7 @@ func.func @test_extract_insert_no_fold_static(%r1: !quantum.reg, %i1: i64, %i2:
%q2 = quantum.extract %r2[0] : !quantum.reg -> !quantum.bit
%r3 = quantum.insert %r2[%i1], %q2 : !quantum.reg, !quantum.bit
- // CHECK: quantum.extract
- // CHECK: quantum.insert
+
%q3 = quantum.extract %r3[%i1] : !quantum.reg -> !quantum.bit
%r4 = quantum.insert %r3[%i2], %q3 : !quantum.reg, !quantum.bit
@@ -167,14 +166,14 @@ func.func @test_interleaved_extract_insert() -> tensor<4xf64> {
// CHECK: [[QBIT:%.+]] = quantum.extract [[QREG:%.+]][
// CHECK: [[QBIT_1:%.+]] = quantum.custom "Hadamard"() [[QBIT]]
// CHECK: [[QREG_1:%.+]] = quantum.insert [[QREG]]
- // CHECK-NOT: quantum.insert
- // COM: check that insert op canonicalization correctly removes unnecessary extract/inserts
+ // CHECK-NOT: quantum.insert
+ // COM: check that insert op canonicalization correctly removes unnecessary extract/inserts
// CHECK: quantum.compbasis qreg [[QREG_1]]
%1 = quantum.extract %0[%c0_i64] : !quantum.reg -> !quantum.bit
%out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit
%2 = quantum.extract %0[%c1_i64] : !quantum.reg -> !quantum.bit
- %3 = quantum.insert %0[%c0_i64], %out_qubits : !quantum.reg, !quantum.bit
- %4 = quantum.insert %3[%c1_i64], %2 : !quantum.reg, !quantum.bit
+ %3 = quantum.insert %0[%c1_i64], %2 : !quantum.reg, !quantum.bit
+ %4 = quantum.insert %3[%c0_i64], %out_qubits : !quantum.reg, !quantum.bit
%5 = quantum.compbasis qreg %4 : !quantum.obs
%6 = quantum.probs %5 : tensor<4xf64>
quantum.dealloc %4 : !quantum.reg
diff --git a/mlir/test/Quantum/DecomposeLoweringTest.mlir b/mlir/test/Quantum/DecomposeLoweringTest.mlir
new file mode 100644
index 0000000000..91bfbe7778
--- /dev/null
+++ b/mlir/test/Quantum/DecomposeLoweringTest.mlir
@@ -0,0 +1,510 @@
+// Copyright 2025 Xanadu Quantum Technologies Inc.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// RUN: quantum-opt --decompose-lowering --split-input-file -verify-diagnostics %s | FileCheck %s
+
+module @two_hadamards {
+ func.func public @test_two_hadamards() -> tensor<4xf64> {
+ %0 = quantum.alloc( 2) : !quantum.reg
+ %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
+ // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64
+ // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64
+ // CHECK: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg
+ // CHECK: [[QUBIT:%.+]] = quantum.extract [[REG]][ 0] : !quantum.reg -> !quantum.bit
+
+ // CHECK: [[QUBIT1:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT]] : !quantum.bit
+ // CHECK: [[QUBIT2:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT1]] : !quantum.bit
+ // CHECK-NOT: quantum.custom "Hadamard"
+ %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit
+
+ // CHECK: [[QUBIT3:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT2]] : !quantum.bit
+ // CHECK: [[QUBIT4:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT3]] : !quantum.bit
+ // CHECK-NOT: quantum.custom "Hadamard"
+ %out_qubits_0 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit
+
+ // CHECK: [[UPDATED_REG:%.+]] = quantum.insert [[REG]][ 0], [[QUBIT4]] : !quantum.reg, !quantum.bit
+ %2 = quantum.insert %0[ 0], %out_qubits_0 : !quantum.reg, !quantum.bit
+ %3 = quantum.compbasis qreg %2 : !quantum.obs
+ %4 = quantum.probs %3 : tensor<4xf64>
+ quantum.dealloc %2 : !quantum.reg
+ return %4 : tensor<4xf64>
+ }
+
+ // Decomposition function should be applied and removed from the module
+ // CHECK-NOT: func.func private @Hadamard_to_RY_decomp
+ func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} {
+ %cst = arith.constant 3.1415926535897931 : f64
+ %cst_0 = arith.constant 1.5707963267948966 : f64
+ %out_qubits = quantum.custom "RZ"(%cst) %arg0 : !quantum.bit
+ %out_qubits_1 = quantum.custom "RY"(%cst_0) %out_qubits : !quantum.bit
+ return %out_qubits_1 : !quantum.bit
+ }
+}
+
+// -----
+
+// Test single Hadamard decomposition
+module @single_hadamard {
+ func.func @test_single_hadamard() -> !quantum.bit {
+ // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64
+ // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64
+ // CHECK: [[REG:%.+]] = quantum.alloc( 1) : !quantum.reg
+ // CHECK: [[QUBIT:%.+]] = quantum.extract [[REG]][ 0] : !quantum.reg -> !quantum.bit
+ %0 = quantum.alloc( 1) : !quantum.reg
+ %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
+
+ // CHECK: [[QUBIT1:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT]] : !quantum.bit
+ // CHECK: [[QUBIT2:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT1]] : !quantum.bit
+ // CHECK-NOT: quantum.custom "Hadamard"
+ %2 = quantum.custom "Hadamard"() %1 : !quantum.bit
+
+ // CHECK: return [[QUBIT2]]
+ return %2 : !quantum.bit
+ }
+
+ // Decomposition function should be applied and removed from the module
+ // CHECK-NOT: func.func private @Hadamard_to_RY_decomp
+ func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} {
+ %cst = arith.constant 3.1415926535897931 : f64
+ %cst_0 = arith.constant 1.5707963267948966 : f64
+ %out_qubits = quantum.custom "RZ"(%cst) %arg0 : !quantum.bit
+ %out_qubits_1 = quantum.custom "RY"(%cst_0) %out_qubits : !quantum.bit
+ return %out_qubits_1 : !quantum.bit
+ }
+}
+
+// -----
+module @recursive {
+ func.func public @test_recursive() -> tensor<4xf64> {
+ %0 = quantum.alloc( 2) : !quantum.reg
+ %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
+ // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64
+ // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64
+ // CHECK: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg
+ // CHECK: [[QUBIT:%.+]] = quantum.extract [[REG]][ 0] : !quantum.reg -> !quantum.bit
+
+ // CHECK: [[QUBIT1:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT]] : !quantum.bit
+ // CHECK: [[QUBIT2:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT1]] : !quantum.bit
+ // CHECK-NOT: quantum.custom "Hadamard"
+ %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit
+
+ // CHECK: [[QUBIT3:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT2]] : !quantum.bit
+ // CHECK: [[QUBIT4:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT3]] : !quantum.bit
+ // CHECK-NOT: quantum.custom "Hadamard"
+ %out_qubits_0 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit
+
+ // CHECK: [[UPDATED_REG:%.+]] = quantum.insert [[REG]][ 0], [[QUBIT4]] : !quantum.reg, !quantum.bit
+ %2 = quantum.insert %0[ 0], %out_qubits_0 : !quantum.reg, !quantum.bit
+ %3 = quantum.compbasis qreg %2 : !quantum.obs
+ %4 = quantum.probs %3 : tensor<4xf64>
+ quantum.dealloc %2 : !quantum.reg
+ return %4 : tensor<4xf64>
+ }
+
+ // Decomposition function should be applied and removed from the module
+ // CHECK-NOT: func.func private @Hadamard_to_RY_decomp
+ func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} {
+ %out_qubits_0 = quantum.custom "RZRY"() %arg0 : !quantum.bit
+ return %out_qubits_0 : !quantum.bit
+ }
+
+ // Decomposition function should be applied and removed from the module
+ // CHECK-NOT: func.func private @RZRY_decomp
+ func.func private @RZRY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "RZRY", llvm.linkage = #llvm.linkage} {
+ %cst = arith.constant 3.1415926535897931 : f64
+ %cst_0 = arith.constant 1.5707963267948966 : f64
+ %out_qubits_1 = quantum.custom "RZ"(%cst) %arg0 : !quantum.bit
+ %out_qubits_2 = quantum.custom "RY"(%cst_0) %out_qubits_1 : !quantum.bit
+ return %out_qubits_2 : !quantum.bit
+ }
+}
+
+// -----
+module @recursive {
+ func.func public @test_recursive() -> tensor<4xf64> {
+ %0 = quantum.alloc( 2) : !quantum.reg
+ %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
+ // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64
+ // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64
+ // CHECK: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg
+ // CHECK: [[QUBIT:%.+]] = quantum.extract [[REG]][ 0] : !quantum.reg -> !quantum.bit
+
+ // CHECK: [[QUBIT1:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT]] : !quantum.bit
+ // CHECK: [[QUBIT2:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT1]] : !quantum.bit
+ // CHECK-NOT: quantum.custom "Hadamard"
+ %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit
+
+ // CHECK: [[QUBIT3:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT2]] : !quantum.bit
+ // CHECK: [[QUBIT4:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT3]] : !quantum.bit
+ // CHECK-NOT: quantum.custom "Hadamard"
+ %out_qubits_0 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit
+
+ // CHECK: [[UPDATED_REG:%.+]] = quantum.insert [[REG]][ 0], [[QUBIT4]] : !quantum.reg, !quantum.bit
+ %2 = quantum.insert %0[ 0], %out_qubits_0 : !quantum.reg, !quantum.bit
+ %3 = quantum.compbasis qreg %2 : !quantum.obs
+ %4 = quantum.probs %3 : tensor<4xf64>
+ quantum.dealloc %2 : !quantum.reg
+ return %4 : tensor<4xf64>
+ }
+
+ // Decomposition function should be applied and removed from the module
+ // CHECK-NOT: func.func private @Hadamard_to_RY_decomp
+ func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} {
+ %out_qubits_0 = quantum.custom "RZRY"() %arg0 : !quantum.bit
+ return %out_qubits_0 : !quantum.bit
+ }
+
+ // Decomposition function should be applied and removed from the module
+ // CHECK-NOT: func.func private @RZRY_decomp
+ func.func private @RZRY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "RZRY", llvm.linkage = #llvm.linkage} {
+ %cst = arith.constant 3.1415926535897931 : f64
+ %cst_0 = arith.constant 1.5707963267948966 : f64
+ %out_qubits_1 = quantum.custom "RZ"(%cst) %arg0 : !quantum.bit
+ %out_qubits_2 = quantum.custom "RY"(%cst_0) %out_qubits_1 : !quantum.bit
+ return %out_qubits_2 : !quantum.bit
+ }
+}
+
+// -----
+
+// Test parametric gates and wires
+module @param_rxry {
+ func.func public @test_param_rxry(%arg0: tensor, %arg1: tensor) -> tensor<2xf64> {
+ %c0_i64 = arith.constant 0 : i64
+
+ // CHECK: [[REG:%.+]] = quantum.alloc( 1) : !quantum.reg
+ %0 = quantum.alloc( 1) : !quantum.reg
+
+ // CHECK: [[WIRE:%.+]] = tensor.extract %arg1[] : tensor
+ %extracted = tensor.extract %arg1[] : tensor
+
+ // CHECK: [[QUBIT:%.+]] = quantum.extract [[REG]][[[WIRE]]] : !quantum.reg -> !quantum.bit
+ %1 = quantum.extract %0[%extracted] : !quantum.reg -> !quantum.bit
+
+ // CHECK: [[PARAM:%.+]] = tensor.extract %arg0[] : tensor
+ %param_0 = tensor.extract %arg0[] : tensor
+
+ // CHECK: [[QUBIT1:%.+]] = quantum.custom "RX"([[PARAM]]) [[QUBIT]] : !quantum.bit
+ // CHECK: [[QUBIT2:%.+]] = quantum.custom "RY"([[PARAM]]) [[QUBIT1]] : !quantum.bit
+ // CHECK-NOT: quantum.custom "ParametrizedRXRY"
+ %out_qubits = quantum.custom "ParametrizedRXRY"(%param_0) %1 : !quantum.bit
+
+ // CHECK: [[UPDATED_REG:%.+]] = quantum.insert [[REG]][ 0], [[QUBIT2]] : !quantum.reg, !quantum.bit
+ %2 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit
+ %3 = quantum.compbasis qreg %2 : !quantum.obs
+ %4 = quantum.probs %3 : tensor<2xf64>
+ quantum.dealloc %2 : !quantum.reg
+ return %4 : tensor<2xf64>
+ }
+
+ // Decomposition function expects tensor while operation provides f64
+ // CHECK-NOT: func.func private @ParametrizedRX_decomp
+ func.func private @ParametrizedRXRY_decomp(%arg0: tensor, %arg1: !quantum.bit) -> !quantum.bit
+ attributes {target_gate = "ParametrizedRXRY", llvm.linkage = #llvm.linkage} {
+ %extracted = tensor.extract %arg0[] : tensor
+ %out_qubits = quantum.custom "RX"(%extracted) %arg1 : !quantum.bit
+ %extracted_0 = tensor.extract %arg0[] : tensor
+ %out_qubits_1 = quantum.custom "RY"(%extracted_0) %out_qubits : !quantum.bit
+ return %out_qubits_1 : !quantum.bit
+ }
+}
+// -----
+
+// Test parametric gates and wires
+module @param_rxry_2 {
+ func.func public @test_param_rxry_2(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<2xf64> {
+ %c0_i64 = arith.constant 0 : i64
+
+ // CHECK: [[REG:%.+]] = quantum.alloc( 1) : !quantum.reg
+ %0 = quantum.alloc( 1) : !quantum.reg
+
+ // CHECK: [[WIRE:%.+]] = tensor.extract %arg2[] : tensor
+ %extracted = tensor.extract %arg2[] : tensor
+
+ // CHECK: [[QUBIT:%.+]] = quantum.extract [[REG]][[[WIRE]]] : !quantum.reg -> !quantum.bit
+ %1 = quantum.extract %0[%extracted] : !quantum.reg -> !quantum.bit
+
+ // CHECK: [[PARAM_0:%.+]] = tensor.extract %arg0[] : tensor
+ %param_0 = tensor.extract %arg0[] : tensor
+
+ // CHECK: [[PARAM_1:%.+]] = tensor.extract %arg1[] : tensor
+ %param_1 = tensor.extract %arg1[] : tensor
+
+ // CHECK: [[QUBIT1:%.+]] = quantum.custom "RX"([[PARAM_0]]) [[QUBIT]] : !quantum.bit
+ // CHECK: [[QUBIT2:%.+]] = quantum.custom "RY"([[PARAM_1]]) [[QUBIT1]] : !quantum.bit
+ // CHECK-NOT: quantum.custom "ParametrizedRXRY"
+ %out_qubits = quantum.custom "ParametrizedRXRY"(%param_0, %param_1) %1 : !quantum.bit
+
+ // CHECK: [[UPDATED_REG:%.+]] = quantum.insert [[REG]][ 0], [[QUBIT2]] : !quantum.reg, !quantum.bit
+ %2 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit
+ %3 = quantum.compbasis qreg %2 : !quantum.obs
+ %4 = quantum.probs %3 : tensor<2xf64>
+ quantum.dealloc %2 : !quantum.reg
+ return %4 : tensor<2xf64>
+ }
+
+ // Decomposition function expects tensor while operation provides f64
+ // CHECK-NOT: func.func private @ParametrizedRX_decomp
+ func.func private @ParametrizedRXRY_decomp(%arg0: tensor, %arg1: tensor, %arg2: !quantum.bit) -> !quantum.bit
+ attributes {target_gate = "ParametrizedRXRY", llvm.linkage = #llvm.linkage} {
+ %extracted_param_0 = tensor.extract %arg0[] : tensor
+ %out_qubits = quantum.custom "RX"(%extracted_param_0) %arg2 : !quantum.bit
+ %extracted_param_1 = tensor.extract %arg1[] : tensor
+ %out_qubits_1 = quantum.custom "RY"(%extracted_param_1) %out_qubits : !quantum.bit
+ return %out_qubits_1 : !quantum.bit
+ }
+}
+// -----
+
+// Test recursive and qreg-based gate decomposition
+module @qreg_base_circuit {
+ func.func public @test_qreg_base_circuit() -> tensor<2xf64> {
+ // CHECK: [[CST:%.+]] = arith.constant 1.000000e+00 : f64
+ %cst = arith.constant 1.000000e+00 : f64
+
+ // CHECK: [[CST_0:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor
+ // CHECK: [[CST_1:%.+]] = arith.constant dense<0> : tensor<1xi64>
+ // CHECK: [[CST_2:%.+]] = arith.constant dense<1.000000e+00> : tensor
+ // CHECK: [[REG:%.+]] = quantum.alloc( 1) : !quantum.reg
+ %0 = quantum.alloc( 1) : !quantum.reg
+
+ // CHECK: [[EXTRACT_QUBIT:%.+]] = quantum.extract [[REG]][ 0] : !quantum.reg -> !quantum.bit
+ // CHECK: [[MRES:%.+]], [[OUT_QUBIT:%.+]] = quantum.measure [[EXTRACT_QUBIT]] : i1, !quantum.bit
+ // CHECK: [[REG1:%.+]] = quantum.insert [[REG]][ 0], [[OUT_QUBIT]] : !quantum.reg, !quantum.bit
+ // CHECK: [[COMPARE:%.+]] = stablehlo.compare NE, [[CST_2]], [[CST_0]], FLOAT : (tensor