diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index 999e5ad137..d622074d73 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -473,6 +473,9 @@ jobs: # macOS requirements.txt python3 -m pip install cuda-quantum==0.6.0 python3 -m pip install oqc-qcaas-client + # Install graphviz for testing the mlir-op-graph integration + sudo apt-get install -y graphviz + python3 -m pip install graphviz make frontend - name: Get Cached LLVM Build @@ -556,6 +559,9 @@ jobs: sudo apt-get install -y libasan6 make python3 --version | grep ${{ needs.constants.outputs.primary_python_version }} python3 -m pip install -r requirements.txt + # Install graphviz for testing the mlir-op-graph integration + sudo apt-get install -y graphviz + python3 -m pip install graphviz make frontend - name: Get Cached LLVM Build diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 8e09946363..b135975266 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -7,6 +7,48 @@

Improvements πŸ› 

+* Add an experimental `outline_state_evolution_pass` xDSL pass to `catalyst.python_interface.transforms`, + which moves all quantum gate operations to a private callable. + [(#8367)](https://github.com/PennyLaneAI/pennylane/pull/8367) + +* A new experimental `split_non_commuting_pass` compiler pass has been added to + `catalyst.python_interface.transforms`. This pass splits quantum functions that + measure observables on the same wires into multiple function executions, where + each execution measures observables on different wires (using the "wires" grouping + strategy). The original function is replaced with calls to these generated functions, + and the results are combined appropriately. + [(#8531)](https://github.com/PennyLaneAI/pennylane/pull/8531) + +* Add the `PCPhaseOp` operation to the xDSL Quantum dialect. + [(#8621)](https://github.com/PennyLaneAI/pennylane/pull/8621) + +* Users can now apply xDSL passes without the need to pass the `pass_plugins` argument to + the `qjit` decorator. + [(#8572)](https://github.com/PennyLaneAI/pennylane/pull/8572) + [(#8573)](https://github.com/PennyLaneAI/pennylane/pull/8573) + [(#2169)](https://github.com/PennyLaneAI/catalyst/pull/2169) + [(#2183)](https://github.com/PennyLaneAI/catalyst/pull/2183) + +* The :meth:`catalyst.python_interface.transforms.convert_to_mbqc_formalism_pass` now + supports :class:`~xdsl.dialects.scf.IndexSwitchOp` in IR and ignores regions that have no body. + [(#8632)](https://github.com/PennyLaneAI/pennylane/pull/8632) + +* The `convert_to_mbqc_formalism` compilation pass now outlines the operations to represent a gate + in the MBQC formalism into subroutines in order to reduce the IR size for large programs. + [(#8619)](https://github.com/PennyLaneAI/pennylane/pull/8619) + +* The :meth:`catalyst.python_interface.Compiler.run` method now accepts a string as input, + which is parsed and transformed with xDSL. + [(#8587)](https://github.com/PennyLaneAI/pennylane/pull/8587) + +* An `is_xdsl_pass` function has been added to the `catalyst.python_interface.pass_api` module. + This function checks if a pass name corresponds to an xDSL implemented pass. + [(#8572)](https://github.com/PennyLaneAI/pennylane/pull/8572) + +* A new `catalyst.python_interface.utils` submodule has been added, containing general-purpose utilities for + working with xDSL. This includes a function that extracts the concrete value of scalar, constant SSA values. + [(#8514)](https://github.com/PennyLaneAI/pennylane/pull/8514) + * Pass instrumentation can be applied to each pass within the `NamedSequenceOp` transform sequence for a qnode. [(#1978)](https://github.com/PennyLaneAI/catalyst/pull/1978) @@ -35,11 +77,6 @@ * `qml.grad` and `qml.jacobian` can now be used with `qjit` when program capture is enabled. [(#2078)](https://github.com/PennyLaneAI/catalyst/pull/2078) -* xDSL passes are now automatically detected when using the `qjit` decorator. - This removes the need to pass the `pass_plugins` argument to the `qjit` decorator. - [(#2169)](https://github.com/PennyLaneAI/catalyst/pull/2169) - [(#2183)](https://github.com/PennyLaneAI/catalyst/pull/2183) - * The ``mlir_opt`` property now correctly handles xDSL passes by automatically detecting when the Python compiler is being used and routing through it appropriately. [(#2190)](https://github.com/PennyLaneAI/catalyst/pull/2190) @@ -72,6 +109,20 @@

Bug fixes πŸ›

+* The experimental xDSL :func:`~catalyst.python_interface.transforms.measurements_from_samples_pass` + pass has been updated to support `shots` defined by an `arith.constant` operation. + [(#8460)](https://github.com/PennyLaneAI/pennylane/pull/8460) + +* The experimental xDSL :func:`~catalyst.python_interface.transforms.diagonalize_measurements` + pass has been updated to fix a bug that included the wrong SSA value for final qubit insertion + and deallocation at the end of the circuit. A clear error is now also raised when there are + observables with overlapping wires. + [(#8383)](https://github.com/PennyLaneAI/pennylane/pull/8383) + +* Fixes a bug in the constructor of the xDSL Quantum dialect's `QubitUnitaryOp` that + prevented an instance from being constructed. + [(#8456)](https://github.com/PennyLaneAI/pennylane/pull/8456) + * Fixes an issue where a heap-to-stack allocation conversion pass was causing SIGSEGV issues during program execution at runtime. [(#2172)](https://github.com/PennyLaneAI/catalyst/pull/2172) @@ -122,6 +173,10 @@

Internal changes βš™οΈ

+* Migrated the `pennylane.compiler.python_compiler` submodule from PennyLane to Catalyst. + It is now accessible as `catalyst.python_interface`. + [(#2199)](https://github.com/PennyLaneAI/catalyst/pull/2199) + * Replaces the deprecated `shape_dtype_to_ir_type` function with the `RankedTensorType.get` method. [(#2159)](https://github.com/PennyLaneAI/catalyst/pull/2159) @@ -173,6 +228,11 @@

Documentation πŸ“

+* Added a "Unified Compiler Cookbook" RST file, along with tutorials, to `catalyst.python_interface.doc`, + which provides a quickstart guide for getting started with xDSL and its integration with PennyLane and + Catalyst. + [(#8571)](https://github.com/PennyLaneAI/pennylane/pull/8571) + * A typo in the code example for :func:`~.passes.ppr_to_ppm` has been corrected. [(#2136)](https://github.com/PennyLaneAI/catalyst/pull/2136) diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index a61f4ccfe3..71537e8e53 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -376,14 +376,14 @@ def to_llvmir(*args, stdin=None, options: Optional[CompileOptions] = None): def to_mlir_opt( - *args, stdin=None, options: Optional[CompileOptions] = None, using_python_compiler=False + *args, stdin=None, options: Optional[CompileOptions] = None, using_unified_compiler=False ): """echo ${input} | catalyst --tool=opt *args *opts -""" # Check if we need to use Python compiler for xDSL passes - if using_python_compiler: + if using_unified_compiler: # Use Python compiler path for xDSL passes # pylint: disable-next=import-outside-toplevel - from pennylane.compiler.python_compiler import Compiler as PythonCompiler + from catalyst.python_interface import Compiler as PythonCompiler compiler = PythonCompiler() stdin = compiler.run(stdin, callback=None) @@ -542,7 +542,7 @@ def check_nested_operations(op): return False @debug_logger - def is_using_python_compiler(self, mlir_module=None): + def is_using_unified_compiler(self, mlir_module=None): """Returns true if we need the Python compiler path. This happens when: @@ -598,11 +598,11 @@ def run(self, mlir_module, *args, **kwargs): (str): filename of shared object """ - if self.is_using_python_compiler(mlir_module): + if self.is_using_unified_compiler(mlir_module): # We keep this module here to keep xDSL requirement optional # Only move this is it has been decided that xDSL is no longer optional. # pylint: disable-next=import-outside-toplevel - from pennylane.compiler.python_compiler import Compiler as PythonCompiler + from catalyst.python_interface import Compiler as PythonCompiler compiler = PythonCompiler() mlir_module = compiler.run(mlir_module) diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py index 8d7e27acfe..749a8b3125 100644 --- a/frontend/catalyst/jax_primitives_utils.py +++ b/frontend/catalyst/jax_primitives_utils.py @@ -378,9 +378,7 @@ def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipelin try: # pylint: disable=import-outside-toplevel - from pennylane.compiler.python_compiler.pass_api import ( - is_xdsl_pass, - ) + from catalyst.python_interface.pass_api import is_xdsl_pass if is_xdsl_pass(_pass.name): uses_xdsl_passes = True diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index 81f873917c..62b0acbde2 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -582,13 +582,13 @@ def mlir_opt(self): """Obtain the MLIR representation after optimization""" if not self.mlir_module: return None - using_python_compiler = self.compiler.is_using_python_compiler(self.mlir_module) + using_unified_compiler = self.compiler.is_using_unified_compiler(self.mlir_module) stdin = self.mlir_module.operation.get_asm( - print_generic_op_form=using_python_compiler, + print_generic_op_form=using_unified_compiler, enable_debug_info=self.compile_options.use_nameloc, ) return to_mlir_opt( - stdin=stdin, options=self.compile_options, using_python_compiler=using_python_compiler + stdin=stdin, options=self.compile_options, using_unified_compiler=using_unified_compiler ) @debug_logger diff --git a/frontend/catalyst/python_interface/__init__.py b/frontend/catalyst/python_interface/__init__.py new file mode 100644 index 0000000000..85fb8ceef5 --- /dev/null +++ b/frontend/catalyst/python_interface/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Python Compiler API for integration of Catalyst with xDSL.""" + +from .compiler import Compiler +from .parser import QuantumParser +from .pass_api import compiler_transform +from .visualization import QMLCollector + +__all__ = [ + "Compiler", + "compiler_transform", + "QuantumParser", + "QMLCollector", +] diff --git a/frontend/catalyst/python_interface/compiler.py b/frontend/catalyst/python_interface/compiler.py new file mode 100644 index 0000000000..f98ff2b8d8 --- /dev/null +++ b/frontend/catalyst/python_interface/compiler.py @@ -0,0 +1,83 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This file contains the implementation of the PennyLane-xDSL integration API.""" + + +import io + +from jax._src.interpreters import mlir +from jaxlib.mlir.dialects import stablehlo +from jaxlib.mlir.ir import Context as jaxContext +from jaxlib.mlir.ir import Module as jaxModule +from pennylane.typing import Callable +from xdsl.context import Context as xContext +from xdsl.dialects.builtin import ModuleOp +from xdsl.passes import ModulePass, PassPipeline +from xdsl.printer import Printer + +from catalyst.python_interface.parser import QuantumParser +from catalyst.python_interface.pass_api import ApplyTransformSequence + + +# pylint: disable=too-few-public-methods +class Compiler: + """Compiler namespace""" + + @staticmethod + def run( + module: jaxModule | str, + callback: Callable[[ModulePass, ModuleOp, ModulePass], None] | None = None, + ) -> jaxModule | str: + """Runs the apply-transform-sequence pass. + + The apply-transform-sequence pass is a "meta-pass". In other words, + it is a pass that runs other passes. + + Args: + module: Either a Jax MLIR module or MLIR IR as a string + callback: Optional callback function called between passes + + Returns: + jaxModule | str: jaxModule if the input was a jaxModule, else a string. + """ + # Convert to generic text format + is_jax_module = isinstance(module, jaxModule) + if is_jax_module: + gentxtmod = module.operation.get_asm( + binary=False, print_generic_op_form=True, assume_verified=True + ) + else: + gentxtmod = module + + # Parse and transform with xDSL + ctx = xContext(allow_unregistered=True) + parser = QuantumParser(ctx, gentxtmod) + # xmod is modified in place + xmod = parser.parse_module() + pipeline = PassPipeline((ApplyTransformSequence(callback=callback),)) + pipeline.apply(ctx, xmod) + + # Convert back to string + buffer = io.StringIO() + Printer(stream=buffer, print_generic_format=True).print_op(xmod) + + # Convert back to jaxModule if input was jaxModule + if is_jax_module: + with jaxContext() as jctx: + jctx.allow_unregistered_dialects = True + jctx.append_dialect_registry(mlir.upstream_dialects) + stablehlo.register_dialect(jctx) # pylint: disable=no-member + newmod: jaxModule = jaxModule.parse(buffer.getvalue()) + return newmod + return buffer.getvalue() diff --git a/frontend/catalyst/python_interface/conversion.py b/frontend/catalyst/python_interface/conversion.py new file mode 100644 index 0000000000..975a46b3cb --- /dev/null +++ b/frontend/catalyst/python_interface/conversion.py @@ -0,0 +1,171 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for converting to xDSL module.""" + +from collections.abc import Callable, Sequence +from functools import wraps +from typing import TypeAlias + +from jax._src.lib import _jax +from jaxlib.mlir.dialects import stablehlo as jstablehlo +from jaxlib.mlir.ir import Context as jContext +from jaxlib.mlir.ir import Module as jModule +from xdsl.context import Context as xContext +from xdsl.dialects import builtin as xbuiltin +from xdsl.dialects import func as xfunc +from xdsl.ir import Dialect as xDialect +from xdsl.traits import SymbolTable as xSymbolTable + +from catalyst import QJIT +from catalyst.python_interface.parser import QuantumParser + +JaxJittedFunction: TypeAlias = _jax.PjitFunction # pylint: disable=c-extension-no-member + + +def _mlir_module_inline(func: JaxJittedFunction, *args, **kwargs) -> jModule: + """Get the MLIR module from a jax.jitted function""" + return func.lower(*args, **kwargs).compiler_ir() + + +def mlir_module(func: JaxJittedFunction) -> Callable[..., jModule]: + """Returns a wrapper that creates an MLIR module from a jax.jitted function.""" + + @wraps(func) + def wrapper(*args, **kwargs) -> jModule: + return _mlir_module_inline(func, *args, **kwargs) + + return wrapper + + +def _generic_str_inline(func: JaxJittedFunction, *args, **kwargs) -> str: # pragma: no cover + """Create the generic textual representation for a jax.jitted function""" + lowered = func.lower(*args, **kwargs) + mod = lowered.compiler_ir() + return mod.operation.get_asm(binary=False, print_generic_op_form=True, assume_verified=True) + + +def generic_str(func: JaxJittedFunction) -> Callable[..., str]: # pragma: no cover + """Returns a wrapper that creates the generic textual representation for a + jax.jitted function.""" + + @wraps(func) + def wrapper(*args, **kwargs) -> str: + return _generic_str_inline(func, *args, **kwargs) + + return wrapper + + +def parse_generic_to_xdsl_module( + program: str, extra_dialects: Sequence[xDialect] | None = None +) -> xbuiltin.ModuleOp: # pragma: no cover + """Parses a generic MLIR program string to an xDSL module.""" + ctx = xContext(allow_unregistered=True) + parser = QuantumParser(ctx, program, extra_dialects=extra_dialects) + moduleOp: xbuiltin.ModuleOp = parser.parse_module() + return moduleOp + + +def parse_generic_to_mlir_module(program: str) -> jModule: # pragma: no cover + """Parses a generic MLIR program string to an MLIR module.""" + with jContext() as ctx: + ctx.allow_unregistered_dialects = True + jstablehlo.register_dialect(ctx) # pylint: disable=no-member + return jModule.parse(program) + + +def mlir_from_docstring(func: Callable) -> jModule: # pragma: no cover + """Returns a wrapper that parses an MLIR program string located in the docstring + into an MLIR module.""" + + @wraps(func) + def wrapper(*_, **__): + return parse_generic_to_mlir_module(func.__doc__) + + return wrapper + + +def _xdsl_module_inline( + func: JaxJittedFunction, *args, **kwargs +) -> xbuiltin.ModuleOp: # pragma: no cover + """Get the xDSL module from a jax.jitted function""" + generic_repr = _generic_str_inline(func, *args, **kwargs) + return parse_generic_to_xdsl_module(generic_repr) + + +def xdsl_from_docstring(func: Callable) -> xbuiltin.ModuleOp: # pragma: no cover + """Returns a wrapper that parses an MLIR program string located in the docstring + into an xDSL module.""" + + @wraps(func) + def wrapper(*_, **__): + return parse_generic_to_xdsl_module(func.__doc__) + + return wrapper + + +def xdsl_module(func: JaxJittedFunction) -> Callable[..., xbuiltin.ModuleOp]: # pragma: no cover + """Returns a wrapper that creates an xDSL module from a jax.jitted function.""" + + @wraps(func) + def wrapper(*args, **kwargs) -> xbuiltin.ModuleOp: + return _xdsl_module_inline(func, *args, **kwargs) + + return wrapper + + +def inline_module( + from_mod: xbuiltin.ModuleOp, to_mod: xbuiltin.ModuleOp, change_main_to: str = None +) -> None: + """Inline the contents of one xDSL module into another xDSL module. The inlined body is appended + to the end of ``to_mod``. + + If ``from_mod`` has a ``main`` function, its name is changed to ``change_main_to`` if specified. + """ + if change_main_to: + main = xSymbolTable.lookup_symbol(from_mod, "main") + if main is not None: + assert isinstance(main, xfunc.FuncOp) + main.properties["sym_name"] = xbuiltin.StringAttr(change_main_to) + + for op in from_mod.body.ops: + xSymbolTable.insert_or_update(to_mod, op.clone()) + + +def inline_jit_to_module(func: JaxJittedFunction, mod: xbuiltin.ModuleOp) -> Callable[..., None]: + """Inline a ``jax.jit``-ed Python function to an xDSL module. The inlined body is appended + to the end of ``mod`` in-place. The name of the entry point function of ``func`` is the same + as the name of ``func``.""" + + @wraps(func) + def wrapper(*args, **kwargs): + func_mod = _xdsl_module_inline(func, *args, **kwargs) + inline_module(func_mod, mod, change_main_to=func.__name__) + + return wrapper + + +def xdsl_from_qjit(func: QJIT) -> Callable[..., xbuiltin.ModuleOp]: + """Decorator to convert QJIT-ed functions into xDSL modules.""" + + @wraps(func) + def wrapper(*args, **kwargs): + func.jaxpr, *_ = func.capture(args, **kwargs) + _mlir_module = func.generate_ir() + _generic_str = _mlir_module.operation.get_asm( + binary=False, print_generic_op_form=True, assume_verified=True + ) + return parse_generic_to_xdsl_module(_generic_str) + + return wrapper diff --git a/frontend/catalyst/python_interface/dialects/__init__.py b/frontend/catalyst/python_interface/dialects/__init__.py new file mode 100644 index 0000000000..1b0ab7d1a9 --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This submodule contains xDSL dialects for the unified compiler.""" + +from .catalyst import Catalyst +from .mbqc import MBQC +from .qec import QEC +from .quantum import Quantum +from .stablehlo import StableHLO +from .transform import Transform + +__all__ = ["Catalyst", "MBQC", "Quantum", "QEC", "StableHLO", "Transform"] diff --git a/frontend/catalyst/python_interface/dialects/catalyst.py b/frontend/catalyst/python_interface/dialects/catalyst.py new file mode 100644 index 0000000000..b42b256c37 --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/catalyst.py @@ -0,0 +1,268 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This file contains the Catalyst dialect for the Python compiler. + +This file was originally ported automatically by xDSL (using the ``xdsl-tblgen`` tool) +and modified manually to support the unified compiler. + +The catalyst dialect serves as a standard library for the Catalyst compiler. +It contains data structures that support core compiler functionality. +""" + +from typing import ClassVar + +from xdsl.dialects.builtin import ( + I64, + ArrayAttr, + DenseArrayBase, + DictionaryAttr, + FlatSymbolRefAttrConstr, + FunctionType, + IntegerAttr, + IntegerType, + MemRefType, + StringAttr, + SymbolRefAttr, + UnitAttr, + i32, +) +from xdsl.ir import AttributeCovT, Dialect, Generic, ParametrizedAttribute, TypeAttribute +from xdsl.irdl import ( + AnyAttr, + IRDLOperation, + ParsePropInAttrDict, + VarConstraint, + base, + irdl_attr_definition, + irdl_op_definition, + operand_def, + opt_operand_def, + opt_prop_def, + prop_def, + region_def, + result_def, + var_operand_def, + var_result_def, +) + + +@irdl_attr_definition +class ArrayListType(Generic[AttributeCovT], ParametrizedAttribute, TypeAttribute): + """A dynamically resizable array""" + + name = "catalyst.arraylist" + + element_type: AttributeCovT + + +@irdl_op_definition +class AssertionOp(IRDLOperation): + """Asserts condition at runtime.""" + + name = "catalyst.assert" + + assertion = operand_def(IntegerType(1)) + + error = prop_def(StringAttr) + + +@irdl_op_definition +class CallbackCallOp(IRDLOperation): + """CallbackCallOp operation.""" + + name = "catalyst.callback_call" + + assembly_format = """ + $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + """ + + callee = prop_def(FlatSymbolRefAttrConstr) + + inputs = var_operand_def() + + arg_attrs = opt_prop_def(ArrayAttr[DictionaryAttr]) + + res_attrs = opt_prop_def(ArrayAttr[DictionaryAttr]) + + callback_results = var_result_def() + + irdl_options = [ParsePropInAttrDict()] + + +@irdl_op_definition +class CallbackOp(IRDLOperation): + """Operation denoting a symbol to refer to user callbacks.""" + + name = "catalyst.callback" + + sym_name = prop_def(StringAttr) + + function_type = prop_def(FunctionType) + + id = prop_def(IntegerAttr[I64]) + + argc = prop_def(IntegerAttr[I64]) + + resc = prop_def(IntegerAttr[I64]) + + arg_attrs = opt_prop_def(ArrayAttr[DictionaryAttr]) + + res_attrs = opt_prop_def(ArrayAttr[DictionaryAttr]) + + body = region_def() + + +@irdl_op_definition +class CustomCallOp(IRDLOperation): + """CustomCall operation""" + + name = "catalyst.custom_call" + + assembly_format = """ + `fn` `(`$call_target_name`)` `(` $inputs `)` + attr-dict `:` functional-type(operands, results) + """ + + inputs = var_operand_def() + + call_target_name = prop_def(StringAttr) + + number_original_arg = opt_prop_def(DenseArrayBase.constr(i32)) + + custom_results = var_result_def() + + irdl_options = [ParsePropInAttrDict()] + + +@irdl_op_definition +class LaunchKernelOp(IRDLOperation): + """LaunchKernelOp operation.""" + + name = "catalyst.launch_kernel" + + assembly_format = """ + $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + """ + + callee = prop_def(SymbolRefAttr) + + inputs = var_operand_def() + + arg_attrs = opt_prop_def(ArrayAttr[DictionaryAttr]) + + res_attrs = opt_prop_def(ArrayAttr[DictionaryAttr]) + + kernel_results = var_result_def() + + irdl_options = [ParsePropInAttrDict()] + + +@irdl_op_definition +class ListDeallocOp(IRDLOperation): + """Deallocate the underlying memory of an arraylist.""" + + name = "catalyst.list_dealloc" + + assembly_format = """ $list attr-dict `:` type($list) """ + + list = operand_def(ArrayListType) + + +@irdl_op_definition +class ListInitOp(IRDLOperation): + """Initialize a dynamically resizable arraylist.""" + + name = "catalyst.list_init" + + assembly_format = """ attr-dict `:` type($list) """ + + list = result_def(ArrayListType) + + +@irdl_op_definition +class ListLoadDataOp(IRDLOperation): + """Get the underlying memref storing the data of an array list.""" + + name = "catalyst.list_load_data" + + assembly_format = """ $list attr-dict `:` type($list) `->` type($data) """ + + list = operand_def(ArrayListType) + + data = result_def(MemRefType) + + +@irdl_op_definition +class ListPopOp(IRDLOperation): + """Remove an element from the end of an array list and return it.""" + + name = "catalyst.list_pop" + + assembly_format = """ $list attr-dict `:` type($list) """ + + T: ClassVar = VarConstraint("T", AnyAttr()) + + list = operand_def(base(ArrayListType[T])) + + result = result_def(T) + + +@irdl_op_definition +class ListPushOp(IRDLOperation): + """Append an element to the end of an array list.""" + + name = "catalyst.list_push" + + assembly_format = """ $value `,` $list attr-dict `:` type($list) """ + + T: ClassVar = VarConstraint("T", AnyAttr()) + + value = operand_def(T) + + list = operand_def(base(ArrayListType[T])) + + +@irdl_op_definition +class PrintOp(IRDLOperation): + """Prints numeric values or constant strings at runtime.""" + + name = "catalyst.print" + + val = opt_operand_def() + + const_val = opt_prop_def(StringAttr) + + print_descriptor = prop_def(UnitAttr) + + +Catalyst = Dialect( + "catalyst", + [ + AssertionOp, + CallbackCallOp, + CallbackOp, + CustomCallOp, + LaunchKernelOp, + ListDeallocOp, + ListInitOp, + ListLoadDataOp, + ListPopOp, + ListPushOp, + PrintOp, + ], + [ + ArrayListType, + ], +) diff --git a/frontend/catalyst/python_interface/dialects/mbqc.py b/frontend/catalyst/python_interface/dialects/mbqc.py new file mode 100644 index 0000000000..b608e037fb --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/mbqc.py @@ -0,0 +1,174 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module contains the definition of the MBQC dialect for the Python compiler. + +The MBQC dialect is a set of operations and types used to represent measurement-based +quantum-computing instructions in the xDSL framework. + +It was initially generated by xDSL (using the ``xdsl-tblgen`` tool) starting from the +catalyst/mlir/include/MBQC/IR/MBQCDialect.td file in the catalyst repository. + +For detailed documentation on the operations contained in this dialect, please refer to the MBQC +dialect documentation in Catalyst. +""" + +from typing import TypeAlias + +from xdsl.dialects.builtin import ( + I32, + AnyAttr, + Float64Type, + IntegerAttr, + IntegerType, + StringAttr, + i1, +) +from xdsl.ir import Dialect, EnumAttribute, Operation, SpacedOpaqueSyntaxAttribute, SSAValue +from xdsl.irdl import ( + IRDLOperation, + irdl_attr_definition, + irdl_op_definition, + operand_def, + opt_prop_def, + prop_def, + result_def, +) +from xdsl.utils.exceptions import VerifyException +from xdsl.utils.str_enum import StrEnum # StrEnum is standard in Python>=3.11 + +from catalyst.python_interface.xdsl_extras import MemRefConstraint, TensorConstraint + +from .quantum import QubitType, QuregType + +QubitSSAValue: TypeAlias = SSAValue[QubitType] + + +class MeasurementPlaneEnum(StrEnum): + """Enum containing supported measurement-plane attributes""" + + XY = "XY" + YZ = "YZ" + ZX = "ZX" + + +@irdl_attr_definition +class MeasurementPlaneAttr(EnumAttribute[MeasurementPlaneEnum], SpacedOpaqueSyntaxAttribute): + """Planes in the Bloch sphere representation with support for arbitrary-basis measurements""" + + name = "mbqc.measurement_plane" + + +@irdl_op_definition +class MeasureInBasisOp(IRDLOperation): + """A parametric single-qubit projective measurement in an arbitrary basis.""" + + name = "mbqc.measure_in_basis" + + assembly_format = """ + `[` $plane `,` $angle `]` $in_qubit (`postselect` $postselect^)? attr-dict `:` type(results) + """ + + in_qubit = operand_def(QubitType) + + plane = prop_def(MeasurementPlaneAttr) + + angle = operand_def(Float64Type()) + + postselect = opt_prop_def(IntegerAttr[I32]) + + mres = result_def(IntegerType(1)) + + out_qubit = result_def(QubitType) + + def __init__( + self, + in_qubit: QubitSSAValue | Operation, + plane: MeasurementPlaneAttr, + angle: SSAValue[Float64Type], + postselect: int | IntegerAttr | None = None, + ): + properties = {"plane": plane} + + if isinstance(postselect, int): + postselect = IntegerAttr.from_int_and_width(postselect, 32) + + if postselect is not None: + properties["postselect"] = postselect + + super().__init__( + operands=(in_qubit, angle), + properties=properties, + result_types=(IntegerType(1), QubitType()), + ) + + def verify_(self): + """Verify operation when rewriting.""" + if self.postselect is None: + return + + if self.postselect.value.data not in [0, 1]: # pylint: disable=no-member + raise VerifyException("'postselect' must be 0 or 1.") + + +@irdl_op_definition +class GraphStatePrepOp(IRDLOperation): + """Allocate resources for a new graph state.""" + + name = "mbqc.graph_state_prep" + + assembly_format = """ + `(` $adj_matrix `:` type($adj_matrix) `)` `[` `init` $init_op `,` `entangle` $entangle_op `]` attr-dict `:` type(results) + """ + + adj_matrix = operand_def( + TensorConstraint(element_type=i1, rank=1) | MemRefConstraint(element_type=i1, rank=1) + ) + + init_op = prop_def(StringAttr) + + entangle_op = prop_def(StringAttr) + + qreg = result_def(QuregType) + + def __init__( + self, adj_matrix: AnyAttr, init_op: str | StringAttr, entangle_op: str | StringAttr + ): + if isinstance(init_op, str): + init_op = StringAttr(data=init_op) + + if isinstance(entangle_op, str): + entangle_op = StringAttr(data=entangle_op) + + properties = {"init_op": init_op, "entangle_op": entangle_op} + + qreg = QuregType() + + super().__init__( + operands=(adj_matrix,), + result_types=(qreg,), + properties=properties, + ) + + +MBQC = Dialect( + "mbqc", + [ + MeasureInBasisOp, + GraphStatePrepOp, + ], + [ + MeasurementPlaneAttr, + ], +) diff --git a/frontend/catalyst/python_interface/dialects/qec.py b/frontend/catalyst/python_interface/dialects/qec.py new file mode 100644 index 0000000000..be5ba288f3 --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/qec.py @@ -0,0 +1,216 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module contains the QEC dialect for the Python compiler. + +The QEC dialect is a set of operations and types used to represent quantum error correction +instructions in the xDSL framework. + +It was initially generated by xDSL (using the ``xdsl-tblgen`` tool) starting from the +catalyst/mlir/include/QEC/IR/QECDialect.td file in the catalyst repository. +""" + +from xdsl.dialects.builtin import I16, ArrayAttr, IntegerAttr, IntegerType, StringAttr, i16 +from xdsl.dialects.utils import AbstractYieldOperation +from xdsl.ir import Attribute, Dialect, EnumAttribute, SpacedOpaqueSyntaxAttribute +from xdsl.irdl import ( + AttrSizedOperandSegments, + IRDLOperation, + irdl_attr_definition, + irdl_op_definition, + lazy_traits_def, + operand_def, + opt_operand_def, + opt_prop_def, + prop_def, + region_def, + result_def, + traits_def, + var_operand_def, + var_result_def, +) +from xdsl.traits import HasParent, IsTerminator, Pure, SingleBlockImplicitTerminator +from xdsl.utils.str_enum import StrEnum + +from .quantum import QubitType + + +class LogicalInitKind(StrEnum): + """The initial state of a logical qubit such as |0⟩, |1⟩, |+⟩, |βˆ’βŸ©, |Y⟩, |-Y⟩, |m⟩, or |mΜ…βŸ©.""" + + Zero = "zero" # |0⟩ Non-magic state + One = "one" # |1⟩ Non-magic state + Plus = "plus" # |+⟩ = (|0⟩ + |1⟩) / sqrt(2) Non-magic state + Minus = "minus" # |-⟩ = (|0⟩ - |1⟩) / sqrt(2) Non-magic state + PlusI = "plus_i" # |Y⟩ = (|0⟩ + i|1⟩) / sqrt(2) Non-magic / Magic state + MinusI = "minus_i" # |-Y⟩ = (|0⟩ - i|1⟩) / sqrt(2) Non-magic / Magic state + Magic = "magic" # |m⟩ = |0⟩ + e^{iΟ€/4}|1⟩ Magic state + MagicConj = "magic_conj" # |mΜ…βŸ© = |0⟩ + e^{-iΟ€/4}|1⟩ Magic state + + +@irdl_attr_definition +class LogicalInit(EnumAttribute[LogicalInitKind], SpacedOpaqueSyntaxAttribute): + """The initial state of a logical qubit such as |0⟩, |1⟩, |+⟩, |βˆ’βŸ©, |Y⟩, |-Y⟩, |m⟩, or |mΜ…βŸ©.""" + + name = "qec.enum" + + +# Type alias for a product of Pauli operators, aka a Pauli word. +PauliWord = ArrayAttr[StringAttr] + + +@irdl_op_definition +class YieldOp(AbstractYieldOperation[Attribute]): + """Return results from a layer region""" + + name = "qec.yield" + + traits = lazy_traits_def(lambda: (IsTerminator(), HasParent(LayerOp), Pure())) + + +@irdl_op_definition +class FabricateOp(IRDLOperation): + """Fabricate axillary qubits from qubit factories.""" + + name = "qec.fabricate" + + assembly_format = """ + $init_state attr-dict `:` type($out_qubits) + """ + + init_state = prop_def(LogicalInit) + + out_qubits = var_result_def(QubitType) + + +@irdl_op_definition +class LayerOp(IRDLOperation): + """A layer operation""" + + name = "qec.layer" + + initArgs = var_operand_def() + + results = var_result_def() + + region = region_def("single_block") + + traits = traits_def(SingleBlockImplicitTerminator(YieldOp)) + + # TODO: add a custom parse and print for this operation + + +@irdl_op_definition +class PPMeasurementOp(IRDLOperation): + """Pauli Product Measurement on qubits.""" + + name = "qec.ppm" + + assembly_format = """ + $pauli_product (`(` $rotation_sign^ `)`)? $in_qubits (`cond` `(` $condition^ `)`)? attr-dict `:` type(results) + """ + + irdl_options = [AttrSizedOperandSegments(as_property=True)] + + pauli_product = prop_def(PauliWord) + + rotation_sign = opt_prop_def(IntegerAttr[I16], default_value=IntegerAttr(1, i16)) + + in_qubits = var_operand_def(QubitType) + + condition = opt_operand_def(IntegerType(1)) + + mres = result_def(IntegerType(1)) + + out_qubits = var_result_def(QubitType) + + +@irdl_op_definition +class PPRotationOp(IRDLOperation): + """Pauli Product Rotation on qubits.""" + + name = "qec.ppr" + + assembly_format = """ + $pauli_product `(` $rotation_kind `)` $in_qubits attr-dict (`cond` `(` $condition^ `)`)? `:` type($out_qubits) + """ + + irdl_options = [AttrSizedOperandSegments(as_property=True)] + + pauli_product = prop_def(PauliWord) + + rotation_kind = prop_def(IntegerAttr[IntegerType(16)]) + + in_qubits = var_operand_def(QubitType) + + condition = opt_operand_def(IntegerType(1)) + + out_qubits = var_result_def(QubitType) + + +@irdl_op_definition +class PrepareStateOp(IRDLOperation): + """Initialize existing qubits into a given state.""" + + name = "qec.prepare" + + assembly_format = """ + $init_state $in_qubits attr-dict `:` type($out_qubits) + """ + + init_state = prop_def(LogicalInit) + + in_qubits = var_operand_def(QubitType) + + out_qubits = var_result_def(QubitType) + + +@irdl_op_definition +class SelectPPMeasurementOp(IRDLOperation): + """Multiplexed Pauli product measurement.""" + + name = "qec.select.ppm" + + assembly_format = """ + `(` $select_switch `,` $pauli_product_0 `,` $pauli_product_1 `)` $in_qubits attr-dict `:` type(results) + """ + + select_switch = operand_def(IntegerType(1)) + + pauli_product_0 = prop_def(PauliWord) + + pauli_product_1 = prop_def(PauliWord) + + in_qubits = var_operand_def(QubitType) + + mres = result_def(IntegerType(1)) + + out_qubits = var_result_def(QubitType) + + +QEC = Dialect( + "qec", + [ + FabricateOp, + LayerOp, + PPMeasurementOp, + PPRotationOp, + PrepareStateOp, + SelectPPMeasurementOp, + YieldOp, + ], + [ + LogicalInit, + ], +) diff --git a/frontend/catalyst/python_interface/dialects/quantum.py b/frontend/catalyst/python_interface/dialects/quantum.py new file mode 100644 index 0000000000..8b549f5440 --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/quantum.py @@ -0,0 +1,1138 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This file contains the definition of the Quantum dialect for the unified compiler. + +The Quantum dialect is a set of operations and types used to represent quantum computations +in the xDSL framework. + +It was initially generated by xDSL (using the ``xdsl-tblgen`` tool) +starting from the catalyst/mlir/include/Quantum/IR/QuantumOps.td file in the catalyst repository. +""" +# pylint: disable=too-many-lines + +from collections.abc import Sequence +from typing import TypeAlias + +from xdsl.dialects.builtin import ( + I32, + I64, + ComplexType, + Float64Type, + FloatAttr, + IntegerAttr, + IntegerType, + MemRefType, + StringAttr, + TensorType, + UnitAttr, + i1, + i64, +) +from xdsl.ir import ( + Block, + Dialect, + EnumAttribute, + Operation, + ParametrizedAttribute, + Region, + SpacedOpaqueSyntaxAttribute, + SSAValue, + StrEnum, + TypeAttribute, +) +from xdsl.irdl import ( + AtLeast, + AttrSizedOperandSegments, + AttrSizedResultSegments, + IntSetConstraint, + IRDLOperation, + ParsePropInAttrDict, + SameVariadicResultSize, + irdl_attr_definition, + irdl_op_definition, + lazy_traits_def, + operand_def, + opt_operand_def, + opt_prop_def, + opt_result_def, + prop_def, + region_def, + result_def, + traits_def, + var_operand_def, + var_result_def, +) +from xdsl.traits import ( + HasParent, + IsTerminator, + NoMemoryEffect, + Pure, + ReturnLike, + SingleBlockImplicitTerminator, +) + +from catalyst.python_interface.xdsl_extras import MemRefConstraint, TensorConstraint + +################################################################ +######################## ATTRIBUTES ############################ +################################################################ + + +@irdl_attr_definition +class ObservableType(ParametrizedAttribute, TypeAttribute): + """A quantum observable for use in measurements.""" + + name = "quantum.obs" + + +@irdl_attr_definition +class QubitType(ParametrizedAttribute, TypeAttribute): + """A value-semantic qubit (state).""" + + name = "quantum.bit" + + +@irdl_attr_definition +class QuregType(ParametrizedAttribute, TypeAttribute): + """An array of value-semantic qubits (i.e. quantum register).""" + + name = "quantum.reg" + + +@irdl_attr_definition +class ResultType(ParametrizedAttribute, TypeAttribute): + """A quantum measurement result.""" + + name = "quantum.res" + + +class NamedObservable(StrEnum): + """Known named observables""" + + Identity = "Identity" + PauliX = "PauliX" + PauliY = "PauliY" + PauliZ = "PauliZ" + Hadamard = "Hadamard" + + +@irdl_attr_definition +class NamedObservableAttr(EnumAttribute[NamedObservable], SpacedOpaqueSyntaxAttribute): + """Known named observables""" + + name = "quantum.named_observable" + + +################################################################ +######################## OPERATIONS ############################ +################################################################ + + +QubitSSAValue: TypeAlias = SSAValue[QubitType] +QuregSSAValue: TypeAlias = SSAValue[QuregType] +ObservableSSAValue: TypeAlias = SSAValue[ObservableType] + + +@irdl_op_definition +class AdjointOp(IRDLOperation): + """Calculate the adjoint of the enclosed operations""" + + name = "quantum.adjoint" + + assembly_format = """ + `(` $qreg `)` attr-dict `:` type(operands) $region + """ + + qreg = operand_def(QuregType) + + out_qreg = result_def(QuregType) + + region = region_def("single_block") + + traits = lazy_traits_def(lambda: (NoMemoryEffect(), SingleBlockImplicitTerminator(YieldOp))) + + def __init__( + self, + qreg: QuregSSAValue | Operation, + region: Region | Sequence[Operation] | Sequence[Block], + ): + super().__init__(operands=(qreg,), result_types=(QuregType(),), regions=(region,)) + + +@irdl_op_definition +class AllocOp(IRDLOperation): + """Allocate n qubits into a quantum register.""" + + name = "quantum.alloc" + + assembly_format = """ + `(` ($nqubits^):($nqubits_attr)? `)` attr-dict `:` type(results) + """ + + nqubits = opt_operand_def(i64) + + nqubits_attr = opt_prop_def(IntegerAttr.constr(type=I64, value=AtLeast(0))) + + qreg = result_def(QuregType) + + def __init__(self, nqubits): + if isinstance(nqubits, int): + nqubits = IntegerAttr.from_int_and_width(nqubits, 64) + + if isinstance(nqubits, IntegerAttr): + operands = (None,) + properties = {"nqubits_attr": nqubits} + else: + operands = (nqubits,) + properties = {} + + super().__init__(operands=operands, properties=properties, result_types=(QuregType(),)) + + +@irdl_op_definition +class AllocQubitOp(IRDLOperation): + """Allocate a single qubit.""" + + name = "quantum.alloc_qb" + + assembly_format = """attr-dict `:` type(results)""" + + qubit = result_def(QubitType) + + def __init__(self): + super().__init__( + result_types=(QubitType(),), + ) + + +@irdl_op_definition +class ComputationalBasisOp(IRDLOperation): + """Define a pseudo-obeservable of the computational basis for use in measurements""" + + name = "quantum.compbasis" + + assembly_format = """ + (`qubits` $qubits^)? (`qreg` $qreg^)? attr-dict `:` type(results) + """ + + irdl_options = [AttrSizedOperandSegments(as_property=True)] + + qubits = var_operand_def(QubitType) + + qreg = opt_operand_def(QuregType) + + obs = result_def(ObservableType) + + +@irdl_op_definition +class CountsOp(IRDLOperation): + """Compute sample counts for the given observable for the current state""" + + name = "quantum.counts" + + assembly_format = """ + $obs ( `shape` $dynamic_shape^ )? + ( `in` `(` $in_eigvals^ `:` type($in_eigvals) `,` $in_counts `:` type($in_counts) `)` )? + attr-dict ( `:` type($eigvals)^ `,` type($counts) )? + """ + + irdl_options = [ + AttrSizedOperandSegments(as_property=True), + SameVariadicResultSize(), + ] + + obs = operand_def(ObservableType) + + dynamic_shape = opt_operand_def(i64) + + in_eigvals = opt_operand_def(MemRefConstraint(element_type=Float64Type(), rank=1)) + + in_counts = opt_operand_def(MemRefConstraint(element_type=i64, rank=1)) + + eigvals = opt_result_def(TensorConstraint(element_type=Float64Type(), rank=1)) + + counts = opt_result_def(TensorConstraint(element_type=i64, rank=1)) + + +@irdl_op_definition +class CustomOp(IRDLOperation): + """A generic quantum gate on n qubits with m floating point parameters.""" + + name = "quantum.custom" + + assembly_format = """ + $gate_name `(` $params `)` $in_qubits + (`adj` $adjoint^)? + attr-dict + ( `ctrls` `(` $in_ctrl_qubits^ `)` )? + ( `ctrlvals` `(` $in_ctrl_values^ `)` )? + `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )? + """ + + irdl_options = [ + AttrSizedOperandSegments(as_property=True), + AttrSizedResultSegments(as_property=True), + ] + + params = var_operand_def(Float64Type()) + + in_qubits = var_operand_def(QubitType) + + gate_name = prop_def(StringAttr) + + adjoint = opt_prop_def(UnitAttr) + + in_ctrl_qubits = var_operand_def(QubitType) + + in_ctrl_values = var_operand_def(i1) + + out_qubits = var_result_def(QubitType) + + out_ctrl_qubits = var_result_def(QubitType) + + traits = traits_def(NoMemoryEffect()) + + # pylint: disable=too-many-arguments + def __init__( + self, + *, + gate_name: str | StringAttr, + params: SSAValue[Float64Type] | Sequence[SSAValue[Float64Type]] | None = None, + in_qubits: QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation], + in_ctrl_qubits: ( + QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation] | None + ) = None, + in_ctrl_values: ( + SSAValue[IntegerType] + | Operation + | Sequence[SSAValue[IntegerType]] + | Sequence[Operation] + | None + ) = None, + adjoint: UnitAttr | bool = False, + ): + params = () if params is None else params + in_ctrl_qubits = () if in_ctrl_qubits is None else in_ctrl_qubits + in_ctrl_values = () if in_ctrl_values is None else in_ctrl_values + + if not isinstance(params, Sequence): + params = (params,) + if not isinstance(in_qubits, Sequence): + in_qubits = (in_qubits,) + if not isinstance(in_ctrl_qubits, Sequence): + in_ctrl_qubits = (in_ctrl_qubits,) + if not isinstance(in_ctrl_values, Sequence): + in_ctrl_values = (in_ctrl_values,) + + if isinstance(gate_name, str): + gate_name = StringAttr(data=gate_name) + + out_qubits = tuple(QubitType() for _ in in_qubits) + out_ctrl_qubits = tuple(QubitType() for _ in in_ctrl_qubits) + properties = {"gate_name": gate_name} + if adjoint: + properties["adjoint"] = UnitAttr() + + super().__init__( + operands=(params, in_qubits, in_ctrl_qubits, in_ctrl_values), + result_types=(out_qubits, out_ctrl_qubits), + properties=properties, + ) + + +@irdl_op_definition +class DeallocOp(IRDLOperation): + """Deallocate a quantum register.""" + + name = "quantum.dealloc" + + assembly_format = """ + $qreg attr-dict `:` type(operands) + """ + + qreg = operand_def(QuregType) + + def __init__(self, qreg: QuregSSAValue | Operation): + super().__init__(operands=(qreg,)) + + +@irdl_op_definition +class DeallocQubitOp(IRDLOperation): + """Deallocate a single qubit.""" + + name = "quantum.dealloc_qb" + + assembly_format = """$qubit attr-dict `:` type(operands)""" + + qubit = operand_def(QubitType) + + def __init__(self, qubit: QubitSSAValue | Operation): + super().__init__( + operands=(qubit,), + ) + + +@irdl_op_definition +class DeviceInitOp(IRDLOperation): + """Initialize a quantum device.""" + + name = "quantum.device" + + assembly_format = """ + (`shots` `(` $shots^ `)`)? `[` $lib `,` $device_name `,` $kwargs `]` attr-dict + """ + + irdl_options = [ParsePropInAttrDict()] + + shots = opt_operand_def(i64) + + auto_qubit_management = opt_prop_def(UnitAttr) + + lib = prop_def(StringAttr) + + device_name = prop_def(StringAttr) + + kwargs = prop_def(StringAttr) + + +@irdl_op_definition +class DeviceReleaseOp(IRDLOperation): + """Release the active quantum device.""" + + name = "quantum.device_release" + + assembly_format = "attr-dict" + + +@irdl_op_definition +class ExpvalOp(IRDLOperation): + """Compute the expectation value of the given observable for the current state""" + + name = "quantum.expval" + + assembly_format = "$obs attr-dict `:` type(results)" + + obs = operand_def(ObservableType) + + expval = result_def(Float64Type()) + + def __init__(self, obs: ObservableSSAValue | Operation): + super().__init__(operands=(obs,), result_types=(Float64Type(),)) + + +@irdl_op_definition +class ExtractOp(IRDLOperation): + """Extract a qubit value from a register.""" + + name = "quantum.extract" + + assembly_format = """ + $qreg `[` ($idx^):($idx_attr)? `]` attr-dict `:` type($qreg) `->` type(results) + """ + + qreg = operand_def(QuregType) + + idx = opt_operand_def(i64) + + idx_attr = opt_prop_def(IntegerAttr.constr(type=i64, value=AtLeast(0))) + + qubit = result_def(QubitType) + + traits = traits_def(NoMemoryEffect()) + + def __init__( + self, + qreg: QuregSSAValue | Operation, + idx: int | SSAValue[IntegerType] | Operation | IntegerAttr, + ): + if isinstance(idx, int): + idx = IntegerAttr.from_int_and_width(idx, 64) + + if isinstance(idx, IntegerAttr): + operands = (qreg, None) + properties = {"idx_attr": idx} + else: + operands = (qreg, idx) + properties = {} + + super().__init__( + operands=operands, + result_types=(QubitType(),), + properties=properties, + ) + + +@irdl_op_definition +class FinalizeOp(IRDLOperation): + """Teardown the quantum runtime.""" + + name = "quantum.finalize" + + assembly_format = "attr-dict" + + +@irdl_op_definition +class GlobalPhaseOp(IRDLOperation): + """Global Phase.""" + + name = "quantum.gphase" + + assembly_format = """ + `(` $params `)` + attr-dict + ( `ctrls` `(` $in_ctrl_qubits^ `)` )? + ( `ctrlvals` `(` $in_ctrl_values^ `)` )? + `:` type(results) + """ + + irdl_options = [AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()] + + params = operand_def(Float64Type()) + + adjoint = opt_prop_def(UnitAttr) + + in_ctrl_qubits = var_operand_def(QubitType) + + in_ctrl_values = var_operand_def(i1) + + out_ctrl_qubits = var_result_def(QubitType) + + def __init__( + self, + *, + params: float | SSAValue[Float64Type], + in_ctrl_qubits: ( + QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation] | None + ) = None, + in_ctrl_values: ( + SSAValue[IntegerType] + | Operation + | Sequence[SSAValue[IntegerType]] + | Sequence[Operation] + | None + ) = None, + ): + if isinstance(params, float): + params = FloatAttr(data=params, type=Float64Type()) + in_ctrl_qubits = () if in_ctrl_qubits is None else in_ctrl_qubits + in_ctrl_values = () if in_ctrl_values is None else in_ctrl_values + + if not isinstance(in_ctrl_qubits, Sequence): + in_ctrl_qubits = (in_ctrl_qubits,) + if not isinstance(in_ctrl_values, Sequence): + in_ctrl_values = (in_ctrl_values,) + + out_ctrl_qubits = tuple(QubitType() for _ in in_ctrl_qubits) + + super().__init__( + operands=(params, in_ctrl_qubits, in_ctrl_values), + result_types=(out_ctrl_qubits,), + ) + + +@irdl_op_definition +class HamiltonianOp(IRDLOperation): + """Define a Hamiltonian observable for use in measurements""" + + name = "quantum.hamiltonian" + + assembly_format = """ + `(` $coeffs `:` type($coeffs) `)` $terms attr-dict `:` type(results) + """ + + coeffs = operand_def( + TensorConstraint(element_type=Float64Type(), rank=1) + | (MemRefConstraint(element_type=Float64Type(), rank=1)) + ) + + terms = var_operand_def(ObservableType) + + obs = result_def(ObservableType) + + +@irdl_op_definition +class HermitianOp(IRDLOperation): + """Define a Hermitian observable for use in measurements""" + + name = "quantum.hermitian" + + assembly_format = """ + `(` $matrix `:` type($matrix) `)` $qubits attr-dict `:` type(results) + """ + + matrix = operand_def( + TensorConstraint(element_type=ComplexType(Float64Type()), rank=2) + | MemRefConstraint(element_type=ComplexType(Float64Type()), rank=2) + ) + + qubits = var_operand_def(QubitType) + + obs = result_def(ObservableType) + + +@irdl_op_definition +class InitializeOp(IRDLOperation): + """Initialize the quantum runtime.""" + + name = "quantum.init" + + assembly_format = "attr-dict" + + +@irdl_op_definition +class InsertOp(IRDLOperation): + """Update the qubit value of a register.""" + + name = "quantum.insert" + + assembly_format = """ + $in_qreg `[` ($idx^):($idx_attr)? `]` `,` $qubit attr-dict `:` type($in_qreg) `,` type($qubit) + """ + + in_qreg = operand_def(QuregType) + + idx = opt_operand_def(i64) + + idx_attr = opt_prop_def(IntegerAttr.constr(type=i64, value=AtLeast(0))) + + qubit = operand_def(QubitType) + + out_qreg = result_def(QuregType) + + traits = traits_def(NoMemoryEffect()) + + def __init__( + self, + in_qreg: QuregSSAValue | Operation, + idx: SSAValue[IntegerType] | Operation | int | IntegerAttr, + qubit: QubitSSAValue | Operation, + ): + if isinstance(idx, int): + idx = IntegerAttr.from_int_and_width(idx, 64) + + if isinstance(idx, IntegerAttr): + operands = (in_qreg, None, qubit) + properties = {"idx_attr": idx} + else: + operands = (in_qreg, idx, qubit) + properties = {} + + super().__init__(operands=operands, properties=properties, result_types=(QuregType(),)) + + +@irdl_op_definition +class MeasureOp(IRDLOperation): + """A single-qubit projective measurement in the computational basis.""" + + name = "quantum.measure" + + assembly_format = """ + $in_qubit (`postselect` $postselect^)? attr-dict `:` type(results) + """ + + in_qubit = operand_def(QubitType) + + postselect = opt_prop_def( + IntegerAttr.constr(type=I32, value=IntSetConstraint(frozenset((0, 1)))) + ) + + mres = result_def(i1) + + out_qubit = result_def(QubitType) + + def __init__( + self, in_qubit: QubitSSAValue | Operation, postselect: int | IntegerAttr | None = None + ): + if isinstance(postselect, int): + postselect = IntegerAttr.from_int_and_width(postselect, 32) + + if postselect is None: + properties = {} + else: + properties = {"postselect": postselect} + + super().__init__( + operands=(in_qubit,), properties=properties, result_types=(i1, QubitType()) + ) + + +@irdl_op_definition +class MultiRZOp(IRDLOperation): + """Apply an arbitrary multi Z rotation""" + + name = "quantum.multirz" + + assembly_format = """ + `(` $theta `)` $in_qubits + (`adj` $adjoint^)? + attr-dict + ( `ctrls` `(` $in_ctrl_qubits^ `)` )? + ( `ctrlvals` `(` $in_ctrl_values^ `)` )? + `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )? + """ + + irdl_options = [ + AttrSizedOperandSegments(as_property=True), + AttrSizedResultSegments(as_property=True), + ] + + theta = operand_def(Float64Type()) + + in_qubits = var_operand_def(QubitType) + + adjoint = opt_prop_def(UnitAttr) + + in_ctrl_qubits = var_operand_def(QubitType) + + in_ctrl_values = var_operand_def(i1) + + out_qubits = var_result_def(QubitType) + + out_ctrl_qubits = var_result_def(QubitType) + + traits = traits_def(NoMemoryEffect()) + + # pylint: disable=too-many-arguments + def __init__( + self, + *, + theta: SSAValue[Float64Type], + in_qubits: QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation], + in_ctrl_qubits: ( + QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation] | None + ) = None, + in_ctrl_values: ( + SSAValue[IntegerType] + | Operation + | Sequence[SSAValue[IntegerType]] + | Sequence[Operation] + | None + ) = None, + adjoint: UnitAttr | bool = False, + ): + in_ctrl_qubits = () if in_ctrl_qubits is None else in_ctrl_qubits + in_ctrl_values = () if in_ctrl_values is None else in_ctrl_values + + if not isinstance(in_qubits, Sequence): + in_qubits = (in_qubits,) + if not isinstance(in_ctrl_qubits, Sequence): + in_ctrl_qubits = (in_ctrl_qubits,) + if not isinstance(in_ctrl_values, Sequence): + in_ctrl_values = (in_ctrl_values,) + + out_qubits = tuple(QubitType() for _ in in_qubits) + out_ctrl_qubits = tuple(QubitType() for _ in in_ctrl_qubits) + properties = {"adjoint": UnitAttr()} if adjoint else {} + + super().__init__( + operands=(theta, in_qubits, in_ctrl_qubits, in_ctrl_values), + result_types=(out_qubits, out_ctrl_qubits), + properties=properties, + ) + + +@irdl_op_definition +class NamedObsOp(IRDLOperation): + """Define a Named observable for use in measurements""" + + name = "quantum.namedobs" + + assembly_format = """ + $qubit `[` $type `]` attr-dict `:` type(results) + """ + + qubit = operand_def(QubitType) + + type = prop_def(NamedObservableAttr) + + obs = result_def(ObservableType) + + def __init__(self, qubit: QubitSSAValue | Operation, obs_type: NamedObservableAttr): + super().__init__( + operands=(qubit,), properties={"type": obs_type}, result_types=(ObservableType(),) + ) + + +@irdl_op_definition +class NumQubitsOp(IRDLOperation): + """Get the number of currently allocated qubits.""" + + name = "quantum.num_qubits" + + assembly_format = """ + attr-dict `:` type(results) + """ + + num_qubits = result_def(i64) + + +@irdl_op_definition +class PCPhaseOp(IRDLOperation): + """Apply a projector-controlled phase gate""" + + name = "quantum.pcphase" + + assembly_format = """ + `(` $theta `,` $dim `)` $in_qubits + (`adj` $adjoint^)? + attr-dict + ( `ctrls` `(` $in_ctrl_qubits^ `)` )? + ( `ctrlvals` `(` $in_ctrl_values^ `)` )? + `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )? + """ + + irdl_options = [ + AttrSizedOperandSegments(as_property=True), + AttrSizedResultSegments(as_property=True), + ] + + theta = operand_def(Float64Type()) + + dim = operand_def(Float64Type()) + + in_qubits = var_operand_def(QubitType) + + adjoint = opt_prop_def(UnitAttr) + + in_ctrl_qubits = var_operand_def(QubitType) + + in_ctrl_values = var_operand_def(i1) + + out_qubits = var_result_def(QubitType) + + out_ctrl_qubits = var_result_def(QubitType) + + traits = traits_def(NoMemoryEffect()) + + # pylint: disable=too-many-arguments + def __init__( + self, + *, + theta: SSAValue[Float64Type], + dim: SSAValue[Float64Type], + in_qubits: QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation], + in_ctrl_qubits: ( + QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation] | None + ) = None, + in_ctrl_values: ( + SSAValue[IntegerType] + | Operation + | Sequence[SSAValue[IntegerType]] + | Sequence[Operation] + | None + ) = None, + adjoint: UnitAttr | bool = False, + ): + in_ctrl_qubits = () if in_ctrl_qubits is None else in_ctrl_qubits + in_ctrl_values = () if in_ctrl_values is None else in_ctrl_values + + if not isinstance(in_qubits, Sequence): + in_qubits = (in_qubits,) + if not isinstance(in_ctrl_qubits, Sequence): + in_ctrl_qubits = (in_ctrl_qubits,) + if not isinstance(in_ctrl_values, Sequence): + in_ctrl_values = (in_ctrl_values,) + + out_qubits = tuple(QubitType() for _ in in_qubits) + out_ctrl_qubits = tuple(QubitType() for _ in in_ctrl_qubits) + properties = {"adjoint": UnitAttr()} if adjoint else {} + + super().__init__( + operands=(theta, dim, in_qubits, in_ctrl_qubits, in_ctrl_values), + result_types=(out_qubits, out_ctrl_qubits), + properties=properties, + ) + + +@irdl_op_definition +class ProbsOp(IRDLOperation): + """Compute computational basis probabilities for the current state""" + + name = "quantum.probs" + + assembly_format = """ + $obs ( `shape` $dynamic_shape^ )? + ( `in` `(` $state_in^ `:` type($state_in) `)` )? + attr-dict ( `:` type($probabilities)^ )? + """ + + irdl_options = [AttrSizedOperandSegments(as_property=True)] + + obs = operand_def(ObservableType) + + dynamic_shape = opt_operand_def(i64) + + state_in = opt_operand_def(MemRefConstraint(element_type=Float64Type(), rank=1)) + + probabilities = opt_result_def(TensorConstraint(element_type=Float64Type(), rank=1)) + + +@irdl_op_definition +class QubitUnitaryOp(IRDLOperation): + """Apply an arbitrary fixed unitary matrix""" + + name = "quantum.unitary" + + assembly_format = """ + `(` $matrix `:` type($matrix) `)` $in_qubits + (`adj` $adjoint^)? + attr-dict + ( `ctrls` `(` $in_ctrl_qubits^ `)` )? + ( `ctrlvals` `(` $in_ctrl_values^ `)` )? + `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )? + """ + + irdl_options = [ + AttrSizedOperandSegments(as_property=True), + AttrSizedResultSegments(as_property=True), + ] + + matrix = operand_def( + (TensorConstraint(element_type=ComplexType(Float64Type()), rank=2)) + | (MemRefConstraint(element_type=ComplexType(Float64Type()), rank=2)) + ) + + in_qubits = var_operand_def(QubitType) + + adjoint = opt_prop_def(UnitAttr) + + in_ctrl_qubits = var_operand_def(QubitType) + + in_ctrl_values = var_operand_def(i1) + + out_qubits = var_result_def(QubitType) + + out_ctrl_qubits = var_result_def(QubitType) + + traits = traits_def(NoMemoryEffect()) + + # pylint: disable=too-many-arguments + def __init__( + self, + *, + matrix: SSAValue[TensorType | MemRefType], + in_qubits: QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation], + in_ctrl_qubits: ( + QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation] | None + ) = None, + in_ctrl_values: ( + SSAValue[IntegerType] + | Operation + | Sequence[SSAValue[IntegerType]] + | Sequence[Operation] + | None + ) = None, + adjoint: UnitAttr | bool = False, + ): + in_ctrl_qubits = () if in_ctrl_qubits is None else in_ctrl_qubits + in_ctrl_values = () if in_ctrl_values is None else in_ctrl_values + + if not isinstance(in_qubits, Sequence): + in_qubits = (in_qubits,) + if not isinstance(in_ctrl_qubits, Sequence): + in_ctrl_qubits = (in_ctrl_qubits,) + if not isinstance(in_ctrl_values, Sequence): + in_ctrl_values = (in_ctrl_values,) + + out_qubits = tuple(QubitType() for _ in in_qubits) + out_ctrl_qubits = tuple(QubitType() for _ in in_ctrl_qubits) + properties = {} + if adjoint: + properties["adjoint"] = UnitAttr() + + super().__init__( + operands=(matrix, in_qubits, in_ctrl_qubits, in_ctrl_values), + result_types=(out_qubits, out_ctrl_qubits), + properties=properties, + ) + + +@irdl_op_definition +class SampleOp(IRDLOperation): + """Sample eigenvalues from the given observable for the current state""" + + name = "quantum.sample" + + assembly_format = """ + $obs ( `shape` $dynamic_shape^ )? + ( `in` `(` $in_data^ `:` type($in_data) `)` )? + attr-dict ( `:` type($samples)^ )? + """ + + irdl_options = [AttrSizedOperandSegments(as_property=True)] + + obs = operand_def(ObservableType) + + dynamic_shape = var_operand_def(i64) + + in_data = opt_operand_def(MemRefConstraint(element_type=Float64Type(), rank=(1, 2))) + + samples = opt_result_def(TensorConstraint(element_type=Float64Type(), rank=(1, 2))) + + +@irdl_op_definition +class SetBasisStateOp(IRDLOperation): + """Set basis state.""" + + name = "quantum.set_basis_state" + + assembly_format = """ + `(` $basis_state`)` $in_qubits attr-dict `:` functional-type(operands, results) + """ + + basis_state = operand_def( + (TensorConstraint(element_type=i1, rank=1)) | (MemRefConstraint(element_type=i1, rank=1)) + ) + + in_qubits = var_operand_def(QubitType) + + out_qubits = var_result_def(QubitType) + + +@irdl_op_definition +class SetStateOp(IRDLOperation): + """Set state to a complex vector.""" + + name = "quantum.set_state" + + assembly_format = """ + `(` $in_state `)` $in_qubits attr-dict `:` functional-type(operands, results) + """ + + in_state = operand_def( + (TensorConstraint(element_type=ComplexType(Float64Type()), rank=1)) + | (MemRefConstraint(element_type=ComplexType(Float64Type()), rank=1)) + ) + + in_qubits = var_operand_def(QubitType) + + out_qubits = var_result_def(QubitType) + + +@irdl_op_definition +class StateOp(IRDLOperation): + """Return the current statevector""" + + name = "quantum.state" + + assembly_format = """ + $obs ( `shape` $dynamic_shape^ )? + ( `in` `(` $state_in^ `:` type($state_in) `)` )? + attr-dict ( `:` type($state)^ )? + """ + + irdl_options = [AttrSizedOperandSegments(as_property=True)] + + obs = operand_def(ObservableType) + + dynamic_shape = opt_operand_def(i64) + + state_in = opt_operand_def(MemRefConstraint(element_type=ComplexType(Float64Type()), rank=1)) + + state = opt_result_def(TensorConstraint(element_type=ComplexType(Float64Type()), rank=1)) + + +@irdl_op_definition +class TensorOp(IRDLOperation): + """Define a tensor product of observables for use in measurements""" + + name = "quantum.tensor" + + assembly_format = """ + $terms attr-dict `:` type(results) + """ + + terms = var_operand_def(ObservableType) + + obs = result_def(ObservableType) + + +@irdl_op_definition +class VarianceOp(IRDLOperation): + """Compute the variance of the given observable for the current state""" + + name = "quantum.var" + + assembly_format = """ + $obs attr-dict `:` type(results) + """ + + obs = operand_def(ObservableType) + + variance = result_def(Float64Type()) + + def __init__(self, obs: ObservableSSAValue | Operation): + super().__init__(operands=(obs,), result_types=(Float64Type(),)) + + +@irdl_op_definition +class YieldOp(IRDLOperation): + """Return results from quantum program regions""" + + name = "quantum.yield" + + assembly_format = """ + attr-dict ($retvals^ `:` type($retvals))? + """ + + retvals = var_operand_def(QuregType) + + traits = traits_def(HasParent(AdjointOp), IsTerminator(), Pure(), ReturnLike()) + + +Quantum = Dialect( + "quantum", + [ + AdjointOp, + AllocOp, + AllocQubitOp, + ComputationalBasisOp, + CountsOp, + CustomOp, + DeallocOp, + DeallocQubitOp, + DeviceInitOp, + DeviceReleaseOp, + ExpvalOp, + ExtractOp, + FinalizeOp, + GlobalPhaseOp, + HamiltonianOp, + HermitianOp, + InitializeOp, + InsertOp, + MeasureOp, + MultiRZOp, + NamedObsOp, + NumQubitsOp, + PCPhaseOp, + ProbsOp, + QubitUnitaryOp, + SampleOp, + SetBasisStateOp, + SetStateOp, + StateOp, + TensorOp, + VarianceOp, + YieldOp, + ], + [ + ObservableType, + QubitType, + QuregType, + ResultType, + NamedObservableAttr, + ], +) diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/__init__.py b/frontend/catalyst/python_interface/dialects/stablehlo/__init__.py new file mode 100644 index 0000000000..64380529ac --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/__init__.py @@ -0,0 +1,161 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +StableHLO dialect package for PennyLane's compiler infrastructure. + +This package contains organized elementwise operations and other StableHLO-related +functionality. +""" + +from .attributes import ( + CustomCallApiVersion, + CustomCallApiVersionAttr, + GatherDimensionNumbers, + OutputOperandAlias, + ResultAccuracyModeAttr, + ScatterDimensionNumbers, +) +from .control_flow import ( + IfOp, + OptimizationBarrierOp, + WhileOp, +) +from .data_movement import ( + BroadcastInDimOp, + ConcatenateOp, + DynamicSliceOp, + GatherOp, + ReshapeOp, + ScatterOp, + SliceOp, +) + +# Import the main StableHLO dialect +from .dialect import StableHLO +from .dynamism import ( + DynamicBroadcastInDimOp, +) +from .elementwise_binary import ( + ComplexOp, + DivideOp, + MaximumOp, + MinimumOp, + PowerOp, + RemainderOp, +) +from .elementwise_other import ( + ClampOp, + CompareOp, + ConstantOp, + MapOp, + ReducePrecisionOp, + SelectOp, +) + +# Import all elementwise operations explicitly +from .elementwise_unary import ( + ConvertOp, + CosineOp, + ExponentialMinusOneOp, + ExponentialOp, + FloorOp, + ImagOp, + IsFiniteOp, + LogisticOp, + LogOp, + LogPlusOneOp, + NegateOp, + RealOp, + RoundNearestAfzOp, + RoundNearestEvenOp, + RsqrtOp, + SignOp, + SineOp, + SqrtOp, + TanhOp, + TanOp, +) +from .extensibility import ( + CustomCallOp, +) +from .reduction import ( + ReduceOp, +) + +# Export all operations and the dialect for external use +__all__ = [ + # Main dialect + "StableHLO", + # Elementwise unary operations + "ConvertOp", + "CosineOp", + "ExponentialMinusOneOp", + "ExponentialOp", + "FloorOp", + "ImagOp", + "IsFiniteOp", + "LogOp", + "LogPlusOneOp", + "LogisticOp", + "NegateOp", + "RealOp", + "RoundNearestAfzOp", + "RoundNearestEvenOp", + "RsqrtOp", + "SignOp", + "SineOp", + "SqrtOp", + "TanOp", + "TanhOp", + # Elementwise binary operations + "ComplexOp", + "DivideOp", + "MaximumOp", + "MinimumOp", + "PowerOp", + "RemainderOp", + # Elementwise other operations + "ClampOp", + "CompareOp", + "ConstantOp", + "MapOp", + "ReducePrecisionOp", + "SelectOp", + # Control flow operations + "IfOp", + "WhileOp", + "OptimizationBarrierOp", + # Data movement operations + "BroadcastInDimOp", + "ConcatenateOp", + "DynamicSliceOp", + "GatherOp", + "ReshapeOp", + "ScatterOp", + "SliceOp", + # Dynamism operations + "DynamicBroadcastInDimOp", + # Reduction operations + "ReduceOp", + # Extensibility operations + "CustomCallOp", + # Attributes + "GatherDimensionNumbers", + "ResultAccuracyModeAttr", + "ScatterDimensionNumbers", + "CustomCallApiVersion", + "CustomCallApiVersionAttr", + "OutputOperandAlias", +] diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/attributes.py b/frontend/catalyst/python_interface/dialects/stablehlo/attributes.py new file mode 100644 index 0000000000..7618d77afc --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/attributes.py @@ -0,0 +1,394 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +StableHLO attribute definitions for PennyLane's compiler infrastructure. + +This module provides attribute definitions based on the StableHLO specification +(https://github.com/openxla/stablehlo/blob/main/docs/spec.md), including +attributes for StableHLO operations. +""" + +# pylint: disable=line-too-long + +from collections.abc import Sequence + +from xdsl.dialects.builtin import I64, ArrayAttr, IntegerAttr, i64 +from xdsl.ir import ( + Attribute, + EnumAttribute, + ParametrizedAttribute, + SpacedOpaqueSyntaxAttribute, + StrEnum, +) +from xdsl.irdl import irdl_attr_definition +from xdsl.parser import AttrParser +from xdsl.printer import Printer + + +# Utility functions for dimension array parsing/printing +def parse_dims(parser: AttrParser) -> ArrayAttr[IntegerAttr[I64]]: + """Parse dimension array in [1, 2, 3] format""" + value = parser.parse_comma_separated_list( + AttrParser.Delimiter.SQUARE, + lambda: IntegerAttr(parser.parse_integer(), i64), + ) + return ArrayAttr(value) + + +def print_dims(printer: Printer, dims: ArrayAttr[IntegerAttr[I64]]): + """Print dimension array in [1, 2, 3] format""" + printer.print_string("[") + printer.print_list( + dims.data, + lambda dim: printer.print_string(f"{dim.value.data}"), + ) + printer.print_string("]") + + +class ResultAccuracyMode(StrEnum): + """ + XLA result accuracy mode. + """ + + DEFAULT = "DEFAULT" + HIGH = "HIGHEST" + HIGHEST = "TOLERANCE" + + +@irdl_attr_definition +class ResultAccuracyModeAttr(EnumAttribute[ResultAccuracyMode], SpacedOpaqueSyntaxAttribute): + """ + XLA result accuracy mode. + + See external [documentation](https://github.com/openxla/stablehlo/blob/7c50d4efeaea30bff6aa5e46c7f71170f5aa06af/stablehlo/dialect/StablehloEnums.td#L49-L70). + """ + + name = "stablehlo.result_accuracy_mode" + + +@irdl_attr_definition +class GatherDimensionNumbers(ParametrizedAttribute): + """ + XLA gather dimension numbers. + + This attribute models the dimension information for gather operations. + See external [documentation](https://github.com/openxla/stablehlo/blob/b075e948092d8a27ed0be48f4f8dbaa6df7e2e3e/stablehlo/dialect/StablehloAttrs.td#L42). + """ + + name = "stablehlo.gather" + + offset_dims: ArrayAttr[IntegerAttr[I64]] + collapsed_slice_dims: ArrayAttr[IntegerAttr[I64]] + operand_batching_dims: ArrayAttr[IntegerAttr[I64]] + start_indices_batching_dims: ArrayAttr[IntegerAttr[I64]] + start_index_map: ArrayAttr[IntegerAttr[I64]] + index_vector_dim: IntegerAttr[I64] + + def print_parameters(self, printer: Printer) -> None: + """Print gather dimension numbers in structured format""" + with printer.in_angle_brackets(): + with printer.indented(): + # Print offset_dims + printer.print_string("\noffset_dims = ") + print_dims(printer, self.offset_dims) + printer.print_string(",") + + # Print collapsed_slice_dims + printer.print_string("\ncollapsed_slice_dims = ") + print_dims(printer, self.collapsed_slice_dims) + printer.print_string(",") + + # Print operand_batching_dims + printer.print_string("\noperand_batching_dims = ") + print_dims(printer, self.operand_batching_dims) + printer.print_string(",") + + # Print start_indices_batching_dims + printer.print_string("\nstart_indices_batching_dims = ") + print_dims(printer, self.start_indices_batching_dims) + printer.print_string(",") + + # Print start_index_map + printer.print_string("\nstart_index_map = ") + print_dims(printer, self.start_index_map) + printer.print_string(",") + + # Print index_vector_dim + printer.print_string(f"\nindex_vector_dim = {self.index_vector_dim.value.data}") + printer.print_string("\n") + + @classmethod + def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]: + """Parse gather dimension numbers from structured format""" + with parser.in_angle_brackets(): + # Initialize default values for all fields + offset_dims = ArrayAttr([]) + collapsed_slice_dims = ArrayAttr([]) + operand_batching_dims = ArrayAttr([]) + start_indices_batching_dims = ArrayAttr([]) + start_index_map = ArrayAttr([]) + index_vector_dim = IntegerAttr(0, i64) + + # Try to parse offset_dims + if parser.parse_optional_characters("offset_dims") is not None: + parser.parse_punctuation("=") + offset_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse collapsed_slice_dims + if parser.parse_optional_characters("collapsed_slice_dims") is not None: + parser.parse_punctuation("=") + collapsed_slice_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse operand_batching_dims + if parser.parse_optional_characters("operand_batching_dims") is not None: + parser.parse_punctuation("=") + operand_batching_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse start_indices_batching_dims + if parser.parse_optional_characters("start_indices_batching_dims") is not None: + parser.parse_punctuation("=") + start_indices_batching_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse start_index_map + if parser.parse_optional_characters("start_index_map") is not None: + parser.parse_punctuation("=") + start_index_map = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse index_vector_dim + if parser.parse_optional_characters("index_vector_dim") is not None: + parser.parse_punctuation("=") + index_vector_dim = IntegerAttr(parser.parse_integer(), i64) + + return ( + offset_dims, + collapsed_slice_dims, + operand_batching_dims, + start_indices_batching_dims, + start_index_map, + index_vector_dim, + ) + + +@irdl_attr_definition +class ScatterDimensionNumbers(ParametrizedAttribute): + """ + XLA scatter dimension numbers. + + This attribute models the dimension information for scatter operations. + See external [documentation](https://github.com/openxla/stablehlo/blob/b075e948092d8a27ed0be48f4f8dbaa6df7e2e3e/stablehlo/dialect/StablehloAttrs.td#L28). + """ + + name = "stablehlo.scatter" + + update_window_dims: ArrayAttr[IntegerAttr[I64]] + inserted_window_dims: ArrayAttr[IntegerAttr[I64]] + input_batching_dims: ArrayAttr[IntegerAttr[I64]] + scatter_indices_batching_dims: ArrayAttr[IntegerAttr[I64]] + scatter_dims_to_operand_dims: ArrayAttr[IntegerAttr[I64]] + index_vector_dim: IntegerAttr[I64] + + def print_parameters(self, printer: Printer) -> None: + """Print scatter dimension numbers in structured format""" + with printer.in_angle_brackets(): + with printer.indented(): + # Print update_window_dims + printer.print_string("\nupdate_window_dims = ") + print_dims(printer, self.update_window_dims) + printer.print_string(",") + + # Print inserted_window_dims + printer.print_string("\ninserted_window_dims = ") + print_dims(printer, self.inserted_window_dims) + printer.print_string(",") + + # Print input_batching_dims + printer.print_string("\ninput_batching_dims = ") + print_dims(printer, self.input_batching_dims) + printer.print_string(",") + + # Print scatter_indices_batching_dims + printer.print_string("\nscatter_indices_batching_dims = ") + print_dims(printer, self.scatter_indices_batching_dims) + printer.print_string(",") + + # Print scatter_dims_to_operand_dims + printer.print_string("\nscatter_dims_to_operand_dims = ") + print_dims(printer, self.scatter_dims_to_operand_dims) + printer.print_string(",") + + # Print index_vector_dim + printer.print_string(f"\nindex_vector_dim = {self.index_vector_dim.value.data}") + printer.print_string("\n") + + @classmethod + def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]: + """Parse scatter dimension numbers from structured format""" + with parser.in_angle_brackets(): + # Initialize default values for all fields + update_window_dims = ArrayAttr([]) + inserted_window_dims = ArrayAttr([]) + input_batching_dims = ArrayAttr([]) + scatter_indices_batching_dims = ArrayAttr([]) + scatter_dims_to_operand_dims = ArrayAttr([]) + index_vector_dim = IntegerAttr(0, i64) + + # Try to parse update_window_dims + if parser.parse_optional_characters("update_window_dims") is not None: + parser.parse_punctuation("=") + update_window_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse inserted_window_dims + if parser.parse_optional_characters("inserted_window_dims") is not None: + parser.parse_punctuation("=") + inserted_window_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse input_batching_dims + if parser.parse_optional_characters("input_batching_dims") is not None: + parser.parse_punctuation("=") + input_batching_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse scatter_indices_batching_dims + if parser.parse_optional_characters("scatter_indices_batching_dims") is not None: + parser.parse_punctuation("=") + scatter_indices_batching_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse scatter_dims_to_operand_dims + if parser.parse_optional_characters("scatter_dims_to_operand_dims") is not None: + parser.parse_punctuation("=") + scatter_dims_to_operand_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse index_vector_dim + if parser.parse_optional_characters("index_vector_dim") is not None: + parser.parse_punctuation("=") + index_vector_dim = IntegerAttr(parser.parse_integer(), i64) + + return ( + update_window_dims, + inserted_window_dims, + input_batching_dims, + scatter_indices_batching_dims, + scatter_dims_to_operand_dims, + index_vector_dim, + ) + + +# ===== CustomCall and layout-related attributes ===== + + +class CustomCallApiVersion(StrEnum): + """StableHLO CustomCall API version.""" + + API_VERSION_UNSPECIFIED = "API_VERSION_UNSPECIFIED" + API_VERSION_ORIGINAL = "API_VERSION_ORIGINAL" + API_VERSION_STATUS_RETURNING = "API_VERSION_STATUS_RETURNING" + API_VERSION_STATUS_RETURNING_UNIFIED = "API_VERSION_STATUS_RETURNING_UNIFIED" + API_VERSION_TYPED_FFI = "API_VERSION_TYPED_FFI" + + +@irdl_attr_definition +class CustomCallApiVersionAttr(EnumAttribute[CustomCallApiVersion], SpacedOpaqueSyntaxAttribute): + """StableHLO custom call API version attribute. + + Mirrors StableHLO enum for CustomCall API versions. + """ + + name = "stablehlo.custom_call_api_version" + + +@irdl_attr_definition +class OutputOperandAlias(ParametrizedAttribute): + """ + This attribute captures the alias relationship of the output to one of the + operands for a ``CustomCall`` op, denoted by ``operand_index``. The + ``output_tuple_indices`` and ``operand_tuple_indices`` are used to index into + output and operand types. These indices lists are empty if the corresponding + types are not tuple types, and can be arbitrarily long in case of + arbitrarily nested tuple types. + + See https://www.tensorflow.org/xla/aliasing. + + Example when used as array with in stablehlo.custom-call: + + ```mlir + %0 = "stablehlo.custom_call"(%arg0, %arg1) { + // other attributes + output_operand_alias = [ + #stablehlo.output_operand_alias + ] + } : (tuple, tensor<2x3xf32>>, tensor<5x5xf32>) -> tuple> + + The output and the 0th operand are both tuples. The aliasing shows the + relationship between the 0th element in output tuple with the 1st element in + the 0th operand. And both of them are of the same type: ``tensor<2x3xf32>``. + ``` + """ + + name = "stablehlo.output_operand_alias" + + output_tuple_indices: ArrayAttr[IntegerAttr[I64]] + operand_index: IntegerAttr[I64] + operand_tuple_indices: ArrayAttr[IntegerAttr[I64]] + + def print_parameters(self, printer: Printer) -> None: + """Print the OutputOperandAlias attribute.""" + with printer.in_angle_brackets(): + with printer.indented(): + printer.print_string("\noutput_tuple_indices = ") + print_dims(printer, self.output_tuple_indices) + printer.print_string(",") + + printer.print_string("\noperand_index = ") + printer.print_string(f"{self.operand_index.value.data}") + printer.print_string(",") + + printer.print_string("\noperand_tuple_indices = ") + print_dims(printer, self.operand_tuple_indices) + printer.print_string("\n") + + @classmethod + def parse_parameters(cls, parser: AttrParser): + """Parse the OutputOperandAlias attribute.""" + with parser.in_angle_brackets(): + output_tuple_indices = ArrayAttr([]) + operand_index = IntegerAttr(0, i64) + operand_tuple_indices = ArrayAttr([]) + + if parser.parse_optional_characters("output_tuple_indices") is not None: + parser.parse_punctuation("=") + output_tuple_indices = parse_dims(parser) + parser.parse_optional_punctuation(",") + + if parser.parse_optional_characters("operand_index") is not None: + parser.parse_punctuation("=") + operand_index = IntegerAttr(parser.parse_integer(), i64) + parser.parse_optional_punctuation(",") + + if parser.parse_optional_characters("operand_tuple_indices") is not None: + parser.parse_punctuation("=") + operand_tuple_indices = parse_dims(parser) + + return (output_tuple_indices, operand_index, operand_tuple_indices) diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/control_flow.py b/frontend/catalyst/python_interface/dialects/stablehlo/control_flow.py new file mode 100644 index 0000000000..c9edced70d --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/control_flow.py @@ -0,0 +1,160 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Control flow operations for the StableHLO dialect. +""" + +from typing import TypeVar + +from xdsl.dialects.builtin import AnyTensorType +from xdsl.irdl import ( + IRDLOperation, + irdl_op_definition, + operand_def, + region_def, + traits_def, + var_operand_def, + var_result_def, +) +from xdsl.traits import ( + Pure, + RecursivelySpeculatable, + RecursiveMemoryEffect, + SingleBlockImplicitTerminator, +) +from xdsl_jax.dialects.stablehlo import ReturnOp + +# Import our custom StableHLO types +from .types import HLO_PredTensor, HLO_TensorOrPerAxisQuantizedTensorOrToken, HLO_TensorOrToken + +# Generic type variables for templating +T_IN = TypeVar("T_IN", bound=AnyTensorType) +T_OUT = TypeVar("T_OUT", bound=AnyTensorType) + + +@irdl_op_definition +class IfOp(IRDLOperation): + """ + Produces the output from executing exactly one branch from `true_branch` or + `false_branch` depending on the value of `pred`. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#if + + Example: + %result = "stablehlo.if"(%pred) ({ + "stablehlo.return"(%result_true_branch) : (tensor) -> () + }, { + "stablehlo.return"(%result_false_branch) : (tensor) -> () + }) : (tensor) -> tensor + """ + + name = "stablehlo.if" + + pred = operand_def(HLO_PredTensor) + + res = var_result_def(HLO_TensorOrPerAxisQuantizedTensorOrToken) + + true_branch = region_def("single_block") + + false_branch = region_def("single_block") + + traits = traits_def( + RecursiveMemoryEffect(), + RecursivelySpeculatable(), + SingleBlockImplicitTerminator(ReturnOp), + # TODO: InferTypeOpInterface + # TODO: OpAsmOpInterface + ) + + # TODO: Add custom assembly format + + +# pylint: disable=line-too-long +@irdl_op_definition +class WhileOp(IRDLOperation): + """ + Produces the output from executing `body` function 0 or more times while the + `cond` function outputs `true`. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#while + + Example: + ```mlir + %results0, %results1 = stablehlo.while(%arg0 = %init_i, %arg1 = %init_sum) : tensor, tensor + cond { + %cond = stablehlo.compare LT, %arg0, %ten : (tensor, tensor) -> tensor + stablehlo.return %cond : tensor + } do { + %new_sum = stablehlo.add %arg1, %one : tensor + %new_i = stablehlo.add %arg0, %one : tensor + stablehlo.return %new_i, %new_sum : tensor, tensor + } + """ + + name = "stablehlo.while" + + operand = var_operand_def(HLO_TensorOrPerAxisQuantizedTensorOrToken) + + res = var_result_def(HLO_TensorOrPerAxisQuantizedTensorOrToken) + + cond = region_def("single_block") + + body = region_def("single_block") + + traits = traits_def( + RecursiveMemoryEffect(), + RecursivelySpeculatable(), + SingleBlockImplicitTerminator(ReturnOp), + # TODO: InferTypeOpInterface + # TODO: OpAsmOpInterface + ) + + +# pylint: disable=line-too-long +@irdl_op_definition +class OptimizationBarrierOp(IRDLOperation): + """ + Ensures that the operations that produce the `operand` are executed before any + operations that depend on the `result` and prevents compiler transformations + from moving operations across the barrier. Other than that, the operation is + an identity, i.e. `result` = `operand`. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#optimization_barrier + + Example: + ```mlir + %result0, %result1 = stablehlo.optimization_barrier %operand0, %operand1 : tensor, tensor + ``` + """ + + name = "stablehlo.optimization_barrier" + + operand = var_operand_def(HLO_TensorOrToken) + + res = var_result_def(HLO_TensorOrToken) + + traits = traits_def( + Pure(), + # TODO: HLO_PairwiseSameOperandAndResultType + # TODO: InferTypeOpInterface + ) + + # TODO: Add custom assembly format + # assembly_format = """ + # attr-dict ($operand^ `:` custom(type($operand), type($result))):(`(` `)`)? + # """ diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/data_movement.py b/frontend/catalyst/python_interface/dialects/stablehlo/data_movement.py new file mode 100644 index 0000000000..68ce2bddc7 --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/data_movement.py @@ -0,0 +1,416 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Data movement operations for the StableHLO dialect. +""" + +from xdsl.dialects.builtin import BoolAttr, DenseArrayBase, IntegerAttr, TensorType, i64 +from xdsl.irdl import ( + IRDLOperation, + irdl_op_definition, + operand_def, + opt_prop_def, + prop_def, + region_def, + result_def, + traits_def, + var_operand_def, + var_result_def, +) +from xdsl.irdl.attributes import eq +from xdsl.irdl.constraints import AtLeast +from xdsl.irdl.operations import SameVariadicOperandSize +from xdsl.traits import ( + ConditionallySpeculatable, + NoMemoryEffect, + Pure, + RecursiveMemoryEffect, +) +from xdsl.utils.exceptions import VerifyException +from xdsl.utils.type import get_element_type_or_self + +from catalyst.python_interface.xdsl_extras import ( + AllMatchSameOperatorTrait, + SameOperandsAndResultElementType, + TensorConstraint, +) + +from .attributes import GatherDimensionNumbers, ScatterDimensionNumbers +from .types import HLO_AnyIntegerOrIndexTensor, HLO_AnyTensor, HLO_Int, HLO_IntTensor, HLO_Tensor + + +# pylint: disable=line-too-long +@irdl_op_definition +class BroadcastInDimOp(IRDLOperation): + """ + Expands the dimensions and/or rank of an input tensor by duplicating the + data in the ``operand`` tensor and produces a ``result`` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#broadcast_in_dim + + Example: + ```mlir + %result = stablehlo.broadcast_in_dim %operand, dims = [2, 1] : (tensor<1x3xi32>) -> tensor<2x3x2xi32> + ``` + """ + + name = "stablehlo.broadcast_in_dim" + operand = operand_def(HLO_AnyTensor) + broadcast_dimensions = prop_def(DenseArrayBase.constr(i64)) + result = result_def(HLO_AnyTensor) + + assembly_format = """ + $operand `,` `dims` `=` $broadcast_dimensions + attr-dict `:` functional-type(operands, results) + """ + + traits = traits_def( + NoMemoryEffect(), + # TODO: HLO_SpeculatableIfAllInputsStatic, + # TODO: HLO_CompatibleOperandsAndResultElementType, + ) + + def verify_(self) -> None: + """Verify non-quantized broadcast_in_dim constraints.""" + o_type = self.operand_types[0] + r_type = self.result_types[0] + + # These are constrained to tensors by the op definition + assert isinstance(o_type, TensorType) and isinstance(r_type, TensorType) + + # broadcast_in_dim_c2: broadcast_dimensions size == operand rank + dims = tuple(self.broadcast_dimensions.get_values()) # pylint: disable=no-member + operand_rank = o_type.get_num_dims() + if len(dims) != operand_rank: + raise VerifyException( + "broadcast_dimensions size (" + f"{len(dims)}" + ") does not match operand rank (" + f"{operand_rank}" + ")" + ) + + # broadcast_in_dim_c4: broadcast_dimensions should not have duplicates + if len(set(dims)) != len(dims): + raise VerifyException("broadcast_dimensions should not have duplicates") + + # Result rank and per-dimension checks + result_rank = r_type.get_num_dims() + o_shape = o_type.get_shape() + r_shape = r_type.get_shape() + + for i, dim_index in enumerate(dims): + # broadcast_in_dim_c3: each dim index in bounds of result rank + if dim_index < 0 or dim_index >= result_rank: + raise VerifyException( + "broadcast_dimensions contains invalid value " + f"{dim_index} for result with rank {result_rank}" + ) + + # If operand dim is static, enforce broadcast_in_dim_c5 + if o_shape[i] != -1: + dim_size = o_shape[i] + result_dim_size = r_shape[dim_index] + if dim_size not in (1, result_dim_size): + raise VerifyException( + "size of operand dimension " + f"{i} ({dim_size}) is not equal to 1 or size of result dimension " + f"{dim_index} ({result_dim_size})" + ) + + +# pylint: disable=line-too-long +@irdl_op_definition +class ConcatenateOp(IRDLOperation): + """ + Concatenates a variadic number of tensors in ``inputs`` along ``dimension`` + dimension in the same order as the given arguments and produces a ``result`` + tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#concatenate + + Example: + ```mlir + %result = stablehlo.concatenate %input0, %input1, dim = 0 : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64> + ``` + """ + + name = "stablehlo.concatenate" + + inputs = var_operand_def(HLO_Tensor) + result = result_def(HLO_Tensor) + dimension = prop_def(IntegerAttr.constr(type=eq(i64), value=AtLeast(0))) + + traits = traits_def( + NoMemoryEffect(), + ConditionallySpeculatable(), + SameOperandsAndResultElementType(), + # InferTypeOpInterface(), + ) + + # TODO: Implement CustomDirective + # assembly_format = """ + # custom($inputs) `dim` `=` $dimension attr-dict `:` functional-type(operands, results) + # """ + + +@irdl_op_definition +class DynamicSliceOp(IRDLOperation): + """ + Extracts a slice from the ``operand`` using dynamically-computed starting + indices and produces a ``result`` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_slice + + Example: + ```mlir + %result = stablehlo.dynamic_slice %operand, %start_indices0, %start_indices1, sizes = [2, 2] + : (tensor<4x4xi32>, tensor, tensor) -> tensor<2x2xi32> + ``` + """ + + name = "stablehlo.dynamic_slice" + operand = operand_def(HLO_Tensor) + start_indices = var_operand_def(TensorConstraint(element_type=HLO_Int, rank=0)) + slice_sizes = prop_def(DenseArrayBase.constr(i64)) + result = result_def(HLO_Tensor) + + # TODO: Implement CustomDirective + # assembly_format = """ + # $operand `,` custom($start_indices) + # `sizes` `=` $slice_sizes attr-dict `:` functional-type(operands, results) + # """ + + traits = traits_def( + Pure(), + AllMatchSameOperatorTrait( + ("operand", "result"), lambda x: get_element_type_or_self(x.type), "element type" + ), + # TODO: InferTensorType(), + ) + + +# pylint: disable=line-too-long +@irdl_op_definition +class GatherOp(IRDLOperation): + """ + Gathers slices from ``operand`` tensor from offsets specified in + ``start_indices`` and produces a ``result`` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#gather + + Example: + ```mlir + %result = "stablehlo.gather"(%operand, %start_indices) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3, 4], + collapsed_slice_dims = [1], + operand_batching_dims = [0], + start_indices_batching_dims = [1], + start_index_map = [2, 1], + index_vector_dim = 3>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi64> + ``` + """ + + name = "stablehlo.gather" + operand = operand_def(HLO_Tensor) + start_indices = operand_def(HLO_IntTensor) + dimension_numbers = prop_def(GatherDimensionNumbers) + slice_sizes = prop_def(DenseArrayBase.constr(i64)) + indices_are_sorted = opt_prop_def(BoolAttr, default_value=BoolAttr.from_bool(False)) + result = result_def(HLO_Tensor) + + traits = traits_def( + NoMemoryEffect(), + ConditionallySpeculatable(), + AllMatchSameOperatorTrait( + ("operand", "result"), lambda x: get_element_type_or_self(x.type), "element type" + ), + # TODO: InferTensorTypeWithReify(), + ) + + # TODO: Implement CustomDirective + # assembly_format = """ + # custom($inputs) `dim` `=` $dimension attr-dict `:` functional-type(operands, results) + # """ + + +@irdl_op_definition +class ReshapeOp(IRDLOperation): + """ + Performs reshape of ``operand`` tensor to a ``result`` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reshape + + Example: + ```mlir + %result = stablehlo.reshape %operand : (tensor<2xf32>) -> tensor<1x2xf32> + ``` + """ + + name = "stablehlo.reshape" + operand = operand_def(HLO_AnyTensor) + result = result_def(HLO_AnyTensor) + + assembly_format = """ + operands attr-dict `:` functional-type(operands, results) + """ + + traits = traits_def( + NoMemoryEffect(), + ConditionallySpeculatable(), + # TODO: HLO_CompatibleOperandsAndResultElementType, + ) + + def verify_(self) -> None: + """Verify that the operation has the same shape for all operands and results.""" + o_type = self.operand_types[0] + r_type = self.result_types[0] + + # These are constrained to tensors by the op definition + assert isinstance(o_type, TensorType) and isinstance(r_type, TensorType) + + # If o_type or r_type is dynamically shaped there is nothing to verify. + if not o_type.has_static_shape() or not r_type.has_static_shape(): + return + + # If the operand type is statically shaped (not required) the number of + # elements must match that of the result type. + num_operand_elements = 1 + for dim in o_type.get_shape(): + num_operand_elements *= dim + + num_result_elements = 1 + for dim in r_type.get_shape(): + num_result_elements *= dim + + if num_result_elements != num_operand_elements: + raise VerifyException( + "number of output elements (" + f"{num_result_elements}" + ") doesn't match expected number of elements (" + f"{num_operand_elements}" + ")" + ) + + +@irdl_op_definition +class ScatterOp(IRDLOperation): + """ + Produces ``results`` tensors which are equal to ``inputs`` tensors except that + several slices specified by ``scatter_indices`` are updated with the values + ``updates`` using ``update_computation``. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter + + Example: + ```mlir + %result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = stablehlo.add %arg0, %arg1 : tensor + stablehlo.return %0 : tensor + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [3, 4], + inserted_window_dims = [1], + input_batching_dims = [0], + scatter_indices_batching_dims = [1], + scatter_dims_to_operand_dims = [2, 1], + index_vector_dim = 3>, + indices_are_sorted = false, + unique_indices = false + } : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64> + ``` + """ + + name = "stablehlo.scatter" + inputs = var_operand_def(HLO_Tensor) + scatter_indices = operand_def(HLO_AnyIntegerOrIndexTensor) + updates = var_operand_def(HLO_Tensor) + scatter_dimension_numbers = prop_def(ScatterDimensionNumbers) + indices_are_sorted = opt_prop_def(BoolAttr, default_value=BoolAttr.from_bool(False)) + unique_indices = opt_prop_def(BoolAttr, default_value=BoolAttr.from_bool(False)) + result = var_result_def(HLO_Tensor) + update_computation = region_def("single_block") + # TODO: The MLIR implementation doesn't have the SingleBlockImplicitTerminator trait, + # However, it is checked to have a terminator in the verifier, + # which does not specifically check the terminator to be stablehlo.return. + + traits = traits_def( + RecursiveMemoryEffect(), + ConditionallySpeculatable(), + # TODO: InferTypeOpInterface(), + ) + + irdl_options = [SameVariadicOperandSize()] + + # TODO: MLIR has a custom verifier for the scatter operation. + + +@irdl_op_definition +class SliceOp(IRDLOperation): + """ + Extracts a slice from the ``operand`` using statically-computed starting + indices and produces a ``result`` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#slice + + Example: + ```mlir + %result = stablehlo.slice %operand [1:3, 4:8:2] + : (tensor<3x8xi64>) -> tensor<2x2xi64> + + // Same in generic form: the `1:3` above is mapped to the first entry in + // `start_indices` and `limit_indices`, while `strides` is implicitly 1. + // The `4:8:2` above is parsed into the second entry of `start_indices`, + // `limit_indices` and `strides` respectively. + %result = "stablehlo.slice" (%operand) { + start_indices = array, + limit_indices = array, + strides = array + } : (tensor<3x8xi64>) -> tensor<2x2xi64> + ``` + """ + + name = "stablehlo.slice" + + operand = operand_def(HLO_Tensor) + start_indices = prop_def(DenseArrayBase.constr(i64)) + limit_indices = prop_def(DenseArrayBase.constr(i64)) + strides = prop_def(DenseArrayBase.constr(i64)) + result = result_def(HLO_Tensor) + + # TODO: Implement CustomDirective + # assembly_format = """ + # $operand custom($start_indices, $limit_indices, $strides) + # attr-dict `:` functional-type(operands, results) + # """ + + traits = traits_def( + NoMemoryEffect(), + ConditionallySpeculatable(), + AllMatchSameOperatorTrait(("start_indices", "limit_indices", "strides"), len, "size"), + SameOperandsAndResultElementType(), + ) diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/dialect.py b/frontend/catalyst/python_interface/dialects/stablehlo/dialect.py new file mode 100644 index 0000000000..ce880bba2b --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/dialect.py @@ -0,0 +1,207 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Extended StableHLO dialect that dynamically includes all upstream operations +plus custom operations for PennyLane's compiler infrastructure. + +This module automatically imports all operations and attributes from the upstream +xdsl_jax.dialects.stablehlo and adds custom ones without needing to hardcode +the upstream operation list. +""" + +import xdsl_jax.dialects.stablehlo as xstablehlo +from xdsl.ir import Dialect + +from .attributes import ( + CustomCallApiVersionAttr, + GatherDimensionNumbers, + OutputOperandAlias, + ResultAccuracyModeAttr, + ScatterDimensionNumbers, +) +from .control_flow import ( + IfOp, + OptimizationBarrierOp, + WhileOp, +) +from .data_movement import ( + BroadcastInDimOp, + ConcatenateOp, + DynamicSliceOp, + GatherOp, + ReshapeOp, + ScatterOp, + SliceOp, +) +from .dynamism import ( + DynamicBroadcastInDimOp, +) +from .elementwise_binary import ( + ComplexOp, + DivideOp, + MaximumOp, + MinimumOp, + PowerOp, + RemainderOp, +) +from .elementwise_other import ( + ClampOp, + CompareOp, + ConstantOp, + MapOp, + ReducePrecisionOp, + SelectOp, +) + +# Import all elementwise operations from organized files +from .elementwise_unary import ( + ConvertOp, + CosineOp, + ExponentialMinusOneOp, + ExponentialOp, + FloorOp, + ImagOp, + IsFiniteOp, + LogisticOp, + LogOp, + LogPlusOneOp, + NegateOp, + RealOp, + RoundNearestAfzOp, + RoundNearestEvenOp, + RsqrtOp, + SignOp, + SineOp, + SqrtOp, + TanhOp, + TanOp, +) +from .extensibility import ( + CustomCallOp, +) +from .reduction import ( + ReduceOp, +) +from .types import UniformQuantizedPerAxisType, UniformQuantizedType + +# Operations to add to the dialect +OPERATIONS = [ + ClampOp, + CompareOp, + ComplexOp, + ConstantOp, + ConvertOp, + CosineOp, + DivideOp, + ExponentialMinusOneOp, + ExponentialOp, + FloorOp, + ImagOp, + IsFiniteOp, + LogOp, + LogPlusOneOp, + LogisticOp, + MapOp, + MaximumOp, + MinimumOp, + NegateOp, + PowerOp, + RealOp, + ReducePrecisionOp, + RemainderOp, + RoundNearestAfzOp, + RoundNearestEvenOp, + RsqrtOp, + SelectOp, + SignOp, + SineOp, + SqrtOp, + TanOp, + TanhOp, + # Data movement operations + BroadcastInDimOp, + ConcatenateOp, + DynamicSliceOp, + GatherOp, + ReshapeOp, + ScatterOp, + SliceOp, + # Control flow operations + IfOp, + WhileOp, + OptimizationBarrierOp, + # Dynamism operations + DynamicBroadcastInDimOp, + # Reduction operations + ReduceOp, + # Extensibility operations + CustomCallOp, +] + +# Attributes to add to the dialect +ATTRIBUTES = [ + CustomCallApiVersionAttr, + GatherDimensionNumbers, + ResultAccuracyModeAttr, + OutputOperandAlias, + ScatterDimensionNumbers, + UniformQuantizedPerAxisType, + UniformQuantizedType, +] + +# Operations/attributes from upstream that should be deleted/replaced in the local version +UPSTREAM_OPERATIONS_TO_DELETE = [ + xstablehlo.ConstantOp, +] +UPSTREAM_ATTRIBUTES_TO_DELETE = [] + + +def filter_and_extend_upstream(upstream_list, to_delete, to_add): + """Filter out operations/attributes from upstream list and add new ones. + + Args: + upstream_list: List of operations/attributes to filter + to_delete: List of operations/attributes to remove + to_add: List of operations/attributes to add + + Returns: + Modified list of operations/attributes + """ + filtered_ops = list(upstream_list) + + # Remove operations that should be deleted + for op_to_delete in to_delete: + if op_to_delete in filtered_ops: + filtered_ops.remove(op_to_delete) + + # Add new operations + filtered_ops.extend(to_add) + + return filtered_ops + + +all_operations = filter_and_extend_upstream( + xstablehlo.StableHLO.operations, UPSTREAM_OPERATIONS_TO_DELETE, OPERATIONS +) +all_attributes = filter_and_extend_upstream( + xstablehlo.StableHLO.attributes, UPSTREAM_ATTRIBUTES_TO_DELETE, ATTRIBUTES +) + +# Create the extended StableHLO dialect by dynamically getting upstream components +StableHLO = Dialect( + "stablehlo", + all_operations, + all_attributes, +) diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py b/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py new file mode 100644 index 0000000000..e38b4986f7 --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py @@ -0,0 +1,198 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Dynamism operations for the StableHLO dialect. +""" + +from xdsl.dialects.builtin import DenseArrayBase, TensorType, i64 +from xdsl.irdl import ( + IRDLOperation, + ParsePropInAttrDict, + irdl_op_definition, + operand_def, + opt_prop_def, + prop_def, + result_def, + traits_def, +) +from xdsl.traits import ( + ConditionallySpeculatable, + NoMemoryEffect, +) +from xdsl.utils.exceptions import VerifyException + +from catalyst.python_interface.xdsl_extras import TensorConstraint + +from .types import HLO_AnyTensor, HLO_DimensionValue + + +@irdl_op_definition +class DynamicBroadcastInDimOp(IRDLOperation): + """ + This operation is functionally identical to + [broadcast_in_dim](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#broadcast_in_dim) + op, but the result shape is specified dynamically via ``output_dimensions``. + + It also accepts optional attributes to express static knowledge about the + expanding behavior of dimensions. If not specified, all dimensions are + assumed to be possibly expanding. The sets of dimensions that are known to + be expanding and the set of dimensions that are known to be non-expanding + must be disjoint and they must be a subset of the operand's dimensions. + + See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_broadcast_in_dim + + Example: + ```mlir + %operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64> + %output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64> + %result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) { + broadcast_dimensions = array, + known_expanding_dimensions = array, + known_nonexpanding_dimensions = array + } : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64> + ``` + """ + + name = "stablehlo.dynamic_broadcast_in_dim" + + operand = operand_def(HLO_AnyTensor) + output_dimensions = operand_def(TensorConstraint(element_type=HLO_DimensionValue, rank=1)) + broadcast_dimensions = prop_def(DenseArrayBase.constr(i64)) + known_expanding_dimensions = opt_prop_def(DenseArrayBase.constr(i64)) + known_nonexpanding_dimensions = opt_prop_def(DenseArrayBase.constr(i64)) + result = result_def(HLO_AnyTensor) + + assembly_format = ( + "$operand `,` $output_dimensions `,` `dims` `=` $broadcast_dimensions " + "attr-dict `:` functional-type(operands, results)" + ) + + traits = traits_def( + ConditionallySpeculatable(), + NoMemoryEffect(), + # TODO: InferShapedTypeOpInterface(), + ) + + irdl_options = [ParsePropInAttrDict()] + + # pylint: disable=too-many-branches + def verify_(self): + """Verify the operation.""" + operand_ty = self.operand_types[0] + result_ty = self.result_types[0] + bcast_dims = tuple(self.broadcast_dimensions.get_values()) # pylint: disable=no-member + + # Operand and result must be tensors + assert isinstance(operand_ty, TensorType) and isinstance(result_ty, TensorType) + + self._verify_rank_constraints(bcast_dims, operand_ty, result_ty) + + # dynamic_broadcast_in_dim_c4: broadcast_dimensions should not have duplicates + if len(set(bcast_dims)) != len(bcast_dims): + raise VerifyException("broadcast_dimensions should not have duplicates") + + self._verify_per_dimension_bounds(bcast_dims, operand_ty, result_ty) + + self._verify_expansion_hints(operand_ty) + + def _verify_rank_constraints(self, bcast_dims, operand_ty, result_ty): + """Verify then operand and result tensors against the rank constraints.""" + + operand_rank = operand_ty.get_num_dims() + result_rank = result_ty.get_num_dims() + + # dynamic_broadcast_in_dim_c2: broadcast_dimensions size == operand rank + if len(bcast_dims) != operand_rank: + raise VerifyException( + "broadcast_dimensions size (" + f"{len(bcast_dims)}" + ") does not match operand rank (" + f"{operand_rank}" + ")" + ) + + # dynamic_broadcast_in_dim_c3: result rank >= operand rank + if result_rank < operand_rank: + raise VerifyException( + "result rank (" + f"{result_rank}" + ") is less than operand rank (" + f"{operand_rank}" + ")" + ) + + # dynamic_broadcast_in_dim_c7: output_dimensions shape compatible with result rank + out_dims_ty = self.output_dimensions.type # pylint: disable=no-member + assert isinstance(out_dims_ty, TensorType) + # Must be rank-1 tensor (enforced by type constraint), and length must match result + # rank when statically known + out_shape = out_dims_ty.get_shape() + if len(out_shape) != 1: + raise VerifyException("output_dimensions must be a 1D tensor") + if out_shape[0] != -1 and out_shape[0] != result_rank: + raise VerifyException( + "length of output_dimensions (" + f"{out_shape[0]}" + ") is not compatible with result rank (" + f"{result_rank}" + ")" + ) + + def _verify_per_dimension_bounds(self, bcast_dims, operand_ty, result_ty): + """Verify compatibility of operand and result dimensions.""" + # dynamic_broadcast_in_dim_c5: bounds and per-dimension compatibility + operand_shape = operand_ty.get_shape() + result_shape = result_ty.get_shape() + result_rank = result_ty.get_num_dims() + + for i, dim_index in enumerate(bcast_dims): + if dim_index < 0 or dim_index >= result_rank: + raise VerifyException( + "broadcast_dimensions contains invalid value " + f"{dim_index} for result with rank {result_rank}" + ) + op_dim = operand_shape[i] + res_dim = result_shape[dim_index] + # If operand dim is static and not size-1, require compatibility with result dim + if op_dim not in (-1, 1): + if res_dim not in (-1, op_dim): + raise VerifyException( + "size of operand dimension " + f"{i} ({op_dim}) is not compatible with size of result dimension " + f"{dim_index} ({res_dim})" + ) + + def _verify_expansion_hints(self, operand_ty): + """Verify the operation's expansion hints.""" + # dynamic_broadcast_in_dim_c8: no duplicate expansion hints across both lists + operand_rank = operand_ty.get_num_dims() + + hints = [] + if self.known_expanding_dimensions is not None: + hints.extend(self.known_expanding_dimensions.get_values()) # pylint: disable=no-member + if self.known_nonexpanding_dimensions is not None: + hints.extend( + self.known_nonexpanding_dimensions.get_values() # pylint: disable=no-member + ) + if len(set(hints)) != len(hints): + raise VerifyException("duplicate expansion hint for at least one operand dimension") + + # dynamic_broadcast_in_dim_c9/c10: each hint must reference a valid operand dimension + for h in set(hints): + if h < 0 or h >= operand_rank: + raise VerifyException( + "hint for expanding dimension " + f"{h} does not refer to a valid operand dimension" + ) diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_binary.py b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_binary.py new file mode 100644 index 0000000000..8180c7a208 --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_binary.py @@ -0,0 +1,214 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Binary elementwise operations for the StableHLO dialect. +""" + +import abc +from typing import Generic, TypeVar + +from xdsl.dialects.builtin import AnyTensorType, ComplexType, Float32Type, Float64Type, TensorType +from xdsl.ir import Attribute, SSAValue +from xdsl.irdl import ( + IRDLOperation, + irdl_op_definition, + operand_def, + result_def, + traits_def, +) +from xdsl.traits import NoMemoryEffect + +from catalyst.python_interface.xdsl_extras import ( + Elementwise, + SameOperandsAndResultShape, + SameOperandsElementType, +) + +from .types import ( + HLO_ComplexTensor, + HLO_Fp32Or64Tensor, + HLO_IntFpOrComplexOrQuantizedIntTensor, + HLO_Tensor, +) + +# Type aliases +F32Or64Type = Float32Type | Float64Type +F32Or64TensorType = TensorType[F32Or64Type] +ComplexTensorType = TensorType[ComplexType] + +# Generic type variables for templating +T_LHS = TypeVar("T_LHS", bound=AnyTensorType) +T_RHS = TypeVar("T_RHS", bound=AnyTensorType) +T_OUT = TypeVar("T_OUT", bound=AnyTensorType) + + +class ElementwiseBinaryOperation(IRDLOperation, abc.ABC, Generic[T_LHS, T_RHS, T_OUT]): + """ + Templated base class for elementwise binary operations. + + This class provides a flexible template for binary operations that can work + with different tensor types. + + For more information about the semantics, see: + https://openxla.org/xla/operation_semantics#element-wise_binary_arithmetic_operations + """ + + lhs = operand_def(T_LHS) + rhs = operand_def(T_RHS) + result = result_def(T_OUT) + + traits = traits_def( + NoMemoryEffect(), + SameOperandsAndResultShape(), + Elementwise(), + # TODO: HLO_SpeculatableIfAllInputsStatic(), + ) + + # TODO: Implement CustomDirective + # assembly_format = """ + # $lhs `,` $rhs attr-dict + # `:` custom(type($lhs), type($rhs), type($result)) + # """ + + def __init__(self, lhs: SSAValue, rhs: SSAValue, result_type: Attribute | None = None): + if result_type is None: + result_type = lhs.type + super().__init__(operands=(lhs, rhs), result_types=(result_type,)) + + +@irdl_op_definition +class ComplexOp( + ElementwiseBinaryOperation[HLO_Fp32Or64Tensor, HLO_Fp32Or64Tensor, HLO_ComplexTensor] +): + """ + Performs element-wise conversion to a complex value from a pair of real and + imaginary values, `lhs` and `rhs`, and produces a `result` tensor. + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#complex + Example: + ```mlir + %result = stablehlo.complex %lhs, %rhs : tensor<2xcomplex> + ``` + """ + + name = "stablehlo.complex" + + # assembly_format = """ + # operands attr-dict + # `:` custom(type($lhs), type($rhs), type($result)) + # """ + + traits = traits_def( + NoMemoryEffect(), + SameOperandsElementType(), + SameOperandsAndResultShape(), + # TODO: HLO_SpeculatableIfAllInputsStatic(), + ) + + +@irdl_op_definition +class DivideOp( + ElementwiseBinaryOperation[ + HLO_IntFpOrComplexOrQuantizedIntTensor, + HLO_IntFpOrComplexOrQuantizedIntTensor, + HLO_IntFpOrComplexOrQuantizedIntTensor, + ] +): + """ + Performs element-wise division of dividend `lhs` and divisor `rhs` tensors + and produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#divide + + Example: + ```mlir + %result = stablehlo.divide %lhs, %rhs : tensor<4xf32> + ``` + """ + + name = "stablehlo.divide" + + +@irdl_op_definition +class MaximumOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]): + """ + Performs element-wise max operation on tensors `lhs` and `rhs` and produces + a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#maximum + + Example: + ```mlir + %result = stablehlo.maximum %lhs, %rhs : tensor<4xf32> + ``` + """ + + name = "stablehlo.maximum" + + +@irdl_op_definition +class MinimumOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]): + """ + Performs element-wise min operation on tensors `lhs` and `rhs` and produces a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#minimum + + Example: + ```mlir + %result = stablehlo.minimum %lhs, %rhs : tensor<4xf32> + ``` + """ + + name = "stablehlo.minimum" + + +@irdl_op_definition +class PowerOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]): + """ + Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and + produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power + + Example: + ```mlir + %result = stablehlo.power %lhs, %rhs : tensor<6xf64> + ``` + """ + + name = "stablehlo.power" + + +@irdl_op_definition +class RemainderOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]): + """ + Performs element-wise remainder of dividend `lhs` and divisor `rhs` tensors + and produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#remainder + + Example: + ```mlir + %result = stablehlo.remainder %lhs, %rhs : tensor<4xi64> + ``` + """ + + name = "stablehlo.remainder" diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_other.py b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_other.py new file mode 100644 index 0000000000..18527a028d --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_other.py @@ -0,0 +1,236 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Other elementwise operations for the StableHLO dialect. +""" + +import xdsl_jax.dialects.stablehlo as xstablehlo +from xdsl.dialects.builtin import ( + AnyFloat, + DenseArrayBase, + DenseIntOrFPElementsAttr, + IntegerAttr, + TensorType, + i32, + i64, +) +from xdsl.irdl import ( + IRDLOperation, + attr_def, + irdl_op_definition, + operand_def, + opt_attr_def, + prop_def, + result_def, + traits_def, + var_operand_def, + var_region_def, +) +from xdsl.irdl.attributes import eq +from xdsl.irdl.constraints import AtLeast +from xdsl.traits import NoMemoryEffect, RecursiveMemoryEffect, SingleBlockImplicitTerminator + +from catalyst.python_interface.xdsl_extras import Elementwise, SameOperandsAndResultShape + +from .types import HLO_AnyTensor, HLO_FpOrQuantizedIntTensor, HLO_PredTensor, HLO_Tensor + +# Type aliases +FloatTensorType = TensorType[AnyFloat] + + +@irdl_op_definition +class ClampOp(IRDLOperation): + """Element-wise clamp with min and max bounds. + + See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#clamp + """ + + name = "stablehlo.clamp" + + min = operand_def(HLO_Tensor) + operand = operand_def(HLO_Tensor) + max = operand_def(HLO_Tensor) + result = result_def(HLO_Tensor) + + # TODO: Implement CustomDirective + # assembly_format = """ + # $min `,` $operand `,` $max attr-dict + # `:` custom(type($min), type($operand), type($max), type($result)) + # """ + + traits = traits_def( + NoMemoryEffect(), + # TODO: HLO_SpeculatableIfAllInputsStatic(), + # TODO: HLO_CompatibleOperandsAndResultElementType(), + # TODO: HLO_BroadcastingElementwise(), + # TODO: InferTensorType(), + # TODO: InferShapedTypeOpInterface(), + ) + + +@irdl_op_definition +class CompareOp(IRDLOperation): + """Element-wise compare with direction and type attributes.""" + + name = "stablehlo.compare" + + assembly_format = """ + $comparison_direction `,` $lhs `,` $rhs (`,` $comparison_type^)? attr-dict `:` functional-type(operands, results) + """ + + lhs = operand_def(HLO_Tensor) + rhs = operand_def(HLO_Tensor) + result = result_def(HLO_PredTensor) + comparison_direction = attr_def(xstablehlo.ComparisonDirectionAttr) + comparison_type = opt_attr_def(xstablehlo.ComparisonTypeAttr) + + traits = traits_def( + NoMemoryEffect(), + Elementwise(), + SameOperandsAndResultShape(), + # TODO: HLO_SpeculatableIfAllInputsStatic(), + # TODO: HLO_CompatibleOperandsElementType(), + # TODO: InferTensorTypeWithReify(), + ) + + +@irdl_op_definition +class MapOp(IRDLOperation): + """ + Applies a map function `computation` to `inputs` along the `dimensions` and + produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#map + + Example: + ```mlir + %result = "stablehlo.map"(%input0, %input1) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = stablehlo.multiply %arg0, %arg1 : tensor + stablehlo.return %0 : tensor + }) { + dimensions = array + } : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64> + ``` + """ + + name = "stablehlo.map" + + inputs = var_operand_def(HLO_Tensor) + result = result_def(HLO_Tensor) + dimensions = attr_def(DenseArrayBase.constr(i64)) + computation = var_region_def("single_block") + + traits = traits_def( + RecursiveMemoryEffect(), + SameOperandsAndResultShape(), + SingleBlockImplicitTerminator(xstablehlo.ReturnOp), + # TODO: HLO_RecursivelySpeculatableIfAllInputsStatic(), + # TODO: InferTypeOpInterface + # TODO: InferShapedTypeOpInterface(), + ) + + +@irdl_op_definition +class ReducePrecisionOp(IRDLOperation): + """ + Performs element-wise conversion of `operand` to another floating-point type + that uses `exponent_bits` and `mantissa_bits` and back to the original + floating-point type and produces an `output` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_precision + + Example: + ```mlir + %output = stablehlo.reduce_precision %operand, format = e5m10 : tensor<6xf64> + ``` + """ + + name = "stablehlo.reduce_precision" + + # TODO: Implement CustomDirective + # assembly_format = """ + # $operand `,` `format` `=` custom($exponent_bits, $mantissa_bits) + # attr-dict `:` custom(type($operand), type($output)) + # """ + + operand = operand_def(HLO_FpOrQuantizedIntTensor) + result = result_def(HLO_FpOrQuantizedIntTensor) + + exponent_bits = attr_def(IntegerAttr.constr(type=eq(i32), value=AtLeast(1))) + mantissa_bits = attr_def(IntegerAttr.constr(type=eq(i32), value=AtLeast(0))) + + traits = traits_def( + NoMemoryEffect(), + Elementwise(), + # TODO: HLO_CompatibleOperandsAndResultType(), + # TODO: HLO_SpeculatableIfStaticDimInOutputIsStaticInInput(), + ) + + +@irdl_op_definition +class SelectOp(IRDLOperation): + """ + Produces a `result` tensor where each element is selected from `on_true` or + `on_false` tensor based on the value of the corresponding element of `pred`. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#select + + Example: + ```mlir + %result = stablehlo.select %pred, %on_true, %on_false : tensor<2x2xi1>, tensor<2x2xi32> + ``` + """ + + name = "stablehlo.select" + + # assembly_format = """ + # operands attr-dict `:` + # custom(type($pred), type($on_true), type($on_false), type($result)) + # """ + + pred = operand_def(HLO_PredTensor) + on_true = operand_def(HLO_Tensor) + on_false = operand_def(HLO_Tensor) + result = result_def(HLO_Tensor) + + traits = traits_def( + NoMemoryEffect(), + ) + + +@irdl_op_definition +class ConstantOp(IRDLOperation): + """ + Produces an ``output`` tensor from a constant ``value``. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant + + Example: + ```mlir + %output = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32> + """ + + name = "stablehlo.constant" + + value = prop_def(DenseIntOrFPElementsAttr) + output = result_def(HLO_AnyTensor) + + def __init__(self, value: DenseIntOrFPElementsAttr): + super().__init__(properties={"value": value}, result_types=(value.type,)) diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_unary.py b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_unary.py new file mode 100644 index 0000000000..a8a6e6ca86 --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_unary.py @@ -0,0 +1,552 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unary elementwise operations for the StableHLO dialect. +""" + +import abc +from typing import Generic, TypeVar + +from xdsl.dialects.builtin import ( + I1, + AnyFloat, + AnyTensorType, + ComplexType, + TensorType, +) +from xdsl.ir import Attribute, SSAValue +from xdsl.irdl import ( + IRDLOperation, + irdl_op_definition, + operand_def, + opt_attr_def, + result_def, + traits_def, +) +from xdsl.traits import NoMemoryEffect + +from catalyst.python_interface.xdsl_extras import Elementwise, SameOperandsAndResultShape + +from .attributes import ResultAccuracyMode, ResultAccuracyModeAttr +from .types import ( + HLO_FloatTensor, + HLO_FpComplexOrQuantizedIntTensor, + HLO_FpOrComplexTensor, + HLO_FpOrQuantizedIntTensor, + HLO_IntFpOrComplexOrQuantizedIntTensor, + HLO_NonQuantizedTensor, + HLO_PredTensor, + HLO_SIntFpComplexOrQuantizedIntTensor, +) + +# Type aliases +I1TensorType = TensorType[I1] +FloatTensorType = TensorType[AnyFloat] +FloatOrComplexType = AnyFloat | ComplexType +FloatOrComplexTensorType = TensorType[FloatOrComplexType] +ComplexTensorType = TensorType[ComplexType] + +# Generic type variables for templating +T_IN = TypeVar("T_IN", bound=AnyTensorType) +T_OUT = TypeVar("T_OUT", bound=AnyTensorType) + + +class ElementwiseUnaryOperation(IRDLOperation, abc.ABC, Generic[T_IN, T_OUT]): + """ + Templated base class for elementwise unary operations. + + This class provides a flexible template for unary operations that can work + with different tensor types. + + For more informtation about the semantics, see: + https://openxla.org/xla/operation_semantics#element-wise_unary_functions + """ + + operand = operand_def(T_IN) + result = result_def(T_OUT) + + # TODO: Implement CustomDirective + # assembly_format = """ + # $operand attr-dict `:` custom(type($operand), type($result)) + # """ + + traits = traits_def( + NoMemoryEffect(), + SameOperandsAndResultShape(), + Elementwise(), + # TODO: InferShapedTypeOpInterface(), + # TODO: HLO_SpeculatableIfStaticDimInOutputIsStaticInInput(), + ) + + def __init__(self, operand: SSAValue, result_type: Attribute | None = None): + if result_type is None: + result_type = operand.type + super().__init__(operands=(operand,), result_types=(result_type,)) + + +@irdl_op_definition +class ConvertOp(ElementwiseUnaryOperation[HLO_NonQuantizedTensor, HLO_NonQuantizedTensor]): + """ + Performs an element-wise conversion from one element type to another on + `operand` tensor and produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convert + + Example: + ```mlir + %result = stablehlo.convert %operand : (tensor<3xi64>) -> tensor<3xcomplex> + ``` + """ + + name = "stablehlo.convert" + + traits = traits_def(SameOperandsAndResultShape()) + + +@irdl_op_definition +class CosineOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise cosine operation on `operand` tensor and produces a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cosine + + Example: + ```mlir + %result = stablehlo.cosine %operand : tensor<2xf32> + ``` + """ + + name = "stablehlo.cosine" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + # TODO: implement HLO_CompatibleOperandsAndResultType() + # traits = traits_def( + # HLO_CompatibleOperandsAndResultType() + # ) + + +@irdl_op_definition +class ExponentialMinusOneOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise exponential minus one operation on `operand` tensor + and produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential_minus_one + + Example: + ```mlir + %result = stablehlo.exponential_minus_one %operand : tensor<2xf64> + ``` + """ + + name = "stablehlo.exponential_minus_one" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + # TODO: implement HLO_CompatibleOperandsAndResultType() + # traits = traits_def( + # HLO_CompatibleOperandsAndResultType() + # ) + + +@irdl_op_definition +class ExponentialOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise exponential operation on `operand` tensor and produces + a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential + + Example: + ```mlir + %result = stablehlo.exponential %operand : tensor<2x2xf64> + ``` + """ + + name = "stablehlo.exponential" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + # TODO: implement HLO_CompatibleOperandsAndResultType() + # traits = traits_def( + # HLO_CompatibleOperandsAndResultType() + # ) + + +@irdl_op_definition +class FloorOp(ElementwiseUnaryOperation[HLO_FpOrQuantizedIntTensor, HLO_FpOrQuantizedIntTensor]): + """ + Performs element-wise floor of `operand` tensor and produces a `result` + tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#floor + + Example: + ```mlir + %result = stablehlo.floor %operand : tensor<2xf32> + ``` + """ + + name = "stablehlo.floor" + + +@irdl_op_definition +class ImagOp(ElementwiseUnaryOperation[HLO_FpOrComplexTensor, HLO_FloatTensor]): + """ + Extracts the imaginary part, element-wise, from the `operand` and produces a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#imag + + Example: + ```mlir + %result = stablehlo.imag %operand : (tensor<2xcomplex>) -> tensor<2xf32> + ``` + """ + + name = "stablehlo.imag" + + +@irdl_op_definition +class IsFiniteOp(ElementwiseUnaryOperation[HLO_FpOrQuantizedIntTensor, HLO_PredTensor]): + """ + Performs element-wise check whether the value in `x` is finite (i.e. is + neither +Inf, -Inf, nor NaN) and produces a `y` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#is_finite + + Example: + ```mlir + %y = stablehlo.is_finite %x : (tensor<7xf64>) -> tensor<7xi1> + ``` + """ + + name = "stablehlo.is_finite" + + +@irdl_op_definition +class LogOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise logarithm operation on `operand` tensor and produces a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#log + + Example: + ```mlir + %result = stablehlo.log %operand : tensor<2x2xf64> + ``` + """ + + name = "stablehlo.log" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + +@irdl_op_definition +class LogPlusOneOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise logarithm plus one operation on `operand` tensor and + produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#log_plus_one + + Example: + ```mlir + %result = stablehlo.log_plus_one %operand : tensor<5xf64> + ``` + """ + + name = "stablehlo.log_plus_one" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + +@irdl_op_definition +class LogisticOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise logistic operation on `operand` tensor and produces a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#logistic + + Example: + ```mlir + %result = stablehlo.logistic %operand : tensor<2x2xf64> + ``` + """ + + name = "stablehlo.logistic" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + +@irdl_op_definition +class NegateOp( + ElementwiseUnaryOperation[ + HLO_IntFpOrComplexOrQuantizedIntTensor, HLO_IntFpOrComplexOrQuantizedIntTensor + ] +): + """ + Performs element-wise negation of `operand` tensor and produces a `result` + tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#negate + + Example: + ```mlir + %result = stablehlo.negate %operand : tensor<2x3xi32> + ``` + """ + + name = "stablehlo.negate" + + +@irdl_op_definition +class RealOp(ElementwiseUnaryOperation[HLO_FpOrComplexTensor, HLO_FloatTensor]): + """ + Extracts the real part, element-wise, from the `operand` and produces a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#real + + Example: + ```mlir + %result = stablehlo.real %operand : tensor<2xcomplex> : tensor<2xf32> + ``` + """ + + name = "stablehlo.real" + + +@irdl_op_definition +class RoundNearestAfzOp( + ElementwiseUnaryOperation[HLO_FpOrQuantizedIntTensor, HLO_FpOrQuantizedIntTensor] +): + """ + Performs element-wise rounding towards the nearest integer, breaking ties + away from zero, on the `operand` tensor and produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_afz + + Example: + ```mlir + %result = stablehlo.round_nearest_afz %operand : tensor<5xf64> + ``` + """ + + name = "stablehlo.round_nearest_afz" + + +@irdl_op_definition +class RoundNearestEvenOp( + ElementwiseUnaryOperation[HLO_FpOrQuantizedIntTensor, HLO_FpOrQuantizedIntTensor] +): + """ + Performs element-wise rounding towards the nearest integer, breaking ties + towards the even integer, on the `operand` tensor and produces a `result` + tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_even + + Example: + ```mlir + %result = stablehlo.round_nearest_even %operand : tensor<5xf64> + ``` + """ + + name = "stablehlo.round_nearest_even" + + +@irdl_op_definition +class RsqrtOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise reciprocal square root operation on `operand` tensor + and produces a `result` tensor, implementing the `rSqrt` operation from the + IEEE-754 specification. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rsqrt + + Example: + ```mlir + %result = stablehlo.rsqrt %operand : tensor<2x2xf32> + ``` + """ + + name = "stablehlo.rsqrt" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + +@irdl_op_definition +class SignOp( + ElementwiseUnaryOperation[ + HLO_SIntFpComplexOrQuantizedIntTensor, HLO_SIntFpComplexOrQuantizedIntTensor + ] +): + """ + Returns the sign of the `operand` element-wise and produces a `result` + tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sign + + Example: + ```mlir + %result = stablehlo.sign %operand : tensor<5xf64> + ``` + """ + + name = "stablehlo.sign" + + +@irdl_op_definition +class SineOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise sine operation on `operand` tensor and produces a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sine + + Example: + ```mlir + %result = stablehlo.sine %operand : tensor<2xf32> + ``` + """ + + name = "stablehlo.sine" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + +@irdl_op_definition +class SqrtOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise square root operation on `operand` tensor and produces + a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sqrt + + Example: + ```mlir + %result = stablehlo.sqrt %operand : tensor<2x2xf32> + ``` + """ + + name = "stablehlo.sqrt" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + +@irdl_op_definition +class TanOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise tangent operation on `operand` tensor and + produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tan + + Example: + ```mlir + %result = stablehlo.tan %operand : tensor<2x2xf64> + ``` + """ + + name = "stablehlo.tan" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + +@irdl_op_definition +class TanhOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise hyperbolic tangent operation on `operand` tensor and + produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tanh + + Example: + ```mlir + %result = stablehlo.tanh %operand : tensor<2xf32> + ``` + """ + + name = "stablehlo.tanh" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/extensibility.py b/frontend/catalyst/python_interface/dialects/stablehlo/extensibility.py new file mode 100644 index 0000000000..e6f8f0542d --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/extensibility.py @@ -0,0 +1,167 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Dynamism operations for the StableHLO dialect. +""" + + +from xdsl.dialects.builtin import ( + ArrayAttr, + BoolAttr, + DenseIntElementsAttr, + DictionaryAttr, + FlatSymbolRefAttr, + StringAttr, + TensorType, + TupleType, +) +from xdsl.ir import Attribute +from xdsl.irdl import ( + AnyAttr, + IRDLOperation, + irdl_op_definition, + opt_prop_def, + prop_def, + traits_def, + var_operand_def, + var_result_def, +) +from xdsl.traits import ( + MemoryEffect, +) +from xdsl.utils.exceptions import VerifyException + +from .attributes import CustomCallApiVersion, CustomCallApiVersionAttr, OutputOperandAlias + + +@irdl_op_definition +class CustomCallOp(IRDLOperation): + """ + Encapsulates an implementation-defined operation ``call_target_name`` that + takes ``inputs`` and ``called_computations`` and produces ``results``. + + Depending on the API version there are two ways to pass extra bits of static + information to the external function: + 1. Use ``API_VERSION_TYPED_FFI`` which allows passing a dictionary attribute. + 2. Use a previous API version with a ``StringAttr`` to encode backend config. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#custom_call + + Example: + ```mlir + %results = stablehlo.custom_call @foo(%input0) { + backend_config = {bar = 42 : i32}, + api_version = 4 : i32, + called_computations = [@foo] + } : (tensor) -> tensor + ``` + """ + + name = "stablehlo.custom_call" + + inputs = var_operand_def(AnyAttr()) + call_target_name = prop_def(StringAttr) + has_side_effect = prop_def(BoolAttr, default_value=BoolAttr.from_bool(False)) + backend_config = opt_prop_def(DictionaryAttr | StringAttr) + api_version = prop_def( + CustomCallApiVersionAttr, + default_value=CustomCallApiVersionAttr(CustomCallApiVersion.API_VERSION_ORIGINAL), + ) + called_computations = opt_prop_def(ArrayAttr[FlatSymbolRefAttr], default_value=ArrayAttr([])) + operand_layouts = opt_prop_def(ArrayAttr[DenseIntElementsAttr]) + result_layouts = opt_prop_def(ArrayAttr[DenseIntElementsAttr]) + output_operand_aliases = prop_def(ArrayAttr[OutputOperandAlias]) + + result = var_result_def(AnyAttr()) + + traits = traits_def( + MemoryEffect(), + ) + + # TODO: Implement CustomDirective + # assembly_format = """ + # custom($call_target_name) `(` $inputs `)` + # attr-dict `:` functional-type(operands, results) + # """ + + def verify_(self) -> None: + """Verify the CustomCallOp.""" + # If both operand and result layout attributes are not specified then nothing to verify. + if self.operand_layouts is None and self.result_layouts is None: + return + + # Layout constraints for either both operands & results or none should be specified. + if (self.operand_layouts is None) != (self.result_layouts is None): + raise VerifyException( + "Layout attributes should be specified for either both operands and results " + "or none." + ) + + assert self.operand_layouts is not None and self.result_layouts is not None + + def verify_types_and_layouts( + types: tuple[Attribute, ...], layouts: ArrayAttr, value_name: str + ): + if len(types) != len(layouts.data): + raise VerifyException( + "Number of " + f"{value_name}s must match the number of {value_name} layouts, " + f"{len(types)} != {len(layouts.data)}" + ) + + for index, (ty, layout_attr) in enumerate(zip(types, layouts.data)): + # Tuple types are not fully supported with layout constraints yet + if isinstance(ty, TupleType): + raise VerifyException( + "Tuple types are not fully supported with layout constraints yet" + ) + + try: + dims = list(layout_attr.get_values()) + except Exception as exc: + raise VerifyException("invalid layout attribute") from exc + + # For non-tensor types, layout must be empty + if not isinstance(ty, TensorType): + if len(dims) == 0: + continue + raise VerifyException( + "Only tensor types can have non-empty layout: " + f"{value_name} #{index} of type {ty} has layout {dims}" + ) + + # For ranked tensors, require permutation of [0, rank) + rank = ty.get_num_dims() + if rank != len(dims) or sorted(dims) != list(range(rank)): + raise VerifyException( + f"incorrect layout {dims} for type {ty}, layout must be a permutation " + f"of [0, {rank})" + ) + + # Operand types + operand_types: tuple[Attribute, ...] = tuple(op.type for op in self.operands) + + # Result types: if single tuple result, use its element types + if len(self.result_types) == 1 and isinstance(self.result_types[0], TupleType): + tuple_ty: TupleType = self.result_types[0] + result_types = tuple(tuple_ty.types.data) + else: + result_types = tuple(self.result_types) + + # Verify that operands and operand layouts match. + verify_types_and_layouts(operand_types, self.operand_layouts, "operand") + # Verify that results and result layouts match. + verify_types_and_layouts(result_types, self.result_layouts, "result") diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py b/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py new file mode 100644 index 0000000000..7e9bdcfdf8 --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py @@ -0,0 +1,169 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Dynamism operations for the StableHLO dialect. +""" + +from xdsl.dialects.builtin import DenseArrayBase, i64 +from xdsl.irdl import ( + IRDLOperation, + irdl_op_definition, + prop_def, + region_def, + traits_def, + var_operand_def, + var_result_def, +) +from xdsl.irdl.operations import SameVariadicOperandSize +from xdsl.traits import ( + RecursiveMemoryEffect, + SingleBlockImplicitTerminator, +) +from xdsl.utils.exceptions import VerifyException +from xdsl_jax.dialects import stablehlo as xstablehlo + +from .types import HLO_Tensor + + +@irdl_op_definition +class ReduceOp(IRDLOperation): + """ + Applies a reduction function ``body`` to ``inputs`` and ``init_values`` along the + ``dimensions`` and produces a ``result`` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce + + Example: + ```mlir + %result = "stablehlo.reduce"(%input, %init_value) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = stablehlo.add %arg0, %arg1 : tensor + stablehlo.return %0 : tensor + }) { + dimensions = array + } : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + ``` + """ + + name = "stablehlo.reduce" + + inputs = var_operand_def(HLO_Tensor) + init_values = var_operand_def(HLO_Tensor) + dimensions = prop_def(DenseArrayBase.constr(i64)) + result = var_result_def(HLO_Tensor) + body = region_def("single_block") + + irdl_options = [SameVariadicOperandSize()] + + traits = traits_def( + RecursiveMemoryEffect(), + # TODO: InferShapedTypeOpInterface(), + # TODO: HLO_RecursivelySpeculatableIfAllInputsStatic, + # TODO: InferTensorTypeWithReify(), + SingleBlockImplicitTerminator(xstablehlo.ReturnOp), + ) + + # pylint: disable=no-member + # pylint: disable=too-many-branches + def verify_(self): + """Verify the ReduceOp.""" + # Gather shaped operand/result types + input_types = [op.type for op in self.inputs] + init_types = [op.type for op in self.init_values] + + self._verify_input_and_init_types(input_types, init_types) + + self._verify_reducer_region(input_types) + + def _verify_input_and_init_types(self, input_types, init_types): + """Verify the types of the inputs and init values.""" + + # Basic structural checks mirroring verifyReduceOpInputsAndInferShape + if len(input_types) == 0: + raise VerifyException("expected at least 1 input for reduce") + if len(input_types) != len(init_types): + raise VerifyException("number of inputs must match number of init_values") + + # reduce_c1/c4/c5/i3: verify inputs and infer shape compatibility + dims_attr = self.dimensions + dims = tuple(dims_attr.get_values()) if dims_attr is not None else tuple() + + # All inputs must have equal rank; dimensions must be within rank and unique + # and not empty. + ranks = [] + for t in input_types: + # Tensors by op definition + assert hasattr(t, "get_num_dims") + ranks.append(t.get_num_dims()) + rank0 = ranks[0] + if any(r != rank0 for r in ranks): + raise VerifyException("all inputs must have the same rank") + + if len(dims) == 0: + raise VerifyException("dimensions cannot be empty for reduce") + if len(set(dims)) != len(dims): + raise VerifyException("dimensions should not have duplicates") + if any(d < 0 or d >= rank0 for d in dims): + raise VerifyException("dimensions contains an invalid value") + + # Element type compatibility between each input and its init value + for it, iv in zip(input_types, init_types): + it_elem = it.get_element_type() + iv_elem = iv.get_element_type() + if it_elem != iv_elem: + raise VerifyException("input and init_value must have the same element type") + + def _verify_reducer_region(self, input_types): + """Verify the operation's reducer region.""" + + # reduce_c2/c6: verify reducer region shape + # Expect block with arity 2 * number of inputs, with matching tensor element types + # and 0D tensors + if len(self.body.blocks) != 1: + raise VerifyException("reducer must have a single block") + block = self.body.blocks[0] + + expected_args = 2 * len(input_types) + if len(block.args) != expected_args: + raise VerifyException( + f"reducer must take {expected_args} arguments, got {len(block.args)}" + ) + + # Each pair (arg_i, arg_{i+N}) must be 0D tensors of the input element type + for i, it in enumerate(input_types): + it_elem = it.get_element_type() + acc = block.args[i] + val = block.args[i + len(input_types)] + for a in (acc, val): + a_ty = a.type + if not hasattr(a_ty, "get_num_dims") or a_ty.get_num_dims() != 0: + raise VerifyException("reducer arguments must be rank-0 tensors") + if a_ty.get_element_type() != it_elem: + raise VerifyException( + "reducer argument element types must match input element type" + ) + + # Region must terminate with exactly len(inputs) results + ret = block.ops.last + if len(ret.operands) != len(input_types): + raise VerifyException("reducer must return exactly one value per input") + for i, it in enumerate(input_types): + it_elem = it.get_element_type() + rty = ret.operands[i].type + if not hasattr(rty, "get_num_dims") or rty.get_num_dims() != 0: + raise VerifyException("reducer return values must be rank-0 tensors") + if rty.get_element_type() != it_elem: + raise VerifyException("reducer return element types must match input element type") diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/types.py b/frontend/catalyst/python_interface/dialects/stablehlo/types.py new file mode 100644 index 0000000000..a27f4fc8f2 --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/types.py @@ -0,0 +1,247 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +StableHLO type definitions for PennyLane's compiler infrastructure. + +This module provides type definitions based on the StableHLO specification +(https://github.com/openxla/stablehlo/blob/main/docs/spec.md), including +token types and other necessary type definitions for StableHLO operations. +""" + +from typing import TypeAlias + +from xdsl.dialects.builtin import ( + AnyFloatConstr, + ComplexType, + Float32Type, + Float64Type, + IndexType, + IntAttr, + IntAttrConstraint, + IntegerType, + ParametrizedAttribute, + Signedness, + SignednessAttr, + TensorType, + i1, +) +from xdsl.irdl import eq, irdl_attr_definition +from xdsl.irdl.attributes import EqAttrConstraint, ParamAttrConstraint +from xdsl.irdl.constraints import IntSetConstraint +from xdsl_jax.dialects.stablehlo import TokenType + +from catalyst.python_interface.xdsl_extras.constraints import ( + NestedTupleOfConstraint, +) + + +def _create_param_constrained_type( + base_attr: type, widths: list[int], signedness: Signedness | None = None +): + """Create an integer type constrained using ParamAttrConstraint with IntSetConstraint.""" + width_constraint = IntAttrConstraint(IntSetConstraint(frozenset(widths))) + + if signedness is None: + signedness_constraint = None + else: + signedness_constraint = EqAttrConstraint(SignednessAttr(signedness)) + + return ParamAttrConstraint(base_attr, [width_constraint, signedness_constraint]) + + +# ============================================================================= +# Core StableHLO types constraints +# ============================================================================= + +HLO_Pred = eq(i1) +HLO_PredTensor: TypeAlias = TensorType[HLO_Pred] + +# NOTE: IntegerType is defined in the StableHLO spec as: +# IntegerType ::= SignedIntegerType | UnsignedIntegerType, +# but the MLIR implementation is using signless integers instead of signed, +# and there is a TODO to fix it. + +_HLO_INT_WIDTHS = [2, 4, 8, 16, 32, 64] +HLO_SignedInt = _create_param_constrained_type(IntegerType, _HLO_INT_WIDTHS, Signedness.SIGNED) +HLO_UnsignedInt = _create_param_constrained_type(IntegerType, _HLO_INT_WIDTHS, Signedness.UNSIGNED) +HLO_SignlessInt = _create_param_constrained_type(IntegerType, _HLO_INT_WIDTHS, None) + +HLO_Int: TypeAlias = HLO_UnsignedInt | HLO_SignlessInt +HLO_IntTensor: TypeAlias = TensorType[HLO_Int] + +_HLO_INT_OR_PRED_WIDTHS = [1, 2, 4, 8, 16, 32, 64] +HLO_IntOrPred = _create_param_constrained_type(IntegerType, _HLO_INT_OR_PRED_WIDTHS, None) + + +HLO_AnyIntegerOrIndex: TypeAlias = IntegerType | IndexType +HLO_AnyIntegerOrIndexTensor: TypeAlias = TensorType.constr(HLO_AnyIntegerOrIndex) + +HLO_DimensionValue: TypeAlias = HLO_Int | IndexType + +# Constraint variants for use in unions with ParamAttrConstraint +HLO_Float: TypeAlias = AnyFloatConstr +HLO_Float32Or64: TypeAlias = Float32Type | Float64Type +HLO_FloatTensor: TypeAlias = TensorType.constr(HLO_Float) +HLO_Fp32Or64Tensor: TypeAlias = TensorType.constr(HLO_Float32Or64) + +# Complex as a constraint over element types {f32,f64} +HLO_Complex: TypeAlias = ComplexType[HLO_Float32Or64] +HLO_ComplexTensor: TypeAlias = TensorType.constr(HLO_Complex) + +# ============================================================================= +# Quantized element type definitions +# ============================================================================= + + +@irdl_attr_definition +class UniformQuantizedType(ParametrizedAttribute): + """ + Placeholder for StableHLO per-tensor uniform quantized types. + + Parameterized by width to support different quantized integer widths + (e.g., 8-bit, 16-bit quantization). + """ + + name = "stablehlo.uniform_quantized" + width: IntAttr + signedness: SignednessAttr + + +@irdl_attr_definition +class UniformQuantizedPerAxisType(ParametrizedAttribute): + """ + Placeholder for StableHLO per-axis uniform quantized types. + + Parameterized by width to support different quantized integer widths + (e.g., 8-bit, 16-bit quantization). + """ + + name = "stablehlo.uniform_quantized_per_axis" + width: IntAttr + signedness: SignednessAttr + + +# ============================================================================= +# StableHLO quantized type aliases +# ============================================================================= + +_HLO_QUANTIZED_WIDTHS = [2, 4, 8, 16, 32] + +# Constraint-based types for operation definitions +HLO_QuantizedSignedInt = _create_param_constrained_type( + UniformQuantizedType, _HLO_QUANTIZED_WIDTHS, Signedness.SIGNED +) +HLO_QuantizedUnsignedInt = _create_param_constrained_type( + UniformQuantizedType, _HLO_QUANTIZED_WIDTHS, Signedness.UNSIGNED +) +HLO_QuantizedAnySignednessInt = _create_param_constrained_type( + UniformQuantizedType, _HLO_QUANTIZED_WIDTHS, None +) +HLO_QuantizedInt: TypeAlias = HLO_QuantizedSignedInt | HLO_QuantizedUnsignedInt + +HLO_PerAxisQuantizedSignedInt = _create_param_constrained_type( + UniformQuantizedPerAxisType, _HLO_QUANTIZED_WIDTHS, Signedness.SIGNED +) +HLO_PerAxisQuantizedUnsignedInt = _create_param_constrained_type( + UniformQuantizedPerAxisType, _HLO_QUANTIZED_WIDTHS, Signedness.UNSIGNED +) +HLO_PerAxisQuantizedAnySignednessInt = _create_param_constrained_type( + UniformQuantizedPerAxisType, _HLO_QUANTIZED_WIDTHS, None +) +HLO_PerAxisQuantizedInt: TypeAlias = HLO_PerAxisQuantizedSignedInt | HLO_PerAxisQuantizedUnsignedInt + +# ============================================================================= +# Main tensor type definitions +# ============================================================================= + +HLO_Tensor: TypeAlias = TensorType[HLO_Float | HLO_Complex | HLO_IntOrPred | HLO_QuantizedInt] +HLO_NonQuantizedTensor: TypeAlias = TensorType[HLO_Float | HLO_Complex | HLO_IntOrPred] + +# Note: There is a discrepancy between the StableHLO spec and the MLIR implementation. +# The spec does not allow unranked tensors, but the MLIR implementation +# defines it as a tensor of any type and rank. There is a TODO to fix this in MLIR. +# Therefore, we use the correct ranked tensor type. +HLO_AnyTensor: TypeAlias = TensorType[ + HLO_Float | HLO_Complex | HLO_IntOrPred | HLO_QuantizedInt | HLO_PerAxisQuantizedInt +] +HLO_TensorOrToken: TypeAlias = HLO_Tensor | TokenType +HLO_TensorOrPerAxisQuantizedTensorOrToken: TypeAlias = HLO_AnyTensor | TokenType + +# HLO_AnyTuple : NestedTupleOf<[HLO_AnyTensor, HLO_Token]> +HLO_AnyTuple = NestedTupleOfConstraint([HLO_AnyTensor, TokenType]) + +HLO_CustomCallValue: TypeAlias = HLO_Tensor | TokenType | HLO_AnyTuple + +# ============================================================================= +# HLO combined type definitions +# ============================================================================= + +HLO_PredOrIntTensor: TypeAlias = TensorType.constr(HLO_IntOrPred) + +HLO_FpOrComplexTensor: TypeAlias = TensorType.constr(HLO_Float | HLO_Complex) +HLO_FpOrQuantizedIntTensor: TypeAlias = TensorType.constr(HLO_Float | HLO_QuantizedInt) +HLO_FpComplexOrQuantizedIntTensor: TypeAlias = TensorType.constr( + HLO_Float | HLO_Complex | HLO_QuantizedInt +) +HLO_IntFpOrComplexOrQuantizedIntTensor: TypeAlias = TensorType.constr( + HLO_Int | HLO_Float | HLO_Complex | HLO_QuantizedInt +) +HLO_SIntFpComplexOrQuantizedIntTensor: TypeAlias = TensorType.constr( + HLO_SignedInt | HLO_Float | HLO_Complex | HLO_QuantizedInt +) + + +__all__ = [ + # Core types + "HLO_Pred", + "HLO_PredTensor", + "HLO_Int", + "HLO_IntTensor", + "HLO_AnyIntegerOrIndex", + "HLO_AnyIntegerOrIndexTensor", + "HLO_DimensionValue", + "HLO_Float", + "HLO_Float32Or64", + "HLO_FloatTensor", + "HLO_Fp32Or64Tensor", + "HLO_ComplexTensor", + "HLO_SignedInt", + "HLO_UnsignedInt", + "HLO_SignlessInt", + "HLO_QuantizedSignedInt", + "HLO_QuantizedUnsignedInt", + "HLO_QuantizedAnySignednessInt", + "HLO_QuantizedInt", + "HLO_PerAxisQuantizedSignedInt", + "HLO_PerAxisQuantizedUnsignedInt", + "HLO_PerAxisQuantizedAnySignednessInt", + "HLO_PerAxisQuantizedInt", + # Quantized types + "UniformQuantizedType", + "UniformQuantizedPerAxisType", + "HLO_Tensor", + "HLO_NonQuantizedTensor", + "HLO_AnyTensor", + "HLO_TensorOrToken", + "HLO_TensorOrPerAxisQuantizedTensorOrToken", + "HLO_CustomCallValue", + # Combined types + "HLO_PredOrIntTensor", + "HLO_FpOrComplexTensor", + "HLO_FpOrQuantizedIntTensor", + "HLO_FpComplexOrQuantizedIntTensor", + "HLO_IntFpOrComplexOrQuantizedIntTensor", + "HLO_SIntFpComplexOrQuantizedIntTensor", +] diff --git a/frontend/catalyst/python_interface/dialects/transform.py b/frontend/catalyst/python_interface/dialects/transform.py new file mode 100644 index 0000000000..32449e7d5e --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/transform.py @@ -0,0 +1,122 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This file contains an updated version of the transform dialect. +As of the time of writing, xDSL uses the MLIR released with LLVM's +version 20.1.7. However, https://github.com/PennyLaneAI/catalyst/pull/1916 +will be updating MLIR where the transform dialect has the +`apply_registered_pass` operation re-defined. + +See the following changelog on the above PR + + Things related to transform.apply_registered_pass op: + + It now takes in a dynamic_options + + [MLIR][Transform] Allow ApplyRegisteredPassOp to take options as + a param llvm/llvm-project#142683. We don't need to use this as all our pass options are static. + https://github.com/llvm/llvm-project/pull/142683 + + The options it takes in are now dictionaries instead of strings + [MLIR][Transform] apply_registered_pass op's options as a dict llvm/llvm-project#143159 + https://github.com/llvm/llvm-project/pull/143159 + +This file will re-define the apply_registered_pass operation in xDSL +and the transform dialect. + +Once xDSL moves to a newer version of MLIR, these changes should +be contributed upstream. +""" + +from xdsl.dialects.builtin import Dialect +from xdsl.dialects.transform import ApplyRegisteredPassOp as xApplyRegisteredPassOp +from xdsl.dialects.transform import ( + DictionaryAttr, + StringAttr, +) +from xdsl.dialects.transform import Transform as xTransform +from xdsl.dialects.transform import ( + TransformHandleType, + irdl_op_definition, + operand_def, + prop_def, + result_def, +) +from xdsl.ir import Attribute, SSAValue +from xdsl.irdl import IRDLOperation, ParsePropInAttrDict + + +# pylint: disable=line-too-long +@irdl_op_definition +class ApplyRegisteredPassOp(IRDLOperation): + """ + See external [documentation](https://mlir.llvm.org/docs/Dialects/Transform/#transformapply_registered_pass-transformapplyregisteredpassop). + """ + + name = "transform.apply_registered_pass" + + options = prop_def(DictionaryAttr, default_value=DictionaryAttr({})) + pass_name = prop_def(StringAttr) + target = operand_def(TransformHandleType) + result = result_def(TransformHandleType) + # While this assembly format doesn't match + # the one in upstream MLIR, + # this is because xDSL currently lacks CustomDirectives + # https://mlir.llvm.org/docs/DefiningDialects/Operations/#custom-directives + # https://github.com/xdslproject/xdsl/pull/4829 + # However, storing the property in the attribute should still work + # specially when parsing and printing in generic format. + # Which is how Catalyst and XDSL currently communicate at the moment. + # TODO: Add test. + assembly_format = "$pass_name `to` $target attr-dict `:` functional-type(operands, results)" + irdl_options = [ParsePropInAttrDict()] + + def __init__( + self, + pass_name: str | StringAttr, + target: SSAValue, + options: dict[str | StringAttr, Attribute | str | bool | int] | None = None, + ): + if isinstance(pass_name, str): + pass_name = StringAttr(pass_name) + + if isinstance(options, dict): + options = DictionaryAttr(options) + + super().__init__( + properties={ + "pass_name": pass_name, + "options": options, + }, + operands=[target], + result_types=[target.type], + ) + + +# Copied over from xDSL's sources +# the main difference will be the use +# of a different ApplyRegisteredPassOp +operations = list(xTransform.operations) +del operations[operations.index(xApplyRegisteredPassOp)] +operations.append(ApplyRegisteredPassOp) + +Transform = Dialect( + "transform", + [ + *operations, + ], + [ + *xTransform.attributes, + ], +) diff --git a/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst b/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst new file mode 100644 index 0000000000..1525c030ee --- /dev/null +++ b/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst @@ -0,0 +1,1376 @@ +Unified Compiler Cookbook +========================= + +**Note:** The cookbook is developed with the following package versions, +on Python 3.12.11: + +.. code-block:: bash + + jax==0.6.2 + jaxlib==0.6.2 + numpy==2.3.1 + pennylane==0.44.0-dev19 + pennylane-lightning==0.43.0 + pennylane-catalyst==0.14.0-dev15 + xdsl==0.53.0 + xdsl-jax==git+https://github.com/xdslproject/xdsl-jax.git@895f7c13e8d0f02bbe99d7fb9ebcaafea4ea629f#egg=xdsl_jax + +Note that ``xdsl-jax`` does not currently have a release published on +PyPI, so it needs to be installed from GitHub by running the following: + +.. code-block:: bash + + pip install git+https://github.com/xdslproject/xdsl-jax.git + +Motivation +========== + +As we approach FTQC, quantum compilation becomes a more and more +important research area. Catalyst uses MLIR as its intermediate +representation, which is also the layer in which a majority of the +optimizations happen. However, quantum compilation researchers are not +likely to be accustomed with MLIR or C++. + +So, the motivation of the β€œUnified Compiler” is to provide a Python +layer in which compilation passes can be implemented and applied. +Additionally, we also want to enable researchers to use abstractions +that they are familiar with. We’re aiming to do this using xDSL, which +is a reimplementation of MLIR in Python. + +This document is meant to be a quickstart for users that are interested +in developing compiler passes for Catalyst, but are not familiar with +MLIR or xDSL. + +MLIR basics +=========== + +If readers are already familiar to some degree with MLIR, this section +can be skipped. + +SSA +--- + +β€œIn compiler design, static single assignment form (often abbreviated as +SSA form or simply SSA) is a type of intermediate representation (IR) +where each variable is assigned exactly once” [`1 <#references>`__]. + +SSA is powerful because it allows us to define chains of uses and +definitions, i.e, we can keep track of which operation created a +variable (since each variable only gets created once), and all +operations that use that variable. These chains of uses and definitions, +if used well, can make transformations quite performant, and in some +cases, very simple to implement, but they can also make the IR harder to +parse. + +IR structure +------------ + +MLIR represents programs using a hierarchical, graph-like structure. In +this structure, nodes are called *operations*, and edges are called +*values*. Each value is the result of exactly one operation or argument +for a block (more on that later), and has a type that is defined by the +type system (more on that later also) [`2 <#references>`__]. + +The IR is recursively nested - operations may contain regions, which +contain blocks, which contain operations. More concretely, an operation +may contain zero or more regions, each of which must contain one or more +blocks, each of which holds a list of arguments and an ordered list of +operations that may use those arguments [`3 <#references>`__]. + +Operations +~~~~~~~~~~ + +Operations are basic units of execution. Operations are fully +extensible, i.e., there is no fixed list of operations. Operations can +return zero or more results, take in zero or more operands, declare +properties and attributes, and can have zero or more successors and +regions [`2 <#references>`__]. + +Regions +~~~~~~~ + +A region is an ordered list of blocks. Its semantics are defined by the +operation that contains them [`2 <#references>`__]. For example, an +``scf.IfOp``, which represents a conditional operation, contains two +regions, one for the true branch, and another for the false branch, and +the way we interpret these regions is dependent on the fact that they +belong to the ``scf.IfOp``. + +Blocks +~~~~~~ + +Blocks are lists of operations. The operations inside blocks are +executed in order. Blocks take a list of block arguments, annotated in a +function-like way. The first block in a region is special, and is called +the β€œentry block”. The block arguments of the entry block are also +arguments to its outer region [`2 <#references>`__]. + +Values +~~~~~~ + +In MLIR, there are computable values with a type, a single defining +operation, and zero or more uses [`4 <#references>`__]. These values are +either the results of operations or block arguments, and adhere to SSA +semantics. + +Dialects +-------- + +MLIR uses dialects to allow developers to define a set of high level +operations, attributes, and types, which can be used in the intermediate +representation to represent custom subroutines, etc. and can be +converted into lower level representations using interpretation rules +that can be defined. For example, the MLIR ``arith`` dialect defines +operations for arithmetic computations, the ``linalg`` dialect defines +linear algebra operations, etc. We use dialects to define quantum +instructions such as gates, measurements, and attributes such as qubits, +quantum registers. + +Def-use and use-def chains +-------------------------- + +Frequently, we might have a value and we want to determine which +instructions use this value. The list of all users of a value is the +*def-use chain*. Conversely, we might have an instruction and we want to +determine which values it uses. The list of values used by an +instruction is the *use-def chain* [`5 <#references>`__]. These chains +allow us to iterate through operations topologically, which can be very +powerful when implementing passes. + +xDSL API +======== + +Now that we’re familiar with the high-level constructs that MLIR uses, +let’s go over what their xDSL actual implementation looks like. + +``SSAValue`` +------------ + +``SSAValue`` is the class used to represent variables. SSA values may +used by operations as operands, and returned as results. The three key +properties that this class provides are listed below: + +- ``type``: The type of the value. Since SSA values are variables, their + value is not known at compile time, but their type is. +- ``uses``: A set of all operations that use a given ``SSAValue`` as an + operand. The operations are wrapped around a convenience class called + ``Use``, and the corresponding operation can be accessed using + ``Use.operation``. +- ``owner``: The operation or block that defined a given ``SSAValue`` + (more on that later) + +``SSAValue`` has two subclasses, which will be seen a lot when +inspecting xDSL modules. These are: + +- ``OpResult``: This subclass represents SSA values that are defined by + an operation + + .. code-block:: + + c = add a b + + In the above pseudocode, ``c`` is defined by the operation ``add``, + and in xDSL, will be represented as an ``OpResult`` instance. + +- ``BlockArgument``: This subclass represents SSA values that are block + arguments + +``Attribute`` +------------- + +An attribute defines a compile-time value. These can be used to define +types, and can also be used to define properties of operations that are +known at compile time. For example: + +- ``CustomOp``, which is an operation from the ``Quantum`` dialect (more + on that later) has a ``gate_name`` that must be a string, represented + by the ``StringAttr`` attribute. ``StringAttr`` has a reference to the + concrete string representing the gate name, and we can access this + concrete value at compile-time using ``CustomOp.gate_name.data``. +- ``CustomOp`` also takes qubits as inputs and outputs, which are of + type ``QubitType``. The ``QubitType`` class inherits ``Attribute``, + but we use it to declare a type that can be used to define SSA values. + +Below is the definition of ``QubitType`` to illustrate: + +.. code-block:: python + + # TypeAttribute just means that this attribute can be used to represent + # the types of the operands and results of an operation, but not its + # properties. + # ParametrizedAttribute means that this attribute can take parameters as + # input (although QubitType doesn't have any parameters). + @irdl_attr_definition + class QubitType(ParametrizedAttribute, TypeAttribute): + """A value-semantic qubit (state).""" + + name = "quantum.bit" + +``Operation`` +------------- + +The ``Operation`` class is used to represent operations, which are basic +units of execution. All instructions in a module are operations. In +fact, the modules themselves are operations. Operations contain several +fields used to define their form and function: + +- Operands: operands are runtime values that the operation consumes. + Note that when defining an operation, we only declare the *types* of + the operands, not the actual operands. Only when we construct an + instance of an operation do we provide actual ``SSAValue``\ s as the + operands, and these must adhere to the type system. +- Properties: properties are compile-time values used to define the + semantics of an operation. For example, the ``CustomOp`` operation + that is used to define quantum gates in Catalyst has a property called + ``gate_name``, which is a string specifier of the gate’s name. This + name directly impacts how the operation should be interpreted, and its + value is known at compile-time. +- Attributes: Operation attributes are stored as a dictionary, + containing more compile-time values. Generally, these don’t get used + much in xDSL, but they serve a purpose very similar to properties. +- Result types: operations may return values, and if so, the types of + the return values must be defined. + +Below is the definition of ``CustomOp`` from the ``Quantum`` dialect, +which represents general quantum gates to illustrate what defining +operations looks like: + +.. code-block:: python + + @irdl_op_definition + class CustomOp(IRDLOperation): + """A generic quantum gate on n qubits with m floating point parameters.""" + + name = "quantum.custom" + + # assembly_format defines what the operation should look like + # when pretty-printed in the IR + assembly_format = """ + $gate_name `(` $params `)` $in_qubits + (`adj` $adjoint^)? + attr-dict + ( `ctrls` `(` $in_ctrl_qubits^ `)` )? + ( `ctrlvals` `(` $in_ctrl_values^ `)` )? + `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )? + """ + + # These options are used because we have operands whose lengths aren't + # known (eg. different instances of CustomOp may have different number + # of qubits depending on the gate they represent (2 for CNOT, 1 for + # RX, etc.). These options basically say, "when the operation instance + # is initialized, create 2 properties that store the length of each of + # the different groups of operands and results. + irdl_options = [ + AttrSizedOperandSegments(as_property=True), + AttrSizedResultSegments(as_property=True), + ] + + # var_operand_def means that the length of this operand + # can vary. + params = var_operand_def(EqAttrConstraint(Float64Type())) + + in_qubits = var_operand_def(BaseAttr(QubitType)) + + # prop_def means that gate_name is a required property + gate_name = prop_def(BaseAttr(StringAttr)) + + # opt_prop_def means adjoint is an optional property. Additionally, + # it's type is a UnitAttr(), which essentially means that a given + # instance of CustomOp is an adjoint gate iff it has an adjoint property. + # The value of the property is irrelevant; it only gets meaning from its + # existance. + adjoint = opt_prop_def(EqAttrConstraint(UnitAttr())) + + in_ctrl_qubits = var_operand_def(BaseAttr(QubitType)) + + in_ctrl_values = var_operand_def(EqAttrConstraint(IntegerType(1))) + + # var_result_def means that the length of out_qubits can vary + out_qubits = var_result_def(BaseAttr(QubitType)) + + out_ctrl_qubits = var_result_def(BaseAttr(QubitType)) + +.. _dialects-1: + +Dialects +-------- + +The ``Dialect`` class in xDSL is used as a container around a list of +operations, types, and attributes, and it also declares a name for the +dialect. At the time of writing this, there are currently 4 custom +dialects available in the xDSL layer of Catalyst: + +- ``Quantum``: this dialect contains operations and attributes necessary + for general qubit-level operations, such as gates, measurements, + qubits, etc. +- ``Catalyst``: this dialect contains operations and attributes for + classical computing features unavailable out of the box with xDSL/MLIR +- ``QEC``: This dialect contains operations and attributes useful for + QEC, such as PPRs/PPMs. +- ``MBQC``: This dialect contains operations and attributes for + representing MBQC formalism. + +Pass API +-------- + +xDSL provides an API for defining and applying transformations on +programs (or modules), which is described below: + +``ModulePass`` +~~~~~~~~~~~~~~ + +``ModulePass`` is used to create rewrite passes over an IR module. It is +the parent class used to define compiler passes (or transforms). +``ModulePass`` has two key fields that must be implemented + +- ``name``: This is the name that is used to reference the pass. +- ``apply``: This method takes a ``ModuleOp`` as input and applies the + rewrite pattern of the pass to the module. Note that this mutates the + input module *in-place* rather than returning a new, transformed + module. + +``RewritePattern`` +~~~~~~~~~~~~~~~~~~ + +``RewritePattern`` is the class that provides the API for pattern +matching. The most important method of this class is +``match_and_rewrite``. The first argument to this method is an +operation, and it must be type-hinted using the specific operation we’re +trying to match. This type hint gets used by xDSL to match the operation +we’re trying to rewrite. The second argument is a ``PatternRewriter`` +instance. I will cover this class in detail below, but it is essentially +the class that provides the API for rewriting the operations that we’re +matching. + +For example, if I wanted to match all Hadamard gates, I would use +``CustomOp`` from the ``Quantum`` dialect in the type hint for the first +argument (since there is no ``Hadamard`` operation in the ``Quantum`` +dialect), and check in the body of the method if the op is a +``Hadamard``: + +.. code-block:: python + + from xdsl import pattern_rewriter + from catalyst.python_interface.dialects.quantum import CustomOp + + class MyPattern(pattern_rewriter.RewritePattern): + """Dummy class for example.""" + + # This decorator is what xDSL uses to match operations + # based on the type hint. + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite( + self, op: CustomOp, rewriter: pattern_rewriter.PatternRewriter + ): + if op.gate_name.data != "Hadamard": + # If not Hadamard, we do nothing + return + + # Do whatever we need + +``PatternRewriter`` +~~~~~~~~~~~~~~~~~~~ + +``PatternRewriter`` is the class that provides the API for rewriting the +IR. It includes several methods for replacing/removing/updating +operations, replacing values, replacing uses of a value with another +value, etc. In most cases, any rewriting that users want to do must be +done through this API rather than manually, as it includes state +management for keeping track of whether any changes were made, which is +necessary for the worklist algorithm (more on that later). + +Some key methods are: + +- ``replace_op``: Replaces one operation with another. +- ``replace_all_uses_with``: Replaces all uses of one value with + another. +- ``erase_op``: Erases an operation. If this operation returns any + values, all uses of these values must be updated accordingly before + the erasure. +- ``notify_op_modified``: Method to notify the rewriter that a change + was made to an operation manually. + +The example below shows us implementing a ``RewritePattern`` that +updates all ``Hadamard``\ s with ``PauliX``\ s: + +.. code-block:: python + + from xdsl import pattern_rewriter + from xdsl.dialects import builtin + from catalyst.python_interface.dialects.quantum import CustomOp + + class HToXPattern(pattern_rewriter.RewritePattern): + """Dummy class for example.""" + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite( + self, op: CustomOp, rewriter: pattern_rewriter.PatternRewriter + ): + if op.gate_name.data != "Hadamard": + # If not Hadamard, we do nothing + return + + # Update the gate name to PauliX and notify the rewriter that + # the op was manually updated + op.gate_name = builtin.StringAttr("PauliX") + rewriter.notify_op_modified(op) + + # Alternatively, we could also create a new CustomOp for + # the PauliX from scratch, and replace the Hadamard with + # the new op: + # new_op = CustomOp( + # gate_name="PauliX", + # in_qubits=op.in_qubits, + # ) + # rewriter.replace_op(op, new_op) + +``PatternRewriteWalker`` +~~~~~~~~~~~~~~~~~~~~~~~~ + +``PatternRewriteWalker`` walks over the IR in depth-first order, and +applies a provided ``RewritePattern`` to it. By default, it implements a +worklist algorithm that keeps iterating over the operations and matching +and rewriting until a steady state is reached (i.e.Β no new changes are +detected; this is why ``PatternRewriter`` needs to keep track of whether +any changes were made). + +Putting everything together, we can create a ``ModulePass`` that +replaces all ``Hadamard``\ s with ``PauliX``\ s + +.. code-block:: python + + from xdsl import passes, pattern_rewriter + from xdsl.dialects import builtin + from catalyst.python_interface.dialects.quantum import CustomOp + from catalyst.python_interface import compiler_transform + + class HToXPattern(pattern_rewriter.RewritePattern): + """Dummy class for example.""" + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite( + self, op: CustomOp, rewriter: pattern_rewriter.PatternRewriter + ): + if op.gate_name.data != "Hadamard": + # If not Hadamard, we do nothing + return + + # Update the gate name to PauliX and notify the rewriter that + # the op was manually updated + op.gate_name = builtin.StringAttr("PauliX") + rewriter.notify_op_modified(op) + + class HToXPass(passes.ModulePass): + """Pass that replaces Hadamards with PauliXs""" + + name = "h-to-x" + + def apply(self, ctx, module): + """Apply the iterative pass.""" + walker = pattern_rewriter.PatternRewriteWalker( + pattern=HToXPattern() + ) + walker.rewrite_module(module) + + # We will cover this later + h_to_x_pass = compiler_transform(HToXPass) + +``PassPipeline`` +~~~~~~~~~~~~~~~~ + +``PassPipeline`` is a meta-pass that takes a sequence of +``ModulePass``\ es as input, and applies them to the input module. The +following example shows how a sequence of ``ModulePass``\ s can be +applied to a module ``mod``: + +.. code-block:: python + + from xdsl import passes + + pipeline = passes.PassPipeline((Pass1(), Pass2(), Pass3())) + pipeline.apply(xdsl.context.Context(), mod) + +To complete the example we’ve been building in this section, let’s put +it all together and implement a ``PassPipeline`` to apply the +``HToXPass`` to an xDSL module. + +Let’s first create the module to which we want to apply the pass. For +this, we will use the ``xdsl_from_qjit`` utility, which is described in +the β€œPennyLane integration” section below. + +- **Creating the module** + + .. code-block:: python + + import pennylane as qml + from catalyst.python_interface.conversion import xdsl_from_qjit + + dev = qml.device("lightning.qubit", wires=3) + + @xdsl_from_qjit + @qml.qjit(target="mlir") + @qml.qnode(dev) + def circuit(): + qml.Hadamard(0) + qml.Hadamard(1) + qml.Hadamard(2) + return qml.state() + + >>> mod = circuit() + >>> print(mod) + builtin.module @circuit { + func.func public @jit_circuit() -> (tensor<8xcomplex>) attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<8xcomplex> + func.return %0 : tensor<8xcomplex> + } + builtin.module @module_circuit { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @circuit() -> (tensor<8xcomplex>) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<3> : tensor}> : () -> tensor + %3 = quantum.alloc(3) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = quantum.custom "Hadamard"() %5 : !quantum.bit + %7 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %8 = tensor.extract %7[] : tensor + %9 = quantum.extract %3[%8] : !quantum.reg -> !quantum.bit + %10 = quantum.custom "Hadamard"() %9 : !quantum.bit + %11 = "stablehlo.constant"() <{value = dense<2> : tensor}> : () -> tensor + %12 = tensor.extract %11[] : tensor + %13 = quantum.extract %3[%12] : !quantum.reg -> !quantum.bit + %14 = quantum.custom "Hadamard"() %13 : !quantum.bit + %15 = tensor.extract %0[] : tensor + %16 = quantum.insert %3[%15], %6 : !quantum.reg, !quantum.bit + %17 = tensor.extract %7[] : tensor + %18 = quantum.insert %16[%17], %10 : !quantum.reg, !quantum.bit + %19 = tensor.extract %11[] : tensor + %20 = quantum.insert %18[%19], %14 : !quantum.reg, !quantum.bit + %21 = quantum.compbasis qreg %20 : !quantum.obs + %22 = quantum.state %21 : tensor<8xcomplex> + quantum.dealloc %20 : !quantum.reg + quantum.device_release + func.return %22 : tensor<8xcomplex> + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } + } + +- **Transforming the module** + + In the above module, there are 3 ``CustomOp``\ s, each with gate name + ``Hadamard``. Let’s try applying our pass to it. Bear in mind that + passes update modules in-place: + + .. code-block:: python + + from xdsl import passes + + pipeline = passes.PassPipeline((HToXPass(),)) + pipeline.apply(xdsl.context.Context(), mod) + + >>> print(mod) + builtin.module @circuit { + func.func public @jit_circuit() -> (tensor<8xcomplex>) attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<8xcomplex> + func.return %0 : tensor<8xcomplex> + } + builtin.module @module_circuit { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @circuit() -> (tensor<8xcomplex>) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<3> : tensor}> : () -> tensor + %3 = quantum.alloc(3) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = quantum.custom "PauliX"() %5 : !quantum.bit + %7 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %8 = tensor.extract %7[] : tensor + %9 = quantum.extract %3[%8] : !quantum.reg -> !quantum.bit + %10 = quantum.custom "PauliX"() %9 : !quantum.bit + %11 = "stablehlo.constant"() <{value = dense<2> : tensor}> : () -> tensor + %12 = tensor.extract %11[] : tensor + %13 = quantum.extract %3[%12] : !quantum.reg -> !quantum.bit + %14 = quantum.custom "PauliX"() %13 : !quantum.bit + %15 = tensor.extract %0[] : tensor + %16 = quantum.insert %3[%15], %6 : !quantum.reg, !quantum.bit + %17 = tensor.extract %7[] : tensor + %18 = quantum.insert %16[%17], %10 : !quantum.reg, !quantum.bit + %19 = tensor.extract %11[] : tensor + %20 = quantum.insert %18[%19], %14 : !quantum.reg, !quantum.bit + %21 = quantum.compbasis qreg %20 : !quantum.obs + %22 = quantum.state %21 : tensor<8xcomplex> + quantum.dealloc %20 : !quantum.reg + quantum.device_release + func.return %22 : tensor<8xcomplex> + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } + } + + Great! We can see that all the ``Hadamard``\ s have been replaced with + ``PauliX``\ s, just how we wanted. + +PennyLane integration +===================== + +This section will cover the API in the ``qml.compiler.python_compiler`` +submodule. + +Lowering to MLIR +---------------- + +Catalyst compiles programs using the following workflow, when program +capture is enabled: + +- Trace user function to create ``ClosedJaxpr`` (plxpr) +- Convert plxpr to value-semantic jaxpr +- Lower value-semantic jaxpr to MLIR +- Apply passes defined in a static pipeline to the MLIR + + - These passes include optimizations, further lowering to more + elementary dialects, and lowering to LLVM + +- Generate machine code + +The integration with the xDSL layer happens after we lower to MLIR. We +currently rely on JAX’s API to lower to MLIR. This has the special +effect of lowering to a specific dialect called StableHLO, which is used +to represent all arithmetic operations present in the program. + +Once lowered to MLIR, if the original ``qjit`` decorator specified the +xDSL pass plugin, we pass control over to the xDSL layer, which applies +all transforms that were requested by the user. We can request the use +of the xDSL plugin like so: + +.. code-block:: python + + from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath + + @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + ... + +.. _ir-structure-1: + +IR structure +~~~~~~~~~~~~ + +The lowered MLIR has a lot of structure to aid developers, which is +described below using the following example: + +.. code-block:: python + + import pennylane as qml + + qml.capture.enable() + dev = qml.device("lightning.qubit", wires=1) + + @qml.qjit + @qml.transforms.cancel_inverses + @qml.transforms.merge_rotations + @qml.qnode(dev) + def circuit(): + qml.X(0) + return qml.state() + +>>> print(circuit.mlir) +module @circuit { + func.func public @jit_circuit() -> tensor<2xcomplex> attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<2xcomplex> + return %0 : tensor<2xcomplex> + } + module @module_circuit { + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) { + %0 = transform.apply_registered_pass "merge-rotations" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + %1 = transform.apply_registered_pass "remove-chained-self-inverse" to %0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + transform.yield + } + } + func.func public @circuit() -> tensor<2xcomplex> attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %c0_i64 = arith.constant 0 : i64 + quantum.device shots(%c0_i64) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "PauliX"() %1 : !quantum.bit + %2 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit + %3 = quantum.compbasis qreg %2 : !quantum.obs + %4 = quantum.state %3 : tensor<2xcomplex> + quantum.dealloc %2 : !quantum.reg + quantum.device_release + return %4 : tensor<2xcomplex> + } + } + func.func @setup() { + quantum.init + return + } + func.func @teardown() { + quantum.finalize + return + } +} + +- The program is represented as a module with a function inside it. This + function contains non-QNode user code (which there is none in our + example for simplicity), and calls to QNodes using the + ``catalyst.launch_kernel`` operation. +- QNodes are represented as modules as wellβ€”each QNode has its own + module. These modules have 4 key components: + + - The ``module attributes {transform.with_named_sequence}`` contains + the transforms that the user requested for the QNode. In our + example, those are the ``merge_rotations`` and ``cancel_inverses`` + transforms. + - The ``@circuit`` function represents the body of the QNode. Inside + this body, we initialize a device using the specified shots, + allocate a quantum register, and then apply both quantum and + classical instructions. Note that quantum instructions can *only* be + present inside this function. Note that this function will contain + an attribute called ``qnode``. + +SSA in the ``Quantum`` dialect +------------------------------ + +In the ``Quantum`` dialect, as mentioned earlier, gates are represented +using the ``CustomOp`` operation. This operation accepts ``in_qubits`` +and ``in_ctrl_qubits``, which are variable length sequences of +``QubitType`` attributes, which correspond to wire indices. +``CustomOp``\ s also return ``out_qubits`` and ``out_ctrl_qubits`` of +the same type. + +Before a ``QubitType`` can be used by any gates, it must be extracted +from the quantum register, or a ``QregType``. Quantum registers can be +thought of as a sequence containing all valid wire indices that are +available to be used by gates/measurements. Qubits can be extracted from +and inserted into a quantum register using the ``ExtractOp`` and +``InsertOp`` operations. An ``AllocOp`` is used to allocate a quantum +register with the user-provided number of device wires. + +Let’s take a look at a very simple example. Below, I have a very simple +circuit with two gates applied to the same wire. Let’s take a look at +its MLIR representation: + +- **Example** + + .. code-block:: python + + dev = qml.device("lightning.qubit", wires=3) + + @qml.qjit(target="mlir") + @qml.qnode(dev) + def circuit(): + qml.X(0) + qml.H(0) + return qml.state() + + >>> print(circuit.mlir) + module @circuit { + func.func public @jit_circuit() -> tensor<8xcomplex> attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<8xcomplex> + return %0 : tensor<8xcomplex> + } + module @module_circuit { + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @circuit() -> tensor<8xcomplex> attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %c0_i64 = arith.constant 0 : i64 + quantum.device shots(%c0_i64) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %0 = quantum.alloc( 3) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "PauliX"() %1 : !quantum.bit + %out_qubits_0 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit + %2 = quantum.insert %0[ 0], %out_qubits_0 : !quantum.reg, !quantum.bit + %3 = quantum.compbasis qreg %2 : !quantum.obs + %4 = quantum.state %3 : tensor<8xcomplex> + quantum.dealloc %2 : !quantum.reg + quantum.device_release + return %4 : tensor<8xcomplex> + } + } + func.func @setup() { + quantum.init + return + } + func.func @teardown() { + quantum.finalize + return + } + } + +**Notes** + +- ``quantum.alloc`` initializes a quantum register (``%0``) with 3 + wires, which is the number of wires used to create the PennyLane + device. +- ``quantum.extract`` extracts a qubit corresponding to wire index 0 + (``%1``) + + - ``%1`` is used as input by the ``X`` gate. + +- The ``X`` gate returns a qubit (``%out_qubits``), which is used by the + ``quantum.custom`` corresponding to the ``H`` gate. + + - Note that this is different from how the circuit is defined in + Python. Instead of just using ``%1`` again, the ``H`` gate uses + ``%out_qubits``. + +- The ``H`` gate returns a new qubit (``%out_qubits_0``). This qubit is + consumed by ``quantum.insert``, which inserts updates the quantum + register to essentially say that ``%out_qubits_0`` should be the new + qubit that corresponds to wire index 0. +- In this example, note that ``%1``, ``%out_qubits``, and + ``%out_qubits_0`` all correspond to wire index 0. This has cool + implications, that I will discuss below. + +Dynamic wires +~~~~~~~~~~~~~ + +The lowering rules for Catalyst automatically handle dynamic wires. When +a new dynamic wire, ``w``, is used, all wires used before it are first +inserted into the quantum register using ``InsertOp``. Only then does +the ``QubitType`` corresponding to ``w`` get extracted from the quantum +register using ``ExtractOp``. Consider the following example: + +- **Example** + + .. code-block:: python + + import pennylane as qml + + qml.capture.enable() + + dev = qml.device("lightning.qubit", wires=3) + + @qml.qjit(target="mlir") + @qml.qnode(dev) + def circuit(w1: int, w2: int): + qml.X(0) + qml.Y(w1) + qml.Z(w1) + qml.S(w2) + qml.T(w1) + qml.H(0) + return qml.state() + + >>> print(circuit.mlir) + module @circuit { + func.func public @jit_circuit(%arg0: tensor, %arg1: tensor) -> tensor<8xcomplex> attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_circuit::@circuit(%arg0, %arg1) : (tensor, tensor) -> tensor<8xcomplex> + return %0 : tensor<8xcomplex> + } + module @module_circuit { + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @circuit(%arg0: tensor, %arg1: tensor) -> tensor<8xcomplex> attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %c0_i64 = arith.constant 0 : i64 + quantum.device shots(%c0_i64) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %0 = quantum.alloc( 3) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "PauliX"() %1 : !quantum.bit + %2 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit + %extracted = tensor.extract %arg0[] : tensor + %3 = quantum.extract %2[%extracted] : !quantum.reg -> !quantum.bit + %out_qubits_0 = quantum.custom "PauliY"() %3 : !quantum.bit + %out_qubits_1 = quantum.custom "PauliZ"() %out_qubits_0 : !quantum.bit + %extracted_2 = tensor.extract %arg0[] : tensor + %4 = quantum.insert %2[%extracted_2], %out_qubits_1 : !quantum.reg, !quantum.bit + %extracted_3 = tensor.extract %arg1[] : tensor + %5 = quantum.extract %4[%extracted_3] : !quantum.reg -> !quantum.bit + %out_qubits_4 = quantum.custom "S"() %5 : !quantum.bit + %extracted_5 = tensor.extract %arg1[] : tensor + %6 = quantum.insert %4[%extracted_5], %out_qubits_4 : !quantum.reg, !quantum.bit + %extracted_6 = tensor.extract %arg0[] : tensor + %7 = quantum.extract %6[%extracted_6] : !quantum.reg -> !quantum.bit + %out_qubits_7 = quantum.custom "T"() %7 : !quantum.bit + %extracted_8 = tensor.extract %arg0[] : tensor + %8 = quantum.insert %6[%extracted_8], %out_qubits_7 : !quantum.reg, !quantum.bit + %9 = quantum.extract %8[ 0] : !quantum.reg -> !quantum.bit + %out_qubits_9 = quantum.custom "Hadamard"() %9 : !quantum.bit + %10 = quantum.insert %8[ 0], %out_qubits_9 : !quantum.reg, !quantum.bit + %11 = quantum.compbasis qreg %10 : !quantum.obs + %12 = quantum.state %11 : tensor<8xcomplex> + quantum.dealloc %10 : !quantum.reg + quantum.device_release + return %12 : tensor<8xcomplex> + } + } + func.func @setup() { + quantum.init + return + } + func.func @teardown() { + quantum.finalize + return + } + } + + **Notes** + + - If a qubit is inserted into the quantum register, it is not reused + again, and a new qubit corresponding to the same wire label must be + extracted from the register. + - When a dynamic wire is going to be used for the first time, all + qubits that have been used previously without being re-inserted into + the quantum register must be inserted. + - If a dynamic wire is reused before any other wires, then it does not + need to be inserted to and extracted from the quantum register + again. + - To use static wires after dynamic wires, the dynamic wires are again + re-inserted into the quantum register. + - All of the above make it such that dynamic wires essentially create + barriers that break the qubit def-use chains. This causes some + functionality to be lost, but makes sure that we’re using qubits + safely. + +Implications/notes +~~~~~~~~~~~~~~~~~~ + +- Operations on the same wires can be tracked quickly using the chain of + definitions and uses of the qubits (def-use chains). +- Operations on dynamic wires (i.e wires whose values are only known at + run-time) are handled automatically and we don’t need to worry about + managing how we work around them. For context, this is an issue that + we found in the plxpr variant of ``cancel_inverses``, where two + operations on the same wire that were separated by another operation + on a dynamic wire would get cancelled, which is incorrect + (`source `__). +- One thing to keep in mind is that qubits (``QubitType``) and quantum + registers (``QregType``) are there so that we conform to SSA semantics + in a way that works well for our purposes. They get meaning from how + we choose to interpret them. We could just as easily have defined the + ``Quantum`` dialect and Catalyst’s lowering rules to use wire indices + the same way we do in Python, but we may lose capabilities that MLIR + and xDSL enable through the SSA form. + +``compiler_transform`` +---------------------- + +``compiler_transform`` is the function used to register xDSL +``ModulePass``\ es to be used with ``qjit``-ed workflows. It is +currently accessible as +``qml.compiler.python_compiler.compiler_transform``. + +.. code-block:: python + + from catalyst.python_interface import compiler_transform + + class MyPass(xdsl.passes.ModulePass): + """MyPass that does something""" + + name = "my-pass" + + def apply(self, ctx, module): + # Apply the pass to module + return + + my_pass = compiler_transform(MyPass) + + # Program capture must be enabled to use the compiler transform + # as a decorator + qml.capture.enable() + dev = qml.device("lightning.qubit", wires=1) + + @qml.qjit( + pass_plugins=[catalyst.passes.xdsl_plugin.getXDSLPluginAbsolutePath()] + ) + @my_pass + @qml.qnode(dev) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + circuit(1.5) + +The ``compiler_transform`` function returns an object that gives easy +access to the underlying ``ModulePass``, as well as its name as seen by +the compiler. + +>>> my_pass.module_pass +__main__.MyPass +>>> my_pass.name +'my-pass' + +Additionally, we don’t need to manually apply passes using +``PassPipeline`` when decorating QNodes with registered compiler +transforms. Those transforms get applied automatically when the workflow +is compiled! + +Conversion utilities +-------------------- + +The ``python_compiler.conversion`` submodule provides several utilities +for creating xDSL modules from Python functions. There are many more +utilities in the submodule, but I will focus on the most important ones, +and provide examples of how to use them. + +- ``xdsl_module(jitted_fn: Callable) -> Callable[Any, [xdsl.dialects.builtin.ModuleOp]]``: + Create a wrapper around a ``jax.jit``-ed function that returns an xDSL + module +- ``xdsl_from_qjit(qjitted_fn: QJIT) -> Callable[..., xbuiltin.ModuleOp]``: + Create a wrapper around a ``qjit``-ed function that returns an xDSL + module. This is currently not merged to ``master`` +- ``inline_module(from_mod: xbuiltin.ModuleOp, to_mod: xbuiltin.ModuleOp, change_main_to: str = None) -> None``: + This function takes two modules as input, and inlines the whole body + of the first module into the second. Additionally, if + ``change_main_to`` is provided, it looks for a function named + ``main``, and updates its name to ``change_main_to`` +- ``inline_jit_to_module(func: JaxJittedFunction, mod: xbuiltin.ModuleOp, *args, **kwargs) -> None``: + This function takes a ``jax.jit``-ed function, converts it into an + xDSL module, and then inlines the contents of the xDSL module into + ``mod``. Note that this function does not return anything; instead, it + modifies ``mod`` in-place. + +Check out the :doc:`xDSL conversion utilities tutorial <./xdsl_utils_tutorial>` to see examples of how each of +the utilities can be used. + +Useful patterns +=============== + +Now that we have gone over compilers, xDSL, and how it’s being used in +PennyLane, let’s take a look at some common patterns that might be +useful. + +Post-processing functions +------------------------- + +Post-processing functions are purely classical, so we can leverage the +``xdsl_module`` utility function to create xDSL modules from Python +code, and inject it into the modules we are rewriting as needed. The +:doc:`xDSL post-processing tutorial <./xdsl_post_processing>` shows an +example where we perform very simple post-processing on a QNode that +returns an expectation value by squaring the expectation value. + +Splitting tapes +--------------- + +When implementing passes that may transform our module such that +multiple device executions are required (akin to tape transforms that +return multiple tapes), there are various strategies that can be used. +Because there is no guarantee of shared structure between the tapes, +there is no one perfect strategy that can be used for all transforms. +I’ll use tapes to provide details below about some common cases: + +- If all the tapes are identical (eg. ``dynamic_one_shot``), then the + entire execution can be put inside a ``for`` loop, with + post-processing on the execution outputs done how I showed in the + above notebook. +- If all gates are identical, but measurements are different (eg. + ``split_non_commuting``), we can capture all gates in a single + function, and then use a ``for`` loop that iterates over the number of + tapes. Within this ``for`` loop, we would call the aforementioned + function, which would evolve the state, and apply a different + measurement inside each iteration of the loop. Post-processing can be + handled same as above. +- If there is little/no shared structure between the tapes, we would + need separate functions for each of the transformed tapes. We would + need to call each function one by one, and then use the results for + post-processing. + +All of the above are very non-trivial. I will leave out code examples +for now, as that may be unnecessarily time consuming. If we get to the +stage where we need to write a transform that splits into multiple +tapes, we can revisit this section and the Python compiler/compilation +team can assist in developing such transforms. + +Note +~~~~ + +Currently, we don’t have a consistent API for implementing transforms +that split QNodes into multiple device executions (eg. +``split_non_commuting``, etc.). Thus, it is *very* hard to implement +transforms that do that. We cannot guarantee that one transform that +does split QNodes into multiple device executions will work well when +there are other transforms in the pipeline. + +Writing tests +============= + +**Note to readers**: this section is written based on how the testing +infrastructure for the Python compiler exists in PennyLane. However, the +Python compiler may be getting moved to Catalyst, in which case, the +infrastructure would likely change. + +FileCheck +--------- + +`FileCheck `__ is a +pattern matching file verifier developed by the LLVM project and is +widely used to test MLIR. Since xDSL and MLIR are syntactically +identical, we can use the string representation of xDSL modules for +testing with FileCheck. + +Below is an example of how an MLIR/xDSL string may look when populated +with directives that FileCheck can use for testing. In this example, we +have a void function ``test_func`` that takes no arguments, creates two +qubits (corresponding to different wires), and applies a ``PauliX`` to +each of these wires. We can see that there are 4 comments starting with +``CHECK`` - these comments are what FileCheck uses to match the expected +string against the actual string. + +The number of ``CHECK`` comments does not need to be the same as the +number of operations in the program - we can see below that there is no +``CHECK`` for ``func.func`` and ``return``. + +``CHECK`` statements can assign parameters that can be reused by later +``CHECK``\ s. Below, the first two ``CHECK``\ s create ``q0`` and +``q1``, which are matched using the regular expression ``%.+``, which +matches any expressions that start with ``%`` and contain at least one +character after it. The latter 2 ``CHECK``\ s then use ``q0`` and ``q1`` +as the input qubits for the assertion. + +``CHECK`` statements can contain partial statements. Below, the last two +``CHECK``\ s don’t include the outputs of the ``quantum.custom`` +operations. + +.. code-block:: python + + program = """ + func.func @test_func() { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + %1 = "test.op"() : () -> !quantum.bit + // CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit + // CHECK: quantum.custom "PauliX"() [[q1]] : !quantum.bit + %2 = quantum.custom "PauliX"() %0 : !quantum.bit + %3 = quantum.custom "PauliX"() %0 : !quantum.bit + return + } + """ + +FileCheck essentially uses the MLIR generated by a program and asserts +it against directives that can be specified by the user. Commonly used +directives are: + +- ``CHECK``: This is used to check if a specified pattern is found. Its + syntax is very simple: they are fixed strings that must occur in + order, and horizontal whitespace is ignored by default. +- ``CHECK-NOT``: This directive is used to check that the provided + string does not occur between two matches, before the first match, or + after the last match. +- ``CHECK-DAG``: This directive may be used when it’s necessary to match + strings that don’t have to occur in the same order, but it is able to + match valid topological orderings of the program DAG, with edges from + the definition to the uses of a variable. +- ``CHECK-NEXT``: This directive is used to check that the matched line + occurs right after the previously matched line with no other lines in + between them. +- ``CHECK-SAME``: This directive is used when we want to match lines and + would like to verify that matches happen on the same line as the + previous match. + +To find more details about the above directives, or to learn about other +available directives, please refer to the `FileCheck +documentation `__. + +Test dialect +------------ + +xDSL provides a ``Test`` dialect +(`source `__), +which contains many operations that are useful for unit testing. In our +testing, we found the ``TestOp`` operation to be the most useful. This +operation can produce arbitrary results, which we can use to limit +artificial dependencies on other dialects. + +For example, if I just need to assert that a specific gate is present, +without ``TestOp``, I would need my module to contain an ``AllocOp`` +that creates a quantum register, an ``ExtractOp`` that extracts a qubit +from the register, and only then I can use the qubit for the gate I’m +trying to match. Instead, I can just insert a ``TestOp`` that returns a +qubit and use that. + +This is very powerful for unit testing, as it makes writing tests much +simpler, while also limiting the scope of the test as one would expect +for unit tests. + +In the code block in the previous section, ``TestOp`` has been used in +exactly the described way - we use it to create 2 qubits that are then +used as input for the 2 ``PauliX`` gates. + +.. _pennylane-integration-1: + +PennyLane integration +--------------------- + +To use FileCheck with ``pytest``, we use the ```filecheck`` Python +package `__, which allows us to use +assertions for testing in a way that ``pytest`` can understand. All of +the ``filecheck`` API has been captured inside two fixtures available +within the ``tests/python_compiler`` folder: + +- ``run_filecheck``: This fixture is for unit testing. One can specify a + program along with filecheck directives as a multi-line string. +- ``run_filecheck_qjit``: This fixture is for integration testing. One + can create a normal ``qml.qjit``-ed workflow and include filecheck + directives as in-line comments. + +Let’s write tests for the ``HToXPass`` that was implemented in the +`β€œPass API” <#pass-api>`__ sub-section to illustrate. The dev comments +will explain what is going on. + +``run_filecheck`` example +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + def test_h_to_x_pass(run_filecheck): + """Test that Hadamard gets converted into PauliX.""" + # The original program creates a qubit, and gives it to a + # CustomOp that is a Hadamard. The CHECK directives check that + # the transformed program has a CustomOp that is a PauliX applied + # to the same qubit, and no CustomOp that is a Hadamard. + + # Below we also see how we can use `test.op`, which we use to + # create a qubit to give to the Hadamard. + program = """ + func.func @test_func() { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit + // CHECK-NOT: quantum.custom "Hadamard" + %1 = quantum.custom "Hadamard"() %0 : !quantum.bit + return + } + """ + + # First, we must create the pass pipeline that we want to apply + # to our program. + pipeline = (HToXPass(),) + + # Next, we use run_filecheck to run filecheck testing on the + # program. The fixture will create an xDSL module from the program + # string, call the filecheck API, and assert correctness. + run_filecheck(program, pipeline) + + # Optionally, there are two keyword arguments that can modify the + # behaviour of run_filecheck. Both the roundtrip and verify + # arguments are false by default. + run_filecheck(program, pipeline, roundtrip=True, verify=True) + + # roundtrip=True makes it so that after parsing the program string + # to an xDSL module, we print it as a string again, and then parse + # it back into an xDSL module. This is useful when writing tests + # for dialects, so that we can check that the dialects can be printed + # parsed correctly. + + # verify=True simply runs `module.verify()`, which iteratively uses + # xDSL's verifiers to verify all operations and attributes in the + # module. + +``run_filecheck_qjit`` example +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath + + def test_h_to_x_pass_integration(run_filecheck_qjit): + """Test that Hadamard gets converted into PauliX.""" + # The original program simply applies a Hadamard to a circuit + # Remember that we need to specify the pass plugin. Additionally, + # with qjit, we can use the decorator created using + # `compiler_transform`. To make sure that the xDSL API works + # correctly, program capture must be enabled. + # qml.capture.enable() + @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath]) + @h_to_x_pass + def circuit(): + # CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + # CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit + # CHECK-NOT: quantum.custom "Hadamard" + qml.Hadamard(0) + return qml.state() + + # Finally, we use the run_filecheck_qjit fixture. We pass it our + # original qjitted workflow. It extracts the filecheck directives + # from the workflow, creates an MLIR program, parses it to xDSL, + # applies the specified transforms, and uses the filecheck API to + # assert correctness. + run_filecheck_qjit(circuit) + + # run_filecheck_qjit also accepts a boolean `verify` argument + # that is false by default, which works exactly the same way as + # the `verify` argument of `run_filecheck`. + run_filecheck_qjit(circuit, verify=True) + +Key blockers +============ + +There are several blockers that are currently disabling developers from +taking full advantage of the ``python_compiler`` submodule. These +include: + +* Lack of support for quantum subroutines. This impacts pattern + matching passes that need to substitute the matched operation(s) with + subroutines containing quantum instructions. + +Strategies to circumvent blockers +--------------------------------- + +* We can use dummy subroutines for now. We know what the inputs and + outputs of these subroutines should be, so we can create our own + ``FuncOp``\ s that adhere to the input/output spec and just have + their body be empty for now. To see an example where we create a + dummy quantum subroutine and use it to develop a pass, check out the + :doc:`xDSL subroutines tutorial <./xdsl_dummy_quantum_subroutines>`. + +Suggested reading +================= + +Useful dialects +--------------- + +- ``scf``: Structured control flow +- ``func``: Functions +- ``builtin``: Core types and attributes +- ``arith``: arithmetic operations +- ``stablehlo``: Advanced match dialect + +References +========== + +#. Wikimedia Foundation. (2025, August 11). *Static single-assignment + form*. Wikipedia. + https://en.wikipedia.org/wiki/Static_single-assignment_form +#. *MLIR Language Reference*. MLIR. (n.d.). + https://mlir.llvm.org/docs/LangRef/ +#. *Understanding the IR Structure*. MLIR. (n.d.-b). + https://mlir.llvm.org/docs/Tutorials/UnderstandingTheIRStructure/ +#. *Mlir::Value class reference*. MLIR. (n.d.-b). + https://mlir.llvm.org/doxygen/classmlir_1_1Value.html +#. *LLVM Programmer’s Manual*. LLVM. (n.d.). + https://llvm.org/docs/ProgrammersManual.html diff --git a/frontend/catalyst/python_interface/doc/xdsl_dummy_quantum_subroutines.rst b/frontend/catalyst/python_interface/doc/xdsl_dummy_quantum_subroutines.rst new file mode 100644 index 0000000000..a1e3a53324 --- /dev/null +++ b/frontend/catalyst/python_interface/doc/xdsl_dummy_quantum_subroutines.rst @@ -0,0 +1,216 @@ +.. code-block:: python + + from dataclasses import dataclass + + import pennylane as qml + from catalyst.python_interface.conversion import xdsl_from_qjit + from catalyst.python_interface.dialects.quantum import CustomOp, QubitType + + from xdsl import context, passes, pattern_rewriter + from xdsl.builder import ImplicitBuilder + from xdsl.dialects import builtin, func + from xdsl.ir import Block, Region + +Convert into xDSL module +======================== + +.. code-block:: python + + dev = qml.device("lightning.qubit", wires=5) + + @xdsl_from_qjit + @qml.qjit(target="mlir") + @qml.qnode(dev) + def circuit(x): + qml.H(0) + return qml.expval(qml.Z(0)) + + +>>> qjit_mod = circuit(1.5) +>>> print(qjit_mod) +builtin.module @circuit { + func.func public @jit_circuit(%arg2 : tensor) -> (tensor) attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_circuit::@circuit(%arg2) : (tensor) -> tensor + func.return %0 : tensor + } + builtin.module @module_circuit { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @circuit(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor + %3 = quantum.alloc(5) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = quantum.custom "Hadamard"() %5 : !quantum.bit + %7 = quantum.namedobs %6[PauliZ] : !quantum.obs + %8 = quantum.expval %7 : f64 + %9 = tensor.from_elements %8 : tensor + %10 = tensor.extract %0[] : tensor + %11 = quantum.insert %3[%10], %6 : !quantum.reg, !quantum.bit + quantum.dealloc %11 : !quantum.reg + quantum.device_release + func.return %9 : tensor + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } +} + + +Let’s create a quantum subroutine +================================= + +This subroutine’s purpose is to replace Hadamard gates with a blackbox +(which will be empty for now), so it must take a single qubit as its +argument. Additionally, we are assuming that the qubits on which the +subroutine will act are not just the ones that the subroutine takes as +input, so we must also provide the quantum register as input, and also +give it as the output. + +``FuncOp``\ s have a single region with a single block that contains the +body of the function. + +We need to build the function’s body by populating its inner ``Block`` +with operations. We can do so using the ``xdsl.builder.ImplicitBuilder`` +class. This class can be used as a context manager that takes a +``Block`` as input, and any operations created within the context of the +builder get added to the block. Let’s try it out. + +Here, we create a subroutine that we will use to replace +``Hadamard``\ s. This subroutine applies a gate provided by the user, +and returns the ``out_qubit`` of the gate. + +.. code-block:: python + + def create_hadamard_replacement_subroutine(gate_name): + input_types = (QubitType(),) + output_types = (QubitType(),) + block = Block(arg_types=input_types) + + with ImplicitBuilder(block): + in_qubits = [block.args[0]] + op1 = CustomOp(in_qubits=in_qubits, gate_name=gate_name) + func.ReturnOp(op1.out_qubits[0]) + + region = Region([block]) + funcOp = func.FuncOp("replace_hadamard", (input_types, output_types), region=region) + return funcOp + + +>>> funcOp = create_hadamard_replacement_subroutine("S") +>>> print(funcOp) +func.func @replace_hadamard(%0 : !quantum.bit) -> !quantum.bit { + %1 = quantum.custom "S"() %0 : !quantum.bit + func.return %1 : !quantum.bit +} + + +Now, we write a pass to do the substitution +=========================================== + +.. code-block:: python + + class ReplaceHadamardPattern(pattern_rewriter.RewritePattern): + + def __init__(self, subroutine: func.FuncOp): + self.subroutine = subroutine + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, customOp: CustomOp, rewriter: pattern_rewriter.PatternRewriter): + if customOp.gate_name.data != "Hadamard": + return + + callOp = func.CallOp( + builtin.SymbolRefAttr("replace_hadamard"), + [customOp.in_qubits[0]], + self.subroutine.function_type.outputs.data, + ) + rewriter.insert_op_after_matched_op(callOp) + rewriter.replace_all_uses_with(customOp.out_qubits[0], callOp.results[0]) + rewriter.erase_op(customOp) + + + @dataclass(frozen=True) + class ReplaceHadamardPass(passes.ModulePass): + name = "replace-hadamard" + gate_name: str + + def apply(self, ctx: context.Context, module: builtin.ModuleOp): + funcOp = create_hadamard_replacement_subroutine(self.gate_name) + module.regions[0].blocks.first.add_op(funcOp) + + pattern_rewriter.PatternRewriteWalker( + pattern_rewriter.GreedyRewritePatternApplier([ReplaceHadamardPattern(funcOp)]) + ).rewrite_module(module) + +Let’s see it in action +====================== + +Here, we will replace all ``Hadamard``\ s with ``S``\ s + +.. code-block:: python + + ctx = context.Context() + + pipeline = passes.PassPipeline((ReplaceHadamardPass("S"),)) + pipeline.apply(ctx, qjit_mod) + +Great! We can see below that ``Hadamard`` was replaced by a call to +``replace_hadamard``, which applies a single ``S`` gate. + +>>> print(qjit_mod) +builtin.module @circuit { + func.func public @jit_circuit(%arg2 : tensor) -> (tensor) attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_circuit::@circuit(%arg2) : (tensor) -> tensor + func.return %0 : tensor + } + builtin.module @module_circuit { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @circuit(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor + %3 = quantum.alloc(5) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = func.call @replace_hadamard(%5) : (!quantum.bit) -> !quantum.bit + %7 = quantum.namedobs %6[PauliZ] : !quantum.obs + %8 = quantum.expval %7 : f64 + %9 = tensor.from_elements %8 : tensor + %10 = tensor.extract %0[] : tensor + %11 = quantum.insert %3[%10], %6 : !quantum.reg, !quantum.bit + quantum.dealloc %11 : !quantum.reg + quantum.device_release + func.return %9 : tensor + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } + func.func @replace_hadamard(%0 : !quantum.bit) -> !quantum.bit { + %1 = quantum.custom "S"() %0 : !quantum.bit + func.return %1 : !quantum.bit + } +} diff --git a/frontend/catalyst/python_interface/doc/xdsl_post_processing.rst b/frontend/catalyst/python_interface/doc/xdsl_post_processing.rst new file mode 100644 index 0000000000..b9ca0ac3ae --- /dev/null +++ b/frontend/catalyst/python_interface/doc/xdsl_post_processing.rst @@ -0,0 +1,225 @@ +Simple tutorial for injecting functions into xDSL modules +========================================================= + +.. code-block:: python + + from dataclasses import dataclass + import jax + + import pennylane as qml + from catalyst.python_interface.conversion import inline_module, xdsl_from_qjit, xdsl_module + + from xdsl import context, passes, pattern_rewriter + from xdsl.dialects import builtin, func + from xdsl.traits import SymbolTable + from xdsl.rewriter import InsertPoint + +Create workflow and convert to xDSL module +========================================== + +.. code-block:: python + + @xdsl_from_qjit + @qml.qjit(target="mlir") + def workflow(x, y): + dev = qml.device("lightning.qubit", wires=5) + + @qml.qnode(dev) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + res = circuit(x) + return res - y + + +>>> xmod = workflow(3.5, 4.5) +>>> print(xmod) +builtin.module @workflow { + func.func public @jit_workflow(%arg2 : tensor, %arg3 : tensor) -> (tensor) attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_circuit::@circuit(%arg2) : (tensor) -> tensor + %1 = "stablehlo.convert"(%arg3) : (tensor) -> tensor + %2 = "stablehlo.subtract"(%0, %1) : (tensor, tensor) -> tensor + func.return %2 : tensor + } + builtin.module @module_circuit { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @circuit(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor + %3 = quantum.alloc(5) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = tensor.extract %arg0[] : tensor + %7 = quantum.custom "RX"(%6) %5 : !quantum.bit + %8 = quantum.namedobs %7[PauliZ] : !quantum.obs + %9 = quantum.expval %8 : f64 + %10 = tensor.from_elements %9 : tensor + %11 = tensor.extract %0[] : tensor + %12 = quantum.insert %3[%11], %7 : !quantum.reg, !quantum.bit + quantum.dealloc %12 : !quantum.reg + quantum.device_release + func.return %10 : tensor + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } +} + + +Now, let’s try creating a pass that squares the output of the qnode +=================================================================== + +To do so, we can use the ``inline_module`` utility to easily add our +post-processing function into the module we’re transforming. First, we +create the function that squares the input value and turn it into an +xDSL module. + +.. code-block:: python + + @jax.jit + def square(x): + return x * x + + +>>> square_mod = xdsl_module(square)(1.5) +>>> print(square_mod) +builtin.module @jit_square attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0 : tensor) -> (tensor {jax.result_info = "result"}) { + %0 = "stablehlo.multiply"(%arg0, %arg0) : (tensor, tensor) -> tensor + func.return %0 : tensor + } +} + + +.. code-block:: python + + def is_kernel_launch(op): + return op.name == "catalyst.launch_kernel" + + + class SquarePattern(pattern_rewriter.RewritePattern): + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter): + # We only rewrite the function that calls the qnode, and the caller of the qnode will + # always have catalyst.launch_kernel present. Additionally, we only rewrite the caller + # if it hasn't already been rewritten. We can put a UnitAttr() inside the caller's + # attributes to indicate whether it has been rewritten or not. + if funcOp.attributes.get("transformed") == builtin.UnitAttr() or not any( + is_kernel_launch(op) for op in funcOp.body.ops + ): + return + + # Update funcOp to inidicate that it has been rewritten + funcOp.attributes["transformed"] = builtin.UnitAttr() + + # Insert square into the module + mod = funcOp.parent_op() + inline_module(square_mod, mod, change_main_to="square") + square_fn = SymbolTable.lookup_symbol(mod, "square") + + # Call square_fn and use its results instead of the qnode's results for + # the rest of the function + for op in funcOp.body.walk(): + if is_kernel_launch(op): + callOp = func.CallOp( + builtin.SymbolRefAttr(square_fn.sym_name), + op.results, + square_fn.function_type.outputs.data, + ) + rewriter.insert_op(callOp, InsertPoint.after(op)) + + # We have inserted a CallOp that takes the output of the qnode as input. Let's call + # the qnode output %0, and the CallOp output %1. The following replaces all uses of + # %0 with %1 EXCEPT for the case where %0 is an input to callOp + op.results[0].replace_by_if(callOp.results[0], lambda use: use.operation != callOp) + rewriter.notify_op_modified(funcOp) + + + @dataclass(frozen=True) + class SquarePass(passes.ModulePass): + name = "square" + + def apply(self, ctx: context.Context, module: builtin.ModuleOp): + pattern_rewriter.PatternRewriteWalker( + pattern_rewriter.GreedyRewritePatternApplier([SquarePattern()]) + ).rewrite_module(module) + +Let’s apply the pass to our workflow +==================================== + +.. code-block:: python + + ctx = context.Context() + + pipeline = passes.PassPipeline((SquarePass(),)) + pipeline.apply(ctx, xmod) + +Great! Let’s see what the transformed module looks like +======================================================= + +As you can see below, the ``square_xdsl`` function is the first function +in the module, and it gets called by ``jit_workflow``, and its +inputs/outputs are consistent with the behaviour we wanted. + +>>> print(xmod) +builtin.module @workflow { + func.func public @jit_workflow(%arg2 : tensor, %arg3 : tensor) -> (tensor) attributes {llvm.emit_c_interface, transformed} { + %0 = catalyst.launch_kernel @module_circuit::@circuit(%arg2) : (tensor) -> tensor + %1 = func.call @square(%0) : (tensor) -> tensor + %2 = "stablehlo.convert"(%arg3) : (tensor) -> tensor + %3 = "stablehlo.subtract"(%1, %2) : (tensor, tensor) -> tensor + func.return %3 : tensor + } + builtin.module @module_circuit { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @circuit(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor + %3 = quantum.alloc(5) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = tensor.extract %arg0[] : tensor + %7 = quantum.custom "RX"(%6) %5 : !quantum.bit + %8 = quantum.namedobs %7[PauliZ] : !quantum.obs + %9 = quantum.expval %8 : f64 + %10 = tensor.from_elements %9 : tensor + %11 = tensor.extract %0[] : tensor + %12 = quantum.insert %3[%11], %7 : !quantum.reg, !quantum.bit + quantum.dealloc %12 : !quantum.reg + quantum.device_release + func.return %10 : tensor + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } + func.func public @square(%arg0 : tensor) -> (tensor {jax.result_info = "result"}) { + %0 = "stablehlo.multiply"(%arg0, %arg0) : (tensor, tensor) -> tensor + func.return %0 : tensor + } +} diff --git a/frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst b/frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst new file mode 100644 index 0000000000..8c0163a3de --- /dev/null +++ b/frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst @@ -0,0 +1,404 @@ +Python compiler utilities +========================= + +All utilities we care about are in the +``catalyst.python_interface.conversion`` submodule. + +.. code-block:: python + + import pennylane as qml + + from catalyst.python_interface.conversion import ( + inline_jit_to_module, + inline_module, + xdsl_from_qjit, + xdsl_module, + ) + +``xdsl_module`` +=============== + +This function takes a ``jax.jit``-ed function as input, and returns a +wrapper. This wrapper can be called to return an xDSL module. Note that +this function is intended to be used to covert purely classical +functions into xDSL modules. Let’s take a look at a very simple example: + +.. code-block:: python + + import jax + + + @jax.jit + def inner(x): + return x**2 + + + @jax.jit + def outer(x, y): + return inner(x) - y + + +>>> wrapped_outer = xdsl_module(outer) +>>> jit_mod = wrapped_outer(1.5, 2.5) +>>> print(jit_mod) +builtin.module @jit_outer attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg1 : tensor, %arg2 : tensor) -> (tensor {jax.result_info = "result"}) { + %0 = func.call @inner(%arg1) : (tensor) -> tensor + %1 = "stablehlo.subtract"(%0, %arg2) : (tensor, tensor) -> tensor + func.return %1 : tensor + } + func.func private @inner(%arg0 : tensor) -> (tensor) { + %0 = "stablehlo.multiply"(%arg0, %arg0) : (tensor, tensor) -> tensor + func.return %0 : tensor + } +} + + +Nice! Key points to note: \* The module has the same name as the +decorated function (``outer``) with the ``jit_`` prefix. \* The entry +point function is aptly named ``main``. \* Any jitted functions called +within the entry point have their own function inside the module, and a +corresponding ``func.call`` operation where it gets called. \* If +``inner`` was not decorated with ``jax.jit``, its body would have been +inlined into ``outer``: + +.. code-block:: python + + def inner2(x): + return x**2 + + + @xdsl_module + @jax.jit + def outer2(x, y): + return inner2(x) - y + + +>>> mod2 = outer2(1.5, 2.5) +>>> print(mod2) +builtin.module @jit_outer2 attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0 : tensor, %arg1 : tensor) -> (tensor {jax.result_info = "result"}) { + %0 = "stablehlo.multiply"(%arg0, %arg0) : (tensor, tensor) -> tensor + %1 = "stablehlo.subtract"(%0, %arg1) : (tensor, tensor) -> tensor + func.return %1 : tensor + } +} + + +``xdsl_from_qjit`` +================== + +Since ``xdsl_module`` is for purely classical code, there is another +function that can lower hybrid quantum-classical code written in Python +to an xDSL module. ``xdsl_from_qjit`` takes a ``QJIT``-ed function as +input, and converts it into an xDSL module with the same structure as a +Catalyst program. Let’s check it out: + +.. code-block:: python + + @xdsl_from_qjit + @qml.qjit + def workflow(x, y): + dev = qml.device("lightning.qubit", wires=5) + + @qml.qnode(dev) + def qnode(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + return qnode(x) ** 2 - y + + +>>> qjit_mod = workflow(2.5, 3.5) +>>> print(qjit_mod) +builtin.module @workflow { + func.func public @jit_workflow(%arg2 : tensor, %arg3 : tensor) -> (tensor) attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_qnode::@qnode(%arg2) : (tensor) -> tensor + %1 = "stablehlo.multiply"(%0, %0) : (tensor, tensor) -> tensor + %2 = "stablehlo.convert"(%arg3) : (tensor) -> tensor + %3 = "stablehlo.subtract"(%1, %2) : (tensor, tensor) -> tensor + func.return %3 : tensor + } + builtin.module @module_qnode { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @qnode(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor + %3 = quantum.alloc(5) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = tensor.extract %arg0[] : tensor + %7 = quantum.custom "RX"(%6) %5 : !quantum.bit + %8 = quantum.namedobs %7[PauliZ] : !quantum.obs + %9 = quantum.expval %8 : f64 + %10 = tensor.from_elements %9 : tensor + %11 = tensor.extract %0[] : tensor + %12 = quantum.insert %3[%11], %7 : !quantum.reg, !quantum.bit + quantum.dealloc %12 : !quantum.reg + quantum.device_release + func.return %10 : tensor + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } +} + + +Nice! The usefulness of having this utility is that all dialects that we +would commonly find being used at the MLIR layer are already loaded, so +users don’t need to worry about loading dialects manually. + +``inline_jit_to_module`` +======================== + +This utility takes ``xdsl_module`` a step beyond. It takes a +``jax.jit``-ed function ``func`` and an xDSL module ``mod`` as input. It +lowers ``func`` to an xDSL module, and appends the lowered module’s body +to the end of ``mod``\ ’s body. An additional step that this function +takes is that it renames the ``main`` function in the lowered module to +the same name as ``func`` so that it’s easier to find. + +Let’s try inlining the previous ``outer`` function into ``qjit_mod``. +``inline_jit_to_module`` will modify ``qjit_mod`` in-place: + +>>> inline_jit_to_module(outer, qjit_mod)(1.5, 3.5) +>>> print(qjit_mod) +builtin.module @workflow { + func.func public @jit_workflow(%arg2 : tensor, %arg3 : tensor) -> (tensor) attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_qnode::@qnode(%arg2) : (tensor) -> tensor + %1 = "stablehlo.multiply"(%0, %0) : (tensor, tensor) -> tensor + %2 = "stablehlo.convert"(%arg3) : (tensor) -> tensor + %3 = "stablehlo.subtract"(%1, %2) : (tensor, tensor) -> tensor + func.return %3 : tensor + } + builtin.module @module_qnode { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @qnode(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor + %3 = quantum.alloc(5) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = tensor.extract %arg0[] : tensor + %7 = quantum.custom "RX"(%6) %5 : !quantum.bit + %8 = quantum.namedobs %7[PauliZ] : !quantum.obs + %9 = quantum.expval %8 : f64 + %10 = tensor.from_elements %9 : tensor + %11 = tensor.extract %0[] : tensor + %12 = quantum.insert %3[%11], %7 : !quantum.reg, !quantum.bit + quantum.dealloc %12 : !quantum.reg + quantum.device_release + func.return %10 : tensor + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } + func.func public @outer(%arg1 : tensor, %arg2 : tensor) -> (tensor {jax.result_info = "result"}) { + %0 = func.call @inner(%arg1) : (tensor) -> tensor + %1 = "stablehlo.subtract"(%0, %arg2) : (tensor, tensor) -> tensor + func.return %1 : tensor + } + func.func private @inner(%arg0 : tensor) -> (tensor) { + %0 = "stablehlo.multiply"(%arg0, %arg0) : (tensor, tensor) -> tensor + func.return %0 : tensor + } +} + + +Nice! We can see that two new functions have been added to the bottom of +``qjit_mod``, corresponding to ``outer`` and ``inner``. This allows us +to make function calls to both functions, which can be useful for +embedding post-processing logic into a module, say, when applying a +compiler pass. + +``inline_module`` +================= + +This utility takes two modules as input, and inlines the body of the +first module into the body of the second module. Additionally, +recognizing that modules created from ``jax.jit``-ed functions may +contain a ``FuncOp`` called ``main``, the function also takes a string +as an optional input to rename the ``main`` function to something else. + +Let’s try revisiting the above modules and inlining them using +``inline_module`` instead of ``inline_jit_to_module``: + +.. code-block:: python + + @xdsl_from_qjit + @qml.qjit + def workflow2(x, y): + dev = qml.device("lightning.qubit", wires=5) + + @qml.qnode(dev) + def qnode(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + return qnode(x) ** 2 - y + + +>>> qjit_mod2 = workflow2(2.5, 3.5) +>>> print(qjit_mod2) +builtin.module @workflow2 { + func.func public @jit_workflow2(%arg2 : tensor, %arg3 : tensor) -> (tensor) attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_qnode::@qnode(%arg2) : (tensor) -> tensor + %1 = "stablehlo.multiply"(%0, %0) : (tensor, tensor) -> tensor + %2 = "stablehlo.convert"(%arg3) : (tensor) -> tensor + %3 = "stablehlo.subtract"(%1, %2) : (tensor, tensor) -> tensor + func.return %3 : tensor + } + builtin.module @module_qnode { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @qnode(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor + %3 = quantum.alloc(5) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = tensor.extract %arg0[] : tensor + %7 = quantum.custom "RX"(%6) %5 : !quantum.bit + %8 = quantum.namedobs %7[PauliZ] : !quantum.obs + %9 = quantum.expval %8 : f64 + %10 = tensor.from_elements %9 : tensor + %11 = tensor.extract %0[] : tensor + %12 = quantum.insert %3[%11], %7 : !quantum.reg, !quantum.bit + quantum.dealloc %12 : !quantum.reg + quantum.device_release + func.return %10 : tensor + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } +} + + +.. code-block:: python + + @jax.jit + def inner3(x): + return x**2 + + + @jax.jit + def outer3(x, y): + return inner3(x) - y + + +>>> wrapped_outer3 = xdsl_module(outer3) +>>> jit_mod3 = wrapped_outer3(1.5, 2.5) +>>> print(jit_mod3) +builtin.module @jit_outer3 attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg1 : tensor, %arg2 : tensor) -> (tensor {jax.result_info = "result"}) { + %0 = func.call @inner3(%arg1) : (tensor) -> tensor + %1 = "stablehlo.subtract"(%0, %arg2) : (tensor, tensor) -> tensor + func.return %1 : tensor + } + func.func private @inner3(%arg0 : tensor) -> (tensor) { + %0 = "stablehlo.multiply"(%arg0, %arg0) : (tensor, tensor) -> tensor + func.return %0 : tensor + } +} + + +Now, we will inline the contents of ``jit_mod3`` into ``qjit_mod2``, and +rename ``main`` to ``outer3``. As seen below, ``outer3`` and ``inner3`` +have been inlined into ``qjit_mod2``, just like we wanted. + +One might wonder why we need both ``inline_jit_to_module`` and +``inline_module``. At this stage, the intention is just to enable as +much functionality for users of the Python compiler as possible. There +may be use cases where users may want to create their own module +manually instead of having one be created automatically by JAX or +Catalyst. + +>>> inline_module(jit_mod3, qjit_mod2, change_main_to="outer3") +>>> print(qjit_mod2) +builtin.module @workflow2 { + func.func public @jit_workflow2(%arg2 : tensor, %arg3 : tensor) -> (tensor) attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_qnode::@qnode(%arg2) : (tensor) -> tensor + %1 = "stablehlo.multiply"(%0, %0) : (tensor, tensor) -> tensor + %2 = "stablehlo.convert"(%arg3) : (tensor) -> tensor + %3 = "stablehlo.subtract"(%1, %2) : (tensor, tensor) -> tensor + func.return %3 : tensor + } + builtin.module @module_qnode { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @qnode(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor + %3 = quantum.alloc(5) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = tensor.extract %arg0[] : tensor + %7 = quantum.custom "RX"(%6) %5 : !quantum.bit + %8 = quantum.namedobs %7[PauliZ] : !quantum.obs + %9 = quantum.expval %8 : f64 + %10 = tensor.from_elements %9 : tensor + %11 = tensor.extract %0[] : tensor + %12 = quantum.insert %3[%11], %7 : !quantum.reg, !quantum.bit + quantum.dealloc %12 : !quantum.reg + quantum.device_release + func.return %10 : tensor + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } + func.func public @outer3(%arg1 : tensor, %arg2 : tensor) -> (tensor {jax.result_info = "result"}) { + %0 = func.call @inner3(%arg1) : (tensor) -> tensor + %1 = "stablehlo.subtract"(%0, %arg2) : (tensor, tensor) -> tensor + func.return %1 : tensor + } + func.func private @inner3(%arg0 : tensor) -> (tensor) { + %0 = "stablehlo.multiply"(%arg0, %arg0) : (tensor, tensor) -> tensor + func.return %0 : tensor + } +} diff --git a/frontend/catalyst/python_interface/parser.py b/frontend/catalyst/python_interface/parser.py new file mode 100644 index 0000000000..b7169362e8 --- /dev/null +++ b/frontend/catalyst/python_interface/parser.py @@ -0,0 +1,70 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for translating JAX to xDSL""" + +from collections.abc import Sequence + +from xdsl.context import Context as xContext +from xdsl.dialects import arith as xarith +from xdsl.dialects import builtin as xbuiltin +from xdsl.dialects import func as xfunc +from xdsl.dialects import scf as xscf +from xdsl.dialects import tensor as xtensor +from xdsl.ir import Dialect as xDialect +from xdsl.parser import Parser as xParser + +from catalyst.python_interface.dialects import MBQC, QEC, Catalyst, Quantum, StableHLO, Transform + + +class QuantumParser(xParser): # pylint: disable=abstract-method + """A subclass of ``xdsl.parser.Parser`` that automatically loads relevant dialects + into the input context. + + Args: + ctx (xdsl.context.Context): Context to use for parsing. + input (str): Input program string to parse. + name (str): The name for the input. ``""`` by default. + extra_dialects (Sequence[xdsl.ir.Dialect]): Any additional dialects + that should be loaded into the context before parsing. + """ + + default_dialects: tuple[xDialect] = ( + xarith.Arith, + xbuiltin.Builtin, + xfunc.Func, + xscf.Scf, + StableHLO, + xtensor.Tensor, + Transform, + Quantum, + MBQC, + Catalyst, + QEC, + ) + + # pylint: disable=redefined-builtin + def __init__( + self, + ctx: xContext, + input: str, + name: str = "", + extra_dialects: Sequence[xDialect] | None = (), + ) -> None: + super().__init__(ctx, input, name) + + extra_dialects = extra_dialects or () + for dialect in self.default_dialects + tuple(extra_dialects): + if self.ctx.get_optional_dialect(dialect.name) is None: + self.ctx.load_dialect(dialect) diff --git a/frontend/catalyst/python_interface/pass_api/__init__.py b/frontend/catalyst/python_interface/pass_api/__init__.py new file mode 100644 index 0000000000..ceb32f9de5 --- /dev/null +++ b/frontend/catalyst/python_interface/pass_api/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""xDSL transforms core API.""" + +from .apply_transform_sequence import ( + ApplyTransformSequence, + available_passes, + is_xdsl_pass, + register_pass, +) +from .compiler_transform import PassDispatcher, compiler_transform +from .transform_interpreter import TransformFunctionsExt, TransformInterpreterPass + +__all__ = [ + "ApplyTransformSequence", + "available_passes", + "is_xdsl_pass", + "PassDispatcher", + "register_pass", + "TransformFunctionsExt", + "TransformInterpreterPass", + "compiler_transform", +] diff --git a/frontend/catalyst/python_interface/pass_api/apply_transform_sequence.py b/frontend/catalyst/python_interface/pass_api/apply_transform_sequence.py new file mode 100644 index 0000000000..7a07218b2a --- /dev/null +++ b/frontend/catalyst/python_interface/pass_api/apply_transform_sequence.py @@ -0,0 +1,84 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This file contains the pass that applies all passes present in the program representation.""" + + +from dataclasses import dataclass + +from pennylane.typing import Callable +from xdsl.context import Context +from xdsl.dialects import builtin +from xdsl.passes import ModulePass, PassPipeline + +from .transform_interpreter import TransformInterpreterPass + +available_passes = {} + + +def register_pass(name, _callable): + """Registers the passes available in the dictionary""" + available_passes[name] = _callable # pragma: no cover + + +def is_xdsl_pass(pass_name: str) -> bool: + """Check if a pass name corresponds to an xDSL implemented pass. + + This function checks if the pass is registered in PennyLane's unified compiler + pass registry, which dynamically tracks all available xDSL passes. + + Args: + pass_name (str): Name of the pass to check + + Returns: + bool: True if this is an xDSL compiler pass + """ + return pass_name in available_passes + + +@dataclass(frozen=True) +class ApplyTransformSequence(ModulePass): + """ + Looks for nested modules. Nested modules in this context are guaranteed to correspond + to qnodes. These modules are already annotated with which passes are to be executed. + The pass ApplyTransformSequence will run passes annotated in the qnode modules. + + At the end, we delete the list of passes as they have already been applied. + """ + + name = "apply-transform-sequence" + callback: Callable[[ModulePass, builtin.ModuleOp, ModulePass], None] | None = None + + def apply(self, ctx: Context, op: builtin.ModuleOp) -> None: + """Applies the transformation""" + nested_modules = [] + for region in op.regions: + for block in region.blocks: + for operation in block.ops: + if isinstance(operation, builtin.ModuleOp): + nested_modules.append(operation) + + pipeline = PassPipeline( + (TransformInterpreterPass(passes=available_passes, callback=self.callback),) + ) + for operation in nested_modules: + pipeline.apply(ctx, operation) + + for mod in nested_modules: + for region in mod.regions: + for block in region.blocks: + for operation in block.ops: + if isinstance(operation, builtin.ModuleOp) and operation.get_attr_or_prop( + "transform.with_named_sequence" + ): + block.erase_op(operation) # pragma: no cover diff --git a/frontend/catalyst/python_interface/pass_api/compiler_transform.py b/frontend/catalyst/python_interface/pass_api/compiler_transform.py new file mode 100644 index 0000000000..87943c555d --- /dev/null +++ b/frontend/catalyst/python_interface/pass_api/compiler_transform.py @@ -0,0 +1,64 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Core API for registering xDSL transforms for use with PennyLane and Catalyst.""" + +from collections.abc import Callable + +from pennylane.transforms.core.transform_dispatcher import TransformDispatcher +from xdsl.passes import ModulePass + +from catalyst.from_plxpr import register_transform + +from .apply_transform_sequence import register_pass + + +def _create_null_transform(name: str) -> Callable: + """Create a dummy tape transform. The tape transform raises an error if used.""" + + def null_transform(_): + raise RuntimeError( + f"Cannot apply the {name} pass without '@qml.qjit'. Additionally, program capture " + "must be enabled using 'qml.capture.enable()'. Otherwise, the pass must be applied " + f"using the '@catalyst.passes.apply_pass(\"{name}\")' decorator." + ) + + return null_transform + + +class PassDispatcher(TransformDispatcher): + """Wrapper class for applying passes to QJIT-ed workflows.""" + + name: str + module_pass: ModulePass + + def __init__(self, module_pass: ModulePass): + self.module_pass = module_pass + self.name = module_pass.name + tape_transform = _create_null_transform(self.name) + super().__init__(tape_transform) + + +def compiler_transform(module_pass: ModulePass) -> PassDispatcher: + """Wrapper function to register xDSL passes to use with QJIT-ed workflows.""" + dispatcher = PassDispatcher(module_pass) + + # Registration to map from plxpr primitive to pass + register_transform(dispatcher, module_pass.name, False) + + # Registration for apply-transform-sequence interpreter + def get_pass_cls(): + return module_pass + + register_pass(module_pass.name, get_pass_cls) + return dispatcher diff --git a/frontend/catalyst/python_interface/pass_api/transform_interpreter.py b/frontend/catalyst/python_interface/pass_api/transform_interpreter.py new file mode 100644 index 0000000000..cf635c1591 --- /dev/null +++ b/frontend/catalyst/python_interface/pass_api/transform_interpreter.py @@ -0,0 +1,146 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +"""Custom Transform Dialect Interpreter Pass + +Differs from xDSL's upstream implementation by allowing passes +to be passed in as options and for the pipeline to have a callback and apply it +after every pass is run. The callback differs from how xDSL callback mechanism +is integrated into the PassPipeline object since PassPipeline only runs +if there are more than two passes. Here we are running one pass at a time +which will prevent the callback from being called. + + +See here (link valid with xDSL 0.46): https://github.com/xdslproject/xdsl/blob/334492e660b1726bc661efc7afb927e74bac48f4/xdsl/passes.py#L211-L222 +""" + +import io +from collections.abc import Callable + +from xdsl.context import Context +from xdsl.dialects import builtin +from xdsl.dialects.transform import NamedSequenceOp +from xdsl.interpreter import Interpreter, PythonValues, impl, register_impls +from xdsl.interpreters.transform import TransformFunctions +from xdsl.parser import Parser +from xdsl.passes import ModulePass, PassPipeline +from xdsl.printer import Printer +from xdsl.rewriter import Rewriter +from xdsl.utils.exceptions import PassFailedException + +from catalyst.compiler import _quantum_opt +from catalyst.python_interface.dialects.transform import ApplyRegisteredPassOp + + +@register_impls +class TransformFunctionsExt(TransformFunctions): + """ + Unlike the implementation available in xDSL, this implementation overrides + the semantics of the `transform.apply_registered_pass` operation by + first always attempting to apply the xDSL pass, but if it isn't found + then it will try to run this pass in Catalyst. + """ + + def __init__(self, ctx, passes, callback=None): + super().__init__(ctx, passes) + # The signature of the callback function is assumed to be + # def callback(previous_pass: ModulePass, module: ModuleOp, next_pass: ModulePass, pass_level=None) -> None + self.callback = callback + self.pass_level = 0 + + def _pre_pass_callback(self, compilation_pass, module): + """Callback wrapper to run the callback function before the pass.""" + if not self.callback: + return + if self.pass_level == 0: + # Since this is the first pass, there is no previous pass + self.callback(None, module, compilation_pass, pass_level=0) + + def _post_pass_callback(self, compilation_pass, module): + """Increment level and run callback if defined.""" + if not self.callback: + return + self.pass_level += 1 + self.callback(compilation_pass, module, None, pass_level=self.pass_level) + + @impl(ApplyRegisteredPassOp) + def run_apply_registered_pass_op( + self, + _interpreter: Interpreter, + op: ApplyRegisteredPassOp, + args: PythonValues, + ) -> PythonValues: + """Try to run the pass in xDSL, if not found then run it in Catalyst.""" + + pass_name = op.pass_name.data + module = args[0] + + # ---- xDSL path ---- + if pass_name in self.passes: + pass_class = self.passes[pass_name]() + pass_instance = pass_class(**op.options.data) + pipeline = PassPipeline((pass_instance,)) + self._pre_pass_callback(pass_instance, module) + pipeline.apply(self.ctx, module) + self._post_pass_callback(pass_instance, module) + return (module,) + + # ---- Catalyst path ---- + buffer = io.StringIO() + Printer(stream=buffer, print_generic_format=True).print_op(module) + + schedule = f"--{pass_name}" + self._pre_pass_callback(pass_name, module) + modified = _quantum_opt(schedule, "-mlir-print-op-generic", stdin=buffer.getvalue()) + + data = Parser(self.ctx, modified).parse_module() + rewriter = Rewriter() + rewriter.replace_op(module, data) + self._post_pass_callback(pass_name, data) + return (data,) + + +class TransformInterpreterPass(ModulePass): + """Transform dialect interpreter""" + + passes: dict[str, Callable[[], type[ModulePass]]] + name = "transform-interpreter" + callback: Callable[[ModulePass, builtin.ModuleOp, ModulePass], None] | None = None + + entry_point: str = "__transform_main" + + def __init__(self, passes, callback): + self.passes = passes + self.callback = callback + + @staticmethod + def find_transform_entry_point(root: builtin.ModuleOp, entry_point: str) -> NamedSequenceOp: + """Find the entry point of the program""" + for op in root.walk(): + if isinstance(op, NamedSequenceOp) and op.sym_name.data == entry_point: + return op + raise PassFailedException( # pragma: no cover + f"{root} could not find a nested named sequence with name: {entry_point}" + ) + + def apply(self, ctx: Context, op: builtin.ModuleOp) -> None: + """Run the interpreter with op.""" + schedule = TransformInterpreterPass.find_transform_entry_point(op, self.entry_point) + interpreter = Interpreter(op) + interpreter.register_implementations(TransformFunctionsExt(ctx, self.passes, self.callback)) + schedule.parent_op().detach() + if self.callback: + self.callback(None, op, None, pass_level=0) + interpreter.call_op(schedule, (op,)) diff --git a/frontend/catalyst/python_interface/transforms/__init__.py b/frontend/catalyst/python_interface/transforms/__init__.py new file mode 100644 index 0000000000..905893b2bf --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/__init__.py @@ -0,0 +1,64 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PennyLane-xDSL transformations API.""" + +from .mbqc import ( + ConvertToMBQCFormalismPass, + DecomposeGraphStatePass, + NullDecomposeGraphStatePass, + OutlineStateEvolutionPass, + convert_to_mbqc_formalism_pass, + decompose_graph_state_pass, + null_decompose_graph_state_pass, + outline_state_evolution_pass, +) +from .quantum import ( + CombineGlobalPhasesPass, + DiagonalizeFinalMeasurementsPass, + IterativeCancelInversesPass, + MeasurementsFromSamplesPass, + MergeRotationsPass, + SplitNonCommutingPass, + combine_global_phases_pass, + diagonalize_final_measurements_pass, + iterative_cancel_inverses_pass, + measurements_from_samples_pass, + merge_rotations_pass, + split_non_commuting_pass, +) + +__all__ = [ + # Quantum + "combine_global_phases_pass", + "CombineGlobalPhasesPass", + "diagonalize_final_measurements_pass", + "DiagonalizeFinalMeasurementsPass", + "iterative_cancel_inverses_pass", + "IterativeCancelInversesPass", + "measurements_from_samples_pass", + "MeasurementsFromSamplesPass", + "merge_rotations_pass", + "MergeRotationsPass", + "split_non_commuting_pass", + "SplitNonCommutingPass", + # MBQC + "convert_to_mbqc_formalism_pass", + "ConvertToMBQCFormalismPass", + "decompose_graph_state_pass", + "DecomposeGraphStatePass", + "OutlineStateEvolutionPass", + "outline_state_evolution_pass", + "null_decompose_graph_state_pass", + "NullDecomposeGraphStatePass", +] diff --git a/frontend/catalyst/python_interface/transforms/mbqc/__init__.py b/frontend/catalyst/python_interface/transforms/mbqc/__init__.py new file mode 100644 index 0000000000..134e9d3e76 --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/mbqc/__init__.py @@ -0,0 +1,48 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""xDSL transformations API specifically for the MBQC transform.""" + +from .convert_to_mbqc_formalism import ConvertToMBQCFormalismPass, convert_to_mbqc_formalism_pass +from .decompose_graph_state import ( + DecomposeGraphStatePass, + NullDecomposeGraphStatePass, + decompose_graph_state_pass, + null_decompose_graph_state_pass, +) +from .graph_state_utils import ( + edge_iter, + generate_adj_matrix, + get_graph_state_edges, + get_num_aux_wires, + n_vertices_from_packed_adj_matrix, +) +from .outline_state_evolution import OutlineStateEvolutionPass, outline_state_evolution_pass + +__all__ = [ + # Passes + "ConvertToMBQCFormalismPass", + "DecomposeGraphStatePass", + "OutlineStateEvolutionPass", + "NullDecomposeGraphStatePass", + "convert_to_mbqc_formalism_pass", + "decompose_graph_state_pass", + "null_decompose_graph_state_pass", + "outline_state_evolution_pass", + # Utils + "get_num_aux_wires", + "get_graph_state_edges", + "n_vertices_from_packed_adj_matrix", + "edge_iter", + "generate_adj_matrix", +] diff --git a/frontend/catalyst/python_interface/transforms/mbqc/convert_to_mbqc_formalism.py b/frontend/catalyst/python_interface/transforms/mbqc/convert_to_mbqc_formalism.py new file mode 100644 index 0000000000..4476d62e96 --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/mbqc/convert_to_mbqc_formalism.py @@ -0,0 +1,750 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This file contains the implementation of the convert_to_mbqc_formalism transform, +written using xDSL.""" + +import math +from dataclasses import dataclass + +from xdsl import builder, context, passes, pattern_rewriter +from xdsl.dialects import arith, builtin, func, scf +from xdsl.dialects.scf import ForOp, IfOp, IndexSwitchOp, WhileOp +from xdsl.ir import SSAValue +from xdsl.ir.core import Block, OpResult, Region +from xdsl.rewriter import InsertPoint + +from catalyst.python_interface.dialects.mbqc import ( + GraphStatePrepOp, + MeasureInBasisOp, + MeasurementPlaneAttr, + MeasurementPlaneEnum, +) +from catalyst.python_interface.dialects.quantum import ( + CustomOp, + DeallocQubitOp, + ExtractOp, + GlobalPhaseOp, + QubitType, +) +from catalyst.python_interface.pass_api import compiler_transform + +from .graph_state_utils import generate_adj_matrix, get_num_aux_wires + +_PAULIS = { + "PauliX", + "PauliY", + "PauliZ", + "Identity", +} + +_MBQC_ONE_QUBIT_GATES = { + "Hadamard", + "S", + "RZ", + "RotXZX", +} + +_MBQC_TWO_QUBIT_GATES = { + "CNOT", +} + +_MBQC_GATES = _MBQC_ONE_QUBIT_GATES | _MBQC_TWO_QUBIT_GATES + + +@dataclass(frozen=True) +class ConvertToMBQCFormalismPass(passes.ModulePass): + """Pass that converts gates in the MBQC gate set to the MBQC formalism.""" + + name = "convert-to-mbqc-formalism" + + def _prep_graph_state(self, gate_name: str): + """Add a graph state prep operation into the subroutine for each gate and extract and + return auxiliary qubits in the graph state. + + Args: + gate_name[str]: Name of gate operation. + + Return: + graph_qubit_dict : A dictionary of qubits in the graph + """ + num_aux_wres = get_num_aux_wires(gate_name) + + adj_matrix_op = arith.ConstantOp( + builtin.DenseIntOrFPElementsAttr.from_list( + type=builtin.TensorType( + builtin.IntegerType(1), shape=(len(generate_adj_matrix(gate_name)),) + ), + data=generate_adj_matrix(gate_name), + ) + ) + + graph_state_prep_op = GraphStatePrepOp(adj_matrix_op.result, "Hadamard", "CZ") + graph_state_reg = graph_state_prep_op.results[0] + + graph_qubit_dict = {} + # Extract qubit from the graph state reg + for i in range(num_aux_wres): + extract_op = ExtractOp(graph_state_reg, i) + + # Note the following line maps the aux qubit index in the register to the + # standard context book MBQC representation. Note that auxiliary qubits in + # the graph state for one qubit gates only hit the `if` branch as `i` is + # always less than `4`, while the auxiliary qubits in graph state for a `CNOT` + # gate with an index >= 7 would hit the `else` branch. + key = i + 2 if i < 7 else i + 3 + + graph_qubit_dict[key] = extract_op.results[0] + + return graph_qubit_dict + + def _insert_xy_basis_measure_op( + self, + const_angle_op: arith.ConstantOp, + qubit: QubitType, + ): + """Add an arbitrary basis measure related operations to the subroutine. + + Args: + const_angle_op (arith.ConstantOp) : The angle of measurement basis. + qubit (QubitType) : The target qubit to be measured. + + Returns: + The results include: 1. a measurement result; 2, a result qubit. + """ + plane_op = MeasurementPlaneAttr(MeasurementPlaneEnum("XY")) + # Create a MeasureInBasisOp op + measure_op = MeasureInBasisOp(in_qubit=qubit, plane=plane_op, angle=const_angle_op) + # Returns the results of the newly created measure_op. + # The results include: 1, a measurement result; 2, a result qubit. + return measure_op.results + + def _insert_cond_arbitrary_basis_measure_op( + self, + meas_parity: builtin.IntegerType, + angle: SSAValue[builtin.Float64Type], + plane: str, + qubit: QubitType, + ): + """ + Add a conditional arbitrary basis measurement operation based on a previous measurement + result. + + Args: + meas_parity (builtin.IntegerType) : A parity of previous measurements. + angle (SSAValue[builtin.Float64Type]) : An angle SSAValue from a parametric gate + operation. + plane (str): Plane of the measurement basis. + qubit (QubitType) : The target qubit to be measured. + + Returns: + The results include: 1. a measurement result; 2, a result qubit. + """ + constant_one_op = arith.ConstantOp.from_int_and_width(1, builtin.i1) + cond = arith.CmpiOp(meas_parity, constant_one_op, "eq") + branch = scf.IfOp( + cond, + ( + builtin.IntegerType(1), + QubitType(), + ), + Region(Block()), + Region(Block()), + ) + + plane_op = MeasurementPlaneAttr(MeasurementPlaneEnum(plane)) + + with builder.ImplicitBuilder(branch.true_region): + measure_op = MeasureInBasisOp(in_qubit=qubit, plane=plane_op, angle=angle) + scf.YieldOp(measure_op.results[0], measure_op.results[1]) + with builder.ImplicitBuilder(branch.false_region): + const_neg_angle_op = arith.NegfOp(angle) + measure_neg_op = MeasureInBasisOp( + in_qubit=qubit, plane=plane_op, angle=const_neg_angle_op.result + ) + scf.YieldOp(measure_neg_op.results[0], measure_neg_op.results[1]) + + return branch.results + + def _hadamard_measurements(self, graph_qubit_dict): + """Add measurement ops for a Hadamard gate to the subroutine""" + const_x_angle = arith.ConstantOp( + builtin.FloatAttr(data=0.0, type=builtin.Float64Type()) + ) # measure_x + const_y_angle = arith.ConstantOp( + builtin.FloatAttr(data=math.pi / 2, type=builtin.Float64Type()) + ) # measure_y + m1, graph_qubit_dict[1] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[1] + ) + m2, graph_qubit_dict[2] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[2] + ) + m3, graph_qubit_dict[3] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[3] + ) + m4, graph_qubit_dict[4] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[4] + ) + return [m1, m2, m3, m4], graph_qubit_dict + + def _s_measurements(self, graph_qubit_dict): + """Add measurement ops for a S gate to the subroutine""" + const_x_angle = arith.ConstantOp( + builtin.FloatAttr(data=0.0, type=builtin.Float64Type()) + ) # measure_x + const_y_angle = arith.ConstantOp( + builtin.FloatAttr(data=math.pi / 2, type=builtin.Float64Type()) + ) # measure_y + m1, graph_qubit_dict[1] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[1] + ) + m2, graph_qubit_dict[2] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[2] + ) + m3, graph_qubit_dict[3] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[3] + ) + m4, graph_qubit_dict[4] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[4] + ) + return [m1, m2, m3, m4], graph_qubit_dict + + def _rz_measurements(self, graph_qubit_dict, params): + """Add measurement ops for a RZ gate to the subroutine""" + const_x_angle = arith.ConstantOp( + builtin.FloatAttr(data=0.0, type=builtin.Float64Type()) + ) # measure_x + m1, graph_qubit_dict[1] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[1] + ) + m2, graph_qubit_dict[2] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[2] + ) + m3, graph_qubit_dict[3] = self._insert_cond_arbitrary_basis_measure_op( + m2, params[0], "XY", graph_qubit_dict[3] + ) + m4, graph_qubit_dict[4] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[4] + ) + return [m1, m2, m3, m4], graph_qubit_dict + + def _rotxzx_measurements(self, graph_qubit_dict, params): + """Add measurement ops for a RotXZX gate to the subroutine""" + const_x_angle = arith.ConstantOp( + builtin.FloatAttr(data=0.0, type=builtin.Float64Type()) + ) # measure_x + m1, graph_qubit_dict[1] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[1] + ) + m2, graph_qubit_dict[2] = self._insert_cond_arbitrary_basis_measure_op( + m1, params[0], "XY", graph_qubit_dict[2] + ) + m3, graph_qubit_dict[3] = self._insert_cond_arbitrary_basis_measure_op( + m2, params[1], "XY", graph_qubit_dict[3] + ) + + m1_xor_m3 = arith.XOrIOp(m1, m3) + + m4, graph_qubit_dict[4] = self._insert_cond_arbitrary_basis_measure_op( + m1_xor_m3.result, params[2], "XY", graph_qubit_dict[4] + ) + return [m1, m2, m3, m4], graph_qubit_dict + + def _cnot_measurements(self, graph_qubit_dict): + """Add measurement ops for a CNOT gate to the subroutine""" + const_x_angle = arith.ConstantOp( + builtin.FloatAttr(data=0.0, type=builtin.Float64Type()) + ) # measure_x + const_y_angle = arith.ConstantOp( + builtin.FloatAttr(data=math.pi / 2, type=builtin.Float64Type()) + ) # measure_y + m1, graph_qubit_dict[1] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[1] + ) + m2, graph_qubit_dict[2] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[2] + ) + m3, graph_qubit_dict[3] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[3] + ) + m4, graph_qubit_dict[4] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[4] + ) + m5, graph_qubit_dict[5] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[5] + ) + m6, graph_qubit_dict[6] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[6] + ) + m8, graph_qubit_dict[8] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[8] + ) + m9, graph_qubit_dict[9] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[9] + ) + m10, graph_qubit_dict[10] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[10] + ) + m11, graph_qubit_dict[11] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[11] + ) + m12, graph_qubit_dict[12] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[12] + ) + m13, graph_qubit_dict[13] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[13] + ) + m14, graph_qubit_dict[14] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[14] + ) + + return [m1, m2, m3, m4, m5, m6, m8, m9, m10, m11, m12, m13, m14], graph_qubit_dict + + def _parity_check( + self, + mres: list[builtin.IntegerType], + additional_const_one: bool = False, + ): + """Add parity check related operations to the subroutine. + + Args: + mres (list[builtin.IntegerType]): A list of the mid-measurement results. + additional_const_one (bool) : Whether we need to add an additional const one to + get the parity or not. Defaults to False. + + Returns: + The result of parity check. + """ + prev_res = mres[0] + xor_op = None + # Create xor ops to iterate all elements in the mres and insert them to the IR + for i in range(1, len(mres)): + xor_op = arith.XOrIOp(prev_res, mres[i]) + prev_res = xor_op.result + + # Create an xor op for an additional const one and insert ops to the IR + if additional_const_one: + constant_one_op = arith.ConstantOp.from_int_and_width(1, builtin.i1) + xor_op = arith.XOrIOp(prev_res, constant_one_op) + prev_res = xor_op.result + + return prev_res + + def _insert_cond_byproduct_op( + self, + parity_res: OpResult, + gate_name: str, + qubit: QubitType, + ): + """Add a byproduct op related operations to the subroutine. + + Args: + parity_res (OpResult) : Parity check result. + gate_name (str) : The name of the gate to be corrected. + qubit (QubitType) : The result auxiliary qubit to be corrected. + + Return: + The result auxiliary qubit. + """ + constant_one_op = arith.ConstantOp.from_int_and_width(1, builtin.i1) + cond = arith.CmpiOp(parity_res, constant_one_op, "eq") + branch = scf.IfOp(cond, (QubitType(),), Region(Block()), Region(Block())) + + with builder.ImplicitBuilder(branch.true_region): + byproduct_op = CustomOp(in_qubits=qubit, gate_name=gate_name) + scf.YieldOp(byproduct_op.results[0]) + with builder.ImplicitBuilder(branch.false_region): + scf.YieldOp(qubit) + return branch.results[0] + + def _hadamard_corrections( + self, + mres: list[builtin.IntegerType], + qubit: QubitType, + ): + """Add correction ops of a Hadamard gate to the subroutine. + + Args: + mres (list[builtin.IntegerType]): A list of the mid-measurement results. + qubit (QubitType) : An auxiliary result qubit. + + Returns: + The result auxiliary qubit. + """ + m1, m2, m3, m4 = mres + + # X correction + x_parity = self._parity_check([m1, m3, m4]) + res_aux_qubit = self._insert_cond_byproduct_op(x_parity, "PauliX", qubit) + + # Z correction + z_parity = self._parity_check([m2, m3]) + res_aux_qubit = self._insert_cond_byproduct_op(z_parity, "PauliZ", res_aux_qubit) + + return res_aux_qubit + + def _s_corrections( + self, + mres: list[builtin.IntegerType], + qubit: QubitType, + ): + """Add correction ops of a S gate to the subroutine. + + Args: + mres (list[builtin.IntegerType]): A list of the mid-measurement results. + qubit (QubitType) : An auxiliary result qubit. + + Returns: + The result auxiliary qubit. + """ + m1, m2, m3, m4 = mres + + # X correction + x_parity = self._parity_check([m2, m4]) + res_aux_qubit = self._insert_cond_byproduct_op(x_parity, "PauliX", qubit) + + # Z correction + z_parity = self._parity_check([m1, m2, m3], additional_const_one=True) + res_aux_qubit = self._insert_cond_byproduct_op(z_parity, "PauliZ", res_aux_qubit) + return res_aux_qubit + + def _rot_corrections( + self, + mres: list[builtin.IntegerType], + qubit: QubitType, + ): + """Add correction ops of a RotXZX or RZ gate to the subroutine. + + Args: + mres (list[builtin.IntegerType]): A list of the mid-measurement results. + qubit (QubitType) : An auxiliary result qubit. + + Returns: + The result auxiliary qubit. + """ + m1, m2, m3, m4 = mres + # X correction + x_parity = self._parity_check([m2, m4]) + res_aux_qubit = self._insert_cond_byproduct_op(x_parity, "PauliX", qubit) + + # Z correction + z_parity = self._parity_check([m1, m3]) + res_aux_qubit = self._insert_cond_byproduct_op(z_parity, "PauliZ", res_aux_qubit) + return res_aux_qubit + + def _cnot_corrections( + self, + mres: list[builtin.IntegerType], + qubits: list[QubitType], + ): + """Add correction ops of a CNOT gate to the subroutine. + + Args: + mres (list[builtin.IntegerType]): A list of the mid-measurement results. + qubits (list[QubitType]) : A list of auxiliary result qubits. + + Returns: + The result auxiliary qubits. + """ + m1, m2, m3, m4, m5, m6, m8, m9, m10, m11, m12, m13, m14 = mres + # Corrections for the control qubit + x_parity = self._parity_check([m2, m3, m5, m6]) + ctrl_aux_qubit = self._insert_cond_byproduct_op(x_parity, "PauliX", qubits[0]) + z_parity = self._parity_check([m1, m3, m4, m5, m8, m9, m11], additional_const_one=True) + ctrl_aux_qubit = self._insert_cond_byproduct_op(z_parity, "PauliZ", ctrl_aux_qubit) + + # Corrections for the target qubit + x_parity = self._parity_check([m2, m3, m8, m10, m12, m14]) + tgt_aux_qubit = self._insert_cond_byproduct_op(x_parity, "PauliX", qubits[1]) + z_parity = self._parity_check([m9, m11, m13]) + tgt_aux_qubit = self._insert_cond_byproduct_op(z_parity, "PauliZ", tgt_aux_qubit) + + return ctrl_aux_qubit, tgt_aux_qubit + + def _queue_measurements( + self, gate_name: str, graph_qubit_dict, params: None | list[builtin.Float64Type] = None + ): + """Add measurement ops to the subroutine. + + Args: + gate_name (str): Gate name. + graph_qubit_dict (list[builtin.IntegerType]): A list of the mid-measurement results. + params (None | list[builtin.Float64Type]) : Parameters of the gate. + + Returns: + The measurement results and updated graph_qubit_dict. + + """ + match gate_name: + case "Hadamard": + return self._hadamard_measurements(graph_qubit_dict) + case "S": + return self._s_measurements(graph_qubit_dict) + case "RZ": + return self._rz_measurements(graph_qubit_dict, params) + case "RotXZX": + return self._rotxzx_measurements(graph_qubit_dict, params) + case "CNOT": + return self._cnot_measurements(graph_qubit_dict) + case _: + raise ValueError( + f"{gate_name} is not supported in the MBQC formalism. Please decompose it " + "into the MBQC gate set." + ) + + def _insert_byprod_corrections( + self, + gate_name: str, + mres: list[builtin.IntegerType], + qubits: QubitType | list[QubitType], + ): + """Add correction ops for the result auxiliary qubit/s to the subroutine. + Args: + gate_name (str): Gate name. + mres (list[builtin.IntegerType]): A list of the mid-measurement results. + qubits (QubitType | list[QubitType]) : An or a list of auxiliary result qubit. + + Returns: + The result auxiliary qubits. + """ + match gate_name: + case "Hadamard": + return self._hadamard_corrections(mres, qubits) + case "S": + return self._s_corrections(mres, qubits) + case "RotXZX": + return self._rot_corrections(mres, qubits) + case "RZ": + return self._rot_corrections(mres, qubits) + case "CNOT": + return self._cnot_corrections(mres, qubits) + case _: + raise ValueError( + f"{gate_name} is not supported in the MBQC formalism. Please decompose it " + "into the MBQC gate set." + ) + + def _create_single_qubit_gate_subroutine(self, gate_name: str): + """Create a subroutine for a single qubit gate based on the given name. + Args: + gate_name (str): Name of the gate. + + Returns: + The corresponding subroutine (func.FuncOp). + """ + if gate_name not in _MBQC_ONE_QUBIT_GATES: + raise NotImplementedError(f"Subroutine for the {gate_name} gate is not supported.") + # ensure the order of parameters are aligned with customOp + input_types = () + if gate_name == "RZ": + input_types += (builtin.Float64Type(),) + if gate_name == "RotXZX": + input_types += (builtin.Float64Type(),) * 3 + input_types += (QubitType(),) + + output_types = (QubitType(),) + block = Block(arg_types=input_types) + + with builder.ImplicitBuilder(block): + in_qubits = [block.args[-1]] + params = None + if gate_name == "RZ": + params = [block.args[0]] + if gate_name == "RotXZX": + params = [ + block.args[0], + block.args[1], + block.args[2], + ] + + graph_qubit_dict = self._prep_graph_state(gate_name=gate_name) + + cz_op = CustomOp(in_qubits=[in_qubits[0], graph_qubit_dict[2]], gate_name="CZ") + + graph_qubit_dict[1], graph_qubit_dict[2] = cz_op.results + + mres, graph_qubit_dict = self._queue_measurements(gate_name, graph_qubit_dict, params) + + # The following could be removed to support Pauli tracker + by_product_correction = self._insert_byprod_corrections( + gate_name, mres, graph_qubit_dict[5] + ) + + graph_qubit_dict[5] = by_product_correction + + for node in graph_qubit_dict: + if node not in [5]: + _ = DeallocQubitOp(graph_qubit_dict[node]) + + func.ReturnOp(graph_qubit_dict[5]) + + region = Region([block]) + # pylint: disable=line-too-long + # Note that visibility is set as private to ensure the subroutines that are + # not called (dead code) can be eliminated as the + # ["symbol-dce"](https://github.com/PennyLaneAI/catalyst/blob/372c376eb821e830da778fdc8af423eeb487eab6/frontend/catalyst/pipelines.py#L248)_ + # pass was added to the pipeline. + funcOp = func.FuncOp( + gate_name.lower() + "_in_mbqc", + (input_types, output_types), + visibility="private", + region=region, + ) + # Add an attribute to the mbqc transform subroutine + funcOp.attributes["mbqc_transform"] = builtin.NoneAttr() + return funcOp + + def _create_cnot_gate_subroutine(self): + """Create a subroutine for a CNOT gate.""" + gate_name = "CNOT" + input_types = ( + QubitType(), + QubitType(), + ) + output_types = ( + QubitType(), + QubitType(), + ) + block = Block(arg_types=input_types) + + with builder.ImplicitBuilder(block): + in_qubits = [block.args[0], block.args[1]] + + graph_qubit_dict = self._prep_graph_state(gate_name=gate_name) + + # Entangle the op.in_qubits[0] with the graph_qubits_dict[2] + cz_op = CustomOp(in_qubits=[in_qubits[0], graph_qubit_dict[2]], gate_name="CZ") + graph_qubit_dict[1], graph_qubit_dict[2] = cz_op.results + + # Entangle op.in_qubits[1] with with the graph_qubits_dict[10] for a CNOT gate + cz_op = CustomOp(in_qubits=[in_qubits[1], graph_qubit_dict[10]], gate_name="CZ") + graph_qubit_dict[9], graph_qubit_dict[10] = cz_op.results + + mres, graph_qubit_dict = self._queue_measurements(gate_name, graph_qubit_dict) + + # The following could be removed to support Pauli tracker + graph_qubit_dict[7], graph_qubit_dict[15] = self._insert_byprod_corrections( + gate_name, mres, [graph_qubit_dict[7], graph_qubit_dict[15]] + ) + + for node in graph_qubit_dict: + if node not in [7, 15]: + _ = DeallocQubitOp(graph_qubit_dict[node]) + + func.ReturnOp( + *( + graph_qubit_dict[7], + graph_qubit_dict[15], + ) + ) + + region = Region([block]) + # pylint: disable=line-too-long + # Note that visibility is set as private to ensure the subroutines that are + # not called (dead code) can be eliminated as the + # ["symbol-dce"](https://github.com/PennyLaneAI/catalyst/blob/372c376eb821e830da778fdc8af423eeb487eab6/frontend/catalyst/pipelines.py#L248)_ + # pass was added to the pipeline. + funcOp = func.FuncOp( + gate_name.lower() + "_in_mbqc", + (input_types, output_types), + visibility="private", + region=region, + ) + # Add an attribute to the mbqc transform subroutine + funcOp.attributes["mbqc_transform"] = builtin.NoneAttr() + return funcOp + + def apply(self, _ctx: context.Context, op: builtin.ModuleOp) -> None: + """Apply the convert-to-mbqc-formalism pass.""" + # pylint: disable=line-too-long + # Insert subroutines for all gates in the MBQC gate set to the module. + # Note that the visibility of those subroutines are set as private, which ensure the + # ["symbol-dce"](https://github.com/PennyLaneAI/catalyst/blob/372c376eb821e830da778fdc8af423eeb487eab6/frontend/catalyst/pipelines.py#L248)_ + # pass could eliminate the unreferenced subroutines. + subroutine_dict = {} + + for gate_name in _MBQC_ONE_QUBIT_GATES: + funcOp = self._create_single_qubit_gate_subroutine(gate_name) + op.regions[0].blocks.first.add_op(funcOp) + subroutine_dict[gate_name] = funcOp + + cnot_funcOp = self._create_cnot_gate_subroutine() + op.regions[0].blocks.first.add_op(cnot_funcOp) + subroutine_dict["CNOT"] = cnot_funcOp + + pattern_rewriter.PatternRewriteWalker( + pattern_rewriter.GreedyRewritePatternApplier( + [ConvertToMBQCFormalismPattern(subroutine_dict)] + ), + apply_recursively=False, + ).rewrite_module(op) + + +convert_to_mbqc_formalism_pass = compiler_transform(ConvertToMBQCFormalismPass) + + +class ConvertToMBQCFormalismPattern( + pattern_rewriter.RewritePattern +): # pylint: disable=too-few-public-methods + """RewritePattern for converting to the MBQC formalism.""" + + def __init__(self, subroutines_dict): + self.subroutine_dict = subroutines_dict + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite( + self, + root: func.FuncOp | IfOp | WhileOp | ForOp | IndexSwitchOp, + rewriter: pattern_rewriter.PatternRewriter, + /, + ): + """Match and rewrite for converting to the MBQC formalism.""" + + # Ensure that "Hadamard"/"CZ" gates in mbqc_transform subroutines are not converted. + if isinstance(root, func.FuncOp) and "mbqc_transform" in root.attributes: + return + + for region in root.regions: + # Continue if the region has no block (i.e., function that has no body, and the body is + # defined in runtime.) + if not region.blocks: + continue + + for op in region.ops: + if isinstance(op, CustomOp) and op.gate_name.data in _MBQC_GATES: + callee = builtin.SymbolRefAttr(op.gate_name.data.lower() + "_in_mbqc") + arguments = [] + for param in op.params: + arguments.append(param) + for qubit in op.in_qubits: + arguments.append(qubit) + + return_types = self.subroutine_dict[ + op.gate_name.data + ].function_type.outputs.data + callOp = func.CallOp(callee, arguments, return_types) + rewriter.insert_op(callOp, InsertPoint.before(op)) + for i, out_qubit in enumerate(op.out_qubits): + rewriter.replace_all_uses_with(out_qubit, callOp.results[i]) + rewriter.erase_op(op) + + elif isinstance(op, GlobalPhaseOp) or ( + isinstance(op, CustomOp) and op.gate_name.data in _PAULIS + ): + continue + elif isinstance(op, CustomOp): + raise NotImplementedError( + f"{op.gate_name.data} cannot be converted to the MBQC formalism." + ) diff --git a/frontend/catalyst/python_interface/transforms/mbqc/decompose_graph_state.py b/frontend/catalyst/python_interface/transforms/mbqc/decompose_graph_state.py new file mode 100644 index 0000000000..0d89b171ba --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/mbqc/decompose_graph_state.py @@ -0,0 +1,207 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains the implementations of the decompose_graph_state and +null_decompose_graph_state transforms, written using xDSL. + +.. note:: + + The transforms contained in this module make frequent use of the *densely packed adjacency + matrix* graph representation. For a detailed description of this graph representation, see the + documentation for the GraphStatePrepOp operation in the MBQC dialect in Catalyst. +""" + +from collections.abc import Sequence +from dataclasses import dataclass +from typing import TypeAlias + +from xdsl import context, passes, pattern_rewriter +from xdsl.dialects import builtin +from xdsl.pattern_rewriter import PatternRewriter, RewritePattern + +from catalyst.python_interface.dialects import mbqc, quantum +from catalyst.python_interface.pass_api import compiler_transform + +from .graph_state_utils import edge_iter, n_vertices_from_packed_adj_matrix + +DenselyPackedAdjMatrix: TypeAlias = Sequence[int] | Sequence[bool] + + +@dataclass(frozen=True) +class DecomposeGraphStatePass(passes.ModulePass): + """The decompose-graph-state pass replaces ``graph_state_prep`` operations with their + corresponding sequence of quantum operations for execution on state simulators. + """ + + name = "decompose-graph-state" + + def apply(self, _ctx: context.Context, op: builtin.ModuleOp) -> None: + """Apply the decompose-graph-state pass.""" + + walker = pattern_rewriter.PatternRewriteWalker(DecomposeGraphStatePattern()) + walker.rewrite_module(op) + + +decompose_graph_state_pass = compiler_transform(DecomposeGraphStatePass) + + +# pylint: disable=too-few-public-methods +class DecomposeGraphStatePattern(RewritePattern): + """Rewrite pattern for the decompose-graph-state transform.""" + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, graph_prep_op: mbqc.GraphStatePrepOp, rewriter: PatternRewriter, /): + """Match and rewrite pattern for graph_state_prep ops.""" + # These are the names of the gates that realize the desired initial individual qubit state + # and entangled state, respectively + init_op_gate_name = graph_prep_op.init_op.data + entangle_op_gate_name = graph_prep_op.entangle_op.data + + adj_matrix = _parse_adj_matrix(graph_prep_op) + n_vertices = n_vertices_from_packed_adj_matrix(adj_matrix) + + # Allocate a register with as many qubits as vertices in the graph + alloc_op = quantum.AllocOp(n_vertices) + rewriter.insert_op(alloc_op) + + # This dictionary maps wires indices in the register to qubit SSA values + graph_qubits_map: dict[int, quantum.QubitSSAValue] = {} + + # In this section, we create the sequences of quantum.extract, quantum.custom (for the init + # and entangle gates) and quantum.insert ops, and gather them up into list for insertion + # later on. + qextract_ops: list[quantum.ExtractOp] = [] + for i in range(n_vertices): + qextract_op = quantum.ExtractOp(alloc_op.qreg, i) + qextract_ops.append(qextract_op) + graph_qubits_map[i] = qextract_op.qubit + + init_ops: list[quantum.CustomOp] = [] + for i in range(n_vertices): + init_op = quantum.CustomOp(in_qubits=graph_qubits_map[i], gate_name=init_op_gate_name) + init_ops.append(init_op) + graph_qubits_map[i] = init_op.out_qubits[0] + + entangle_ops: list[quantum.CustomOp] = [] + for edge in edge_iter(adj_matrix): + q0 = graph_qubits_map[edge[0]] + q1 = graph_qubits_map[edge[1]] + entangle_op = quantum.CustomOp(in_qubits=(q0, q1), gate_name=entangle_op_gate_name) + entangle_ops.append(entangle_op) + graph_qubits_map[edge[0]] = entangle_op.out_qubits[0] + graph_qubits_map[edge[1]] = entangle_op.out_qubits[1] + + qinsert_ops: list[quantum.InsertOp] = [] + qreg = alloc_op.qreg + for i in range(n_vertices): + qinsert_op = quantum.InsertOp(in_qreg=qreg, idx=i, qubit=graph_qubits_map[i]) + qinsert_ops.append(qinsert_op) + qreg = qinsert_op.out_qreg + + # In this section, we iterate over the ops created above and insert them. + # Note that we do not need to specify the insertion point here; all ops are inserted before + # the matched op, automatically putting them in the order we want them in. + for qextract_op in qextract_ops: + rewriter.insert_op(qextract_op) + + for init_op in init_ops: + rewriter.insert_op(init_op) + + for entangle_op in entangle_ops: + rewriter.insert_op(entangle_op) + + for qinsert_op in qinsert_ops: + rewriter.insert_op(qinsert_op) + + # The register that is the result of the last quantum.insert op replaces the register that + # was the result of the graph_state_prep op + rewriter.replace_all_uses_with(graph_prep_op.results[0], qinsert_ops[-1].results[0]) + + # Finally, erase the ops that have now been replaced with quantum ops + rewriter.erase_matched_op() + + # Erase the constant op that returned the adjacency matrix only if it has no other uses + if graph_prep_op.adj_matrix.uses.get_length() == 0: + rewriter.erase_op(graph_prep_op.adj_matrix.owner) + + +@dataclass(frozen=True) +class NullDecomposeGraphStatePass(passes.ModulePass): + """The null-decompose-graph-state pass replaces ``graph_state_prep`` operations with a single + quantum-register allocation operation for execution on null devices. + """ + + name = "null-decompose-graph-state" + + def apply(self, _ctx: context.Context, op: builtin.ModuleOp) -> None: + """Apply the null-decompose-graph-state pass.""" + + walker = pattern_rewriter.PatternRewriteWalker(NullDecomposeGraphStatePattern()) + walker.rewrite_module(op) + + +null_decompose_graph_state_pass = compiler_transform(NullDecomposeGraphStatePass) + + +# pylint: disable=too-few-public-methods +class NullDecomposeGraphStatePattern(RewritePattern): + """Rewrite pattern for the null-decompose-graph-state transform.""" + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, graph_prep_op: mbqc.GraphStatePrepOp, rewriter: PatternRewriter, /): + """Match and rewrite pattern for graph_state_prep ops.""" + adj_matrix = _parse_adj_matrix(graph_prep_op) + n_vertices = n_vertices_from_packed_adj_matrix(adj_matrix) + + # Allocate a register with as many qubits as vertices in the graph + alloc_op = quantum.AllocOp(n_vertices) + rewriter.insert_op(alloc_op) + + # The newly allocated register replaces the register that was the result of the + # graph_state_prep op + rewriter.replace_all_uses_with(graph_prep_op.results[0], alloc_op.results[0]) + + # Finally, erase the ops that have now been replaced with quantum ops + rewriter.erase_matched_op() + + # Erase the constant op that returned the adjacency matrix only if it has no other uses + if graph_prep_op.adj_matrix.uses.get_length() == 0: + rewriter.erase_op(graph_prep_op.adj_matrix.owner) + + +def _parse_adj_matrix(graph_prep_op: mbqc.GraphStatePrepOp) -> list[int]: + """Parse the adjacency matrix from the result of the ConstantOp given as input to the + graph_state_prep op. + + We assume that the adjacency matrix is stored as a DenseIntOrFPElementsAttr, whose data is + accessible as a Python 'bytes' array. Converting this bytes array to a list results in integer + elements, whose values are typically either 0 for 'false' or 255 for 'true'. + + Returns: + list[int]: The densely packed adjacency matrix as a list of ints. See the note in the module + documentation for a description of this format. + """ + adj_matrix_const_op = graph_prep_op.adj_matrix.owner + adj_matrix_value = adj_matrix_const_op.properties.get("value") + assert adj_matrix_value is not None and hasattr( + adj_matrix_value, "data" + ), f"Unable to read graph adjacency matrix from op `{adj_matrix_const_op}`" + + adj_matrix_bytes = adj_matrix_value.data + assert isinstance(adj_matrix_bytes, builtin.BytesAttr), ( + f"Expected graph adjacency matrix data to be of type 'builtin.BytesAttr', but got " + f"{type(adj_matrix_bytes).__name__}" + ) + + return list(adj_matrix_bytes.data) diff --git a/frontend/catalyst/python_interface/transforms/mbqc/graph_state_utils.py b/frontend/catalyst/python_interface/transforms/mbqc/graph_state_utils.py new file mode 100644 index 0000000000..f5b7369d3a --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/mbqc/graph_state_utils.py @@ -0,0 +1,260 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for graph state representations""" + +import math +from collections.abc import Generator, Sequence +from typing import TypeAlias + +from pennylane.exceptions import CompileError + +DenselyPackedAdjMatrix: TypeAlias = Sequence[int] | Sequence[bool] + + +_MBQC_GATE_SET = {"Hadamard", "S", "RZ", "RotXZX", "CNOT"} +_SINGLE_QUBIT_AUX_WIRE_NUM = 4 +_CNOT_AUX_WIRE_NUM = 13 + + +def get_num_aux_wires(gate_name: str) -> int: + """ + Return the number of auxiliary wires required for gates from the MBQC gate set. + The number of auxiliary qubits for a single qubit gate is 4, while it is 13 for a + CNOT gate. + + Args: + gate_name (str): The name of a gate. + + Returns: + The number of auxiliary wires. + """ + if gate_name == "CNOT": + return _CNOT_AUX_WIRE_NUM + if gate_name in _MBQC_GATE_SET: + return _SINGLE_QUBIT_AUX_WIRE_NUM + raise ValueError(f"{gate_name} is not supported in the MBQC formalism.") + + +def get_graph_state_edges(gate_name: str) -> list[tuple[int, int]]: + """ + Return a list of edges information in the graph state of a gate. + + - The connectivity of the target qubits in the register and auxiliary qubits for a + single-qubit gate is: + + tgt -- 0 -- 1 -- 2 -- 3 + + Note that the target qubit is not in the adjacency matrix and the connectivity + of the auxiliary qubits is: + edges_in_adj_matrix = [ + (0, 1), + (1, 2), + (2, 3), + ] + + Wire 1 in the above isn't the target wire described in the Fig.2 of + `arXiv:quant-ph/0301052 `_], + 1 in the above maps to 3 in the figure. + + - The connectivity of the ctrl/target qubits in the register and auxiliary qubits for a + CNOT gate is: + + ctl -- 0 -- 1 -- 2 -- 3 -- 4 -- 5 + | + 6 + | + tgt -- 7 -- 8 -- 9 -- 10 -- 11 -- 12 + + Note that both ctrl and target qubits are not in the adjacency matrix and the connectivity + of the auxiliary qubits is: + edges_in_adj_matrix = [ + (0, 1), + (1, 2), + (2, 3), + (3, 4), + (4, 5), + (2, 6), + (7, 8), + (6, 9), + (8, 9), + (9, 10), + (10, 11), + (11, 12), + ] + + This graph is labelled based on the rows and columns of the adjacent matrix, but maps on + to the graph described in the Fig.2 of + [`arXiv:quant-ph/0301052 `_], where wire 1 is the + control and wire 9 is the target. + + Args: + gate_name (str): The name of a gate. + + Returns: + A list of edges information in the graph state for the given gate. + """ + + if gate_name == "CNOT": + return [ + (0, 1), + (1, 2), + (2, 3), + (3, 4), + (4, 5), + (2, 6), + (7, 8), + (6, 9), + (8, 9), + (9, 10), + (10, 11), + (11, 12), + ] + if gate_name in _MBQC_GATE_SET: + return [ + (0, 1), + (1, 2), + (2, 3), + ] + raise ValueError(f"{gate_name} is not supported in the MBQC formalism.") + + +def n_vertices_from_packed_adj_matrix(adj_matrix: DenselyPackedAdjMatrix) -> int: + """Returns the number of vertices in the graph represented by the given densely packed adjacency + matrix. + + Args: + adj_matrix (DenselyPackedAdjMatrix): The densely packed adjacency matrix, given as a + sequence of bools or ints. See the note in the module documentation for a description of + this format. + + Raises: + CompileError: If the number of elements in `adj_matrix` is not compatible with the number of + elements in the lower-triangular part of a square matrix, excluding the elements along + the diagonal. + + Returns: + int: The number of vertices in the graph. + + Example: + >>> _n_vertices_from_packed_adj_matrix([1, 1, 0, 0, 1, 1]) + 4 + """ + assert isinstance( + adj_matrix, Sequence + ), f"Expected `adj_matrix` to be a sequence, but got {type(adj_matrix).__name__}" + + m = len(adj_matrix) + + # The formula to compute the number of vertices, N, in the graph from the number elements in the + # densely packed adjacency matrix, m, is + # N = (1 + sqrt(1 + 8m)) / 2 + # To avoid floating-point errors in the sqrt function, we break it down into integer-arithmetic + # operations and ensure that the solution is one where N is mathematically a true integer. + + discriminant = 1 + 8 * m + sqrt_discriminant = math.isqrt(discriminant) + + # Check if it's a perfect square + if sqrt_discriminant * sqrt_discriminant != discriminant: + raise CompileError( + f"The number of elements in the densely packed adjacency matrix is {m}, which does not " + f"correspond to an integer number of graph vertices" + ) + + # The numerator, 1 + sqrt(1 + 8m), must be even for the result to be an integer. The quantity + # sqrt(1 + 8m) will always be odd if it's a perfect square, so the quantity (1 + sqrt(1 + 8m)) + # will always be even. We can therefore safely divide (using integer division). + return (1 + sqrt_discriminant) // 2 + + +def edge_iter(adj_matrix: DenselyPackedAdjMatrix) -> Generator[tuple[int, int], None, None]: + """Generate an iterator over the edges in a graph represented by the given densely packed + adjacency matrix. + + Args: + adj_matrix (DenselyPackedAdjMatrix): The densely packed adjacency matrix, given as a + sequence of bools or ints. See the note in the module documentation for a description of + this format. + + Yields: + tuple[int, int]: The next edge in the graph, represented as the pair of vertices labelled + according to their indices in the adjacency matrix. + + Example: + >>> for edge in _edge_iter([1, 1, 0, 0, 1, 1]): + ... print(edge) + (0, 1) + (0, 2) + (1, 3) + (2, 3) + """ + # Calling `_n_vertices_from_packed_adj_matrix()` asserts that the input `adj_matrix` is in the + # correct format and is valid. + n_vertices_from_packed_adj_matrix(adj_matrix) + + j = 1 + k = 0 + for entry in adj_matrix: + if entry: + yield (k, j) + k += 1 + if k == j: + k = 0 + j += 1 + + +def _adj_matrix_generation_helper( + num_vertices: int, edges_in_adj_matrix: list[tuple[int, int]] +) -> list: + """Helper function to generate an adjacency matrix to represent the connectivity of auxiliary + qubits in a graph state for a gate operation with the number of vertices and edges information. + Note that the adjacency matrix here means the lower triangular part of the full adjacency + matrix. It can be represented as below and `x` marks here denotes the matrix diagonal. + x + + x + + + x + . + ........ + . + + + + + + x + + Args: + num_vertices (int) : Number of vertices in the adjacency matrix. + edges_in_adj_matrix (list[tuple[int, int]]): List of edges in the adjacency matrix. + + Return: + An adjacency matrix represents the connectivity of vertices. + """ + adj_matrix_length = num_vertices * (num_vertices - 1) // 2 + adj_matrix = [0] * adj_matrix_length + for edge in edges_in_adj_matrix: + col, row = edge + n = col + (row - 1) * row // 2 + adj_matrix[n] = 1 + + return adj_matrix + + +def generate_adj_matrix(op_name: str) -> list: + """Generate an adjacency matrix represents the connectivity of auxiliary qubits in a + graph state for a gate operation. + Args: + op_name (str): The gate name. Note that only a gate in the MBQC gate set is supported. + Returns: + An adjacent matrix represents the connectivity of auxiliary qubits. + """ + num_aux_wires = get_num_aux_wires(op_name) + edges_list = get_graph_state_edges(op_name) + return _adj_matrix_generation_helper(num_aux_wires, edges_list) diff --git a/frontend/catalyst/python_interface/transforms/mbqc/outline_state_evolution.py b/frontend/catalyst/python_interface/transforms/mbqc/outline_state_evolution.py new file mode 100644 index 0000000000..6035040c18 --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/mbqc/outline_state_evolution.py @@ -0,0 +1,477 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This file contains the implementation of the outline_state_evolution transform. + +Known limitations +----------------- + + * If the current pass is applied multiple times, the transform will fail as it would redefined + the `state_evolution` func. This is caused by the way we define the terminal_boundary_op. + Each time the pass is applied to the IR, it would insert a new terminal_boundary_op into the + IR. TODOs: Instead of inserting a new `terminal_boundary_op` op to the IR when applying the + pass, it would be better to: + 1. define a quantum.terminator op before this pass and use it as a delineation of + quantum gate operation; + 2. move the `simplify_io` to a separate pass. +""" + +from dataclasses import dataclass +from itertools import chain + +from xdsl import context, passes, pattern_rewriter +from xdsl.dialects import builtin, func +from xdsl.ir import Operation, SSAValue +from xdsl.rewriter import InsertPoint + +from catalyst.python_interface.dialects import quantum +from catalyst.python_interface.pass_api import compiler_transform + + +@dataclass(frozen=True) +class OutlineStateEvolutionPass(passes.ModulePass): + """Pass that puts gate operations into an outline_state_evolution callable.""" + + name = "outline-state-evolution" + + def apply(self, _ctx: context.Context, op: builtin.ModuleOp) -> None: + """Apply the outline-state-evolution pass.""" + for op_ in op.ops: + if isinstance(op_, func.FuncOp) and "qnode" in op_.attributes: + rewriter = pattern_rewriter.PatternRewriter(op_) + OutlineStateEvolutionPattern().match_and_rewrite(op_, rewriter) + + +outline_state_evolution_pass = compiler_transform(OutlineStateEvolutionPass) + + +class OutlineStateEvolutionPattern(pattern_rewriter.RewritePattern): + """RewritePattern for outlined state evolution regions in a quantum function.""" + + # pylint: disable=too-few-public-methods + def _get_parent_module(self, op: func.FuncOp) -> builtin.ModuleOp: + """Get the first ancestral builtin.ModuleOp op of a given func.func op.""" + while (op := op.parent_op()) and not isinstance(op, builtin.ModuleOp): + pass + if op is None: + raise RuntimeError( + "The given qnode func is not nested within a builtin.module. Please ensure the " + "qnode func is defined in a builtin.module." + ) + return op + + def __init__(self): + self.module: builtin.ModuleOp = None + self.original_func_op: func.FuncOp = None + + # To determine the boundary of quantum gate operations in the IR + self.alloc_op: quantum.AllocOp = None + self.terminal_boundary_op: Operation = None + + # Input and outputs of the state evolution func + self.required_inputs: list[SSAValue] = None + self.required_outputs: list[SSAValue] = None + + # State evolution function region + self.state_evolution_func: func.FuncOp = None + + def match_and_rewrite( + self, func_op: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter, / + ): + """Transform a quantum function (qnode) to outline state evolution regions. + This implementation assumes that there is only one `quantum.alloc` operation in + the func operations with a "qnode" attribute and all quantum operations are between + the unique `quantum.alloc` operation and the terminal_boundary_op. All operations in between + are to be moved to the newly created outline-state-evolution function operation.""" + if "qnode" not in func_op.attributes: + return + + self.original_func_op = func_op + + self.module = self._get_parent_module(func_op) + + # Simplify the quantum I/O to use only registers at boundaries + self._simplify_quantum_io(func_op, rewriter) + + # Create a new function op for the state evolution region and insert it + # into the parent scope of the original func with qnode attribute + self._create_state_evolution_function(rewriter) + + # Replace the original region with a call to the state evolution function + # by inserting the corresponding callOp and update the rest of operations + # in the qnode func. + self._finalize_transformation() + + def _get_qubit_idx(self, op: Operation) -> int | None: + """Get the index of qubit that an ExtractOp op extracts.""" + if getattr(op, "idx", None): + return op.idx + return getattr(op, "idx_attr", None) + + # pylint: disable=too-many-arguments,too-many-positional-arguments + def _set_up_terminal_boundary_op( + self, + current_reg: quantum.QuregType, + terminal_boundary_op: Operation | None, + qubit_to_reg_idx: dict, + op: Operation, + rewriter: pattern_rewriter.PatternRewriter, + ): + """Set up the terminal boundary operation. This terminal_boundary_op is set as the last + quantum.insert operations added to the IR.""" + insert_ops = set() + + # Insert all qubits recorded in the qubit_to_reg_idx dict before the + # pre-assumed terminal operations. + insertion_point = InsertPoint.before(op) + for qb, idx in qubit_to_reg_idx.items(): + insert_op = quantum.InsertOp(current_reg, idx, qb) + rewriter.insert_op(insert_op, insertion_point=insertion_point) + insert_ops.add(insert_op) + terminal_boundary_op = insert_op + current_reg = insert_op.out_qreg + + # Add the `"terminal_boundary"` attribute to the last newly added + # `quantum.insert` operation. + if terminal_boundary_op is None: + raise RuntimeError("A terminal_boundary_op op is not found in the circuit.") + terminal_boundary_op.attributes["terminal_boundary"] = builtin.UnitAttr() + prev_qreg = terminal_boundary_op.in_qreg + + # extract ops + insertion_point = InsertPoint.before(op) + for qb, idx in list(qubit_to_reg_idx.items()): + extract_op = quantum.ExtractOp(current_reg, idx) + rewriter.insert_op(extract_op, insertion_point=insertion_point) + qb.replace_by_if(extract_op.qubit, lambda use: use.operation not in insert_ops) + for use in qb.uses: + rewriter.notify_op_modified(use.operation) + # update the qubit_to_reg_idx dict + qubit_to_reg_idx[extract_op.qubit] = idx + # pop out qb from the dict + del qubit_to_reg_idx[qb] + return current_reg, prev_qreg, terminal_boundary_op + + # pylint: disable=too-many-branches + def _simplify_quantum_io( + self, func_op: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter + ) -> func.FuncOp: + """Simplify quantum I/O to use only registers at segment boundaries. + + This ensures that state evolution regions only take registers as input/output, + not individual qubits. + """ + current_reg = None + # Note that all qubits recorded in the `qubit_to_reg_idx` will be inserted into + # the IR and the last insert_op will be set as the `terminal_boundary_op`.` + qubit_to_reg_idx = {} + terminal_boundary_op = None + terminal_op_in_reg = None + + for op in func_op.body.ops: + match op: + case quantum.AllocOp(): + current_reg = op.qreg + case quantum.ExtractOp(): + # Update register mapping + extract_idx = self._get_qubit_idx(op) + qubit_to_reg_idx[op.qubit] = extract_idx + # branch to update extract_op with new qreg + if op.qreg is terminal_op_in_reg: + insertion_point = InsertPoint.before(op) + extract_op = quantum.ExtractOp(current_reg, extract_idx) + rewriter.insert_op(extract_op, insertion_point=insertion_point) + rewriter.replace_all_uses_with(op.results[0], extract_op.results[0]) + rewriter.erase_op(op) + + case quantum.MeasureOp(): + # TODOs: what if the qubit that quantum.measure target at is reset? + # Not a concern by EOY 2025 + qubit_to_reg_idx[op.out_qubit] = qubit_to_reg_idx[op.in_qubit] + del qubit_to_reg_idx[op.in_qubit] + case quantum.CustomOp(): + # To update the qubit_to_reg_idx map for the return type. + for i, qb in enumerate(chain(op.in_qubits, op.in_ctrl_qubits)): + qubit_to_reg_idx[op.results[i]] = qubit_to_reg_idx[qb] + del qubit_to_reg_idx[qb] + case quantum.InsertOp(): + if not terminal_op_in_reg: + if op.idx_attr and qubit_to_reg_idx[op.qubit] is not op.idx_attr: + raise ValueError("op.qubit should be op.idx_attr.") + del qubit_to_reg_idx[op.qubit] + current_reg = op.out_qreg + # branch to update insert_op with new qreg + if op.in_qreg is terminal_op_in_reg: + insertion_point = InsertPoint.before(op) + index = op.idx if op.idx else op.idx_attr + insert_op = quantum.InsertOp(current_reg, index, op.qubit) + rewriter.insert_op(insert_op, insertion_point=insertion_point) + rewriter.replace_all_uses_with(op.out_qreg, insert_op.out_qreg) + rewriter.erase_op(op) + + case _ if ( + isinstance( + op, + ( + quantum.ComputationalBasisOp, + quantum.NamedObsOp, + quantum.HamiltonianOp, + quantum.TensorOp, + ), + ) + and not terminal_boundary_op + ): + current_reg, terminal_op_in_reg, terminal_boundary_op = ( + self._set_up_terminal_boundary_op( + current_reg, terminal_boundary_op, qubit_to_reg_idx, op, rewriter + ) + ) + case _: + # Handle other operations that might has qreg result + # Note that this branch might not be tested so far as adjoint op is not + # tested so far. + if reg := next( + (reg for reg in op.results if isinstance(reg.type, quantum.QuregType)), None + ): + current_reg = reg + + def _create_state_evolution_function(self, rewriter: pattern_rewriter.PatternRewriter): + """Create a new func.func for the state evolution region using clone approach.""" + + alloc_op, terminal_boundary_op = self._find_evolution_range() + + # collect operation from alloc_op to terminal_boundary_op + ops_to_clone = self._collect_operations_in_range(alloc_op, terminal_boundary_op) + + # collect required inputs for the state evolution funcOp + required_inputs = self._collect_required_inputs_for_state_evolution_func(ops_to_clone) + + # collect required outputs for the state evolution funcOp + required_outputs = self._collect_required_outputs(ops_to_clone, terminal_boundary_op) + + register_inputs = [] + other_inputs = [] + for val in required_inputs: + if isinstance(val.type, quantum.QuregType): + register_inputs.append(val) + else: + other_inputs.append(val) + + register_outputs = [] + other_outputs = [] + for val in required_outputs: + if isinstance(val.type, quantum.QuregType): + register_outputs.append(val) + else: + other_outputs.append(val) + + ordered_inputs = register_inputs + other_inputs + ordered_outputs = register_outputs + other_outputs + + input_types = [val.type for val in ordered_inputs] + output_types = [val.type for val in ordered_outputs] + fun_type = builtin.FunctionType.from_lists(input_types, output_types) + + # create a new func.func op and insert it into the IR + state_evolution_func = func.FuncOp( + self.original_func_op.sym_name.data + ".state_evolution", fun_type, visibility="public" + ) + rewriter.insert_op(state_evolution_func, InsertPoint.at_end(self.module.body.block)) + + # pylint: disable=line-too-long + # TODOs: how to define the `value_mapper` arg is not stated in the xdl.core module + # [here](https://github.com/xdslproject/xdsl/blob/e1301e0204bcf6ea5ed433e7da00bee57d07e695/xdsl/ir/core.py#L1429)_. + # It looks like storing ssa value to be cloned would maintain the dependency + # relationship required to build the new DAG for the new ops. + block = state_evolution_func.regions[0].block + value_mapper = {} # only args ssavlue is required + for input_, block_arg in zip(ordered_inputs, block.args): + value_mapper[input_] = block_arg + + self._clone_operations_to_block(ops_to_clone, block, value_mapper) + self._add_return_statement(block, ordered_outputs, value_mapper) + + self.required_inputs = ordered_inputs + self.required_outputs = ordered_outputs + self.alloc_op = alloc_op + self.terminal_boundary_op = terminal_boundary_op + self.state_evolution_func = state_evolution_func + + def _find_evolution_range(self): + """find alloc_op and terminal_boundary_op""" + alloc_op = None + terminal_boundary_op = None + + for op in self.original_func_op.body.walk(): + if isinstance(op, quantum.AllocOp): + alloc_op = op + elif hasattr(op, "attributes") and "terminal_boundary" in op.attributes: + terminal_boundary_op = op + + if not (alloc_op and terminal_boundary_op): + raise RuntimeError("Could not find both alloc_op and terminal_boundary_op") + + return alloc_op, terminal_boundary_op + + def _collect_operations_in_range(self, begin_op, end_op): + """collect top-level operations in range, let op.clone() handle nesting""" + ops_to_clone = [] + + if begin_op.parent_block() != end_op.parent_block(): + raise ValueError("begin_op and end_op are not in the same block") + + block = begin_op.parent_block() + + # skip until begin_op + op_iter = iter(block.ops) + while (op := next(op_iter, None)) != begin_op: + pass + + # collect top-level operations until end_op + while (op := next(op_iter, None)) != end_op: + ops_to_clone.append(op) + + # collect the terminal_boundary_op + if op is not None: + ops_to_clone.append(op) + + return ops_to_clone + + def _collect_required_inputs_for_state_evolution_func( + self, ops: list[Operation] + ) -> list[SSAValue]: + """Collect required inputs for the state evolution funcOp with a given list of operations. + Note that this method does not intent to keep the order of required input SSAValues. + """ + ops_walk = list(chain(*[op.walk() for op in ops])) + + # a set records the ssa values defined by the ops list + ops_defined_values = set() + # a set records all the ssa values required for all operations in the ops list + all_operands = set() + + for nested_op in ops_walk: + ops_defined_values.update(nested_op.results) + all_operands.update(nested_op.operands) + + if hasattr(nested_op, "regions") and nested_op.regions: + for region in nested_op.regions: + for block in region.blocks: + ops_defined_values.update(block.args) + + # the ssa values not defined by the operations in the ops list + missing_defs = list(all_operands - ops_defined_values) + required_inputs = [v for v in missing_defs if v is not None] + + return required_inputs + + def _collect_required_outputs( + self, ops: list[Operation], terminal_op: Operation + ) -> list[SSAValue]: + """Get required outputs for the state evolution funcOp with a given list of operations. + Note: It only considers the values that are defined in the operations and required by + the operations after the terminal operation. + """ + ops_walk = list(chain(*[op.walk() for op in ops])) + + ops_defined_values = set() + + for op_walk in ops_walk: + ops_defined_values.update(op_walk.results) + + # use list here to maintain the order of required outputs + required_outputs = [] + found_terminal = False + for op in self.original_func_op.body.walk(): + # branch for the operations before the terminal_op + if op == terminal_op: + found_terminal = True + continue + + # branch for the operations after the terminal_op + if found_terminal: + for operand in op.operands: + # a required output is an operand defined by a result of op in the ops + if operand in ops_defined_values and operand not in required_outputs: + required_outputs.append(operand) + + return required_outputs + + def _add_return_statement(self, target_block, required_outputs, value_mapper): + """add a func.return op to function""" + return_values = [] + for output_val in required_outputs: + if output_val not in value_mapper: + raise ValueError(f"output_val {output_val} not in value_mapper") + return_values.append(value_mapper[output_val]) + + return_op = func.ReturnOp(*return_values) + target_block.add_op(return_op) + + def _clone_operations_to_block(self, ops_to_clone, target_block, value_mapper): + """Clone operations to target block, use value_mapper to update references""" + for op in ops_to_clone: + cloned_op = op.clone(value_mapper) + target_block.add_op(cloned_op) + + def _finalize_transformation(self): + """Replace the original function with a call to the state evolution function.""" + original_block = self.original_func_op.body.block + ops_list = list(original_block.ops) + + begin_idx = None + end_idx = None + for i, op in enumerate(ops_list): + if op == self.alloc_op: + begin_idx = i + 1 + elif op == self.terminal_boundary_op: + end_idx = i + 1 + break + + if begin_idx is None: + raise RuntimeError("A quantum.alloc operation is not found in original function.") + if end_idx is None: + raise RuntimeError( + "A terminal_boundary_op operation is not found in original function." + ) + if begin_idx > end_idx: + raise RuntimeError( + "A quantum.alloc operation should come before the terminal_boundary_op." + ) + + post_ops = ops_list[end_idx:] + + call_args = list(self.required_inputs) + result_types = [val.type for val in self.required_outputs] + + call_op = func.CallOp(self.state_evolution_func.sym_name.data, call_args, result_types) + + # TODO: I just removed all ops and add them again to update with value_mapper. + # It's not efficient, just because it's easy to implement. Should using replace use method + # instead. + # De-attach all ops of the original function + call_result_mapper = {} + for i, required_output in enumerate(self.required_outputs): + call_result_mapper[required_output] = call_op.results[i] + + value_mapper = call_result_mapper.copy() + original_block.add_op(call_op) + for op in post_ops: + cloned_op = op.clone(value_mapper) + original_block.add_op(cloned_op) + + # replace ops_list with call_op + for op in chain(reversed(post_ops), reversed(ops_list[begin_idx:end_idx])): + op.detach() + op.erase() diff --git a/frontend/catalyst/python_interface/transforms/quantum/__init__.py b/frontend/catalyst/python_interface/transforms/quantum/__init__.py new file mode 100644 index 0000000000..a3141f8b78 --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/quantum/__init__.py @@ -0,0 +1,41 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""xDSL API for quantum transforms""" +from .cancel_inverses import IterativeCancelInversesPass, iterative_cancel_inverses_pass +from .combine_global_phases import CombineGlobalPhasesPass, combine_global_phases_pass +from .diagonalize_measurements import ( + DiagonalizeFinalMeasurementsPass, + diagonalize_final_measurements_pass, +) +from .measurements_from_samples import ( + MeasurementsFromSamplesPass, + measurements_from_samples_pass, +) +from .merge_rotations import MergeRotationsPass, merge_rotations_pass +from .split_non_commuting import SplitNonCommutingPass, split_non_commuting_pass + +__all__ = [ + "combine_global_phases_pass", + "CombineGlobalPhasesPass", + "diagonalize_final_measurements_pass", + "DiagonalizeFinalMeasurementsPass", + "iterative_cancel_inverses_pass", + "IterativeCancelInversesPass", + "measurements_from_samples_pass", + "MeasurementsFromSamplesPass", + "merge_rotations_pass", + "MergeRotationsPass", + "split_non_commuting_pass", + "SplitNonCommutingPass", +] diff --git a/frontend/catalyst/python_interface/transforms/quantum/cancel_inverses.py b/frontend/catalyst/python_interface/transforms/quantum/cancel_inverses.py new file mode 100644 index 0000000000..51d668b2ba --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/quantum/cancel_inverses.py @@ -0,0 +1,100 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This file contains the implementation of the cancel_inverses transform, +written using xDSL.""" + +from dataclasses import dataclass + +from xdsl import context, passes, pattern_rewriter +from xdsl.dialects import builtin, func +from xdsl.ir import Operation + +from catalyst.python_interface.dialects.quantum import CustomOp +from catalyst.python_interface.pass_api import compiler_transform + +self_inverses = [ + "Identity", + "Hadamard", + "PauliX", + "PauliY", + "PauliZ", + "CNOT", + "CZ", + "CY", + "CH", + "SWAP", + "Toffoli", + "CCZ", +] + + +def _can_cancel(op: CustomOp, next_op: Operation) -> bool: + if isinstance(next_op, CustomOp): + if op.gate_name.data == next_op.gate_name.data: + if ( + op.out_qubits == next_op.in_qubits + and op.out_ctrl_qubits == next_op.in_ctrl_qubits + and op.in_ctrl_values == next_op.in_ctrl_values + ): + return True + + return False + + +class IterativeCancelInversesPattern( + pattern_rewriter.RewritePattern +): # pylint: disable=too-few-public-methods + """RewritePattern for iteratively cancelling consecutive self-inverse gates.""" + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter, /): + """Implementation of rewriting FuncOps that may contain operations corresponding to + self-inverse gates.""" + for op in funcOp.body.walk(): + + while isinstance(op, CustomOp) and op.gate_name.data in self_inverses: + + next_user = None + for use in op.results[0].uses: + user = use.operation + if _can_cancel(op, user): + next_user: CustomOp = user + break + + if next_user is None: + break + + for q1, q2 in zip(op.in_qubits, next_user.out_qubits, strict=True): + rewriter.replace_all_uses_with(q2, q1) + for cq1, cq2 in zip(op.in_ctrl_qubits, next_user.out_ctrl_qubits, strict=True): + rewriter.replace_all_uses_with(cq2, cq1) + rewriter.erase_op(next_user) + rewriter.erase_op(op) + op = op.in_qubits[0].owner + + +@dataclass(frozen=True) +class IterativeCancelInversesPass(passes.ModulePass): + """Pass for iteratively cancelling consecutive self-inverse gates.""" + + name = "xdsl-cancel-inverses" + + def apply(self, _ctx: context.Context, op: builtin.ModuleOp) -> None: + """Apply the iterative cancel inverses pass.""" + pattern_rewriter.PatternRewriteWalker( + pattern_rewriter.GreedyRewritePatternApplier([IterativeCancelInversesPattern()]) + ).rewrite_module(op) + + +iterative_cancel_inverses_pass = compiler_transform(IterativeCancelInversesPass) diff --git a/frontend/catalyst/python_interface/transforms/quantum/combine_global_phases.py b/frontend/catalyst/python_interface/transforms/quantum/combine_global_phases.py new file mode 100644 index 0000000000..cea795a3e3 --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/quantum/combine_global_phases.py @@ -0,0 +1,87 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This file contains the implementation of the combine_global_phases transform, +written using xDSL.""" + +from dataclasses import dataclass + +from xdsl import context, passes, pattern_rewriter +from xdsl.dialects import arith, builtin, func +from xdsl.dialects.scf import ForOp, IfOp, WhileOp +from xdsl.rewriter import InsertPoint + +from catalyst.python_interface.dialects.quantum import GlobalPhaseOp +from catalyst.python_interface.pass_api import compiler_transform + + +class CombineGlobalPhasesPattern( + pattern_rewriter.RewritePattern +): # pylint: disable=too-few-public-methods + """RewritePattern for combining all :class:`~pennylane.GlobalPhase` gates within the same region + at the last global phase gate.""" + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite( + self, + root: func.FuncOp | IfOp | ForOp | WhileOp, + rewriter: pattern_rewriter.PatternRewriter, + /, + ): # pylint: disable=cell-var-from-loop + """Match and rewrite for the combine-global-phases pattern acting on functions or + control-flow blocks containing GlobalPhase operations. + """ + + for region in root.regions: + phi = None + global_phases = [] + for op in region.ops: + if isinstance(op, GlobalPhaseOp): + global_phases.append(op) + + if len(global_phases) < 2: + continue + + prev = global_phases[0] + phi_sum = prev.operands[0] + for current in global_phases[1:]: + phi = current.operands[0] + addOp = arith.AddfOp(phi, phi_sum) + rewriter.insert_op(addOp, InsertPoint.before(current)) + phi_sum = addOp.result + + rewriter.erase_op(prev) + prev = current + + prev.operands[0].replace_by_if(phi_sum, lambda use: use.operation == prev) + rewriter.notify_op_modified(prev) + + +@dataclass(frozen=True) +class CombineGlobalPhasesPass(passes.ModulePass): + """Pass that combines all global phases within a region into the last global phase operation + within the region. + """ + + name = "combine-global-phases" + + def apply(self, _ctx: context.Context, op: builtin.ModuleOp) -> None: + """Apply the combine-global-phases pass.""" + pattern_rewriter.PatternRewriteWalker( + CombineGlobalPhasesPattern(), + apply_recursively=False, + ).rewrite_module(op) + + +combine_global_phases_pass = compiler_transform(CombineGlobalPhasesPass) diff --git a/frontend/catalyst/python_interface/transforms/quantum/diagonalize_measurements.py b/frontend/catalyst/python_interface/transforms/quantum/diagonalize_measurements.py new file mode 100644 index 0000000000..0e07f013fa --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/quantum/diagonalize_measurements.py @@ -0,0 +1,161 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This file contains the implementation of the diagonalize_final_measurements transform, +written using xDSL. + + +Known Limitations +----------------- + * Only observables PauliX, PauliY, PauliZ, Hadamard and Identity are currently supported when + using this transform (but these are also the only observables currently supported in the + Quantum dialect as NamedObservable). + * Unlike the current tape-based implementation of the transform, it doesn't allow for + diagonalization of a subset of observables. + * Unlike the current tape-based implementation of the transform, conversion to measurements + based on eigvals and wires (rather than the PauliZ observable) is not currently supported. + * Unlike the tape-based implementation, this pass will NOT raise an error if given a circuit + that is invalid because it contains non-commuting measurements. It should be assumed that + this transform results in incorrect outputs unless split_non_commuting is applied to break + non-commuting measurements into separate tapes. +""" + +from dataclasses import dataclass + +from pennylane.ops import Hadamard, PauliX, PauliY +from xdsl import context, passes, pattern_rewriter +from xdsl.dialects import arith, builtin +from xdsl.rewriter import InsertPoint + +from catalyst.python_interface.dialects.quantum import ( + CustomOp, + GlobalPhaseOp, + MultiRZOp, + NamedObservable, + NamedObservableAttr, + NamedObsOp, + QubitUnitaryOp, +) + +from ...pass_api import compiler_transform + + +def _generate_mapping(): + _gate_map = {} + _params_map = {} + + for op in PauliX(0), PauliY(0), Hadamard(0): + diagonalizing_gates = op.diagonalizing_gates() + + _gate_map[op.name] = [gate.name for gate in diagonalizing_gates] + _params_map[op.name] = [gate.data for gate in diagonalizing_gates] + + return _gate_map, _params_map + + +_gate_map, _params_map = _generate_mapping() + + +def _diagonalize(obs: NamedObsOp) -> bool: + """Whether to diagonalize a given observable.""" + if obs.type.data in {"PauliZ", "Identity"}: + return False + if obs.type.data in _gate_map: + return True + raise NotImplementedError(f"Observable {obs.type.data} is not supported for diagonalization") + + +class DiagonalizeFinalMeasurementsPattern( + pattern_rewriter.RewritePattern +): # pylint: disable=too-few-public-methods + """RewritePattern for diagonalizing final measurements.""" + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite( + self, observable: NamedObsOp, rewriter: pattern_rewriter.PatternRewriter, / + ): + """Replace non-diagonalized observables with their diagonalizing gates and PauliZ.""" + + if _diagonalize(observable): + + diagonalizing_gates = _gate_map[observable.type.data] + params = _params_map[observable.type.data] + + qubit = observable.qubit + + insert_point = InsertPoint.before(observable) + + for name, op_data in zip(diagonalizing_gates, params): + if op_data: + param_ssa_values = [] + for param in op_data: + paramOp = arith.ConstantOp( + builtin.FloatAttr(data=param, type=builtin.Float64Type()) + ) + rewriter.insert_op(paramOp, insert_point) + param_ssa_values.append(paramOp.results[0]) + + gate = CustomOp(in_qubits=qubit, gate_name=name, params=param_ssa_values) + else: + gate = CustomOp(in_qubits=qubit, gate_name=name) + + rewriter.insert_op(gate, insert_point) + + qubit = gate.out_qubits[0] + + # we need to replace the initial qubit use everwhere EXCEPT the use that is now the + # input to the first diagonalizing gate. Its not enough to only change the NamedObsOp, + # because the qubit might be inserted/deallocated later + uses_to_change = [ + use + for use in observable.qubit.uses + if not isinstance( + use.operation, (CustomOp, GlobalPhaseOp, MultiRZOp, QubitUnitaryOp) + ) + ] + num_observables = len( + [use for use in uses_to_change if isinstance(use.operation, NamedObsOp)] + ) + + if num_observables > 1: + raise RuntimeError( + "Each wire can only have one set of diagonalizing gates applied, but the " + "circuit contains multiple observables with the same wire." + ) + + observable.qubit.replace_by_if(qubit, lambda use: use in uses_to_change) + for use in uses_to_change: + rewriter.notify_op_modified(use.operation) + + # then we also update the observable to be in the Z basis. Since this is done with the + # rewriter, we don't need to call `rewriter.notify_modified(observable)` regarding this + diag_obs = NamedObsOp( + qubit=qubit, obs_type=NamedObservableAttr(NamedObservable("PauliZ")) + ) + rewriter.replace_op(observable, diag_obs) + + +@dataclass(frozen=True) +class DiagonalizeFinalMeasurementsPass(passes.ModulePass): + """Pass for diagonalizing final measurements.""" + + name = "diagonalize-final-measurements" + + def apply(self, _ctx: context.Context, op: builtin.ModuleOp) -> None: + """Apply the diagonalize final measurements pass.""" + pattern_rewriter.PatternRewriteWalker(DiagonalizeFinalMeasurementsPattern()).rewrite_module( + op + ) + + +diagonalize_final_measurements_pass = compiler_transform(DiagonalizeFinalMeasurementsPass) diff --git a/frontend/catalyst/python_interface/transforms/quantum/measurements_from_samples.py b/frontend/catalyst/python_interface/transforms/quantum/measurements_from_samples.py new file mode 100644 index 0000000000..0d6ec6b3a4 --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/quantum/measurements_from_samples.py @@ -0,0 +1,661 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains the implementation of the measurements_from_samples transform, +written using xDSL. + +Known Limitations +----------------- + + * Only measurements in the computational basis (or where the observable is a Pauli Z op) are + currently supported; for arbitrary observables we require an equivalent compilation pass of the + diagonalize_measurements transform. + * The compilation pass assumes a static number of shots. + * Usage patterns that are not yet supported with program capture are also not supported in the + compilation pass. For example, operator arithmetic is not currently supported, such as + qml.expval(qml.Y(0) @ qml.X(1)). + * qml.counts() is not supported since the return type/shape is different in PennyLane and + Catalyst. See + https://docs.pennylane.ai/projects/catalyst/en/stable/dev/quick_start.html#measurements + for more information. +""" + +from abc import abstractmethod +from dataclasses import dataclass +from itertools import islice + +import jax +import jax.numpy as jnp +from pennylane.exceptions import CompileError +from xdsl import context, ir, passes, pattern_rewriter +from xdsl.dialects import arith, builtin, func, tensor +from xdsl.pattern_rewriter import PatternRewriter, RewritePattern +from xdsl.rewriter import InsertPoint + +from catalyst.python_interface.conversion import xdsl_module +from catalyst.python_interface.dialects import quantum +from catalyst.python_interface.pass_api import compiler_transform + + +@dataclass(frozen=True) +class MeasurementsFromSamplesPass(passes.ModulePass): + """Pass that replaces all terminal measurements in a program with a single + :func:`pennylane.sample` measurement, and adds postprocessing instructions to recover the + original measurement. + """ + + name = "measurements-from-samples" + + def apply(self, _ctx: context.Context, op: builtin.ModuleOp) -> None: + """Apply the measurements-from-samples pass.""" + shots = _get_static_shots_value_from_first_device_op(op) + + greedy_applier = pattern_rewriter.GreedyRewritePatternApplier( + [ + ExpvalAndVarPattern(shots), + ProbsPattern(shots), + CountsPattern(shots), + StatePattern(shots), + ] + ) + walker = pattern_rewriter.PatternRewriteWalker(greedy_applier, apply_recursively=False) + walker.rewrite_module(op) + + +measurements_from_samples_pass = compiler_transform(MeasurementsFromSamplesPass) + + +class MeasurementsFromSamplesPattern(RewritePattern): + """Rewrite pattern base class for the ``measurements_from_samples`` transform, which replaces + all terminal measurements in a program with a single :func:`pennylane.sample` measurement, and + adds postprocessing instructions to recover the original measurement. + + Args: + shots (int): The number of shots (e.g. as retrieved from the DeviceInitOp). + """ + + def __init__(self, shots: int): + super().__init__() + assert isinstance( + shots, int + ), f"Expected `shots` to be an integer value but got {type(shots).__name__}" + self._shots = shots + + @abstractmethod + def match_and_rewrite(self, op: ir.Operation, rewriter: PatternRewriter, /): + """Abstract method for measurements-from-samples match-and-rewrite patterns.""" + + @classmethod + def get_observable_op(cls, op: quantum.ExpvalOp | quantum.VarianceOp) -> quantum.NamedObsOp: + """Return the observable op (quantum.NamedObsOp) given as an input operand to `op`. + + We assume that `op` is either a quantum.ExpvalOp or quantum.VarianceOp, but this is not + strictly enforced. + + Args: + op (quantum.ExpvalOp | quantum.VarianceOp): The op that uses the observable op. + + Returns: + quantum.NamedObsOp: The observable op. + """ + observable_op = op.operands[0].owner + cls._validate_observable_op(observable_op) + + return observable_op + + @staticmethod + def _validate_observable_op(op: quantum.NamedObsOp): + """Validate the observable op. + + Assert that the op is a quantum.NamedObsOp and check if it is supported in the current + implementation of the measurements-from-samples transform. + + Raises: + NotImplementedError: If the observable is anything but a PauliZ quantum.NamedObsOp. + """ + assert isinstance( + op, quantum.NamedObsOp + ), f"Expected `op` to be a quantum.NamedObsOp, but got {type(op).__name__}" + + if op.type.data != "PauliZ": + raise NotImplementedError( + f"Observable '{op.type.data}' used as input to measurement operation is not " + f"supported for the measurements_from_samples transform; currently only the " + f"PauliZ observable is permitted" + ) + + @staticmethod + def insert_compbasis_op( + in_qubit: ir.SSAValue, ref_op: ir.Operation, rewriter: PatternRewriter + ) -> quantum.ComputationalBasisOp: + """Create and insert a computational-basis op (quantum.ComputationalBasisOp). + + The computation-basis op uses `in_qubit` as its input operand. It is inserted *before* the + given reference operation, `ref_op`, using the supplied `rewriter`. + + Args: + in_qubit (SSAValue): The SSA value used as input to the computational-basis op. + ref_op (Operation): The reference op before which the quantum.ComputationalBasisOp is + inserted. + rewriter (PatternRewriter): The xDSL pattern rewriter. + + Returns: + quantum.ComputationalBasisOp: The inserted computation-basis op. + """ + assert isinstance(in_qubit, ir.SSAValue) and isinstance(in_qubit.type, quantum.QubitType), ( + f"Expected `in_qubit` to be an SSAValue with type quantum.QubitType, but got " + f"{type(in_qubit).__name__}" + ) + + # The input operands are [[qubit, ...], qreg] + compbasis_op = quantum.ComputationalBasisOp( + operands=[in_qubit, None], result_types=[quantum.ObservableType()] + ) + rewriter.insert_op(compbasis_op, insertion_point=InsertPoint.before(ref_op)) + + return compbasis_op + + @staticmethod + def insert_sample_op( + compbasis_op: quantum.ComputationalBasisOp, + shots: int, + n_qubits: int, + rewriter: PatternRewriter, + ) -> quantum.SampleOp: + """Create and insert a sample op (quantum.SampleOp). + + The type of the returned samples array is currently restricted to be static, with shape + (shots, n_qubits). + + The sample op is inserted after the supplied `compbasis_op`. + + Args: + compbasis_op (quantum.ComputationalBasisOp): The computational-basis op used as the + input operand to the sample op. + shots (int): Number of shots (to set the shape of the sample op returned array). + n_qubits (int): Number of qubits (to set the shape of the sample op returned array). + rewriter (PatternRewriter): The xDSL pattern rewriter. + + Returns: + quantum.SampleOp: The inserted sample op. + """ + assert isinstance(compbasis_op, quantum.ComputationalBasisOp), ( + f"Expected `compbasis_op` to be a quantum.ComputationalBasisOp, but got " + f"{type(compbasis_op).__name__}" + ) + + sample_op = quantum.SampleOp( + operands=[compbasis_op.results[0], None, None], + result_types=[builtin.TensorType(builtin.Float64Type(), [shots, n_qubits])], + ) + rewriter.insert_op(sample_op, insertion_point=InsertPoint.after(compbasis_op)) + + return sample_op + + @staticmethod + def get_postprocessing_func_op_from_block_by_name( + block: ir.Block, name: str + ) -> func.FuncOp | None: + """Return the post-processing FuncOp from the given `block` with the given `name`. + + If the block does not contain a FuncOp with the matching name, returns None. + + Args: + block (Block): The xDSL block to search. + name (str): The name of the post-processing FuncOp. + + Returns: + func.FuncOp: The FuncOp with matching name. + None: If no match was found. + """ + for op in block.ops: + if isinstance(op, func.FuncOp) and op.sym_name.data == name: + return op + + return None + + @classmethod + def get_postprocessing_funcs_from_module_and_insert( + cls, + postprocessing_module: builtin.ModuleOp, + matched_op: ir.Operation, + name: str | None = None, + ) -> func.FuncOp: + """Get the post-processing FuncOp from `postprocessing_module` (and any helper functions + also contained in `postprocessing_module`) and insert it (them) immediately after the FuncOp + (in the same block) that contains `matched_op`. + + The post-processing function recovers the original measurement process result from the + samples array. This post-postprocessing function is optionally renamed to `name`, if given. + + Args: + postprocessing_module (builtin.ModuleOp): The MLIR module containing the post-processing + FuncOp. + matched_op (Operation): The reference op, the parent of which is used as the + reference point when inserting the post-processing FuncOp. This is usually the op + matched in the call to match_and_rewrite(). + name (str, optional): The name to assign to the post-processing FuncOp, if given. + + Returns: + func.FuncOp: The inserted post-processing FuncOp. + """ + parent_func_op = matched_op.parent_op() + + assert isinstance(parent_func_op, func.FuncOp), ( + f"Expected parent of matched op '{matched_op}' to be a func.FuncOp, but got " + f"{type(parent_func_op).__name__}" + ) + + # This first op in `postprocessing_module` is the "main" post-processing function + postprocessing_func_op = postprocessing_module.body.ops.first.clone() + assert isinstance(postprocessing_func_op, func.FuncOp), ( + f"Expected the first operator of `postprocessing_module` to be a func.FuncOp but " + f"got {type(postprocessing_func_op).__name__}" + ) + + # The function name from jax.jit is 'main'; rename it here + if name is not None: + postprocessing_func_op.sym_name = builtin.StringAttr(data=name) + + parent_block = parent_func_op.parent + parent_block.insert_op_after(postprocessing_func_op, parent_func_op) + + # Get and insert helper functions, if any + if len(postprocessing_module.body.ops) > 1: + prev_op = postprocessing_func_op + for _op in islice(postprocessing_module.body.ops, 1, None): + helper_op = _op.clone() + parent_block.insert_op_after(helper_op, prev_op) + prev_op = helper_op + + return postprocessing_func_op + + @staticmethod + def insert_constant_int_op( + value: int, + insert_point: InsertPoint, + rewriter: PatternRewriter, + value_type: int = 64, + ) -> arith.ConstantOp: + """Create and insert a constant op with the given integer value. + + The integer value is contained within a rankless, dense tensor. + + Args: + value (int): The integer value. + insert_point (InsertPoint): The insertion point for the constant op. + rewriter (PatternRewriter): The xDSL pattern rewriter. + value_type (int, optional): The integer value type (i.e. number of bits). + Defaults to 64. + + Returns: + arith.ConstantOp: The created constant op. + """ + constant_int_op = arith.ConstantOp( + builtin.DenseIntOrFPElementsAttr.from_list( + type=builtin.TensorType(builtin.IntegerType(value_type), shape=()), data=(value,) + ) + ) + rewriter.insert_op(constant_int_op, insertion_point=insert_point) + + return constant_int_op + + @staticmethod + def get_n_qubits_from_qreg(qreg: ir.SSAValue): + """Get the number of qubits from a qreg SSA value. + + This method walks back through the SSA graph from the qreg until it reaches its root + quantum.alloc op, or alloc-like op (with possibly zero or more quantum.insert ops + in-between), from which the number of qubits is extracted. + + An op is "alloc-like" if it has an 'nqubits_attr' attribute. + + Args: + qreg (SSAValue): The qreg SSA value. + """ + assert isinstance(qreg, ir.SSAValue) and isinstance(qreg.type, quantum.QuregType), ( + f"Expected `qreg` to be an SSAValue with type quantum.QuregType, but got " + f"{type(qreg).__name__}" + ) + + def _walk_back_to_alloc_op( + insert_or_alloc_op: quantum.AllocOp | quantum.InsertOp, + ) -> quantum.AllocOp | None: + """Recursively walk back from a quantum.insert op to its root quantum.alloc op or + alloc-like op. + + Once found, return the quantum.alloc op. + """ + if ( + isinstance(insert_or_alloc_op, quantum.AllocOp) + or insert_or_alloc_op.properties.get("nqubits_attr") is not None + ): + return insert_or_alloc_op + + if isinstance(insert_or_alloc_op, quantum.InsertOp): + return _walk_back_to_alloc_op(insert_or_alloc_op.operands[0].owner) + + return None + + alloc_op = _walk_back_to_alloc_op(qreg.owner) + assert alloc_op is not None, "Unable to walk back from qreg to alloc op" + + nqubits_attr = alloc_op.properties.get("nqubits_attr") + assert ( + nqubits_attr is not None + ), "Unable to determine number of qubits from alloc op; missing property 'nqubits_attr'" + + n_qubits = nqubits_attr.value.data + assert n_qubits is not None, "Unable to determine number of qubits from qreg SSA value" + + return n_qubits + + +class ExpvalAndVarPattern(MeasurementsFromSamplesPattern): + """A rewrite pattern for the ``measurements_from_samples`` transform that matches and rewrites + ``qml.expval()`` and ``qml.var()`` operations. + + Args: + shots (int): The number of shots (e.g. as retrieved from the DeviceInitOp). + """ + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite( + self, matched_op: quantum.ExpvalOp | quantum.VarianceOp, rewriter: PatternRewriter, / + ): + """Match and rewrite for quantum.ExpvalOp.""" + observable_op = self.get_observable_op(matched_op) + in_qubit = observable_op.operands[0] + compbasis_op = self.insert_compbasis_op(in_qubit, observable_op, rewriter) + sample_op = self.insert_sample_op(compbasis_op, self._shots, 1, rewriter) + + # Insert the post-processing function into current module or get handle to it if already + # inserted + match matched_op: + case quantum.ExpvalOp(): + postprocessing_func_name = f"expval_from_samples.tensor.{self._shots}x1xf64" + postprocessing_jit_func = _postprocessing_expval + case quantum.VarianceOp(): + postprocessing_func_name = f"var_from_samples.tensor.{self._shots}x1xf64" + postprocessing_jit_func = _postprocessing_var + case _: + assert False, ( + f"Expected a quantum.ExpvalOp or quantum.VarianceOp, but got " + f"{type(matched_op).__name__}" + ) + + postprocessing_func_op = self.get_postprocessing_func_op_from_block_by_name( + matched_op.parent_op().parent, postprocessing_func_name + ) + + if postprocessing_func_op is None: + # TODO: Do we have to set the shape of the samples array statically here? Or can the + # shape (shots, wire) be dynamic and given as SSA values? + # Same goes for the column/wire indices (the second argument). + postprocessing_module = postprocessing_jit_func( + jax.core.ShapedArray([self._shots, 1], float), 0 + ) + + postprocessing_func_op = self.get_postprocessing_funcs_from_module_and_insert( + postprocessing_module, matched_op, postprocessing_func_name + ) + + # Insert the op that specifies which column in the samples array to access. + # TODO: This also assumes MP acts on a single wire (hence we always use column 0 of the + # samples here); what if MP acts on multiple wires? + column_index_op = self.insert_constant_int_op( + 0, insert_point=InsertPoint.after(sample_op), rewriter=rewriter + ) + + # Insert the call to the post-processing function + postprocessing_func_call_op = func.CallOp( + callee=builtin.FlatSymbolRefAttr(postprocessing_func_op.sym_name), + arguments=[sample_op.results[0], column_index_op], + return_types=[builtin.TensorType(builtin.Float64Type(), shape=())], + ) + + # The op to replace is not the expval/var op itself, but the tensor.from_elements + # op that follows + op_to_replace = list(matched_op.results[0].uses)[0].operation + assert isinstance( + op_to_replace, tensor.FromElementsOp + ), f"Expected to replace a tensor.from_elements op, but got {type(op_to_replace).__name__}" + rewriter.replace_op(op_to_replace, postprocessing_func_call_op) + + # Finally, erase the expval/var op and its associated observable op + rewriter.erase_matched_op() + rewriter.erase_op(observable_op) + + +class ProbsPattern(MeasurementsFromSamplesPattern): + """A rewrite pattern for the ``measurements_from_samples`` transform that matches and rewrites + ``qml.probs()`` operations. + + Args: + shots (int): The number of shots (e.g. as retrieved from the DeviceInitOp). + """ + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, probs_op: quantum.ProbsOp, rewriter: PatternRewriter, /): + """Match and rewrite for quantum.ProbsOp.""" + compbasis_op = probs_op.operands[0].owner + + n_qubits = None + if compbasis_op.qreg is not None: + n_qubits = self.get_n_qubits_from_qreg(compbasis_op.qreg) + + elif not compbasis_op.qubits == (): + n_qubits = len(compbasis_op.qubits) + + assert ( + n_qubits is not None + ), "Unable to determine number of qubits from quantum.compbasis op" + + sample_op = self.insert_sample_op(compbasis_op, self._shots, n_qubits, rewriter) + + # Insert the post-processing function into current module or + # get handle to it if already inserted + postprocessing_func_name = f"probs_from_samples.tensor.{self._shots}x{n_qubits}xf64" + + postprocessing_func_op = self.get_postprocessing_func_op_from_block_by_name( + probs_op.parent_op().parent, postprocessing_func_name + ) + + if postprocessing_func_op is None: + # TODO: Do we have to set the shape of the samples array statically here? Or can the + # shape (shots, wire) be dynamic and given as SSA values? + # Same goes for the column/wire indices (the second argument). + postprocessing_module = _postprocessing_probs( + jax.core.ShapedArray([self._shots, n_qubits], float) + ) + + postprocessing_func_op = self.get_postprocessing_funcs_from_module_and_insert( + postprocessing_module, probs_op, postprocessing_func_name + ) + + # Insert the call to the post-processing function + postprocessing_func_call_op = func.CallOp( + callee=builtin.FlatSymbolRefAttr(postprocessing_func_op.sym_name), + arguments=[sample_op.results[0]], + return_types=[builtin.TensorType(builtin.Float64Type(), shape=(2**n_qubits,))], + ) + + rewriter.replace_matched_op(postprocessing_func_call_op) + + +class CountsPattern(MeasurementsFromSamplesPattern): + """A rewrite pattern for the ``measurements_from_samples`` transform that matches and rewrites + ``qml.counts()`` operations. + + Currently there is no plan to support ``qml.counts()`` for this transform. It is included for + completeness and to notify users that workloads containing ``counts`` measurement processes are + not supported with the measurements-from-samples transform. + + Args: + shots (int): The number of shots (e.g. as retrieved from the DeviceInitOp). + """ + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, counts_op: quantum.CountsOp, rewriter: PatternRewriter, /): + """Match and rewrite for quantum.CountsOp.""" + raise NotImplementedError("qml.counts() operations are not supported.") + + +class StatePattern(MeasurementsFromSamplesPattern): + """A rewrite pattern for the ``measurements_from_samples`` transform that matches and rewrites + ``qml.state()`` operations. + + It is not possible to recover a quantum state from samples; this pattern is included for + completeness and to notify users that workloads containing ``state`` measurement processes are + not supported with the measurements-from-samples transform. + + Args: + shots (int): The number of shots (e.g. as retrieved from the DeviceInitOp). + """ + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, state_op: quantum.StateOp, rewriter: PatternRewriter, /): + """Match and rewrite for quantum.StateOp.""" + raise NotImplementedError("qml.state() operations are not supported.") + + +def _get_static_shots_value_from_first_device_op(module: builtin.ModuleOp) -> int: + """Returns the number of shots as a static (i.e. known at compile time) integer value from the + first instance of a device-initialization op (quantum.DeviceInitOp) found in `module`. + + If `module` contains multiple quantum.DeviceInitOp ops, only the number of shots from the + *first* instance is used, and the others are ignored. + + This function expects the number of shots to be an SSA value given as an operand to the + quantum.DeviceInitOp op. It also assumes that the number of shots is static, retrieving it from + the 'value' attribute of its corresponding constant op. + + Args: + module (builtin.ModuleOp): The MLIR module containing the quantum.DeviceInitOp. + + Returns: + int: The number of shots. + + Raises: + CompileError: If `module` does not contain a quantum.DeviceInitOp. + """ + device_op = None + + for op in module.body.walk(): + if isinstance(op, quantum.DeviceInitOp): + device_op = op + break + + if device_op is None: + raise CompileError( + "Cannot get number of shots; the module does not contain a quantum.DeviceInitOp" + ) + + # The number of shots is passed as an SSA value operand to the DeviceInitOp + shots_operand = device_op.shots + shots_extract_op = shots_operand.owner + + if isinstance(shots_extract_op, tensor.ExtractOp): + shots_constant_op = shots_extract_op.operands[0].owner + shots_value_attribute: builtin.DenseIntOrFPElementsAttr = shots_constant_op.properties.get( + "value" + ) + if shots_value_attribute is None: + raise ValueError("Cannot get number of shots; the constant op has no 'value' attribute") + + shots_int_values = shots_value_attribute.get_values() + if len(shots_int_values) != 1: + raise ValueError(f"Expected a single shots value, got {len(shots_int_values)}") + + return shots_int_values[0] + + if isinstance(shots_extract_op, arith.ConstantOp): + shots_value_attribute: builtin.IntAttr = shots_extract_op.properties.get("value") + return shots_value_attribute.value.data + + raise ValueError( + f"Expected owner of shots operand to be a tensor.ExtractOp or arith.ConstantOp but got " + f"{type(shots_extract_op).__name__}" + ) + + +@xdsl_module +@jax.jit +def _postprocessing_expval(samples, column): + """Post-processing to recover the expectation value from the given `samples` array for each + requested `column` in the array. + + This function assumes that the samples are in the computational basis (0s and 1s) and that the + observable operand of the expectation value has eigenvalues +1 and -1. + + Args: + samples (jax.core.ShapedArray): Array of samples, with shape (shots, wires). + column (int, jax.core.ShapedArray): Column index (or indices) of the `samples` array over + which the expectation value is computed. + + Returns: + jax.core.ShapedArray: The expectation value for each requested column. + """ + return jnp.mean(1.0 - 2.0 * samples[:, column], axis=0) + + +@xdsl_module +@jax.jit +def _postprocessing_var(samples, column): + """Post-processing to recover the variance from the given `samples` array for each requested + `column` in the array. + + This function assumes that the samples are in the computational basis (0s and 1s) and that the + observable operand of the variance has eigenvalues +1 and -1. + + Args: + samples (jax.core.ShapedArray): Array of samples, with shape (shots, wires). + column (int, jax.core.ShapedArray): Column index (or indices) of the `samples` array over + which the variance is computed. + + Returns: + jax.core.ShapedArray: The variance for each requested column. + """ + return jnp.var(1.0 - 2.0 * samples[:, column], axis=0) + + +@xdsl_module +@jax.jit +def _postprocessing_probs(samples): + """Post-processing to recover the probability values from the given `samples` array. + + This function assumes that the samples are in the computational basis (0s and 1s). + + Args: + samples (jax.core.ShapedArray): Array of samples, with shape (shots, wires). + """ + n_samples = samples.shape[0] + n_wires = samples.shape[1] + + # Convert samples from a list of 0, 1 integers to base 10 representation + powers_of_two = 2 ** jnp.arange(n_wires)[::-1] + indices = samples @ powers_of_two + dim = 2**n_wires + + # This block is effectively equivalent to `jnp.bincount(indices.astype(int), length=dim)`. + # However, we are currently not able to use jnp.bincount with Catalyst because after lowering, + # it contains a stablehlo.scatter op with , + # which we currently do not support. + # If Catalyst PR https://github.com/PennyLaneAI/catalyst/pull/1849 is merged, then we should be + # able to use bincount. + counts = jnp.zeros(dim, dtype=int) + for i in indices.astype(int): + counts = counts.at[i].add(1) + + return counts / n_samples diff --git a/frontend/catalyst/python_interface/transforms/quantum/merge_rotations.py b/frontend/catalyst/python_interface/transforms/quantum/merge_rotations.py new file mode 100644 index 0000000000..1f4a553bc8 --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/quantum/merge_rotations.py @@ -0,0 +1,139 @@ +# Copyright 2018-2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This file contains the implementation of the merge_rotations transform, +written using xDSL.""" + +from dataclasses import dataclass + +from xdsl import context, passes, pattern_rewriter +from xdsl.dialects import arith, builtin, func +from xdsl.ir import Operation +from xdsl.rewriter import InsertPoint + +from catalyst.python_interface.dialects.quantum import CustomOp +from catalyst.python_interface.pass_api import compiler_transform + +# Can handle all composible rotations except Rot... for now +composable_rotations = [ + "RX", + "RY", + "RZ", + "PhaseShift", + "CRX", + "CRY", + "CRZ", + "ControlledPhaseShift", + "IsingXX", + "IsingYY", + "IsingXY", + "IsingZZ", + "MultiRZ", + "SingleExcitation", + "DoubleExcitation", + "SingleExcitationMinus", + "SingleExcitationPlus", + "DoubleExcitationMinus", + "DoubleExcitationPlus", + "OrbitalRotation", +] + + +def _can_merge(op: CustomOp, next_op: Operation) -> bool: + if isinstance(next_op, CustomOp): + if op.gate_name.data == next_op.gate_name.data: + if ( + op.out_qubits == next_op.in_qubits + and op.out_ctrl_qubits == next_op.in_ctrl_qubits + and op.in_ctrl_values == next_op.in_ctrl_values + ): + return True + + return False + + +class MergeRotationsPattern( + pattern_rewriter.RewritePattern +): # pylint: disable=too-few-public-methods + """RewritePattern for merging consecutive composable rotations.""" + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter, /): + """Implementation of rewriting FuncOps that may contain operations corresponding to + consecutive composable rotations.""" + for op in funcOp.body.walk(): + if not isinstance(op, CustomOp): + continue + + gate_name = op.gate_name.data + if gate_name not in composable_rotations: + continue + + param = op.operands[0] + while True: + for use in op.results[0].uses: + user = use.operation + if _can_merge(op, user): + next_user = user + break + else: + # No next_user was set because no user could be merged. Go to next op + break + + for q1, q2 in zip(op.in_qubits, op.out_qubits, strict=True): + rewriter.replace_all_uses_with(q2, q1) + for cq1, cq2 in zip(op.in_ctrl_qubits, op.out_ctrl_qubits, strict=True): + rewriter.replace_all_uses_with(cq2, cq1) + + # Whether op is adjoint will determine adjoint of new_op, and how to combine angles + is_adjoint = getattr(op, "adjoint", False) + rewriter.erase_op(op) + + next_param = next_user.operands[0] + next_is_adjoint = getattr(next_user, "adjoint", False) + if is_adjoint == next_is_adjoint: + # If both op and next_user are adjoint, or neither is, we add angles + combOp = arith.AddfOp(param, next_param) + else: + # If one of op and next_user is adjoint, we subtract angles + combOp = arith.SubfOp(param, next_param) + rewriter.insert_op(combOp, InsertPoint.before(next_user)) + param = combOp.result + + new_op = CustomOp( + in_qubits=next_user.in_qubits, + gate_name=next_user.gate_name, + params=(param,), + in_ctrl_qubits=next_user.in_ctrl_qubits, + in_ctrl_values=next_user.in_ctrl_values, + adjoint=is_adjoint, + ) + rewriter.replace_op(next_user, new_op) + op = new_op + + +@dataclass(frozen=True) +class MergeRotationsPass(passes.ModulePass): + """Pass for merging consecutive composable rotation gates.""" + + name = "xdsl-merge-rotations" + + def apply(self, _ctx: context.Context, op: builtin.ModuleOp) -> None: + """Apply the merge rotations pass.""" + pattern_rewriter.PatternRewriteWalker( + pattern_rewriter.GreedyRewritePatternApplier([MergeRotationsPattern()]) + ).rewrite_module(op) + + +merge_rotations_pass = compiler_transform(MergeRotationsPass) diff --git a/frontend/catalyst/python_interface/transforms/quantum/split_non_commuting.py b/frontend/catalyst/python_interface/transforms/quantum/split_non_commuting.py new file mode 100644 index 0000000000..3cd8a47039 --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/quantum/split_non_commuting.py @@ -0,0 +1,506 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +"""This file contains a limited prototype of the split_non_commuting pass. + +Known Limitations +----------------- + + * Only single-term observables with no coefficients are supported - there is no support + for CompositeOp or SymbolicOp observables + * Only the Expval measurement process is supported + * There is no option to specify a grouping strategy (this will be more relevant once + CompositeOp support is added) + * Hence, only the "wires" grouping strategy is implemented, not taking into account + observable-commutation logic yet. + * There is no efficient handling of duplicate observables - a circuit that returns + multiple measurements on the same observable will split into multiple executions + (this will be more relevant once CompositeOp support is added) + +Example: +------------------ +For the following IR: +``` +func.func public circ(...) { + ... + %reg0 = quantum.insert ... + %reg1 = quantum.insert ... + %0 = func.call @circ.state_evolution(%reg0, ...) + %q0 = quantum.extract ... + %q1 = quantum.extract ... + ... + %1 = quantum.expval %q0 + %2 = quantum.expval %q1 + ... + func.return %1, %2 +} +``` + +We want to split the function into two functions based on the wire-based grouping strategy. +``` +func.func public circ(...) { + %0 = func.call @circ.dup0(...) + %1 = func.call @circ.dup1(...) + func.return %0, %1 +} +func.func circ.dup0(...) { + ... + %reg0 = quantum.insert ... + %reg1 = quantum.insert ... + %0 = func.call @circ.state_evolution(%reg1, ...) + %q0 = quantum.extract ... + ... + %1 = quantum.expval %q0 + return %1 +} +func.func circ.dup1(...) { + ... + %reg1 = quantum.insert ... + %0 = func.call @circ.state_evolution(%reg1, ...) + %q1 = quantum.extract ... + ... + %1 = quantum.expval %q1 + return %1 +} +func.func circ.state_evolution(%reg, ...) { + %q0 = quantum.extract ... + %q1 = quantum.extract ... + ... + %new_reg0 = quantum.insert ... + %new_reg1 = quantum.insert ... + return %new_reg1 +} +``` + +reference: https://docs.pennylane.ai/en/stable/code/api/pennylane.transforms.split_non_commuting.html +""" + +from dataclasses import dataclass +from typing import Type, TypeVar + +from xdsl import context, passes, pattern_rewriter +from xdsl.dialects import builtin, func +from xdsl.ir import Operation, SSAValue +from xdsl.rewriter import InsertPoint + +from catalyst.python_interface.dialects import quantum +from catalyst.python_interface.pass_api import compiler_transform + + +@dataclass(frozen=True) +class SplitNonCommutingPass(passes.ModulePass): + """Pass that splits quantum functions measuring non-commuting observables. + + This pass groups measurements using the "wires" grouping strategy and splits + the function into multiple executions, one per group of measurements. + """ + + name = "split-non-commuting" + + def apply(self, _ctx: context.Context, op: builtin.ModuleOp) -> None: + """Apply the split non-commuting pass to all QNode functions in the module.""" + for op_ in op.ops: + if isinstance(op_, func.FuncOp) and "qnode" in op_.attributes: + rewriter = pattern_rewriter.PatternRewriter(op_) + SplitNonCommutingPattern().match_and_rewrite(op_, rewriter) + + +split_non_commuting_pass = compiler_transform(SplitNonCommutingPass) + + +class SplitNonCommutingPattern(pattern_rewriter.RewritePattern): + """Pattern that splits a quantum function into multiple functions based on wire-based grouping. + + Measurements acting on different wires are grouped together, while measurements + acting on the same wire are split into separate groups. + """ + + def __init__(self): + self.module: builtin.ModuleOp = None + + T = TypeVar("T") + + def get_parent_of_type(self, op: Operation, kind: Type[T]) -> T | None: + """Walk up the parent tree until an op of the specified type is found.""" + while (op := op.parent_op()) and not isinstance(op, kind): + pass + if not isinstance(op, kind): + raise ValueError(f"get_parent_of_type: expected {kind} but got {type(op)}, op: {op}") + return op + + def clone_operations_to_block(self, ops_to_clone, target_block, value_mapper): + """Clone operations to target block, use value_mapper to update references""" + for op in ops_to_clone: + cloned_op = op.clone(value_mapper) + target_block.add_op(cloned_op) + + def create_dup_function( + self, func_op: func.FuncOp, i: int, rewriter: pattern_rewriter.PatternRewriter + ): + """Create a new function for the dup region by fully cloning the original function.""" + # Use the same signature as the original function + original_func_type = func_op.function_type + input_types = list(original_func_type.inputs.data) + output_types = list(original_func_type.outputs.data) + fun_type = builtin.FunctionType.from_lists(input_types, output_types) + + dup_func = func.FuncOp(func_op.sym_name.data + ".dup." + str(i), fun_type) + rewriter.insert_op(dup_func, InsertPoint.at_end(self.module.body.block)) + + # Map original function arguments to dup function arguments + dup_block = dup_func.regions[0].block + orig_block = func_op.body.block + value_mapper = {} + for orig_arg, dup_arg in zip(orig_block.args, dup_block.args): + value_mapper[orig_arg] = dup_arg + + # Clone all operations except the return statement + ops_to_clone = [] + return_op = None + for op in orig_block.ops: + if isinstance(op, func.ReturnOp): + return_op = op + else: + ops_to_clone.append(op) + + # Clone operations + self.clone_operations_to_block(ops_to_clone, dup_block, value_mapper) + + # Clone the return statement + if return_op: + return_values = [value_mapper.get(val, val) for val in return_op.operands] + new_return_op = func.ReturnOp(*return_values) + dup_block.add_op(new_return_op) + + # Remove expvals from other groups and update return statement + self.remove_group(dup_func, i) + + return dup_func + + def remove_group(self, dup_func: func.FuncOp, target_group: int): + """Remove measurement operations from other groups and update return statement.""" + # Find the return operation in the dup function + return_op = list(dup_func.body.ops)[-1] + + return_values_to_remove = set[SSAValue]() + for operand in return_op.operands: + group_id = self.find_group_for_return_value(operand) + if group_id != target_group: + return_values_to_remove.add(operand) + + # collect all operations to remove + remove_ops = list[Operation]([value.owner for value in return_values_to_remove]) + + # update return statement + self.update_return_statement(dup_func, return_values_to_remove) + + # remove operations + while remove_ops: + op = remove_ops.pop(0) + users = [use.operation for result in list(op.results) for use in list(result.uses)] + + # if the operation has users, skip it + if len(users) > 0: + continue + + if not self.is_observable_op(op): + # keep walking up the chain + for operand in op.operands: + if operand not in remove_ops: + remove_ops.append(operand.owner) + + op.detach() + op.erase() + + def update_return_statement(self, func_op: func.FuncOp, values_to_remove: set[SSAValue]): + """Update the return statement to remove specified values.""" + # Find the return operation + return_op = None + for op in func_op.body.ops: + if isinstance(op, func.ReturnOp): + return_op = op + break + + if not return_op: + return + + # Filter out values to remove + new_return_values = [val for val in return_op.operands if val not in values_to_remove] + + # Create new return operation + new_return_op = func.ReturnOp(*new_return_values) + + # Replace the old return operation + return_op.detach() + return_op.erase() # Important: erase to remove operand uses + func_op.body.block.add_op(new_return_op) + + # Update function signature + new_output_types = [val.type for val in new_return_values] + input_types = [arg.type for arg in func_op.body.block.args] + new_fun_type = builtin.FunctionType.from_lists(input_types, new_output_types) + func_op.function_type = new_fun_type + + def is_measurement_op(self, op: Operation) -> bool: + """Check if an operation is a measurement operation.""" + # TODO: support more measurement operations + if isinstance(op, quantum.ExpvalOp): + return True + if isinstance(op, (quantum.VarianceOp, quantum.ProbsOp, quantum.SampleOp)): + raise NotImplementedError( + f"measurement operations other than expval are not supported: {op}" + ) + return False + + def is_observable_op(self, op: Operation) -> bool: + """Check if an operation is an observable operation.""" + if isinstance( + op, + ( + quantum.NamedObsOp, + quantum.ComputationalBasisOp, + quantum.HamiltonianOp, + quantum.TensorOp, + ), + ): + return True + return False + + def calculate_num_groups(self, func_op: func.FuncOp) -> int: + """Calculate the number of groups using the "wires" grouping strategy. + + This function groups measurements based on wire overlaps only, disregarding + the actual commutation relations between observables. Measurements acting on + different wires are grouped together, while measurements acting on the same + wire are split into different groups. + + The function also stores the group ID in the "group" attribute of each + measurement operation, which is later used to handle the splitting mechanics. + + Args: + func_op: The function operation containing measurements to group. + + Returns: + The number of groups created. + """ + # Find all measurement operations in the current function + measurement_ops = [op for op in func_op.body.ops if self.is_measurement_op(op)] + + # For each measurement operation, find the qubits the operation acts on + op_to_acted_qubits: dict[Operation, set[SSAValue]] = { + measurement_op: set() for measurement_op in measurement_ops + } + + for measurement_op in measurement_ops: + observable = measurement_op.operands[0] + op_to_acted_qubits[measurement_op].update(self.get_qubits_from_observable(observable)) + + # Group measurements: operations on different qubits can be in the same group + # Operations on the same qubit must be in different groups + groups: list[dict[Operation, set[SSAValue]]] = [] # Each group stores op -> qubits mapping + + for measurement_op, qubits in op_to_acted_qubits.items(): + if len(qubits) > 1: + raise NotImplementedError("operations acting on multiple qubits are not supported") + + # Find a group where no operation acts on any of the same qubits + assigned_group_id = None + + for group_id, group in enumerate(groups): + # Get all qubits already used in this group + used_qubits = set() + for group_qubits in group.values(): + used_qubits.update(group_qubits) + + # Check if this measurement's qubits conflict with the group + if not qubits.intersection(used_qubits): + # No conflict - can add to this group + group[measurement_op] = qubits + assigned_group_id = group_id + break + + # If no suitable group found, create a new one + if assigned_group_id is None: + assigned_group_id = len(groups) + groups.append({measurement_op: qubits}) + + # Tag the measurement operation with the group attribute + measurement_op.attributes["group"] = builtin.IntegerAttr( + assigned_group_id, builtin.IntegerType(64) + ) + + return len(groups) + + def get_qubits_from_observable(self, observable: SSAValue) -> set[SSAValue] | None: + """Get the qubit used by an observable operation. + + Traces back from an observable to find the qubit it operates on. + Handles NamedObsOp, ComputationalBasisOp, HamiltonianOp, and TensorOp. + """ + assert observable.owner is not None, "observable should have an owner" + + acted_qubits = set[SSAValue]() + + obs_op = observable.owner + + # For NamedObsOp, the first operand is the qubit + if isinstance(obs_op, (quantum.NamedObsOp)): + acted_qubits.add(obs_op.operands[0]) + + # For other observable operations, we need to handle multiple qubits + elif isinstance( + obs_op, (quantum.HamiltonianOp, quantum.TensorOp, quantum.ComputationalBasisOp) + ): + raise NotImplementedError(f"unsupported observable operation: {obs_op}") + + return acted_qubits + + def analyze_group_return_positions( + self, func_op: func.FuncOp, num_groups: int + ) -> dict[int, list[int]]: + """Analyze which return value positions belong to each group. + + Returns a dict mapping group_id -> list of final return value positions + Example: {0: [0, 2], 1: [1]} for + return qml.expval(qml.X(0)), qml.expval(qml.X(1)), qml.expval(qml.Y(0)) + """ + # Find the return operation + return_op = list(func_op.body.ops)[-1] + + # For each return value, trace back to find its group + group_positions = {i: [] for i in range(num_groups)} + + for position, return_value in enumerate(return_op.operands): + # Trace back to find the expval operation + group_id = self.find_group_for_return_value(return_value) + if group_id is not None: + group_positions[group_id].append(position) + + return group_positions + + def find_group_for_return_value(self, return_value: SSAValue) -> int | None: + """Trace back from a return value to find which group's expval produced it.""" + # BFS backward to find expval + to_check = [return_value] + checked = set() + + while to_check: + val = to_check.pop(0) + if val in checked: + continue + checked.add(val) + + op = val.owner + + # If we found a measurement operation, check its group + if self.is_measurement_op(op) and "group" in op.attributes: + group_attr = op.attributes["group"] + return group_attr.value.data + + # Otherwise, check operands + to_check.extend([operand for operand in op.operands if operand not in checked]) + + return None + + def replace_original_with_calls( + self, + func_op: func.FuncOp, + dup_functions: list[func.FuncOp], + group_return_positions: dict[int, list[int]], + ): + """Replace original function body with calls to dup functions. + + Args: + dup_functions: List of duplicate functions (one per group) + group_return_positions: Dict mapping group_id -> list of return positions + """ + original_block = func_op.body.block + + for op in reversed(func_op.body.ops): + op.detach() + op.erase() + + # Collect parameters needed for dup function calls + # Dup functions take the same parameters as the original begin/end region + # Look at what original function was using and find corresponding values + call_args = list(original_block.args) # Use function arguments as base + + group_results = dict[int, list[SSAValue]]() # group_id -> list of result values + + for group_id, dup_func in enumerate(dup_functions): + # Get the function signature to determine result types + func_type = dup_func.function_type + result_types = list(func_type.outputs.data) + + # Create the call operation + call_op = func.CallOp(dup_func.sym_name.data, call_args, result_types) + original_block.add_op(call_op) + + # Store results for this group + group_results[group_id] = list(call_op.results) + + # Reconstruct the return statement in the original order + # Calculate total number of return values + total_returns = sum(len(positions) for positions in group_return_positions.values()) + final_return_values = [None] * total_returns + + for group_id, positions in group_return_positions.items(): + group_vals = group_results[group_id] + assert len(group_vals) == len( + positions + ), "number of group values and positions must match" + + for i, position in enumerate(positions): + final_return_values[position] = group_vals[i] + + assert all( + v is not None for v in final_return_values + ), "final return values should not be None" + + # Create new return operation + return_op = func.ReturnOp(*final_return_values) + original_block.add_op(return_op) + + def match_and_rewrite( + self, func_op: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter, / + ): + """Split a quantum function into multiple functions using wire-based grouping. + + Creates one duplicate function per group, where each duplicate function contains + only the measurements from that group. The original function is replaced with + calls to these duplicate functions, and the results are combined in the original + return order. + + Args: + func_op: The function operation to split. + rewriter: The pattern rewriter for creating new operations. + """ + self.module = self.get_parent_of_type(func_op, builtin.ModuleOp) + assert self.module is not None, "got orphaned qnode function" + + # Calculate the number of groups using wires-based grouping strategy + num_groups = self.calculate_num_groups(func_op) + + # Analyze return value positions for each group + group_return_positions = self.analyze_group_return_positions(func_op, num_groups) + + # Create dup function for each group + dup_functions = [] + for i in range(num_groups): + dup_func = self.create_dup_function(func_op, i, rewriter) + dup_functions.append(dup_func) + + # Replace original function body with calls to dup functions + self.replace_original_with_calls(func_op, dup_functions, group_return_positions) diff --git a/frontend/catalyst/python_interface/utils.py b/frontend/catalyst/python_interface/utils.py new file mode 100644 index 0000000000..9db8e39008 --- /dev/null +++ b/frontend/catalyst/python_interface/utils.py @@ -0,0 +1,77 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""General purpose utilities to use with xDSL.""" + +from numbers import Number + +from xdsl.dialects.arith import ConstantOp as arithConstantOp +from xdsl.dialects.builtin import ComplexType, ShapedType +from xdsl.dialects.tensor import ExtractOp as tensorExtractOp +from xdsl.ir import SSAValue + +from catalyst.python_interface.dialects.stablehlo import ConstantOp as hloConstantOp + + +def get_constant_from_ssa(value: SSAValue) -> Number | None: + """Return the concrete value corresponding to an SSA value if it is a numerical constant. + + .. note:: + + This function currently only returns constants if they are scalar. For non-scalar + constants, ``None`` will be returned. + + Args: + value (xdsl.ir.SSAValue): the SSA value to check + + Returns: + Number or None: If the value corresponds to a constant, its concrete value will + be returned, else ``None``. + """ + + # If the value has a shape, we can assume that it is not scalar. We check + # this because constant-like operations can return container types. This includes + # arith.constant, which may return containers, and stablehlo.constant, which + # always returns a container. + if not isinstance(value.type, ShapedType): + owner = value.owner + + if isinstance(owner, arithConstantOp): + const_attr = owner.value + return const_attr.value.data + + # Constant-like operations can also create scalars by returning rank 0 tensors. + # In this case, the owner of a scalar value should be a tensor.extract, which + # uses the aforementioned rank 0 constant tensor as input. + if isinstance(owner, tensorExtractOp): + tensor_ = owner.tensor + if ( + len(owner.indices) == 0 + and len(tensor_.type.shape) == 0 + and isinstance(tensor_.owner, (arithConstantOp, hloConstantOp)) + ): + dense_attr = tensor_.owner.value + # We know that the tensor has shape (). Dense element attributes store + # their data as a sequence. For a scalar, this will be a sequence with + # a single element. + val = dense_attr.get_values()[0] + if isinstance(tensor_.type.element_type, ComplexType): + # If the dtype is complex, the value will be a 2-tuple containing + # the real and imaginary components of the number rather than a + # Python complex number + val = val[0] + 1j * val[1] + + return val + + return None diff --git a/frontend/catalyst/python_interface/visualization/__init__.py b/frontend/catalyst/python_interface/visualization/__init__.py new file mode 100644 index 0000000000..0a550976e5 --- /dev/null +++ b/frontend/catalyst/python_interface/visualization/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Visualization functions for Catalyst and xDSL transformations. +""" + + +from .collector import QMLCollector +from .draw import draw +from .mlir_graph import generate_mlir_graph + +__all__ = ["QMLCollector", "draw", "generate_mlir_graph"] diff --git a/frontend/catalyst/python_interface/visualization/collector.py b/frontend/catalyst/python_interface/visualization/collector.py new file mode 100644 index 0000000000..d2c355e0bf --- /dev/null +++ b/frontend/catalyst/python_interface/visualization/collector.py @@ -0,0 +1,146 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This file contains the implementation of the QMLCollector class, +which collects and maps PennyLane operations and measurements from xDSL.""" + +from functools import singledispatchmethod + +import xdsl +from pennylane.measurements import MeasurementProcess +from pennylane.operation import Operator +from xdsl.dialects import builtin, func +from xdsl.ir import SSAValue + +from catalyst.python_interface.dialects.quantum import ( + AllocOp, + CustomOp, + ExpvalOp, + ExtractOp, + GlobalPhaseOp, + MeasureOp, + MultiRZOp, + ProbsOp, + QubitUnitaryOp, + SampleOp, + SetBasisStateOp, + SetStateOp, + StateOp, + VarianceOp, +) + +from .xdsl_conversion import dispatch_wires_extract, xdsl_to_qml_measurement, xdsl_to_qml_op + + +class QMLCollector: + """Collects PennyLane ops and measurements from an xDSL module. + + Walks all `FuncOp`s in the given module, building a mapping of SSA qubits to wire indices, + and converting supported xDSL operations and measurements to PennyLane objects. + """ + + def __init__(self, module: builtin.ModuleOp): + self.module = module + self.wire_to_ssa_qubits: dict[int, SSAValue] = {} + self.quantum_register: SSAValue | None = None + + @singledispatchmethod + def handle(self, xdsl_op: xdsl.ir.Operation) -> None: + """Default handler for unsupported operations.""" + if len(xdsl_op.regions) > 0: + raise NotImplementedError("xDSL operations with regions are not yet supported.") + + ############################################################ + ### Measurements + ############################################################ + + @handle.register + def _(self, xdsl_meas: StateOp) -> MeasurementProcess: + return xdsl_to_qml_measurement(xdsl_meas) + + @handle.register + def _(self, xdsl_meas_op: ExpvalOp | VarianceOp | ProbsOp | SampleOp) -> MeasurementProcess: + obs_op = xdsl_meas_op.obs.owner + return xdsl_to_qml_measurement(xdsl_meas_op, xdsl_to_qml_measurement(obs_op)) + + @handle.register + def _(self, xdsl_measure: MeasureOp) -> MeasurementProcess: + return xdsl_to_qml_measurement(xdsl_measure) + + ############################################################ + ### Operators + ############################################################ + + @handle.register + def _( + self, + xdsl_op: ( + CustomOp | GlobalPhaseOp | QubitUnitaryOp | SetStateOp | MultiRZOp | SetBasisStateOp + ), + ) -> Operator: + if self.quantum_register is None: + raise ValueError("Quantum register (AllocOp) not found.") + if not self.wire_to_ssa_qubits: + raise NotImplementedError("No wires extracted from the register found.") + return xdsl_to_qml_op(xdsl_op) + + ############################################################ + ### Internal Methods + ############################################################ + + # TODO: this will probably no longer be needed once PR #7937 is merged + def _process_qubit_mapping(self, op): + """Populate wire mappings from AllocOp and ExtractOp.""" + + if isinstance(op, AllocOp): + if self.quantum_register is not None: + raise ValueError("Found more than one AllocOp for this FuncOp.") + self.quantum_register = op.qreg + + elif isinstance(op, ExtractOp): + wire = dispatch_wires_extract(op) + if wire not in self.wire_to_ssa_qubits: + self.wire_to_ssa_qubits[wire] = op.qubit + + def clear_mappings(self): + """Clear all wire and parameter mappings.""" + + self.wire_to_ssa_qubits.clear() + self.quantum_register = None + + def collect(self, reset: bool = True) -> tuple[list[Operator], list[MeasurementProcess]]: + """Collect PennyLane ops and measurements from the module.""" + + if reset: + self.clear_mappings() + + collected_ops: list[Operator] = [] + collected_meas: list[MeasurementProcess] = [] + + for func_op in self.module.body.ops: + + if not isinstance(func_op, func.FuncOp): + continue + + for op in func_op.body.ops: + + self._process_qubit_mapping(op) + result = self.handle(op) + + if isinstance(result, Operator): + collected_ops.append(result) + + elif isinstance(result, MeasurementProcess): + collected_meas.append(result) + + return collected_ops, collected_meas diff --git a/frontend/catalyst/python_interface/visualization/draw.py b/frontend/catalyst/python_interface/visualization/draw.py new file mode 100644 index 0000000000..ac26977404 --- /dev/null +++ b/frontend/catalyst/python_interface/visualization/draw.py @@ -0,0 +1,105 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This file contains the implementation of the `draw` function for the Unified Compiler.""" + +from __future__ import annotations + +import warnings +from functools import wraps +from typing import TYPE_CHECKING + +from pennylane.tape import QuantumScript + +from catalyst import qjit +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from catalyst.python_interface.compiler import Compiler + +from .collector import QMLCollector + +if TYPE_CHECKING: + from pennylane.typing import Callable + from pennylane.workflow.qnode import QNode + from xdsl.dialects.builtin import ModuleOp + +# TODO: This caching mechanism should be improved, +# because now it relies on a mutable global state +_cache_store: dict[Callable, dict[int, tuple[str, str]]] = {} + + +def _get_mlir_module(qnode: QNode, args, kwargs) -> ModuleOp: + """Ensure the QNode is compiled and return its MLIR module.""" + if hasattr(qnode, "mlir_module") and qnode.mlir_module is not None: + return qnode.mlir_module + + func = getattr(qnode, "user_function", qnode) + jitted_qnode = qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(func) + jitted_qnode.jit_compile(args, **kwargs) + return jitted_qnode.mlir_module + + +def draw(qnode: QNode, *, level: None | int = None) -> Callable: + """ + Draw the QNode at the specified level. + + This function can be used to visualize the QNode at different stages of the transformation + pipeline when xDSL or Catalyst compilation passes are applied. + If the specified level is not available, the highest level will be used as a fallback. + + The provided QNode is assumed to be decorated with compilation passes. + If no passes are applied, the original QNode is visualized. + + Args: + qnode (.QNode): the input QNode that is to be visualized. The QNode is assumed to be + compiled with ``qjit``. + level (None | int): the level of transformation to visualize. If `None`, the final + level is visualized. + + + Returns: + Callable: A wrapper function that visualizes the QNode at the specified level. + + """ + cache: dict[int, tuple[str, str]] = _cache_store.setdefault(qnode, {}) + + def _draw_callback(previous_pass, module, next_pass, pass_level=0): + """Callback function for circuit drawing.""" + + pass_instance = previous_pass if previous_pass else next_pass + collector = QMLCollector(module) + ops, meas = collector.collect() + tape = QuantumScript(ops, meas) + pass_name = pass_instance.name if hasattr(pass_instance, "name") else pass_instance + cache[pass_level] = ( + tape.draw(show_matrices=False), + pass_name if pass_level else "No transforms", + ) + + @wraps(qnode) + def wrapper(*args, **kwargs): + if args or kwargs: + warnings.warn( + "The `draw` function does not yet support dynamic arguments.\n" + "To visualize the circuit with dynamic parameters or wires, please use the\n" + "`compiler.python_compiler.visualization.generate_mlir_graph` function instead.", + UserWarning, + ) + mlir_module = _get_mlir_module(qnode, args, kwargs) + Compiler.run(mlir_module, callback=_draw_callback) + + if not cache: + return None + + return cache.get(level, cache[max(cache.keys())])[0] + + return wrapper diff --git a/frontend/catalyst/python_interface/visualization/mlir_graph.py b/frontend/catalyst/python_interface/visualization/mlir_graph.py new file mode 100644 index 0000000000..79196185ba --- /dev/null +++ b/frontend/catalyst/python_interface/visualization/mlir_graph.py @@ -0,0 +1,121 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This file contains the implementation of the MLIR graph generation for the Unified +Compiler framework. +""" +from __future__ import annotations + +import io +import subprocess +from functools import wraps +from pathlib import Path +from typing import TYPE_CHECKING + +from xdsl.printer import Printer + +from catalyst.compiler import CompileError, _get_catalyst_cli_cmd +from catalyst.python_interface.compiler import Compiler + +from .draw import _get_mlir_module + +if TYPE_CHECKING: + from pennylane import QNode + from pennylane.typing import Callable + +try: + from graphviz import Source as GraphSource + + has_graphviz = True +except (ModuleNotFoundError, ImportError) as import_error: # pragma: no cover + has_graphviz = False + + +# TODO: This interface can be removed once the _quantum_opt interface +# implemented in catalyst.compiler returns the `stderr` output +def _quantum_opt_stderr(*args, stdin=None, stderr_return=False): + """Raw interface to quantum-opt""" + return _catalyst(("--tool", "opt"), *args, stdin=stdin, stderr_return=stderr_return) + + +def _catalyst(*args, stdin=None, stderr_return=False): + """Raw interface to catalyst""" + cmd = _get_catalyst_cli_cmd(*args, stdin=stdin) + try: + result = subprocess.run(cmd, input=stdin, check=True, capture_output=True, text=True) + if stderr_return: + return result.stdout, result.stderr + return result.stdout + except subprocess.CalledProcessError as e: + raise CompileError(f"catalyst failed with error code {e.returncode}: {e.stderr}") from e + + +def _mlir_graph_callback(previous_pass, module, next_pass, pass_level=0): + """Callback function for MLIR graph generation.""" + + pass_instance = previous_pass if previous_pass else next_pass + buffer = io.StringIO() + Printer(stream=buffer, print_generic_format=True).print_op(module) + _, dot_graph = _quantum_opt_stderr( + "--view-op-graph", stdin=buffer.getvalue(), stderr_return=True + ) + graph = GraphSource(dot_graph) + + out_dir = Path("mlir_generated_graphs") + out_dir.mkdir(exist_ok=True) + + pass_name = pass_instance.name if hasattr(pass_instance, "name") else pass_instance + pass_name = f"after_{pass_name}" if pass_level else "no_transforms" + out_file = out_dir / f"QNode_level_{pass_level}_{pass_name}.svg" + + with open(out_file, "wb") as f: + f.write(graph.pipe(format="svg")) + + +def generate_mlir_graph(qnode: QNode) -> Callable: + """ + Generate an MLIR graph for the given QNode and saves it to a file. + + This function uses the callback mechanism of the unified compiler framework to generate + the MLIR graph in between compilation passes. The provided QNode is assumed to be decorated + with xDSL compilation passes. The ``qjit`` decorator is used to recompile the QNode with the + passes and the provided arguments. + + If no passes are applied, the original QNode is visualized. + + Args: + qnode (.QNode): the input QNode that is to be visualized. + + + Returns: + Callable: A wrapper function that generates the MLIR graph. + + """ + + if not has_graphviz: + raise ImportError( + "This feature requires graphviz, a library for graph visualization. " + "It can be installed with:\n\npip install graphviz" + ) # pragma: no cover + + @wraps(qnode) + def wrapper(*args, **kwargs): + # We re-compile the qnode to ensure the passes are applied + # with the args and kwargs provided by the user. + # TODO: we could integrate the callback mechanism within `qjit`, + # so that we wouldn't need to recompile the qnode twice. + mlir_module = _get_mlir_module(qnode, args, kwargs) + Compiler.run(mlir_module, callback=_mlir_graph_callback) + + return wrapper diff --git a/frontend/catalyst/python_interface/visualization/xdsl_conversion.py b/frontend/catalyst/python_interface/visualization/xdsl_conversion.py new file mode 100644 index 0000000000..917a292e94 --- /dev/null +++ b/frontend/catalyst/python_interface/visualization/xdsl_conversion.py @@ -0,0 +1,330 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This file contains the implementation of the QMLCollector class, +which collects and maps PennyLane operations and measurements from xDSL.""" + +from __future__ import annotations + +import inspect +from collections.abc import Callable +from typing import TYPE_CHECKING + +from pennylane import ops +from pennylane.measurements import expval, probs, sample, state, var +from pennylane.operation import Operator +from pennylane.ops import MidMeasure +from pennylane.ops import __all__ as ops_all +from pennylane.ops import measure +from xdsl.dialects.builtin import DenseIntOrFPElementsAttr, IntegerAttr, IntegerType +from xdsl.dialects.tensor import ExtractOp as TensorExtractOp +from xdsl.ir import SSAValue + +from catalyst.python_interface.dialects.quantum import ( + CustomOp, + ExtractOp, + GlobalPhaseOp, + MeasureOp, + MultiRZOp, + NamedObsOp, + QubitUnitaryOp, + SetBasisStateOp, + SetStateOp, +) + +if TYPE_CHECKING: + from pennylane.measurements import MeasurementProcess + +has_jax = True +try: + import jax +except ImportError: + has_jax = False + + +from_str_to_PL_gate = { + name: getattr(ops, name) + for name in ops_all + if inspect.isclass(getattr(ops, name, None)) and issubclass(getattr(ops, name), Operator) +} + +from_str_to_PL_measurement = { + "quantum.state": state, + "quantum.probs": probs, + "quantum.sample": sample, + "quantum.expval": expval, + "quantum.var": var, + "quantum.measure": measure, +} + + +###################################################### +### Gate/Measurement resolution +###################################################### + + +def _resolve(name: str, mapping: dict, kind: str): + try: + return mapping[name] + except KeyError as exc: + raise NotImplementedError(f"Unsupported {kind}: {name}") from exc + + +def resolve_gate(name: str) -> Operator: + """Resolve the gate from the name.""" + return _resolve(name, from_str_to_PL_gate, "gate") + + +def resolve_measurement(name: str) -> MeasurementProcess: + """Resolve the measurement from the name.""" + return _resolve(name, from_str_to_PL_measurement, "measurement") + + +###################################################### +### Helpers +###################################################### + + +def _tensor_shape_from_ssa(ssa: SSAValue) -> list[int]: + """Extract the concrete shape from an SSA tensor value.""" + # pylint: disable= protected-access + tensor_abstr_shape = ssa.owner.operand._type.shape.data + return [dim.data for dim in tensor_abstr_shape] + + +def _extract(op, attr: str, resolver: Callable, single: bool = False): + """Helper to extract and resolve attributes.""" + values = getattr(op, attr, None) + if not values: + return [] if not single else None + return resolver(values) if single else [resolver(v) for v in values if v is not None] + + +def _extract_dense_constant_value(op) -> float | int: + """Extract the first value from a stablehlo.constant op.""" + attr = op.properties.get("value") + if isinstance(attr, DenseIntOrFPElementsAttr): + # TODO: handle multi-value cases if needed + return attr.get_values()[0] + raise NotImplementedError(f"Unexpected attr type in constant: {type(attr)}") + + +def _apply_adjoint_and_ctrls(qml_op: Operator, xdsl_op) -> Operator: + """Apply adjoint and control modifiers to a gate if needed.""" + if xdsl_op.properties.get("adjoint"): + qml_op = ops.op_math.adjoint(qml_op) + ctrls = ssa_to_qml_wires(xdsl_op, control=True) + if ctrls: + cvals = ssa_to_qml_params(xdsl_op, control=True) + qml_op = ops.op_math.ctrl(qml_op, control=ctrls, control_values=cvals) + return qml_op + + +# pylint: disable=too-many-return-statements +def resolve_constant_params(ssa: SSAValue) -> float | int: + """Resolve a constant parameter SSA value to a Python float or int.""" + op = ssa.owner + + if isinstance(op, TensorExtractOp): + return resolve_constant_params(op.tensor) + + match op.name: + + case "arith.addf": + return sum(resolve_constant_params(o) for o in op.operands) + + case "arith.constant": + return op.value.value.data # Catalyst + + case "stablehlo.constant": + return _extract_dense_constant_value(op) + + case "stablehlo.convert" | "stablehlo.broadcast_in_dim": + return resolve_constant_params(op.operands[0]) + + case "stablehlo.concatenate": + return [resolve_constant_params(operand) for operand in op.operands] + + case "stablehlo.reshape": + res_type = op.result_types[0] + shape = res_type.get_shape() + type_ = res_type.get_element_type() + return jax.numpy.array(shape, dtype=int if isinstance(type_, IntegerType) else float) + + case _: + raise NotImplementedError(f"Cannot resolve parameters for operation: {op}") + + +def dispatch_wires_extract(op: ExtractOp): + """Dispatch the wire resolution for the given extract operation.""" + if op.idx_attr is not None: # used by Catalyst + return resolve_constant_wire(op.idx_attr) + return resolve_constant_wire(op.idx) # used by xDSL + + +def resolve_constant_wire(ssa: SSAValue) -> float | int: + """Resolve the wire for the given SSA qubit.""" + if isinstance(ssa, IntegerAttr): # Catalyst + return ssa.value.data + + op = ssa.owner + + match op: + + case TensorExtractOp(tensor=tensor): + return resolve_constant_wire(tensor) + + case _ if op.name == "stablehlo.convert": + return resolve_constant_wire(op.operands[0]) + + case _ if op.name == "stablehlo.constant": + return _extract_dense_constant_value(op) + + case ( + CustomOp() + | GlobalPhaseOp() + | QubitUnitaryOp() + | SetStateOp() + | MultiRZOp() + | SetBasisStateOp() + ): + all_qubits = list(getattr(op, "in_qubits", [])) + list( + getattr(op, "in_ctrl_qubits", []) + ) + return resolve_constant_wire(all_qubits[ssa.index]) + + case ExtractOp(): + return dispatch_wires_extract(op) + + case MeasureOp(in_qubit=in_qubit): + return resolve_constant_wire(in_qubit) + + case _: + raise NotImplementedError(f"Cannot resolve wire for op: {op}") + + +###################################################### +### Parameters/Wires Conversion +###################################################### + + +def ssa_to_qml_params( + op, control: bool = False, single: bool = False +) -> list[float | int] | float | int | None: + """Get the parameters from the operation.""" + return _extract(op, "in_ctrl_values" if control else "params", resolve_constant_params, single) + + +def ssa_to_qml_wires(op: CustomOp, control: bool = False) -> list[int]: + """Get the wires from the operation.""" + return _extract(op, "in_ctrl_qubits" if control else "in_qubits", resolve_constant_wire) + + +def ssa_to_qml_wires_named(op: NamedObsOp) -> int: + """Get the wire from the named observable operation.""" + if not op.qubit: + raise ValueError("No qubit found for named observable operation.") + return resolve_constant_wire(op.qubit) + + +############################################################ +### xDSL ---> PennyLane Operators/Measurements conversion +############################################################ + + +def xdsl_to_qml_op(op) -> Operator: + """Convert an xDSL operation into a PennyLane Operator. + + Args: + op: The xDSL operation to convert. + + Returns: + A PennyLane Operator. + """ + + match op.name: + + case "quantum.gphase": + gate = ops.GlobalPhase(ssa_to_qml_params(op, single=True), wires=ssa_to_qml_wires(op)) + + case "quantum.unitary": + gate = ops.qubit.matrix_ops.QubitUnitary( + U=jax.numpy.zeros(_tensor_shape_from_ssa(op.matrix)), wires=ssa_to_qml_wires(op) + ) + + case "quantum.set_state": + gate = ops.qubit.state_preparation.StatePrep( + state=jax.numpy.zeros(_tensor_shape_from_ssa(op.in_state)), + wires=ssa_to_qml_wires(op), + ) + + case "quantum.multirz": + gate = ops.qubit.parametric_ops_multi_qubit.MultiRZ( + theta=_extract(op, "theta", resolve_constant_params, single=True), + wires=ssa_to_qml_wires(op), + ) + + case "quantum.set_basis_state": + gate = ops.qubit.state_preparation.BasisState( + state=jax.numpy.zeros(_tensor_shape_from_ssa(op.basis_state)), + wires=ssa_to_qml_wires(op), + ) + + case "quantum.custom": + gate_cls = resolve_gate(op.properties.get("gate_name").data) + gate = gate_cls(*ssa_to_qml_params(op), wires=ssa_to_qml_wires(op)) + + case _: + raise NotImplementedError(f"Unsupported gate: {op.name}") + + return _apply_adjoint_and_ctrls(gate, op) + + +def xdsl_to_qml_measurement(op, *args, **kwargs) -> MeasurementProcess | Operator: + """Convert any xDSL measurement/observable operation to a PennyLane object. + + Args: + op: The xDSL measurement/observable operation to convert. + + Returns: + A PennyLane MeasurementProcess or Operator. + """ + + match op.name: + + case "quantum.measure": + postselect = op.postselect.value.data if op.postselect is not None else None + return MidMeasure([resolve_constant_wire(op.in_qubit)], postselect=postselect) + + case "quantum.namedobs": + return resolve_gate(op.type.data.value)(wires=ssa_to_qml_wires_named(op)) + + case "quantum.tensor": + return ops.op_math.prod( + *(xdsl_to_qml_measurement(operand.owner) for operand in op.operands) + ) + + case "quantum.hamiltonian": + coeffs = _extract(op, "coeffs", resolve_constant_params, single=True) + ops_list = [xdsl_to_qml_measurement(term.owner) for term in op.terms] + return ops.LinearCombination(coeffs, ops_list) + case "quantum.compbasis": + return _extract(op, "qubits", resolve_constant_wire) + + case ( + "quantum.state" | "quantum.probs" | "quantum.sample" | "quantum.expval" | "quantum.var" + ): + return resolve_measurement(op.name)(*args, **kwargs) + + case _: + raise NotImplementedError(f"Unsupported measurement/observable: {op.name}") diff --git a/frontend/catalyst/python_interface/xdsl_extras/__init__.py b/frontend/catalyst/python_interface/xdsl_extras/__init__.py new file mode 100644 index 0000000000..16785fe2c6 --- /dev/null +++ b/frontend/catalyst/python_interface/xdsl_extras/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains additional utilities and functionality not available upstream in xDSL.""" + +from .constraints import MemRefConstraint, NestedTupleOfConstraint, TensorConstraint +from .traits import ( + AllMatchSameOperatorTrait, + Elementwise, + SameOperandsAndResultElementType, + SameOperandsAndResultShape, + SameOperandsElementType, +) + +__all__ = [ + # Constraints + "NestedTupleOfConstraint", + "MemRefConstraint", + "TensorConstraint", + # Traits + "AllMatchSameOperatorTrait", + "Elementwise", + "SameOperandsAndResultElementType", + "SameOperandsAndResultShape", + "SameOperandsElementType", +] diff --git a/frontend/catalyst/python_interface/xdsl_extras/constraints.py b/frontend/catalyst/python_interface/xdsl_extras/constraints.py new file mode 100644 index 0000000000..c5e597f20b --- /dev/null +++ b/frontend/catalyst/python_interface/xdsl_extras/constraints.py @@ -0,0 +1,238 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains additional type and attribute constraints that are currently not available +upstream in xDSL.""" + +from abc import ABC, abstractmethod +from collections.abc import Collection, Mapping, Sequence +from dataclasses import dataclass +from typing import TypeVar + +from xdsl.dialects.builtin import ( + ArrayAttr, + IntAttr, + IntAttrConstraint, + MemRefType, + TensorType, + TupleType, +) +from xdsl.ir import Attribute, AttributeInvT +from xdsl.irdl import ( + AnyAttr, + AnyInt, + AttrConstraint, + ConstraintContext, + EqIntConstraint, + IntConstraint, + IntSetConstraint, + IRDLAttrConstraint, + RangeLengthConstraint, + RangeOf, + irdl_to_attr_constraint, +) +from xdsl.utils.exceptions import VerifyException + + +@dataclass(frozen=True, init=False) +class ContainerConstraint(AttrConstraint, ABC): + r"""Internal base class for constraining the element type and shape of container types. + + The shape of the container can be constrained by providing a constraint either directly + for the shape, or by providing a constraint for the rank. There are two ways to provide + an explicit shape constraint: + + * By providing an ``IRDLAttrConstraint`` for the ``shape`` argument. + * By providing a sequence of ``int``\ s specifying the concrete expected shape for + the ``shape`` argument. + + There are three ways to provide the rank constraint: + + * By providing an ``IRDLAttrConstraint`` for the ``rank`` argument. + * By providing an ``int`` representing the concrete expected rank for the ``rank`` + argument. + * By providing a collection of ``int``\ s specifying the various allowed ranks for the + ``rank`` argument. + + .. note:: + + Only one of the ``shape`` or ``rank`` constraint may be provided, not both. + + Args: + element_type (IRDLAttrConstraint | None): The constraint for the element type. + Default is ``None``, which indicates that any element type is allowed. + shape (IRDLAttrConstraint | Sequence[int] | None): The constraint for the shape. + Default is ``None``, which indicates that any shape is allowed. + rank (IRDLAttrConstraint | Collection[int] | int | None): The constraint for the + rank. Default is ``None``, which indicates that any rank is allowed. + """ + + element_type: IRDLAttrConstraint[AttributeInvT] + + shape: IRDLAttrConstraint[AttributeInvT] + + def __init__( + self, + *, + element_type: IRDLAttrConstraint[AttributeInvT] | None = None, + shape: IRDLAttrConstraint[AttributeInvT] | Sequence[int] | None = None, + rank: IRDLAttrConstraint[AttributeInvT] | Collection[int] | int | None = None, + ): + element_type = element_type or AnyAttr() + element_type_constr = element_type + shape_constr = None + + if shape is not None and rank is not None: + raise ValueError("Only one of 'shape' or 'rank' may be provided.") + + if shape is None and rank is None: + shape_constr = AnyAttr() + + elif shape is not None: + shape_constr = shape + if isinstance(shape_constr, Sequence): + shape_constr = ArrayAttr([IntAttr(s) for s in shape]) + + # rank is not None + else: + shape_constr = rank + if isinstance(shape_constr, (int, Collection)): + # Constrain shape to have length `rank` if `rank` is an int + if isinstance(shape_constr, int): + length_constr = EqIntConstraint(shape_constr) + else: + length_constr = IntSetConstraint(frozenset(shape_constr)) + shape_constr = ArrayAttr.constr( + RangeLengthConstraint( + constraint=RangeOf(IntAttrConstraint(AnyInt())), length=length_constr + ) + ) + + if not isinstance(element_type_constr, (Attribute, AttrConstraint)): + raise TypeError( + f"{element_type} is not a valid constraint for the 'element_type' argument. " + "'element_type' must be an AttrConstraint or Attribute." + ) + if not isinstance(shape_constr, (Attribute, AttrConstraint)): + if shape is not None: + raise TypeError( + f"{shape} is not a valid constraint for the 'shape' argument. 'shape' " + "must be an AttrConstraint, Attribute, or sequence of integers." + ) + raise TypeError( + f"{rank} is not a valid constraint for the 'rank' argument. 'rank' must be " + "an AttrConstraint, Attribute, integer, or collection of integers." + ) + + object.__setattr__(self, "element_type", element_type_constr) + object.__setattr__(self, "shape", shape_constr) + + @property + @abstractmethod + def expected_type(self) -> type[Attribute]: + """The expected IR type class (e.g., builtin.TensorType).""" + + @property + def type_name(self) -> str: + """The name of the type for use in error messages (e.g., 'tensor').""" + return self.expected_type.name + + def get_bases(self) -> set[type[Attribute]]: + """Get a set of base types that can satisfy this constraint (e.g., {builtin.TensorType}).""" + return {self.expected_type} + + def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None: + """Verify that the attribute meets the constraint.""" + if not isinstance(attr, self.expected_type): + raise VerifyException(f"{attr} should be of type {self.expected_type.__name__}.") + constr = self.expected_type.constr(element_type=self.element_type, shape=self.shape) + constr.verify(attr, constraint_context) + + # pylint: disable=unused-argument + def mapping_type_vars( + self, type_var_mapping: dict[TypeVar, AttrConstraint] + ) -> "ContainerConstraint": + """ + A helper function to make type vars used in attribute definitions concrete when + creating constraints for new attributes or operations. + """ + return self + + +@dataclass(frozen=True, init=False) +class TensorConstraint(ContainerConstraint): + """TensorType constraint for element type and shape.""" + + @property + def expected_type(self): + return TensorType + + +@dataclass(frozen=True, init=False) +class MemRefConstraint(ContainerConstraint): + """MemRefType constraint for element type and shape.""" + + @property + def expected_type(self): + return MemRefType + + +@dataclass(frozen=True, init=False) +class NestedTupleOfConstraint(AttrConstraint[TupleType]): + """Constrain a nested tuple whose flattened leaves all match any allowed constraints.""" + + elem_constraints: tuple[AttrConstraint, ...] + + def __init__(self, elem_constraints: Sequence[object]): + object.__setattr__( + self, + "elem_constraints", + tuple(irdl_to_attr_constraint(c) for c in elem_constraints), + ) + + def get_flattened(self, a: Attribute): + """Get the flattened leaves of a tuple.""" + if isinstance(a, TupleType): + for t in a.types.data: + yield from self.get_flattened(t) + else: + yield a + + def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None: + """Verify that the attribute is a tuple of allowed types.""" + if not isinstance(attr, TupleType): + raise VerifyException(f"expected TupleType, got {type(attr)}") + + leaves = list(self.get_flattened(attr)) + + for i, leaf in enumerate(leaves): + matched = False + for constr in self.elem_constraints: + try: + constr.verify(leaf, constraint_context) + matched = True + break + except VerifyException: + # Try next allowed constraint + pass + if not matched: + raise VerifyException(f"tuple leaf {i} failed all allowed constraints: {leaf}") + + # pylint: disable=unused-argument + def mapping_type_vars( + self, + type_var_mapping: Mapping[TypeVar, AttrConstraint | IntConstraint], + ) -> AttrConstraint: + """Map type variables to constraints.""" + return self diff --git a/frontend/catalyst/python_interface/xdsl_extras/traits.py b/frontend/catalyst/python_interface/xdsl_extras/traits.py new file mode 100644 index 0000000000..937d960a28 --- /dev/null +++ b/frontend/catalyst/python_interface/xdsl_extras/traits.py @@ -0,0 +1,241 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Traits for xDSL operations. + +This module provides operation traits that can be used to define operation invariants, +additional semantic information, or to group operations that have similar properties. +""" + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +from xdsl.dialects.builtin import TensorType, VectorType +from xdsl.ir import Attribute, Operation +from xdsl.traits import OpTrait +from xdsl.utils.exceptions import VerifyException +from xdsl.utils.type import get_element_type_or_self, have_compatible_shape + + +@dataclass(frozen=True) +class SameOperandsAndResultShape(OpTrait): + """Constrain the operation to have the same shape for all operands and results.""" + + # TODO: This trait should be added to ElementwiseBinaryOperation and + # ElementwiseUnaryOperation operations when upstreaming to xdsl. + + def verify(self, op: Operation) -> None: + """Verify that the operation has the same shape for all operands and results.""" + + if len(op.results) < 1 or len(op.operands) < 1: + raise VerifyException(f"'{op.name}' requires at least one result or operand") + + # Get all types (operands and results) to check for compatible shapes + all_types = list(op.operand_types) + list(op.result_types) + + # Check that all types have compatible shapes + for type_to_check in all_types[1:]: + if not have_compatible_shape(all_types[0], type_to_check): + + raise VerifyException( + f"'{op.name}' requires the same shape for all operands and results" + ) + + +@dataclass(frozen=True) +class SameOperandsElementType(OpTrait): + """Constrain the operation to have the same element type for all operands.""" + + # TODO: This trait should be added to ElementwiseBinaryOperation and + # ElementwiseUnaryOperation operations when upstreaming to xdsl. + + def verify(self, op: Operation) -> None: + """Verify that the operation has the same element type for all operands.""" + + if len(op.operands) <= 1: + return + + # Get the element type of the first operand + first_elem_type = get_element_type_or_self(op.operand_types[0]) + + # Check that all other operands have the same element type + for operand_type in op.operand_types[1:]: + elem_type = get_element_type_or_self(operand_type) + if elem_type != first_elem_type: + raise VerifyException( + f"'{op.name}' requires the same element type for all operands" + ) + + +@dataclass(frozen=True) +class SameOperandsAndResultElementType(OpTrait): + """Constrain the operation to have the same element type for all operands and results.""" + + def verify(self, op: Operation) -> None: + """Verify that the operation has the same element type for all operands and results.""" + + if len(op.results) < 1 or len(op.operands) < 1: + raise VerifyException(f"'{op.name}' requires at least one result or operand") + + # Get the element type of the first operand + first_elem_type = get_element_type_or_self(op.operand_types[0]) + + all_types = list(op.operand_types) + list(op.result_types) + + # Check that all other operands have the same element type + for type_to_check in all_types[1:]: + elem_type = get_element_type_or_self(type_to_check) + if elem_type != first_elem_type: + raise VerifyException( + f"'{op.name}' requires the same element type for all operands and results" + ) + + +@dataclass(frozen=True) +class Elementwise(OpTrait): + """ + The following is the definition of the `Elementwise` trait from MLIR: + + https://github.com/llvm/llvm-project/blob/f8cb7987c64dcffb72414a40560055cb717dbf74/mlir/include/mlir/IR/OpDefinition.h#L1378-L1409 + + TODO: Add this trait to all the elementwise operations in xdsl when upstreaming. + + Tags elementwise operations on vectors or tensors. + + NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this trait. + In particular, broadcasting behavior is not allowed. + + An `Elementwise` op must satisfy the following properties: + + 1. If any result is a vector/tensor then at least one operand must also be a + vector/tensor. + 2. If any operand is a vector/tensor then there must be at least one result + and all results must be vectors/tensors. + 3. All operand and result vector/tensor types must be of the same shape. The + shape may be dynamic in which case the op's behaviour is undefined for + non-matching shapes. + 4. The operation must be elementwise on its vector/tensor operands and + results. When applied to single-element vectors/tensors, the result must + be the same per element. + + Rationale: + - 1. and 2. guarantee a well-defined iteration space and exclude the cases + of 0 non-scalar operands or 0 non-scalar results, which complicate a + generic definition of the iteration space. + - 3. guarantees that folding can be done across scalars/vectors/tensors with + the same pattern, as otherwise lots of special handling for type + mismatches would be needed. + - 4. guarantees that no error handling is needed. Higher-level dialects + should reify any needed guards or error handling code before lowering to + an Elementwise op. + """ + + def verify(self, op: Operation) -> None: + """Verify that the operation is elementwise.""" + + # Filter mappable types from results and operands (vectors/tensors only) + result_mappable_types = [t for t in op.result_types if Elementwise.is_mappable_type(t)] + operand_mappable_types = [t for t in op.operand_types if Elementwise.is_mappable_type(t)] + + # If the op only has scalar operand/result types, then we have nothing to check + if not result_mappable_types and not operand_mappable_types: + return + + # If a result is non-scalar, then at least one operand must be non-scalar + if result_mappable_types and not operand_mappable_types: + raise VerifyException( + f"'{op.name}': if a result is non-scalar, then at least one " + "operand must be non-scalar" + ) + + # At this point, operand_mappable_types should not be empty + assert operand_mappable_types, "At least one operand must be a vector or tensor" + + # If an operand is non-scalar, then there must be at least one non-scalar result + if not result_mappable_types: + raise VerifyException( + f"'{op.name}': if an operand is non-scalar, then there must be at " + "least one non-scalar result" + ) + + # If an operand is non-scalar, then all results must be non-scalar + if len(result_mappable_types) != len(op.results): + raise VerifyException( + f"'{op.name}': if an operand is non-scalar, then all results must be non-scalar" + ) + + # All non-scalar operands/results must have the same shape and base type + all_types = operand_mappable_types + result_mappable_types + + # Check that all types have compatible shapes + for type_to_check in all_types[1:]: + if not have_compatible_shape(all_types[0], type_to_check): + raise VerifyException( + f"'{op.name}': all non-scalar operands/results must have the " + "same shape and base type" + ) + + @staticmethod + def is_mappable_type(attr_type: Attribute) -> bool: + """Return True if the type is elementwise-mappable (vector or tensor). + + There is a TODO in MLIR to generalize this trait to avoid hardcoding vector/tensor. + We should update this when the TODO is resolved. + """ + return isinstance(attr_type, (VectorType, TensorType)) + + +@dataclass(frozen=True) +class AllMatchSameOperatorTrait(OpTrait): + """ + Verify that a list of operation attributes all match under the same operator + (e.g., size, rank, type, shape, element type). + + Parameters: + - attr_names: attribute names on the op to compare + - operator: callable taking the attribute value and returning a comparable value + - summary: human-readable name of the property used in error messages + """ + + attr_names: tuple[str, ...] + operator: Callable[[Any], Any] + summary: str + + def verify(self, op: Operation) -> None: + """Verify that the operation attributes all match under the same operator.""" + attributes = [] + for name in self.attr_names: + value = getattr(op, name, None) + if value is None: + return + attributes.append(value) + + if len(attributes) <= 1: + return + + names_str = ", ".join(self.attr_names) + try: + results = [self.operator(attr) for attr in attributes] + except (TypeError, ValueError, AttributeError) as e: + raise VerifyException(f"cannot compute {self.summary} for {{{names_str}}}: {e}") from e + + first = results[0] + if any(res != first for res in results[1:]): + results_str = ", ".join(str(r) for r in results) + raise VerifyException( + f"all of {{{names_str}}} must have the same {self.summary}: got " + f"{self.summary}s {results_str}" + ) diff --git a/frontend/test/lit/test_xdsl_passes.py b/frontend/test/lit/test_xdsl_passes.py index 88c113c6f6..57cd305560 100644 --- a/frontend/test/lit/test_xdsl_passes.py +++ b/frontend/test/lit/test_xdsl_passes.py @@ -19,7 +19,8 @@ """ import pennylane as qml -from pennylane.compiler.python_compiler.transforms import merge_rotations_pass + +from catalyst.python_interface.transforms import merge_rotations_pass def test_mlir_pass_no_attribute(): diff --git a/frontend/test/pytest/conftest.py b/frontend/test/pytest/conftest.py index 685bb20f8e..a37fbd5959 100644 --- a/frontend/test/pytest/conftest.py +++ b/frontend/test/pytest/conftest.py @@ -16,16 +16,12 @@ """ import os -import pathlib from tempfile import TemporaryDirectory from textwrap import dedent import pennylane as qml import pytest -TEST_PATH = os.path.dirname(__file__) -CONFIG_CUSTOM_DEVICE = pathlib.Path(f"{TEST_PATH}/../custom_device/custom_device.toml") - @pytest.fixture(scope="function") def create_temporary_toml_file(request) -> str: diff --git a/frontend/test/pytest/device/test_decomposition.py b/frontend/test/pytest/device/test_decomposition.py index 2889545f40..5f3594de1f 100644 --- a/frontend/test/pytest/device/test_decomposition.py +++ b/frontend/test/pytest/device/test_decomposition.py @@ -14,21 +14,20 @@ """Unit test module for catalyst/device/decomposition.py""" -import os -import pathlib import platform import numpy as np import pennylane as qml import pytest from pennylane.devices.capabilities import DeviceCapabilities, OperatorProperties +from utils import CONFIG_CUSTOM_DEVICE from catalyst import CompileError, ctrl, qjit from catalyst.compiler import get_lib_path from catalyst.device.decomposition import catalyst_decomposer -TEST_PATH = os.path.dirname(__file__) -CONFIG_CUSTOM_DEVICE = pathlib.Path(f"{TEST_PATH}/../../custom_device/custom_device.toml") +# TEST_PATH = os.path.dirname(__file__) +# CONFIG_CUSTOM_DEVICE = pathlib.Path(f"{TEST_PATH}/../../custom_device/custom_device.toml") class TestGateAliases: diff --git a/frontend/test/pytest/python_interface/conftest.py b/frontend/test/pytest/python_interface/conftest.py new file mode 100644 index 0000000000..ca18891cd6 --- /dev/null +++ b/frontend/test/pytest/python_interface/conftest.py @@ -0,0 +1,204 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pytest configuration for tests for the catalyst.python_interface submodule.""" + +from inspect import getsource +from io import StringIO + +import pytest + +from catalyst.python_interface import Compiler, QuantumParser +from catalyst.python_interface.conversion import parse_generic_to_xdsl_module + +deps_available = True + +try: + from filecheck.finput import FInput + from filecheck.matcher import Matcher + from filecheck.options import parse_argv_options + from filecheck.parser import Parser, pattern_for_opts + from xdsl.context import Context + from xdsl.dialects import test + from xdsl.passes import PassPipeline + from xdsl.printer import Printer + +except (ImportError, ModuleNotFoundError): + deps_available = False + + +def _run_filecheck_impl(program_str, pipeline=(), verify=False, roundtrip=False): + """Run filecheck on an xDSL module, comparing it to a program string containing + filecheck directives.""" + if not deps_available: + return + + ctx = Context() + xdsl_module = QuantumParser(ctx, program_str, extra_dialects=(test.Test,)).parse_module() + + if roundtrip: + # Print generic format + stream = StringIO() + Printer(stream=stream, print_generic_format=True).print_op(xdsl_module) + xdsl_module = QuantumParser(ctx, stream.getvalue()).parse_module() + + if verify: + xdsl_module.verify() + + pipeline = PassPipeline(pipeline) + pipeline.apply(ctx, xdsl_module) + + if verify: + xdsl_module.verify() + + stream = StringIO() + Printer(stream).print_op(xdsl_module) + opts = parse_argv_options(["filecheck", __file__]) + matcher = Matcher( + opts, + FInput("no-name", stream.getvalue()), + Parser(opts, StringIO(program_str), *pattern_for_opts(opts)), + ) + + exit_code = matcher.run() + assert ( + exit_code == 0 + ), f""" + filecheck failed with exit code {exit_code}. + + Original program string: + {program_str} + + Parsed module: + {stream.getvalue()} + """ + + +@pytest.fixture(scope="function") +def run_filecheck(): + """Fixture to run filecheck on an xDSL module. + + This fixture uses FileCheck to verify the correctness of a parsed MLIR string. Testers + can provide a pass pipeline to transform the IR, and verify correctness by including + FileCheck directives as comments in the input program string. + + Args: + program_str (str): The MLIR string containing the input program and FileCheck directives + pipeline (tuple[ModulePass]): A sequence containing all passes that should be applied + before running FileCheck + verify (bool): Whether or not to verify the IR after parsing and transforming. + ``False`` by default. + roundtrip (bool): Whether or not to use round-trip testing. This is useful for dialect + tests to verify that xDSL both parses and prints the IR correctly. If ``True``, we parse + the program string into an xDSL module, print it in generic format, and then parse the + generic program string back to an xDSL module. ``False`` by default. + """ + if not deps_available: + pytest.skip("Cannot run lit tests without xDSL and filecheck.") + + yield _run_filecheck_impl + + +def _get_filecheck_directives(qjit_fn): + """Return a string containing all FileCheck directives in the source function.""" + try: + src = getsource(qjit_fn) + except Exception as e: + raise RuntimeError(f"Could not get source for {qjit_fn}") from e + + filecheck_directives = [] + for line in src.splitlines(): + line = line.strip() + if line[0] != "#": + continue + + line = line[1:].strip() + if line.startswith("CHECK"): + filecheck_directives.append("// " + line) + + return "\n".join(filecheck_directives) + + +def _run_filecheck_qjit_impl(qjit_fn, verify=False): + """Run filecheck on a qjit-ed function, using FileCheck directives in its inline + comments to assert correctness.""" + if not deps_available: + return + + checks = _get_filecheck_directives(qjit_fn) + compiler = Compiler() + mlir_module = compiler.run(qjit_fn.mlir_module) + + # The following is done because ``mlir_module`` will be in the generic syntax, and + # we want as many ops to be pretty printed as possible. + mod_str = mlir_module.operation.get_asm( + binary=False, print_generic_op_form=True, assume_verified=True + ) + xdsl_module = parse_generic_to_xdsl_module(mod_str) + + if verify: + xdsl_module.verify() + + opts = parse_argv_options(["filecheck", __file__]) + matcher = Matcher( + opts, + FInput("no-name", str(xdsl_module)), + Parser(opts, StringIO(checks), *pattern_for_opts(opts)), + ) + + exit_code = matcher.run() + assert exit_code == 0, f"filecheck failed with exit code {exit_code}" + + +@pytest.fixture(scope="function") +def run_filecheck_qjit(): + """Fixture to run filecheck on a qjit-ed function. + + This fixture yields a function that takes a QJIT object as input, parses its + MLIR, applies any passes that are present, and uses FileCheck to check the + output IR against FileCheck directives that may be present in the source + function as inline comments. + + Args: + qjit_fn (Callable): The QJIT object on which we want to run lit tests + verify (bool): Whether or not to verify the IR after parsing and transforming. + ``False`` by default. + + An example showing how to use the fixture is shown below. We apply the + ``merge_rotations_pass`` and check that there is only one rotation in + the final IR: + + .. code-block:: python + + def test_qjit(self, run_filecheck_qjit): + # Test that the merge_rotations_pass works as expected when used with `qjit` + dev = qml.device("lightning.qubit", wires=2) + + @qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()]) + @merge_rotations_pass + @qml.qnode(dev) + def circuit(x: float, y: float): + # CHECK: [[phi:%.*]] = arith.addf + # CHECK: quantum.custom "RX"([[phi]]) + # CHECK-NOT: quantum.custom + qml.RX(x, 0) + qml.RX(y, 0) + return qml.state() + + run_filecheck_qjit(circuit) + + """ + if not deps_available: + pytest.skip("Cannot run lit tests without xDSL and filecheck.") + + yield _run_filecheck_qjit_impl diff --git a/frontend/test/pytest/python_interface/dialects/test_catalyst_dialect.py b/frontend/test/pytest/python_interface/dialects/test_catalyst_dialect.py new file mode 100644 index 0000000000..4f5c149ac0 --- /dev/null +++ b/frontend/test/pytest/python_interface/dialects/test_catalyst_dialect.py @@ -0,0 +1,105 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test module for pennylane/compiler/python_compiler/dialects/catalyst.py.""" + +import pytest + +# pylint: disable=wrong-import-position + +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +from catalyst.python_interface.dialects import Catalyst + +all_ops = list(Catalyst.operations) +all_attrs = list(Catalyst.attributes) + +expected_ops_names = { + "AssertionOp": "catalyst.assert", + "CallbackCallOp": "catalyst.callback_call", + "CallbackOp": "catalyst.callback", + "CustomCallOp": "catalyst.custom_call", + "LaunchKernelOp": "catalyst.launch_kernel", + "ListDeallocOp": "catalyst.list_dealloc", + "ListInitOp": "catalyst.list_init", + "ListLoadDataOp": "catalyst.list_load_data", + "ListPopOp": "catalyst.list_pop", + "ListPushOp": "catalyst.list_push", + "PrintOp": "catalyst.print", +} + +expected_attrs_names = { + "ArrayListType": "catalyst.arraylist", +} + + +def test_catalyst_dialect_name(): + """Test that the Catalyst dialect name is correct.""" + assert Catalyst.name == "catalyst" + + +@pytest.mark.parametrize("op", all_ops) +def test_all_operations_names(op): + """Test that all operations have the expected name.""" + op_class_name = op.__name__ + expected_name = expected_ops_names.get(op_class_name) + assert ( + expected_name is not None + ), f"Unexpected operation {op_class_name} found in Catalyst dialect" + assert op.name == expected_name + + +@pytest.mark.parametrize("attr", all_attrs) +def test_all_attributes_names(attr): + """Test that all attributes have the expected name.""" + attr_class_name = attr.__name__ + expected_name = expected_attrs_names.get(attr_class_name) + assert ( + expected_name is not None + ), f"Unexpected attribute {attr_class_name} found in Catalyst dialect" + assert attr.name == expected_name + + +def test_assembly_format(run_filecheck): + """Test the assembly format of the catalyst ops.""" + program = """ + // CHECK: [[LIST:%.+]] = catalyst.list_init : !catalyst.arraylist + %list = catalyst.list_init : !catalyst.arraylist + + // CHECK: [[DATA:%.+]] = catalyst.list_load_data [[LIST]] : !catalyst.arraylist -> memref + %data = catalyst.list_load_data %list : !catalyst.arraylist -> memref + + // CHECK: [[VAL:%.+]] = "test.op"() : () -> f64 + %val = "test.op"() : () -> f64 + + // CHECK: [[POP_RESULT:%.+]] = catalyst.list_pop [[LIST]] : !catalyst.arraylist + %pop_result = catalyst.list_pop %list : !catalyst.arraylist + + // CHECK: catalyst.list_push [[VAL]], [[LIST]] : !catalyst.arraylist + catalyst.list_push %val, %list : !catalyst.arraylist + + // CHECK: catalyst.list_dealloc [[LIST]] : !catalyst.arraylist + catalyst.list_dealloc %list : !catalyst.arraylist + + // CHECK: [[CUSTOM_RESULT:%.+]] = catalyst.custom_call fn("custom_function") ([[VAL]]) : (f64) -> f64 + %custom_result = catalyst.custom_call fn("custom_function")(%val) : (f64) -> f64 + + // CHECK: [[KERNEL_RESULT:%.+]] = catalyst.launch_kernel @kernel_name([[VAL]]) : (f64) -> f64 + %kernel_result = catalyst.launch_kernel @kernel_name(%val) : (f64) -> f64 + + // CHECK: [[CALLBACK_RESULT:%.+]] = catalyst.callback_call @callback_func([[VAL]]) : (f64) -> f64 + %callback_result = catalyst.callback_call @callback_func(%val) : (f64) -> f64 + """ + + run_filecheck(program, roundtrip=True) diff --git a/frontend/test/pytest/python_interface/dialects/test_mbqc_dialect.py b/frontend/test/pytest/python_interface/dialects/test_mbqc_dialect.py new file mode 100644 index 0000000000..6e776e5175 --- /dev/null +++ b/frontend/test/pytest/python_interface/dialects/test_mbqc_dialect.py @@ -0,0 +1,177 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test module for pennylane/compiler/python_compiler/dialects/mbqc.py.""" + + +import pytest + +# pylint: disable=wrong-import-position,line-too-long +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +from xdsl.context import Context +from xdsl.dialects import arith, builtin, test +from xdsl.parser import Parser +from xdsl.utils.exceptions import VerifyException + +from catalyst.python_interface.dialects import Quantum, mbqc + +all_ops = list(mbqc.MBQC.operations) +all_attrs = list(mbqc.MBQC.attributes) + +expected_ops_names = { + "MeasureInBasisOp": "mbqc.measure_in_basis", + "GraphStatePrepOp": "mbqc.graph_state_prep", +} + +expected_attrs_names = { + "MeasurementPlaneAttr": "mbqc.measurement_plane", +} + + +def test_mbqc_dialect_name(): + """Test that the MBQCDialect name is correct.""" + assert mbqc.MBQC.name == "mbqc" + + +@pytest.mark.parametrize("op", all_ops) +def test_all_operations_names(op): + """Test that all operations have the expected name.""" + op_class_name = op.__name__ + expected_name = expected_ops_names.get(op_class_name) + assert expected_name is not None, f"Unexpected operation {op_class_name} found in MBQCDialect" + assert op.name == expected_name + + +@pytest.mark.parametrize("attr", all_attrs) +def test_all_attributes_names(attr): + """Test that all attributes have the expected name.""" + attr_class_name = attr.__name__ + expected_name = expected_attrs_names.get(attr_class_name) + assert expected_name is not None, f"Unexpected attribute {attr_class_name} found in MBQCDialect" + assert attr.name == expected_name + + +def test_assembly_format(run_filecheck): + """Test the assembly format of the mbqc ops.""" + program = r""" + // CHECK: [[angle:%.+]] = "test.op"() : () -> f64 + %angle = "test.op"() : () -> f64 + + // CHECK: [[qubit:%.+]] = "test.op"() : () -> !quantum.bit + %qubit = "test.op"() : () -> !quantum.bit + + // CHECK: [[mres0:%.+]], [[out_qubit0:%.+]] = mbqc.measure_in_basis{{\s*}}[XY, [[angle]]] [[qubit]] : i1, !quantum.bit + %mres0, %out_qubit0 = mbqc.measure_in_basis [XY, %angle] %qubit : i1, !quantum.bit + + // CHECK: [[mres1:%.+]], [[out_qubit1:%.+]] = mbqc.measure_in_basis{{\s*}}[YZ, [[angle]]] [[qubit]] : i1, !quantum.bit + %mres1, %out_qubit1 = mbqc.measure_in_basis [YZ, %angle] %qubit : i1, !quantum.bit + + // CHECK: [[mres2:%.+]], [[out_qubit2:%.+]] = mbqc.measure_in_basis{{\s*}}[ZX, [[angle]]] [[qubit]] : i1, !quantum.bit + %mres2, %out_qubit2 = mbqc.measure_in_basis [ZX, %angle] %qubit : i1, !quantum.bit + + // CHECK: [[mres3:%.+]], [[out_qubit3:%.+]] = mbqc.measure_in_basis{{\s*}}[XY, [[angle]]] [[qubit]] postselect 0 : i1, !quantum.bit + %mres3, %out_qubit3 = mbqc.measure_in_basis [XY, %angle] %qubit postselect 0 : i1, !quantum.bit + + // CHECK: [[mres4:%.+]], [[out_qubit4:%.+]] = mbqc.measure_in_basis{{\s*}}[XY, [[angle]]] [[qubit]] postselect 1 : i1, !quantum.bit + %mres4, %out_qubit4 = mbqc.measure_in_basis [XY, %angle] %qubit postselect 1 : i1, !quantum.bit + + // COM: Check generic format + // CHECK: {{%.+}}, {{%.+}} = mbqc.measure_in_basis[XY, [[angle]]] [[qubit]] postselect 0 : i1, !quantum.bit + %res:2 = "mbqc.measure_in_basis"(%qubit, %angle) <{plane = #mbqc, postselect = 0 : i32}> : (!quantum.bit, f64) -> (i1, !quantum.bit) + + // CHECK: [[adj_matrix:%.+]] = arith.constant {{.*}} : tensor<6xi1> + // CHECK: [[graph_reg:%.+]] = mbqc.graph_state_prep{{\s*}}([[adj_matrix]] : tensor<6xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + %adj_matrix = arith.constant dense<[1, 0, 1, 0, 0, 1]> : tensor<6xi1> + %graph_reg = mbqc.graph_state_prep (%adj_matrix : tensor<6xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + """ + + run_filecheck(program, roundtrip=True) + + +class TestMeasureInBasisOp: + """Unit tests for the mbqc.measure_in_basis op.""" + + @pytest.mark.parametrize("plane,", ["XY", "YZ", "ZX"]) + @pytest.mark.parametrize("postselect", ["", "postselect 0", "postselect 1"]) + def test_measure_in_basis_properties(self, plane, postselect): + """Test the parsing of the mbqc.measure_in_basis op's properties.""" + program = rf""" + %angle = "test.op"() : () -> f64 + %qubit = "test.op"() : () -> !quantum.bit + + %mres, %out_qubit = mbqc.measure_in_basis [{plane}, %angle] %qubit {postselect} : i1, !quantum.bit + """ + + ctx = Context() + + ctx.load_dialect(builtin.Builtin) + ctx.load_dialect(test.Test) + ctx.load_dialect(Quantum) + ctx.load_dialect(mbqc.MBQC) + + module = Parser(ctx, program).parse_module() + + measure_in_basis_op: mbqc.MeasureInBasisOp = module.ops.last + assert isinstance(measure_in_basis_op, mbqc.MeasureInBasisOp) + + assert measure_in_basis_op.properties["plane"].data == plane + + if postselect: + assert measure_in_basis_op.properties["postselect"].value.data == int(postselect[-1]) + else: + assert measure_in_basis_op.properties.get("postselect") is None + + @pytest.mark.parametrize("postselect", [-1, 2]) + def test_invalid_postselect_raises_on_verify(self, postselect): + """Test that using an invalid postselect value (a value other than 0 or 1) raises a + VerifyException during verification.""" + + program = rf""" + %angle = "test.op"() : () -> f64 + %qubit = "test.op"() : () -> !quantum.bit + + %mres, %out_qubit = mbqc.measure_in_basis [XY, %angle] %qubit postselect {postselect} : i1, !quantum.bit + """ + + ctx = Context() + + ctx.load_dialect(builtin.Builtin) + ctx.load_dialect(test.Test) + ctx.load_dialect(Quantum) + ctx.load_dialect(mbqc.MBQC) + + module = Parser(ctx, program).parse_module() + + measure_in_basis_op: mbqc.MeasureInBasisOp = module.ops.last + assert isinstance(measure_in_basis_op, mbqc.MeasureInBasisOp) + + with pytest.raises(VerifyException, match="'postselect' must be 0 or 1"): + measure_in_basis_op.verify_() + + @pytest.mark.parametrize("init_op", ["Hadamard", builtin.StringAttr(data="Hadamard")]) + @pytest.mark.parametrize("entangle_op", ["CZ", builtin.StringAttr(data="CZ")]) + def test_graph_state_prep_instantiation(self, init_op, entangle_op): + """Test the instantiation of a mbqc.graph_state_prep op.""" + adj_matrix = [1, 0, 1, 0, 0, 1] + adj_matrix_op = arith.ConstantOp( + builtin.DenseIntOrFPElementsAttr.from_list( + type=builtin.TensorType(builtin.IntegerType(1), shape=(6,)), data=adj_matrix + ) + ) + graph_state_prep_op = mbqc.GraphStatePrepOp(adj_matrix_op.result, init_op, entangle_op) + + assert graph_state_prep_op.adj_matrix == adj_matrix_op.result + assert graph_state_prep_op.init_op == builtin.StringAttr(data="Hadamard") + assert graph_state_prep_op.entangle_op == builtin.StringAttr(data="CZ") diff --git a/frontend/test/pytest/python_interface/dialects/test_qec_dialect.py b/frontend/test/pytest/python_interface/dialects/test_qec_dialect.py new file mode 100644 index 0000000000..02ed4a6f97 --- /dev/null +++ b/frontend/test/pytest/python_interface/dialects/test_qec_dialect.py @@ -0,0 +1,94 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test module for pennylane/compiler/python_compiler/dialects/qec.py.""" + +import pytest + +# pylint: disable=wrong-import-position +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +from catalyst.python_interface.dialects import QEC + +all_ops = list(QEC.operations) +all_attrs = list(QEC.attributes) + +expected_ops_names = { + "FabricateOp": "qec.fabricate", + "LayerOp": "qec.layer", + "PPMeasurementOp": "qec.ppm", + "PPRotationOp": "qec.ppr", + "PrepareStateOp": "qec.prepare", + "SelectPPMeasurementOp": "qec.select.ppm", + "YieldOp": "qec.yield", +} + +expected_attrs_names = { + "LogicalInit": "qec.enum", +} + + +def test_qec_dialect_name(): + """Test that the QEC dialect name is correct.""" + assert QEC.name == "qec" + + +@pytest.mark.parametrize("op", all_ops) +def test_all_operations_names(op): + """Test that all operations have the expected name.""" + op_class_name = op.__name__ + expected_name = expected_ops_names.get(op_class_name) + assert expected_name is not None, f"Unexpected operation {op_class_name} found in QEC dialect" + assert op.name == expected_name + + +@pytest.mark.parametrize("attr", all_attrs) +def test_all_attributes_names(attr): + """Test that all attributes have the expected name.""" + attr_class_name = attr.__name__ + expected_name = expected_attrs_names.get(attr_class_name) + assert expected_name is not None, f"Unexpected attribute {attr_class_name} found in QEC dialect" + assert attr.name == expected_name + + +def test_assembly_format(run_filecheck): + """Test the assembly format of the qec ops.""" + program = """ + // CHECK: [[QUBIT:%.+]] = "test.op"() : () -> !quantum.bit + %qubit = "test.op"() : () -> !quantum.bit + + // CHECK: [[COND:%.+]] = "test.op"() : () -> i1 + %cond = "test.op"() : () -> i1 + + // CHECK: [[FABRICATED:%.+]] = qec.fabricate magic : !quantum.bit + %fabricated = qec.fabricate magic : !quantum.bit + + // CHECK: [[PREPARED:%.+]] = qec.prepare zero [[QUBIT]] : !quantum.bit + %prepared = qec.prepare zero %qubit : !quantum.bit + + // CHECK: [[ROTATED:%.+]] = qec.ppr ["X", "I", "Z"](4) [[QUBIT]] : !quantum.bit + %rotated = qec.ppr ["X", "I", "Z"](4) %qubit : !quantum.bit + + // CHECK: [[MEASURED:%.+]], [[OUT_QUBITS:%.+]] = qec.ppm ["X", "I", "Z"] [[QUBIT]] : i1, !quantum.bit + %measured, %out_qubits = qec.ppm ["X", "I", "Z"] %qubit : i1, !quantum.bit + + // CHECK: [[MEASURED_COND:%.+]], [[OUT_QUBITS_COND:%.+]] = qec.ppm ["X", "I", "Z"] [[QUBIT]] cond([[COND]]) : i1, !quantum.bit + %measured_cond, %out_qubits_cond = qec.ppm ["X", "I", "Z"] %qubit cond(%cond) : i1, !quantum.bit + + // CHECK: [[SELECT_MEASURED:%.+]], [[SELECT_OUT:%.+]] = qec.select.ppm([[COND]], ["X"], ["Z"]) [[QUBIT]] : i1, !quantum.bit + %select_measured, %select_out = qec.select.ppm (%cond, ["X"], ["Z"]) %qubit : i1, !quantum.bit + + """ + + run_filecheck(program) diff --git a/frontend/test/pytest/python_interface/dialects/test_quantum_dialect.py b/frontend/test/pytest/python_interface/dialects/test_quantum_dialect.py new file mode 100644 index 0000000000..d83561a193 --- /dev/null +++ b/frontend/test/pytest/python_interface/dialects/test_quantum_dialect.py @@ -0,0 +1,694 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test module for pennylane/compiler/python_compiler/dialects/quantum.py.""" + +import pytest + +# pylint: disable=wrong-import-position +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +from xdsl.dialects.builtin import ( + I32, + ComplexType, + Float64Type, + StringAttr, + TensorType, + UnitAttr, + i1, +) +from xdsl.dialects.test import TestOp +from xdsl.ir import AttributeCovT, OpResult + +from catalyst.python_interface.dialects import Quantum +from catalyst.python_interface.dialects.quantum import ( + CustomOp, + NamedObservableAttr, + ObservableType, + QubitType, + QuregType, +) + +all_ops = list(Quantum.operations) +all_attrs = list(Quantum.attributes) + +expected_ops_names = { + "AdjointOp": "quantum.adjoint", + "AllocOp": "quantum.alloc", + "AllocQubitOp": "quantum.alloc_qb", + "ComputationalBasisOp": "quantum.compbasis", + "CountsOp": "quantum.counts", + "CustomOp": "quantum.custom", + "DeallocOp": "quantum.dealloc", + "DeallocQubitOp": "quantum.dealloc_qb", + "DeviceInitOp": "quantum.device", + "DeviceReleaseOp": "quantum.device_release", + "ExpvalOp": "quantum.expval", + "ExtractOp": "quantum.extract", + "FinalizeOp": "quantum.finalize", + "GlobalPhaseOp": "quantum.gphase", + "HamiltonianOp": "quantum.hamiltonian", + "HermitianOp": "quantum.hermitian", + "InitializeOp": "quantum.init", + "InsertOp": "quantum.insert", + "MeasureOp": "quantum.measure", + "MultiRZOp": "quantum.multirz", + "NamedObsOp": "quantum.namedobs", + "NumQubitsOp": "quantum.num_qubits", + "PCPhaseOp": "quantum.pcphase", + "ProbsOp": "quantum.probs", + "QubitUnitaryOp": "quantum.unitary", + "SampleOp": "quantum.sample", + "SetBasisStateOp": "quantum.set_basis_state", + "SetStateOp": "quantum.set_state", + "StateOp": "quantum.state", + "TensorOp": "quantum.tensor", + "VarianceOp": "quantum.var", + "YieldOp": "quantum.yield", +} + +expected_attrs_names = { + "ObservableType": "quantum.obs", + "QubitType": "quantum.bit", + "QuregType": "quantum.reg", + "ResultType": "quantum.res", + "NamedObservableAttr": "quantum.named_observable", +} + +TestOp.__test__ = False +"""Setting this attribute silences the PytestCollectionWarning that TestOp can not be collected +for testing, because it is a class with __init__ method.""" + + +# Test function taken from xdsl/utils/test_value.py +def create_ssa_value(t: AttributeCovT) -> OpResult[AttributeCovT]: + """Create a single SSA value with the given type for testing purposes.""" + op = TestOp(result_types=(t,)) + return op.results[0] + + +q0 = create_ssa_value(QubitType()) +q1 = create_ssa_value(QubitType()) +q2 = create_ssa_value(QubitType()) +qreg = create_ssa_value(QuregType()) +theta = create_ssa_value(Float64Type()) +dim = create_ssa_value(Float64Type()) +pauli_x = NamedObservableAttr("PauliX") +obs = create_ssa_value(ObservableType()) +i = create_ssa_value(I32) +matrix = create_ssa_value(TensorType(element_type=Float64Type, shape=(2, 2))) +coeffs = create_ssa_value(TensorType(Float64Type(), shape=(10,))) +samples = create_ssa_value(TensorType(Float64Type(), shape=(8, 7))) +basis_state = create_ssa_value(TensorType(i1, shape=(8,))) +state = create_ssa_value(TensorType(ComplexType(Float64Type()), shape=(16,))) + +expected_ops_init_kwargs = { + "AdjointOp": {"qreg": qreg, "region": (CustomOp(gate_name="CNOT", in_qubits=(q0, q1)),)}, + "AllocOp": {"nqubits": 3}, + "AllocQubitOp": {}, + "ComputationalBasisOp": {"operands": (q0, None), "result_types": (obs,)}, + "CountsOp": { + "operands": (obs, i, None, None), + "result_types": (TensorType(Float64Type(), shape=(1,)), TensorType(I32, shape=(1,))), + }, + "CustomOp": { + "gate_name": "RX", + "in_qubits": (q0, q1), + "in_ctrl_qubits": (q2,), + "params": (theta,), + "adjoint": True, + }, + "DeallocOp": {"qreg": qreg}, + "DeallocQubitOp": {"qubit": q0}, + "DeviceInitOp": { + "operands": (i,), + "properties": {"lib": StringAttr("lib"), "device_name": StringAttr("my_device")}, + }, + "DeviceReleaseOp": {}, + "ExpvalOp": {"obs": obs}, + "ExtractOp": {"qreg": qreg, "idx": i}, + "FinalizeOp": {}, + "GlobalPhaseOp": {"params": theta, "in_ctrl_qubits": q0}, + "HamiltonianOp": {"operands": (coeffs, (obs,)), "result_types": (obs,)}, + "HermitianOp": {"operands": (matrix, (q0, q1)), "result_types": (obs,)}, + "InitializeOp": {}, + "InsertOp": {"in_qreg": qreg, "idx": i, "qubit": q1}, + "MeasureOp": {"in_qubit": q0, "postselect": i}, + "MultiRZOp": { + "theta": theta, + "in_qubits": (q1, q0), + "in_ctrl_qubits": (q2,), + "in_ctrl_values": (i,), + "adjoint": UnitAttr(), + }, + "NamedObsOp": {"qubit": q0, "obs_type": pauli_x}, + "NumQubitsOp": {"result_types": (i,)}, + "PCPhaseOp": { + "theta": theta, + "dim": dim, + "in_qubits": (q1, q0), + "in_ctrl_qubits": (q2,), + "in_ctrl_values": (i,), + "adjoint": False, + }, + "ProbsOp": { + "operands": (obs, i, None), + "result_types": (TensorType(Float64Type(), shape=(8,)),), + }, + "QubitUnitaryOp": {"matrix": matrix, "in_qubits": (q2,), "adjoint": True}, + "SampleOp": {"operands": (obs, i, samples), "result_types": (samples,)}, + "SetBasisStateOp": {"operands": (basis_state, (q0, q2)), "result_types": ((q1, q2),)}, + "SetStateOp": {"operands": (state, (q0, q1)), "result_types": ((q0, q1),)}, + "StateOp": {"operands": (obs, i, state), "result_types": (state,)}, + "TensorOp": {"operands": ((obs, obs),), "result_types": (obs,)}, + "VarianceOp": {"obs": (obs,)}, + "YieldOp": {"operands": (qreg,)}, +} + + +def test_quantum_dialect_name(): + """Test that the QuantumDialect name is correct.""" + assert Quantum.name == "quantum" + + +@pytest.mark.parametrize("op", all_ops) +def test_all_operations_names(op): + """Test that all operations have the expected name.""" + op_class_name = op.__name__ + expected_name = expected_ops_names.get(op_class_name) + assert ( + expected_name is not None + ), f"Unexpected operation {op_class_name} found in QuantumDialect" + assert op.name == expected_name + + +def test_only_existing_operations_are_expected(): + """Test that the expected operations above only contain existing operations.""" + existing_ops_names = {op.__name__ for op in all_ops} + assert existing_ops_names == set(expected_ops_names) + + +@pytest.mark.parametrize("op", all_ops) +def test_operation_construction(op): + """Test the constructors of operations in the Quantum dialect.""" + kwargs = expected_ops_init_kwargs[op.__name__] + _ = op(**kwargs) + + +@pytest.mark.parametrize("attr", all_attrs) +def test_all_attributes_names(attr): + """Test that all attributes have the expected name.""" + attr_class_name = attr.__name__ + expected_name = expected_attrs_names.get(attr_class_name) + assert ( + expected_name is not None + ), f"Unexpected attribute {attr_class_name} found in QuantumDialect" + assert attr.name == expected_name + + +def test_only_existing_attributes_are_expected(): + """Test that the expected attributes above only contain existing attributes.""" + existing_attrs_names = {attr.__name__ for attr in all_attrs} + assert existing_attrs_names == set(expected_attrs_names) + + +class TestAssemblyFormat: + """Lit tests for assembly format of operations/attributes in the Quantum + dialect.""" + + def test_qubit_qreg_operations(self, run_filecheck): + """Test that the assembly format for operations for allocation/deallocation of + qubits/quantum registers works correctly.""" + + # Tests for allocation/deallocation ops: AllocOp, DeallocOp, AllocQubitOp, DeallocQubitOp + # Tests for extraction/insertion ops: ExtractOp, InsertOp + program = """ + ////////////////// **Allocation of register with dynamic number of wires** ////////////////// + // CHECK: [[NQUBITS:%.+]] = "test.op"() : () -> i64 + // CHECK: [[QREG_DYN:%.+]] = quantum.alloc([[NQUBITS]]) : !quantum.reg + %nqubits = "test.op"() : () -> i64 + %qreg_dynamic = quantum.alloc(%nqubits) : !quantum.reg + + ////////////////// **Deallocation of dynamic register** ////////////////// + // CHECK: quantum.dealloc [[QREG_DYN]] : !quantum.reg + quantum.dealloc %qreg_dynamic : !quantum.reg + + ////////////////// **Allocation of register with static number of wires** ////////////////// + // CHECK: [[QREG_STATIC:%.+]] = quantum.alloc(10) : !quantum.reg + %qreg_static = quantum.alloc(10) : !quantum.reg + + ////////////////// **Deallocation of static register** ////////////////// + // CHECK: quantum.dealloc [[QREG_STATIC]] : !quantum.reg + quantum.dealloc %qreg_static : !quantum.reg + + ////////////////// **Dynamic qubit allocation** ////////////////// + // CHECK: [[DYN_QUBIT:%.+]] = quantum.alloc_qb : !quantum.bit + %dyn_qubit = quantum.alloc_qb : !quantum.bit + + ////////////////// **Dynamic qubit deallocation** ////////////////// + // CHECK: quantum.dealloc_qb [[DYN_QUBIT]] : !quantum.bit + quantum.dealloc_qb %dyn_qubit : !quantum.bit + + ////////////////////////////////////////////////////// + ////////////////// Quantum register ////////////////// + ////////////////////////////////////////////////////// + // CHECK: [[QREG:%.+]] = "test.op"() : () -> !quantum.reg + %qreg = "test.op"() : () -> !quantum.reg + + ////////////////// **Static qubit extraction** ////////////////// + // CHECK: [[STATIC_QUBIT:%.+]] = quantum.extract [[QREG]][[[STATIC_INDEX:0]]] : !quantum.reg -> !quantum.bit + %static_qubit = quantum.extract %qreg[0] : !quantum.reg -> !quantum.bit + + ////////////////// **Dynamic qubit extraction** ////////////////// + // CHECK: [[DYN_INDEX:%.+]] = "test.op"() : () -> i64 + // CHECK: [[DYN_QUBIT1:%.+]] = quantum.extract [[QREG]][[[DYN_INDEX]]] : !quantum.reg -> !quantum.bit + %dyn_index = "test.op"() : () -> i64 + %dyn_qubit1 = quantum.extract %qreg[%dyn_index] : !quantum.reg -> !quantum.bit + + ////////////////// **Static qubit insertion** ////////////////// + // CHECK: [[QREG1:%.+]] = quantum.insert [[QREG]][[[STATIC_INDEX]]], [[STATIC_QUBIT]] : !quantum.reg, !quantum.bit + %qreg1 = quantum.insert %qreg[0], %static_qubit : !quantum.reg, !quantum.bit + + ////////////////// **Dynamic qubit insertion** ////////////////// + // CHECK: quantum.insert [[QREG1]][[[DYN_INDEX]]], [[DYN_QUBIT1]] : !quantum.reg, !quantum.bit + %qreg2 = quantum.insert %qreg1[%dyn_index], %dyn_qubit1 : !quantum.reg, !quantum.bit + """ + + run_filecheck(program, roundtrip=True, verify=True) + + def test_quantum_ops(self, run_filecheck): + """Test that the assembly format for quantum non-terminal operations works correctly.""" + + # Tests for CustomOp, GlobalPhaseOp, MeasureOp, MultiRZOp, QubitUnitaryOp + program = """ + //////////////////////////////////////////////////////////////////////// + ////////////////// Qubits, params, and control values ////////////////// + //////////////////////////////////////////////////////////////////////// + ////////////////// **Qubits** ////////////////// + // CHECK: [[Q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[Q1:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[Q2:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[Q3:%.+]] = "test.op"() : () -> !quantum.bit + %q0 = "test.op"() : () -> !quantum.bit + %q1 = "test.op"() : () -> !quantum.bit + %q2 = "test.op"() : () -> !quantum.bit + %q3 = "test.op"() : () -> !quantum.bit + + ////////////////// **Params** ////////////////// + // CHECK: [[PARAM1:%.+]] = "test.op"() : () -> f64 + // CHECK: [[PARAM2:%.+]] = "test.op"() : () -> f64 + // CHECK: [[MAT_TENSOR:%.+]] = "test.op"() : () -> tensor<4x4xcomplex> + // CHECK: [[MAT_MEMREF:%.+]] = "test.op"() : () -> memref<4x4xcomplex> + %param1 = "test.op"() : () -> f64 + %param2 = "test.op"() : () -> f64 + %mat_tensor = "test.op"() : () -> tensor<4x4xcomplex> + %mat_memref = "test.op"() : () -> memref<4x4xcomplex> + + ////////////////// **Control values** ////////////////// + // CHECK: [[TRUE_CST:%.+]] = "test.op"() : () -> i1 + // CHECK: [[FALSE_CST:%.+]] = "test.op"() : () -> i1 + %true_cst = "test.op"() : () -> i1 + %false_cst = "test.op"() : () -> i1 + + /////////////////////////////////////////////////////////////////////// + ///////////////////////// **Operation tests** ///////////////////////// + /////////////////////////////////////////////////////////////////////// + + ////////////////// **CustomOp tests** ////////////////// + // No params, no control wires + // CHECK: {{%.+}}, {{%.+}} = quantum.custom "Gate"() [[Q0]], [[Q1]] : !quantum.bit, !quantum.bit + %qc1, %qc2 = quantum.custom "Gate"() %q0, %q1 : !quantum.bit, !quantum.bit + + // Params, no control wires + // CHECK: {{%.+}}, {{%.+}} = quantum.custom "ParamGate"([[PARAM1]], [[PARAM2]]) [[Q0]], [[Q1]] : !quantum.bit, !quantum.bit + %qc3, %qc4 = quantum.custom "ParamGate"(%param1, %param2) %q0, %q1 : !quantum.bit, !quantum.bit + + // Control wires and values + // CHECK: {{%.+}}, {{%.+}} = quantum.custom "ControlledGate"() [[Q0]] ctrls([[Q1]]) ctrlvals([[TRUE_CST]]) : !quantum.bit ctrls !quantum.bit + %qc5, %qc6 = quantum.custom "ControlledGate"() %q0 ctrls(%q1) ctrlvals(%true_cst) : !quantum.bit ctrls !quantum.bit + + // Adjoint + // CHECK: {{%.+}} = quantum.custom "AdjGate"() [[Q0]] adj : !quantum.bit + %qc8 = quantum.custom "AdjGate"() %q0 adj : !quantum.bit + + ////////////////// **GlobalPhaseOp tests** ////////////////// + // No control wires + // CHECK: quantum.gphase([[PARAM1]]) : + quantum.gphase(%param1) : + + // Control wires and values + // CHECK: {{%.+}}, {{%.+}} = quantum.gphase([[PARAM1]]) ctrls([[Q0]], [[Q1]]) ctrlvals([[FALSE_CST]], [[TRUE_CST]]) : !quantum.bit, !quantum.bit + %qg1, %qg2 = quantum.gphase(%param1) ctrls(%q0, %q1) ctrlvals(%false_cst, %true_cst) : !quantum.bit, !quantum.bit + + // Adjoint + // CHECK: {{%.+}} = quantum.gphase([[PARAM1]]) {adjoint} ctrls([[Q0]]) ctrlvals([[TRUE_CST]]) : !quantum.bit + %qg3 = quantum.gphase(%param1) {adjoint} ctrls(%q0) ctrlvals(%true_cst) : !quantum.bit + + ////////////////// **MultiRZOp tests** ////////////////// + // No control wires + // CHECK: {{%.+}}, {{%.+}} = quantum.multirz([[PARAM1]]) [[Q0]], [[Q1]] : !quantum.bit, !quantum.bit + %qm1, %qm2 = quantum.multirz(%param1) %q0, %q1 : !quantum.bit, !quantum.bit + + // Control wires and values + // CHECK: {{%.+}}, {{%.+}}, {{%.+}} = quantum.multirz([[PARAM1]]) [[Q0]], [[Q1]] ctrls([[Q2]]) ctrlvals([[TRUE_CST]]) : !quantum.bit, !quantum.bit + %qm3, %qm4, %qm5 = quantum.multirz(%param1) %q0, %q1 ctrls(%q2) ctrlvals(%true_cst) : !quantum.bit, !quantum.bit ctrls !quantum.bit + + // Adjoint + // CHECK: {{%.+}}, {{%.+}} = quantum.multirz([[PARAM1]]) [[Q0]], [[Q1]] adj : !quantum.bit, !quantum.bit + %qm6, %qm7 = quantum.multirz(%param1) %q0, %q1 adj : !quantum.bit, !quantum.bit + + ////////////////// **PCPhaseOp tests** ////////////////// + // No control wires + // CHECK: {{%.+}}, {{%.+}}, {{%.+}} = quantum.pcphase([[PARAM1]], [[PARAM2]]) [[Q0]], [[Q1]], [[Q2]] : !quantum.bit, !quantum.bit, !quantum.bit + %qp1, %qp2, %qp3 = quantum.pcphase(%param1, %param2) %q0, %q1, %q2 : !quantum.bit, !quantum.bit, !quantum.bit + + // Control wires and values + // CHECK: {{%.+}}, {{%.+}}, {{%.+}} = quantum.pcphase([[PARAM1]], [[PARAM2]]) [[Q0]], [[Q1]] ctrls([[Q2]]) ctrlvals([[TRUE_CST]]) : !quantum.bit, !quantum.bit ctrls !quantum.bit + %qp4, %qp5, %qp6 = quantum.pcphase(%param1, %param2) %q0, %q1 ctrls(%q2) ctrlvals(%true_cst) : !quantum.bit, !quantum.bit ctrls !quantum.bit + + // Adjoint + // CHECK: {{%.+}}, {{%.+}} = quantum.pcphase([[PARAM1]], [[PARAM2]]) [[Q0]], [[Q1]] adj : !quantum.bit, !quantum.bit + %qp7, %qp8 = quantum.pcphase(%param1, %param2) %q0, %q1 adj : !quantum.bit, !quantum.bit + + ////////////////// **QubitUnitaryOp tests** ////////////////// + // No control wires + // CHECK: {{%.+}}, {{%.+}} = quantum.unitary([[MAT_TENSOR]] : tensor<4x4xcomplex>) [[Q0]], [[Q1]] : !quantum.bit, !quantum.bit + %qb1, %qb2 = quantum.unitary(%mat_tensor : tensor<4x4xcomplex>) %q0, %q1 : !quantum.bit, !quantum.bit + + // Control wires and values + // CHECK: {{%.+}}, {{%.+}} {{%.+}} = quantum.unitary([[MAT_TENSOR]] : tensor<4x4xcomplex>) [[Q0]], [[Q1]] ctrls([[Q2]]) ctrlvals([[FALSE_CST]]) : !quantum.bit, !quantum.bit ctrls !quantum.bit + %qb3, %qb4, %qb5 = quantum.unitary(%mat_tensor : tensor<4x4xcomplex>) %q0, %q1 ctrls(%q2) ctrlvals(%false_cst) : !quantum.bit, !quantum.bit ctrls !quantum.bit + + // Adjoint + // CHECK: {{%.+}}, {{%.+}} = quantum.unitary([[MAT_TENSOR]] : tensor<4x4xcomplex>) [[Q0]], [[Q1]] adj : !quantum.bit, !quantum.bit + %qb6, %qb7 = quantum.unitary(%mat_tensor : tensor<4x4xcomplex>) %q0, %q1 adj : !quantum.bit, !quantum.bit + + // MemRef + // CHECK: {{%.+}}, {{%.+}} = quantum.unitary([[MAT_MEMREF]] : memref<4x4xcomplex>) [[Q0]], [[Q1]] : !quantum.bit, !quantum.bit + %qb8, %qb9 = quantum.unitary(%mat_memref : memref<4x4xcomplex>) %q0, %q1 : !quantum.bit, !quantum.bit + + ////////////////// **MeasureOp tests** ////////////////// + // No postselection + // CHECK: {{%.+}}, {{%.+}} = quantum.measure [[Q0]] : i1, !quantum.bit + %mres1, %mqubit1 = quantum.measure %q0 : i1, !quantum.bit + + // Postselection + // CHECK: {{%.+}}, {{%.+}} = quantum.measure [[Q1]] postselect 0 : i1, !quantum.bit + // CHECK: {{%.+}}, {{%.+}} = quantum.measure [[Q2]] postselect 1 : i1, !quantum.bit + %mres2, %mqubit2 = quantum.measure %q1 postselect 0 : i1, !quantum.bit + %mres3, %mqubit3 = quantum.measure %q2 postselect 1 : i1, !quantum.bit + """ + + run_filecheck(program, roundtrip=True, verify=True) + + def test_state_prep(self, run_filecheck): + """Test that the assembly format for state prep operations works correctly.""" + + # Tests for SetBasisStateOp, SetStateOp + program = """ + //////////////////////////////////////////// + ////////////////// Qubits ////////////////// + //////////////////////////////////////////// + // CHECK: [[Q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[Q1:%.+]] = "test.op"() : () -> !quantum.bit + %q0 = "test.op"() : () -> !quantum.bit + %q1 = "test.op"() : () -> !quantum.bit + + ////////////////// **SetBasisStateOp tests** ////////////////// + // Basis state containers + // CHECK: [[BASIS_TENSOR:%.+]] = "test.op"() : () -> tensor<2xi1> + // CHECK: [[BASIS_MEMREF:%.+]] = "test.op"() : () -> memref<2xi1> + %basis_tensor = "test.op"() : () -> tensor<2xi1> + %basis_memref = "test.op"() : () -> memref<2xi1> + + // Basis state operations + // CHECK: [[Q2:%.+]], [[Q3:%.+]] = quantum.set_basis_state([[BASIS_TENSOR]]) [[Q0]], [[Q1]] : (tensor<2xi1>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit) + // CHECK: [[Q4:%.+]], [[Q5:%.+]] = quantum.set_basis_state([[BASIS_MEMREF]]) [[Q2]], [[Q3]] : (memref<2xi1>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit) + %q2, %q3 = quantum.set_basis_state(%basis_tensor) %q0, %q1 : (tensor<2xi1>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit) + %q4, %q5 = quantum.set_basis_state(%basis_memref) %q2, %q3 : (memref<2xi1>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit) + + ////////////////// **SetStateOp tests** ////////////////// + // State vector containers + // CHECK: [[STATE_TENSOR:%.+]] = "test.op"() : () -> tensor<4xcomplex> + // CHECK: [[STATE_MEMREF:%.+]] = "test.op"() : () -> memref<4xcomplex> + %state_tensor = "test.op"() : () -> tensor<4xcomplex> + %state_memref = "test.op"() : () -> memref<4xcomplex> + + // State prep operations + // CHECK: [[Q6:%.+]], [[Q7:%.+]] = quantum.set_state([[STATE_TENSOR]]) [[Q4]], [[Q5]] : (tensor<4xcomplex>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit) + // CHECK: quantum.set_state([[STATE_MEMREF]]) [[Q6]], [[Q7]] : (memref<4xcomplex>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit) + %q6, %q7 = quantum.set_state(%state_tensor) %q4, %q5 : (tensor<4xcomplex>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit) + %q8, %q9 = quantum.set_state(%state_memref) %q6, %q7 : (memref<4xcomplex>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit) + """ + + run_filecheck(program, roundtrip=True, verify=True) + + def test_observables(self, run_filecheck): + """Test that the assembly format for observable operations works correctly.""" + + # Tests for observables: ComputationalBasisOp, HamiltonianOp, HermitianOp, + # NamedObsOp, TensorOp + program = """ + ////////////////////////////////////////////////////// + //////////// Quantum register and qubits //////////// + ////////////////////////////////////////////////////// + // CHECK: [[QREG:%.+]] = "test.op"() : () -> !quantum.reg + %qreg = "test.op"() : () -> !quantum.reg + + // CHECK: [[Q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[Q1:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[Q2:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[Q3:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[Q4:%.+]] = "test.op"() : () -> !quantum.bit + %q0 = "test.op"() : () -> !quantum.bit + %q1 = "test.op"() : () -> !quantum.bit + %q2 = "test.op"() : () -> !quantum.bit + %q3 = "test.op"() : () -> !quantum.bit + %q4 = "test.op"() : () -> !quantum.bit + + ////////////////////////////////////////////// + //////////// **Observable tests** //////////// + ////////////////////////////////////////////// + + //////////// **NamedObsOp** //////////// + // CHECK: [[X_OBS:%.+]] = quantum.namedobs [[Q0]][PauliX] : !quantum.obs + // CHECK: [[Y_OBS:%.+]] = quantum.namedobs [[Q1]][PauliY] : !quantum.obs + // CHECK: [[Z_OBS:%.+]] = quantum.namedobs [[Q2]][PauliZ] : !quantum.obs + // CHECK: [[H_OBS:%.+]] = quantum.namedobs [[Q3]][Hadamard] : !quantum.obs + // CHECK: [[I_OBS:%.+]] = quantum.namedobs [[Q4]][Identity] : !quantum.obs + %x_obs = quantum.namedobs %q0[PauliX] : !quantum.obs + %y_obs = quantum.namedobs %q1[PauliY] : !quantum.obs + %z_obs = quantum.namedobs %q2[PauliZ] : !quantum.obs + %h_obs = quantum.namedobs %q3[Hadamard] : !quantum.obs + %i_obs = quantum.namedobs %q4[Identity] : !quantum.obs + + //////////// **HermitianOp** //////////// + // Create tensor/memref + // CHECK: [[HERM_TENSOR:%.+]] = "test.op"() : () -> tensor<2x2xcomplex> + // CHECK: [[HERM_MEMREF:%.+]] = "test.op"() : () -> memref<2x2xcomplex> + %herm_tensor = "test.op"() : () -> tensor<2x2xcomplex> + %herm_memref = "test.op"() : () -> memref<2x2xcomplex> + + // Create Hermitians + // CHECK: [[HERM1:%.+]] = quantum.hermitian([[HERM_TENSOR]] : tensor<2x2xcomplex>) [[Q0]] : !quantum.obs + // CHECK: [[HERM2:%.+]] = quantum.hermitian([[HERM_MEMREF]] : memref<2x2xcomplex>) [[Q1]] : !quantum.obs + %herm1 = quantum.hermitian(%herm_tensor : tensor<2x2xcomplex>) %q0 : !quantum.obs + %herm2 = quantum.hermitian(%herm_memref : memref<2x2xcomplex>) %q1 : !quantum.obs + + //////////// **TensorOp** //////////// + // CHECK: [[TENSOR_OBS:%.+]] = quantum.tensor [[X_OBS]], [[HERM2]], [[I_OBS]] : !quantum.obs + %tensor_obs = quantum.tensor %x_obs, %herm2, %i_obs : !quantum.obs + + //////////// **HamiltonianOp** //////////// + // Create tensor/memref + // CHECK: [[HAM_TENSOR:%.+]] = "test.op"() : () -> tensor<3xf64> + // CHECK: [[HAM_MEMREF:%.+]] = "test.op"() : () -> memref<3xf64> + %ham_tensor = "test.op"() : () -> tensor<3xf64> + %ham_memref = "test.op"() : () -> memref<3xf64> + + // Create Hamiltonians + // CHECK: {{%.+}} = quantum.hamiltonian([[HAM_TENSOR]] : tensor<3xf64>) [[TENSOR_OBS]], [[X_OBS]], [[HERM1]] : !quantum.obs + // CHECK: {{%.+}} = quantum.hamiltonian([[HAM_MEMREF]] : memref<3xf64>) [[TENSOR_OBS]], [[X_OBS]], [[HERM1]] : !quantum.obs + %ham1 = quantum.hamiltonian(%ham_tensor : tensor<3xf64>) %tensor_obs, %x_obs, %herm1 : !quantum.obs + %ham2 = quantum.hamiltonian(%ham_memref : memref<3xf64>) %tensor_obs, %x_obs, %herm1 : !quantum.obs + + //////////// **ComputationalBasisOp** //////////// + // CHECK: {{%.+}} = quantum.compbasis qubits [[Q0]], [[Q1]] : !quantum.obs + // CHECK: {{%.+}} = quantum.compbasis qreg [[QREG]] : !quantum.obs + %cb_01 = quantum.compbasis qubits %q0, %q1 : !quantum.obs + %cb_all = quantum.compbasis qreg %qreg : !quantum.obs + """ + + run_filecheck(program, roundtrip=True, verify=True) + + def test_measurements(self, run_filecheck): + """Test that the assembly format for measurement operations works correctly.""" + + # Tests for measurements: CountsOp, ExpvalOp, MeasureOp, ProbsOp, SampleOp, + # StateOp, VarianceOp + program = """ + /////////////////////////////////////////////////// + //////////// Observables and constants //////////// + /////////////////////////////////////////////////// + // CHECK: [[OBS:%.+]] = "test.op"() : () -> !quantum.obs + %obs = "test.op"() : () -> !quantum.obs + + // CHECK: [[DYN_WIRES:%.+]] = "test.op"() : () -> i64 + %dyn_wires = "test.op"() : () -> i64 + // CHECK: [[DYN_SHOTS:%.+]] = "test.op"() : () -> i64 + %dyn_shots = "test.op"() : () -> i64 + + /////////////////////////////////////////////// + //////////// **Measurement tests** //////////// + /////////////////////////////////////////////// + + ///////////////////// **ExpvalOp** ///////////////////// + // CHECK: {{%.+}} = quantum.expval [[OBS]] : f64 + %expval = quantum.expval %obs : f64 + + ///////////////////// **VarianceOp** ///////////////////// + // CHECK: {{%.+}} = quantum.var [[OBS]] : f64 + %var = quantum.var %obs : f64 + + ///////////////////// **CountsOp** ///////////////////// + // Counts with static shape + // CHECK: {{%.+}}, {{%.+}} = quantum.counts [[OBS]] : tensor<6xf64>, tensor<6xi64> + %eigvals1, %counts1 = quantum.counts %obs : tensor<6xf64>, tensor<6xi64> + + // Counts with dynamic shape + // CHECK: {{%.+}}, {{%.+}} = quantum.counts [[OBS]] shape [[DYN_WIRES]] : tensor, tensor + %eigvals2, %counts2 = quantum.counts %obs shape %dyn_wires : tensor, tensor + + // Counts with no results (mutate memref in-place) + // CHECK: [[EIGVALS_IN:%.+]] = "test.op"() : () -> memref<16xf64> + // CHECK: [[COUNTS_IN:%.+]] = "test.op"() : () -> memref<16xi64> + // CHECK: quantum.counts [[OBS]] in([[EIGVALS_IN]] : memref<16xf64>, [[COUNTS_IN]] : memref<16xi64>) + %eigvals_in = "test.op"() : () -> memref<16xf64> + %counts_in = "test.op"() : () -> memref<16xi64> + quantum.counts %obs in(%eigvals_in : memref<16xf64>, %counts_in : memref<16xi64>) + + ///////////////////// **ProbsOp** ///////////////////// + // Probs with static shape + // CHECK: {{%.+}} = quantum.probs [[OBS]] : tensor<8xf64> + %probs1 = quantum.probs %obs : tensor<8xf64> + + // Probs with dynamic shape + // CHECK: {{%.+}} = quantum.probs [[OBS]] shape [[DYN_WIRES]] : tensor + %probs2 = quantum.probs %obs shape %dyn_wires : tensor + + // Probs with no results (mutate memref in-place) + // CHECK: [[PROBS_IN:%.+]] = "test.op"() : () -> memref<16xf64> + // CHECK: quantum.probs [[OBS]] in([[PROBS_IN]] : memref<16xf64>) + %probs_in = "test.op"() : () -> memref<16xf64> + quantum.probs %obs in(%probs_in : memref<16xf64>) + + ///////////////////// **StateOp** ///////////////////// + // State with static shape + // CHECK: {{%.+}} = quantum.state [[OBS]] : tensor<8xcomplex> + %state1 = quantum.state %obs : tensor<8xcomplex> + + // State with dynamic shape + // CHECK: {{%.+}} = quantum.state [[OBS]] shape [[DYN_WIRES]] : tensor> + %state2 = quantum.state %obs shape %dyn_wires : tensor> + + // State with no results (mutate memref in-place) + // CHECK: [[STATE_IN:%.+]] = "test.op"() : () -> memref<16xcomplex> + // CHECK: quantum.state [[OBS]] in([[STATE_IN]] : memref<16xcomplex>) + %state_in = "test.op"() : () -> memref<16xcomplex> + quantum.state %obs in(%state_in : memref<16xcomplex>) + + ///////////////////// **SampleOp** ///////////////////// + // Samples with static shape + // CHECK: {{%.+}} = quantum.sample [[OBS]] : tensor<10x3xf64> + %samples1 = quantum.sample %obs : tensor<10x3xf64> + + // Samples with dynamic wires + // CHECK: {{%.+}} = quantum.sample [[OBS]] shape [[DYN_WIRES]] : tensor<10x?xf64> + %samples2 = quantum.sample %obs shape %dyn_wires : tensor<10x?xf64> + + // Samples with dynamic shots + // CHECK: {{%.+}} = quantum.sample [[OBS]] shape [[DYN_SHOTS]] : tensor + %samples3 = quantum.sample %obs shape %dyn_shots : tensor + + // Samples with dynamic wires and shots + // CHECK: {{%.+}} = quantum.sample [[OBS]] shape [[DYN_SHOTS]], [[DYN_WIRES]] : tensor + %samples4 = quantum.sample %obs shape %dyn_shots, %dyn_wires : tensor + + // Samples with no results (mutate memref in-place) + // CHECK: [[SAMPLES_IN:%.+]] = "test.op"() : () -> memref<7x4xf64> + // CHECK: quantum.sample [[OBS]] in([[SAMPLES_IN]] : memref<7x4xf64>) + %samples_in = "test.op"() : () -> memref<7x4xf64> + quantum.sample %obs in(%samples_in : memref<7x4xf64>) + """ + + run_filecheck(program, roundtrip=True, verify=True) + + def test_miscellaneous_operations(self, run_filecheck): + """Test that the assembly format for miscelleneous operations + works correctly.""" + + # Tests for AdjointOp, DeviceInitOp, DeviceReleaseOp, FinalizeOp, InitializeOp, + # NumQubitsOp, YieldOp + program = """ + ////////////////////////////////////////// + //////////// Quantum register //////////// + ////////////////////////////////////////// + // CHECK: [[QREG:%.+]] = "test.op"() : () -> !quantum.reg + %qreg = "test.op"() : () -> !quantum.reg + + //////////// **AdjointOp and YieldOp tests** //////////// + // CHECK: quantum.adjoint([[QREG]]) : !quantum.reg { + // CHECK-NEXT: ^bb0([[ARG_QREG:%.+]] : !quantum.reg): + // CHECK-NEXT: quantum.yield [[ARG_QREG]] : !quantum.reg + // CHECK-NEXT: } + %qreg1 = quantum.adjoint(%qreg) : !quantum.reg { + ^bb0(%arg_qreg: !quantum.reg): + quantum.yield %arg_qreg : !quantum.reg + } + + //////////// **DeviceInitOp tests** //////////// + // Integer SSA value for shots + // CHECK: [[SHOTS:%.+]] = "test.op"() : () -> i64 + %shots = "test.op"() : () -> i64 + + // No auto qubit management + // CHECK: quantum.device shots([[SHOTS]]) ["foo", "bar", "baz"] + quantum.device shots(%shots) ["foo", "bar", "baz"] + + // Auto qubit management + // CHECK: quantum.device shots([[SHOTS]]) ["foo", "bar", "baz"] {auto_qubit_management} + quantum.device shots(%shots) ["foo", "bar", "baz"] {auto_qubit_management} + + //////////// **DeviceReleaseOp tests** //////////// + // CHECK: quantum.device_release + quantum.device_release + + //////////// **FinalizeOp tests** //////////// + // CHECK: quantum.finalize + quantum.finalize + + //////////// **InitializeOp tests** //////////// + // CHECK: quantum.init + quantum.init + + //////////// **NumQubitsOp tests** //////////// + // CHECK: quantum.num_qubits : i64 + %nqubits = quantum.num_qubits : i64 + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/python_interface/dialects/test_stablehlo_dialect.py b/frontend/test/pytest/python_interface/dialects/test_stablehlo_dialect.py new file mode 100644 index 0000000000..c63ed80e44 --- /dev/null +++ b/frontend/test/pytest/python_interface/dialects/test_stablehlo_dialect.py @@ -0,0 +1,959 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test module for pennylane/compiler/python_compiler/dialects/stablehlo.py.""" +# pylint: disable=line-too-long +import pytest + +pytestmark = pytest.mark.usefixtures("requires_xdsl") + + +def test_all_unary_operations(run_filecheck): + """Test all unary elementwise operations.""" + program = r""" + // CHECK: %[[tf32:.*]] = "test.op"() : () -> tensor + %tf32 = "test.op"() : () -> tensor + + // CHECK: %[[tf64:.*]] = "test.op"() : () -> tensor + %tf64 = "test.op"() : () -> tensor + + // CHECK: %[[tcomplex:.*]] = "test.op"() : () -> tensor> + %tcomplex = "test.op"() : () -> tensor> + + // CHECK: %convert = "stablehlo.convert"(%[[tf32]]) : (tensor) -> tensor + %convert = "stablehlo.convert"(%tf32) : (tensor) -> tensor + + // CHECK: %cos = "stablehlo.cosine"(%[[tf32]]) : (tensor) -> tensor + %cos = "stablehlo.cosine"(%tf32) : (tensor) -> tensor + + // CHECK: %exp = "stablehlo.exponential"(%[[tf32]]) : (tensor) -> tensor + %exp = "stablehlo.exponential"(%tf32) : (tensor) -> tensor + + // CHECK: %exponential_minus_one = "stablehlo.exponential_minus_one"(%[[tf32]]) : (tensor) -> tensor + %exponential_minus_one = "stablehlo.exponential_minus_one"(%tf32) : (tensor) -> tensor + + // CHECK: %floor = "stablehlo.floor"(%[[tf64]]) : (tensor) -> tensor + %floor = "stablehlo.floor"(%tf64) : (tensor) -> tensor + + // CHECK: %imag = "stablehlo.imag"(%[[tcomplex]]) : (tensor>) -> tensor + %imag = "stablehlo.imag"(%tcomplex) : (tensor>) -> tensor + + // CHECK: %is_finite = "stablehlo.is_finite"(%[[tf32]]) : (tensor) -> tensor + %is_finite = "stablehlo.is_finite"(%tf32) : (tensor) -> tensor + + // CHECK: %log = "stablehlo.log"(%[[tf32]]) : (tensor) -> tensor + %log = "stablehlo.log"(%tf32) : (tensor) -> tensor + + // CHECK: %log_plus_one = "stablehlo.log_plus_one"(%[[tf64]]) : (tensor) -> tensor + %log_plus_one = "stablehlo.log_plus_one"(%tf64) : (tensor) -> tensor + + // CHECK: %logistic = "stablehlo.logistic"(%[[tf32]]) : (tensor) -> tensor + %logistic = "stablehlo.logistic"(%tf32) : (tensor) -> tensor + + // CHECK: %negate = "stablehlo.negate"(%[[tf32]]) : (tensor) -> tensor + %negate = "stablehlo.negate"(%tf32) : (tensor) -> tensor + + // CHECK: %real = "stablehlo.real"(%[[tcomplex]]) : (tensor>) -> tensor + %real = "stablehlo.real"(%tcomplex) : (tensor>) -> tensor + + // CHECK: %round_afz = "stablehlo.round_nearest_afz"(%[[tf64]]) : (tensor) -> tensor + %round_afz = "stablehlo.round_nearest_afz"(%tf64) : (tensor) -> tensor + + // CHECK: %round_even = "stablehlo.round_nearest_even"(%[[tf64]]) : (tensor) -> tensor + %round_even = "stablehlo.round_nearest_even"(%tf64) : (tensor) -> tensor + + // CHECK: %rsqrt = "stablehlo.rsqrt"(%[[tf32]]) : (tensor) -> tensor + %rsqrt = "stablehlo.rsqrt"(%tf32) : (tensor) -> tensor + + // CHECK: %sign = "stablehlo.sign"(%[[tf32]]) : (tensor) -> tensor + %sign = "stablehlo.sign"(%tf32) : (tensor) -> tensor + + // CHECK: %sin = "stablehlo.sine"(%[[tf32]]) : (tensor) -> tensor + %sin = "stablehlo.sine"(%tf32) : (tensor) -> tensor + + // CHECK: %sqrt = "stablehlo.sqrt"(%[[tf32]]) : (tensor) -> tensor + %sqrt = "stablehlo.sqrt"(%tf32) : (tensor) -> tensor + + // CHECK: %tan = "stablehlo.tan"(%[[tf64]]) : (tensor) -> tensor + %tan = "stablehlo.tan"(%tf64) : (tensor) -> tensor + + // CHECK: %tanh = "stablehlo.tanh"(%[[tf32]]) : (tensor) -> tensor + %tanh = "stablehlo.tanh"(%tf32) : (tensor) -> tensor + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def test_all_binary_operations(run_filecheck): + """Test all binary elementwise operations.""" + program = r""" + // CHECK: %[[tf32_1:.*]] = "test.op"() : () -> tensor + %tf32_1 = "test.op"() : () -> tensor + + // CHECK: %[[tf32_2:.*]] = "test.op"() : () -> tensor + %tf32_2 = "test.op"() : () -> tensor + + // CHECK: %[[tf64_1:.*]] = "test.op"() : () -> tensor + %tf64_1 = "test.op"() : () -> tensor + + // CHECK: %[[tf64_2:.*]] = "test.op"() : () -> tensor + %tf64_2 = "test.op"() : () -> tensor + + // CHECK: %complex = "stablehlo.complex"(%[[tf32_1]], %[[tf32_2]]) : (tensor, tensor) -> tensor> + %complex = "stablehlo.complex"(%tf32_1, %tf32_2) : (tensor, tensor) -> tensor> + + // CHECK: %divide = "stablehlo.divide"(%[[tf32_1]], %[[tf32_2]]) : (tensor, tensor) -> tensor + %divide = "stablehlo.divide"(%tf32_1, %tf32_2) : (tensor, tensor) -> tensor + + // CHECK: %maximum = "stablehlo.maximum"(%[[tf32_1]], %[[tf32_2]]) : (tensor, tensor) -> tensor + %maximum = "stablehlo.maximum"(%tf32_1, %tf32_2) : (tensor, tensor) -> tensor + + // CHECK: %minimum = "stablehlo.minimum"(%[[tf32_1]], %[[tf32_2]]) : (tensor, tensor) -> tensor + %minimum = "stablehlo.minimum"(%tf32_1, %tf32_2) : (tensor, tensor) -> tensor + + // CHECK: %power = "stablehlo.power"(%[[tf64_1]], %[[tf64_2]]) : (tensor, tensor) -> tensor + %power = "stablehlo.power"(%tf64_1, %tf64_2) : (tensor, tensor) -> tensor + + // CHECK: %remainder = "stablehlo.remainder"(%[[tf32_1]], %[[tf32_2]]) : (tensor, tensor) -> tensor + %remainder = "stablehlo.remainder"(%tf32_1, %tf32_2) : (tensor, tensor) -> tensor + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def test_all_other_operations(run_filecheck): + """Test all other elementwise operations.""" + program = r""" + // CHECK: %[[tf32:.*]] = "test.op"() : () -> tensor + %tf32 = "test.op"() : () -> tensor + + // CHECK: %[[tf64:.*]] = "test.op"() : () -> tensor + %tf64 = "test.op"() : () -> tensor + + // CHECK: %[[ti1:.*]] = "test.op"() : () -> tensor + %ti1 = "test.op"() : () -> tensor + + // CHECK: %clamp = "stablehlo.clamp"(%[[tf32]], %[[tf32]], %[[tf32]]) : (tensor, tensor, tensor) -> tensor + %clamp = "stablehlo.clamp"(%tf32, %tf32, %tf32) : (tensor, tensor, tensor) -> tensor + + // CHECK: %compare = stablehlo.compare EQ, %[[tf32]], %[[tf32]] : (tensor, tensor) -> tensor + %compare = "stablehlo.compare"(%tf32, %tf32) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + + // CHECK: %map = "stablehlo.map"(%[[tf32]], %[[tf32]]) ({ + // CHECK: ^[[bb0:.*]](%arg0 : tensor, %arg1 : tensor): + // CHECK: %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor, tensor) -> tensor + // CHECK: "stablehlo.return"(%0) : (tensor) -> () + // CHECK: }) {dimensions = array} : (tensor, tensor) -> tensor + %map = "stablehlo.map"(%tf32, %tf32) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) { + dimensions = array + } : (tensor, tensor) -> tensor + + // CHECK: %reduce_precision = "stablehlo.reduce_precision"(%[[tf64]]) {exponent_bits = 5 : i32, mantissa_bits = 10 : i32} : (tensor) -> tensor + %reduce_precision = "stablehlo.reduce_precision"(%tf64) {exponent_bits = 5 : i32, mantissa_bits = 10 : i32} : (tensor) -> tensor + + // CHECK: %select = "stablehlo.select"(%[[ti1]], %[[tf32]], %[[tf32]]) : (tensor, tensor, tensor) -> tensor + %select = "stablehlo.select"(%ti1, %tf32, %tf32) : (tensor, tensor, tensor) -> tensor + + // CHECK: %constant1 = "stablehlo.constant"() <{value = dense<[[0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00]]> : tensor<2x2xf32>}> : () -> tensor<2x2xf32> + %constant1 = "stablehlo.constant"() <{value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>}> : () -> tensor<2x2xf32> + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def test_invalid_ir_shape_mismatch(run_filecheck): + """Test that operations with shape mismatches are properly rejected.""" + program = r""" + %tf32_2x3 = "test.op"() : () -> tensor<2x3xf32> + %tf64_3x2 = "test.op"() : () -> tensor<3x2xf64> + + // This should fail verification due to shape mismatch + %convert = "stablehlo.convert"(%tf32_2x3) : (tensor<2x3xf32>) -> tensor<3x2xf64> + """ + + with pytest.raises( + Exception, match="all non-scalar operands/results must have the same shape and base type" + ): + run_filecheck(program, roundtrip=True, verify=True) + + +def test_invalid_ir_type_mismatch(run_filecheck): + """Test that operations with type mismatches are properly rejected.""" + program = r""" + %ti32 = "test.op"() : () -> tensor<2x3xi32> + + // This should fail verification due to type mismatch (cosine expects float/complex) + %cos = "stablehlo.cosine"(%ti32) : (tensor<2x3xi32>) -> tensor<2x3xi32> + """ + + with pytest.raises(Exception, match="'operand' at position 0 does not verify"): + run_filecheck(program, roundtrip=True, verify=True) + + +def test_invalid_ir_missing_operands(run_filecheck): + """Test that operations with missing operands are properly rejected.""" + program = r""" + %result = "stablehlo.convert"() : () -> tensor<2x3xf64> + """ + + with pytest.raises(Exception, match="Expected 1 operand"): + run_filecheck(program, roundtrip=True, verify=True) + + +def test_invalid_ir_trait_verification_failure(run_filecheck): + """Test that operations that violate trait constraints are properly rejected.""" + program = r""" + %tf32_2x3 = "test.op"() : () -> tensor<2x3xf32> + %tf64_3x2 = "test.op"() : () -> tensor<3x2xf64> + + // This should fail verification due to shape mismatch between operands + %complex = "stablehlo.complex"(%tf32_2x3, %tf64_3x2) : (tensor<2x3xf32>, tensor<3x2xf64>) -> tensor<2x3xcomplex> + """ + + with pytest.raises(Exception, match="requires the same shape"): + run_filecheck(program, roundtrip=True, verify=True) + + +def test_invalid_ir_operand_result_shape_mismatch(run_filecheck): + """Test that operations with operand vs result shape mismatches are properly rejected.""" + program = r""" + %tf32_2x3 = "test.op"() : () -> tensor<2x3xf32> + + // This should fail verification due to shape mismatch between operand and result + %convert = "stablehlo.convert"(%tf32_2x3) : (tensor<2x3xf32>) -> tensor<3x2xf64> + """ + + with pytest.raises( + Exception, match="all non-scalar operands/results must have the same shape and base type" + ): + run_filecheck(program, roundtrip=True, verify=True) + + +def test_control_flow_operations(run_filecheck): + """Test the IfOp operation.""" + program = r""" + // Test IfOp: + + // CHECK: %[[pred:.*]] = "test.op"() : () -> tensor + %pred = "test.op"() : () -> tensor + + // CHECK: %[[result:.*]] = "stablehlo.if"(%[[pred]]) ({ + // CHECK: "stablehlo.return"(%[[pred]]) : (tensor) -> () + // CHECK: }, { + // CHECK: "stablehlo.return"(%[[pred]]) : (tensor) -> () + // CHECK: }) : (tensor) -> tensor + %result = "stablehlo.if"(%pred) ({ + "stablehlo.return"(%pred) : (tensor) -> () + }, { + "stablehlo.return"(%pred) : (tensor) -> () + }) : (tensor) -> tensor + + // Test WhileOp: + + // CHECK: %[[init_i:.*]] = "test.op"() : () -> tensor + %init_i = "test.op"() : () -> tensor + + // CHECK: %[[init_sum:.*]] = "test.op"() : () -> tensor + %init_sum = "test.op"() : () -> tensor + + // CHECK: %[[ten:.*]] = "test.op"() : () -> tensor + %ten = "test.op"() : () -> tensor + + // CHECK: %[[one:.*]] = "test.op"() : () -> tensor + %one = "test.op"() : () -> tensor + + // CHECK: %[[results:.*]], %[[results_1:.*]] = "stablehlo.while"(%[[init_i]], %[[init_sum]]) ({ + // CHECK: ^{{.*}}(%[[arg0:.*]] : tensor, %[[arg1:.*]] : tensor): + // CHECK: %[[cond:.*]] = stablehlo.compare LT, %[[arg0]], %[[ten]] : (tensor, tensor) -> tensor + // CHECK: "stablehlo.return"(%[[cond]]) : (tensor) -> () + // CHECK: }, { + // CHECK: ^{{.*}}(%[[arg0_1:.*]] : tensor, %[[arg1_1:.*]] : tensor): + // CHECK: %[[new_sum:.*]] = "stablehlo.add"(%[[arg1_1]], %[[one]]) : (tensor, tensor) -> tensor + // CHECK: %[[new_i:.*]] = "stablehlo.add"(%[[arg0_1]], %[[one]]) : (tensor, tensor) -> tensor + // CHECK: "stablehlo.return"(%[[new_i]], %[[new_sum]]) : (tensor, tensor) -> () + // CHECK: }) : (tensor, tensor) -> (tensor, tensor) + %results:2 = "stablehlo.while"(%init_i, %init_sum) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %cond = "stablehlo.compare"(%arg0, %ten) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%cond) : (tensor) -> () + }, { + ^bb0(%arg0: tensor, %arg1: tensor): + %new_sum = "stablehlo.add"(%arg1, %one) : (tensor, tensor) -> tensor + %new_i = "stablehlo.add"(%arg0, %one) : (tensor, tensor) -> tensor + "stablehlo.return"(%new_i, %new_sum) : (tensor, tensor) -> () + }) : (tensor, tensor) -> (tensor, tensor) + + // Test OptimizationBarrierOp: + + // CHECK: %[[operand:.*]] = "test.op"() : () -> tensor + %operand = "test.op"() : () -> tensor + + // CHECK: %[[result2:.*]] = "stablehlo.optimization_barrier"(%[[operand]]) : (tensor) -> tensor + %result2 = "stablehlo.optimization_barrier"(%operand) : (tensor) -> tensor + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def test_data_movement_operations(run_filecheck): + """Test all data movement operations.""" + program = r""" + ////////////////// Setup test operations ////////////////// + // CHECK: %[[input1:.*]] = "test.op"() : () -> tensor<3x2xi64> + %input1 = "test.op"() : () -> tensor<3x2xi64> + + // CHECK: %[[input2:.*]] = "test.op"() : () -> tensor<1x2xi64> + %input2 = "test.op"() : () -> tensor<1x2xi64> + + // CHECK: %[[operand:.*]] = "test.op"() : () -> tensor<2x3x4x2xi32> + %operand = "test.op"() : () -> tensor<2x3x4x2xi32> + + // CHECK: %[[start_indices:.*]] = "test.op"() : () -> tensor<2x2x3x2xi64> + %start_indices = "test.op"() : () -> tensor<2x2x3x2xi64> + + // CHECK: %[[reshape_input:.*]] = "test.op"() : () -> tensor<2xf32> + %reshape_input = "test.op"() : () -> tensor<2xf32> + + // CHECK: %[[scatter_input:.*]] = "test.op"() : () -> tensor<2x3x4x2xi64> + %scatter_input = "test.op"() : () -> tensor<2x3x4x2xi64> + + // CHECK: %[[scatter_indices:.*]] = "test.op"() : () -> tensor<2x2x3x2xi64> + %scatter_indices = "test.op"() : () -> tensor<2x2x3x2xi64> + + // CHECK: %[[scatter_updates:.*]] = "test.op"() : () -> tensor<2x2x3x2x2xi64> + %scatter_updates = "test.op"() : () -> tensor<2x2x3x2x2xi64> + + // CHECK: %[[slice_input:.*]] = "test.op"() : () -> tensor<3x4xi64> + %slice_input = "test.op"() : () -> tensor<3x4xi64> + + // CHECK: %[[broadcast_input:.*]] = "test.op"() : () -> tensor<1x3xi32> + %broadcast_input = "test.op"() : () -> tensor<1x3xi32> + + ////////////////// Test ConcatenateOp ////////////////// + // CHECK: %concatenate = "stablehlo.concatenate"(%[[input1]], %[[input2]]) <{dimension = 0 : i64}> : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64> + %concatenate = "stablehlo.concatenate"(%input1, %input2) {dimension = 0 : i64} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64> + + ////////////////// Test GatherOp ////////////////// + // CHECK: %gather = "stablehlo.gather"(%[[operand]], %[[start_indices]]) + // CHECK-SAME: dimension_numbers = #stablehlo.gather< + // CHECK-NEXT: offset_dims = [3, 4], + // CHECK-NEXT: collapsed_slice_dims = [1], + // CHECK-NEXT: operand_batching_dims = [0], + // CHECK-NEXT: start_indices_batching_dims = [1], + // CHECK-NEXT: start_index_map = [2, 1], + // CHECK-NEXT: index_vector_dim = 3 + // CHECK-NEXT: slice_sizes = array, indices_are_sorted = false + %gather = "stablehlo.gather"(%operand, %start_indices) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3, 4], + collapsed_slice_dims = [1], + operand_batching_dims = [0], + start_indices_batching_dims = [1], + start_index_map = [2, 1], + index_vector_dim = 3>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32> + + ////////////////// Test ReshapeOp ////////////////// + // CHECK: %reshape = stablehlo.reshape %[[reshape_input]] : (tensor<2xf32>) -> tensor<1x2xf32> + %reshape = "stablehlo.reshape"(%reshape_input) : (tensor<2xf32>) -> tensor<1x2xf32> + + ////////////////// Test ScatterOp ////////////////// + // CHECK: %scatter = "stablehlo.scatter"(%[[scatter_input]], %[[scatter_indices]], %[[scatter_updates]]) + // CHECK-SAME: scatter_dimension_numbers = #stablehlo.scatter< + // CHECK-NEXT: update_window_dims = [3, 4], + // CHECK-NEXT: inserted_window_dims = [1], + // CHECK-NEXT: input_batching_dims = [0], + // CHECK-NEXT: scatter_indices_batching_dims = [1], + // CHECK-NEXT: scatter_dims_to_operand_dims = [2, 1], + // CHECK-NEXT: index_vector_dim = 3 + // CHECK-NEXT: indices_are_sorted = false, unique_indices = false + // CHECK-NEXT: ^[[bb0:.*]](%arg0 : tensor, %arg1 : tensor): + // CHECK-NEXT: %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + // CHECK-NEXT: "stablehlo.return"(%0) : (tensor) -> () + %scatter = "stablehlo.scatter"(%scatter_input, %scatter_indices, %scatter_updates) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [3, 4], + inserted_window_dims = [1], + input_batching_dims = [0], + scatter_indices_batching_dims = [1], + scatter_dims_to_operand_dims = [2, 1], + index_vector_dim = 3>, + indices_are_sorted = false, + unique_indices = false + } : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64> + + ////////////////// Test SliceOp ////////////////// + // CHECK: %slice = "stablehlo.slice"(%[[slice_input]]) + // CHECK-SAME: start_indices = array, + // CHECK-SAME: limit_indices = array, + // CHECK-SAME: strides = array + // CHECK-SAME: : (tensor<3x4xi64>) -> tensor<2x2xi64> + %slice = "stablehlo.slice"(%slice_input) { + start_indices = array, + limit_indices = array, + strides = array + } : (tensor<3x4xi64>) -> tensor<2x2xi64> + + ////////////////// Test BroadcastInDimOp ////////////////// + // CHECK: %broadcast = stablehlo.broadcast_in_dim %[[broadcast_input]], dims = [2, 1] : (tensor<1x3xi32>) -> tensor<2x3x2xi32> + %broadcast = "stablehlo.broadcast_in_dim"(%broadcast_input) {broadcast_dimensions = array} : (tensor<1x3xi32>) -> tensor<2x3x2xi32> + + ////////////////// Test DynamicSliceOp ////////////////// + // CHECK: %[[dyn_operand:.*]] = "test.op"() : () -> tensor<4x4xi32> + %dyn_operand = "test.op"() : () -> tensor<4x4xi32> + + // CHECK: %[[start0:.*]] = "test.op"() : () -> tensor + %start0 = "test.op"() : () -> tensor + + // CHECK: %[[start1:.*]] = "test.op"() : () -> tensor + %start1 = "test.op"() : () -> tensor + + // CHECK: %dynamic_slice = "stablehlo.dynamic_slice"(%[[dyn_operand]], %[[start0]], %[[start1]]) + // CHECK-SAME: slice_sizes = array + // CHECK-SAME: : (tensor<4x4xi32>, tensor, tensor) -> tensor<2x3xi32> + %dynamic_slice = "stablehlo.dynamic_slice"(%dyn_operand, %start0, %start1) { + slice_sizes = array + } : (tensor<4x4xi32>, tensor, tensor) -> tensor<2x3xi32> + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def test_invalid_slice_operations(run_filecheck): + """Test invalid slice operations that should fail verification.""" + program_slice_mismatch = r""" + // CHECK: %input = "test.op"() : () -> tensor<3x8xi64> + %input = "test.op"() : () -> tensor<3x8xi64> + + // This should fail verification due to mismatched array sizes + // CHECK: %slice = "stablehlo.slice"(%input) {start_indices = array, limit_indices = array, strides = array} : (tensor<3x8xi64>) -> tensor<2x2xi64> + %slice = "stablehlo.slice"(%input) { + start_indices = array, + limit_indices = array, + strides = array + } : (tensor<3x8xi64>) -> tensor<2x2xi64> + """ + + with pytest.raises( + Exception, + match="all of \\{start_indices, limit_indices, strides\\} must have the same size: got sizes 2, 3, 2", + ): + run_filecheck(program_slice_mismatch, roundtrip=True, verify=True) + + +def test_invalid_slice_element_type_mismatch(run_filecheck): + """Test that SliceOp rejects mismatched operand/result element types.""" + program = r""" + %slice_input = "test.op"() : () -> tensor<3x4xi64> + // CHECK: %slice_input = "test.op"() : () -> tensor<3x4xi64> + // Mismatched element type: operand is i64, result is f32 + %slice = "stablehlo.slice"(%slice_input) { + start_indices = array, + limit_indices = array, + strides = array + } : (tensor<3x4xi64>) -> tensor<2x2xf32> + """ + + # Expect verification failure due to element type mismatch + with pytest.raises( + Exception, match="requires the same element type for all operands and results" + ): + run_filecheck(program, roundtrip=True, verify=True) + + +def test_invalid_gather_element_type_mismatch(run_filecheck): + """Test that GatherOp rejects mismatched operand/result element types.""" + program = r""" + %operand = "test.op"() : () -> tensor<2x3x4x2xi32> + %start_indices = "test.op"() : () -> tensor<2x2x3x2xi64> + + // Mismatched element type: operand is i32, result is f32 + %gather_bad = "stablehlo.gather"(%operand, %start_indices) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3, 4], + collapsed_slice_dims = [1], + operand_batching_dims = [0], + start_indices_batching_dims = [1], + start_index_map = [2, 1], + index_vector_dim = 3>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xf32> + """ + + # Expect verification failure due to element type mismatch between operand and result + with pytest.raises( + Exception, match=r"all of \{operand, result\} must have the same element type" + ): + run_filecheck(program, roundtrip=True, verify=True) + + +def test_invalid_reshape_operations(run_filecheck): + """Test invalid reshape operations that should fail verification.""" + program_reshape_mismatch = r""" + %reshape_input = "test.op"() : () -> tensor<2xf32> + + // This should fail verification due to element count mismatch (2 != 4) + %reshape_bad = "stablehlo.reshape"(%reshape_input) : (tensor<2xf32>) -> tensor<2x2xf32> + """ + + with pytest.raises(Exception, match="number of output elements"): + run_filecheck(program_reshape_mismatch, roundtrip=True, verify=True) + + +def test_invalid_broadcast_in_dim_operations(run_filecheck): + """Test invalid broadcast_in_dim operations that should fail verification.""" + # Test dims size mismatch. + program_broadcast_dims_size_mismatch = r""" + %broadcast_input = "test.op"() : () -> tensor<1x3xi32> + + // dims has size 1, but operand rank is 2 + %broadcast_bad = "stablehlo.broadcast_in_dim"(%broadcast_input) {broadcast_dimensions = array} : (tensor<1x3xi32>) -> tensor<2x3x2xi32> + """ + + with pytest.raises(Exception, match="broadcast_dimensions size .* does not match operand rank"): + run_filecheck(program_broadcast_dims_size_mismatch, roundtrip=True, verify=True) + + # Test duplicate dims. + program_broadcast_duplicate_dims = r""" + %broadcast_input = "test.op"() : () -> tensor<1x3xi32> + + // duplicate entries in broadcast_dimensions are not allowed + %broadcast_bad = "stablehlo.broadcast_in_dim"(%broadcast_input) {broadcast_dimensions = array} : (tensor<1x3xi32>) -> tensor<2x3x2xi32> + """ + + with pytest.raises(Exception, match="broadcast_dimensions should not have duplicates"): + run_filecheck(program_broadcast_duplicate_dims, roundtrip=True, verify=True) + + # Test dim index out of bounds. + program_broadcast_dim_oob = r""" + %broadcast_input = "test.op"() : () -> tensor<1x3xi32> + + // result rank is 2, but dims contains 2 (out of bounds) + %broadcast_bad = "stablehlo.broadcast_in_dim"(%broadcast_input) {broadcast_dimensions = array} : (tensor<1x3xi32>) -> tensor<2x3xi32> + """ + + with pytest.raises(Exception, match="broadcast_dimensions contains invalid value"): + run_filecheck(program_broadcast_dim_oob, roundtrip=True, verify=True) + + # Test operand dim not 1 and not equal to result dim. + program_broadcast_dim_mismatch = r""" + %broadcast_input = "test.op"() : () -> tensor<2x3xi32> + + // operand[0] = 2, result[0] = 4; dims = [0, 2] -> mismatch on dim 0 + %broadcast_bad = "stablehlo.broadcast_in_dim"(%broadcast_input) {broadcast_dimensions = array} : (tensor<2x3xi32>) -> tensor<4x3x2xi32> + """ + + with pytest.raises( + Exception, + match="size of operand dimension .* is not equal to 1 or size of result dimension", + ): + run_filecheck(program_broadcast_dim_mismatch, roundtrip=True, verify=True) + + +def test_dynamism_operations(run_filecheck): + """Test all dynamism operations.""" + program = r""" + ////////////////// Setup ////////////////// + // CHECK: %[[operand:.*]] = "test.op"() : () -> tensor<1x3xi64> + %operand = "test.op"() : () -> tensor<1x3xi64> + + // CHECK: %[[out_dims:.*]] = "test.op"() : () -> tensor<3xi64> + %out_dims = "test.op"() : () -> tensor<3xi64> + + ////////////////// Test DynamicBroadcastInDimOp ////////////////// + // CHECK: %dynamic_bcast = stablehlo.dynamic_broadcast_in_dim %[[operand]], %[[out_dims]], dims = [2, 1] : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64> + %dynamic_bcast = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out_dims) { + broadcast_dimensions = array + } : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64> + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def test_reduction_operations(run_filecheck): + """Test all reduction operations.""" + program = r""" + ////////////////// Setup ////////////////// + // CHECK: %[[input:.*]] = "test.op"() : () -> tensor<1x6xi64> + %input = "test.op"() : () -> tensor<1x6xi64> + + // CHECK: %[[init:.*]] = "test.op"() : () -> tensor + %init = "test.op"() : () -> tensor + + ////////////////// Test ReduceOp ////////////////// + // CHECK: %reduce = "stablehlo.reduce"(%[[input]], %[[init]]) <{dimensions = array}> ({ + // CHECK: ^[[bb0:.*]](%arg0 : tensor, %arg1 : tensor): + // CHECK: %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + // CHECK: "stablehlo.return"(%0) : (tensor) -> () + // CHECK: }) : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + %reduce = "stablehlo.reduce"(%input, %init) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) {dimensions = array} : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def test_invalid_reduction_operations(run_filecheck): + """Test invalid cases for ReduceOp verifier.""" + + # Duplicate dimensions + program_dup_dims = r""" + %input = "test.op"() : () -> tensor<1x6xi64> + %init = "test.op"() : () -> tensor + + %reduce = "stablehlo.reduce"(%input, %init) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) {dimensions = array} : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + """ + + with pytest.raises(Exception, match=r"dimensions should not have duplicates"): + run_filecheck(program_dup_dims, roundtrip=True, verify=True) + + # Dimension out of range + program_dim_oob = r""" + %input = "test.op"() : () -> tensor<1x6xi64> + %init = "test.op"() : () -> tensor + + %reduce = "stablehlo.reduce"(%input, %init) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) {dimensions = array} : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + """ + + with pytest.raises(Exception, match=r"dimensions contains an invalid value"): + run_filecheck(program_dim_oob, roundtrip=True, verify=True) + + # Input/init element type mismatch + program_elem_mismatch = r""" + %input = "test.op"() : () -> tensor<1x6xi64> + %init = "test.op"() : () -> tensor + + %reduce = "stablehlo.reduce"(%input, %init) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) {dimensions = array} : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + """ + + with pytest.raises(Exception, match=r"input and init_value must have the same element type"): + run_filecheck(program_elem_mismatch, roundtrip=True, verify=True) + + # Reducer wrong arity (expects 2 args per input; give 1) + program_wrong_arity = r""" + %input = "test.op"() : () -> tensor<1x6xi64> + %init = "test.op"() : () -> tensor + + %reduce = "stablehlo.reduce"(%input, %init) ({ + ^bb0(%acc: tensor): + "stablehlo.return"(%acc) : (tensor) -> () + }) {dimensions = array} : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + """ + + with pytest.raises(Exception, match=r"reducer must take 2 arguments, got 1"): + run_filecheck(program_wrong_arity, roundtrip=True, verify=True) + + # Reducer arg wrong rank (should be 0D) + program_arg_rank = r""" + %input = "test.op"() : () -> tensor<1x6xi64> + %init = "test.op"() : () -> tensor + + %reduce = "stablehlo.reduce"(%input, %init) ({ + ^bb0(%arg0: tensor<2xi64>, %arg1: tensor<2xi64>): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi64> + "stablehlo.return"(%0) : (tensor<2xi64>) -> () + }) {dimensions = array} : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + """ + + with pytest.raises(Exception, match=r"reducer arguments must be rank-0 tensors"): + run_filecheck(program_arg_rank, roundtrip=True, verify=True) + + # Reducer return wrong count + program_return_count = r""" + %input = "test.op"() : () -> tensor<1x6xi64> + %init = "test.op"() : () -> tensor + + %reduce = "stablehlo.reduce"(%input, %init) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + "stablehlo.return"() : () -> () + }) {dimensions = array} : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + """ + + with pytest.raises(Exception, match=r"reducer must return exactly one value per input"): + run_filecheck(program_return_count, roundtrip=True, verify=True) + + +def test_custom_call_basic(run_filecheck): + """CustomCallOp minimal form without layouts should verify.""" + program = r""" + // CHECK: %[[ARG:.*]] = "test.op"() : () -> tensor<2x3xi32> + %arg = "test.op"() : () -> tensor<2x3xi32> + + // CHECK: %[[RES:.*]] = "stablehlo.custom_call"(%[[ARG]]) + // CHECK-SAME: call_target_name = "foo" + // CHECK-SAME: api_version = #stablehlo + %res = "stablehlo.custom_call"(%arg) { + call_target_name = "foo", + api_version = #stablehlo, + output_operand_aliases = [] + } : (tensor<2x3xi32>) -> tensor<2x3xi32> + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def test_custom_call_with_layouts(run_filecheck): + """CustomCallOp with matching operand/result layouts should verify.""" + program = r""" + // CHECK: %[[ARG:.*]] = "test.op"() : () -> tensor<2x3xi32> + %arg = "test.op"() : () -> tensor<2x3xi32> + + // CHECK: %[[RES:.*]] = "stablehlo.custom_call"(%[[ARG]]) + // CHECK-SAME: operand_layouts = [dense<[1, 0]> : tensor<2xindex>] + // CHECK-SAME: result_layouts = [dense<[1, 0]> : tensor<2xindex>] + %res = "stablehlo.custom_call"(%arg) { + call_target_name = "foo", + api_version = #stablehlo, + operand_layouts = [dense<[1, 0]> : tensor<2xindex>], + result_layouts = [dense<[1, 0]> : tensor<2xindex>], + output_operand_aliases = [] + } : (tensor<2x3xi32>) -> tensor<2x3xi32> + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def test_custom_call_missing_result_layouts(run_filecheck): + """Providing only operand_layouts should fail (must provide both or none).""" + program = r""" + %arg = "test.op"() : () -> tensor<2x3xi32> + + %res = "stablehlo.custom_call"(%arg) { + call_target_name = "foo", + api_version = #stablehlo, + operand_layouts = [dense<[1, 0]> : tensor<2xindex>], + output_operand_aliases = [] + } : (tensor<2x3xi32>) -> tensor<2x3xi32> + """ + + with pytest.raises( + Exception, + match=r"either both operands and results or none", + ): + run_filecheck(program, roundtrip=True, verify=True) + + +def test_custom_call_layouts_mismatch(run_filecheck): + """Number of layouts must match number of operands/results.""" + program = r""" + %arg0 = "test.op"() : () -> tensor<2x3xi32> + %arg1 = "test.op"() : () -> tensor<2x3xi32> + + %res = "stablehlo.custom_call"(%arg0, %arg1) { + call_target_name = "foo", + api_version = #stablehlo, + operand_layouts = [dense<[1, 0]> : tensor<2xindex>], + result_layouts = [dense<[1, 0]> : tensor<2xindex>], + output_operand_aliases = [] + } : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + """ + + with pytest.raises( + Exception, match=r"Number of operands must match the number of operand layouts" + ): + run_filecheck(program, roundtrip=True, verify=True) + + +def test_custom_call_incorrect_layout_perm(run_filecheck): + """Layout must be a permutation of [0, rank).""" + program = r""" + %arg = "test.op"() : () -> tensor<2x3xi32> + + %res = "stablehlo.custom_call"(%arg) { + call_target_name = "foo", + api_version = #stablehlo, + operand_layouts = [dense<[0]> : tensor<1xindex>], + result_layouts = [dense<[0]> : tensor<1xindex>], + output_operand_aliases = [] + } : (tensor<2x3xi32>) -> tensor<2x3xi32> + """ + + with pytest.raises(Exception, match=r"layout must be a permutation of \[0, 2\)"): + run_filecheck(program, roundtrip=True, verify=True) + + +def test_custom_call_single_tuple_result_with_element_layouts(run_filecheck): + """Single tuple result with element-wise layouts should verify (common case).""" + program = r""" + // CHECK: %[[ARG0:.*]] = "test.op"() : () -> tensor<2x3xi32> + // CHECK: %[[ARG1:.*]] = "test.op"() : () -> tensor<1xi32> + %arg0 = "test.op"() : () -> tensor<2x3xi32> + %arg1 = "test.op"() : () -> tensor<1xi32> + + // CHECK: %[[RES:.*]] = "stablehlo.custom_call"(%[[ARG0]]) + // CHECK-SAME: call_target_name = "foo" + // CHECK-SAME: api_version = #stablehlo + // CHECK-SAME: operand_layouts = [dense<[1, 0]> : tensor<2xindex>] + // CHECK-SAME: result_layouts = [dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>] + %res = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + api_version = #stablehlo, + operand_layouts = [dense<[1, 0]> : tensor<2xindex>], + result_layouts = [dense<[1, 0]> : tensor<2xindex>, dense<[0]> : tensor<1xindex>], + output_operand_aliases = [] + } : (tensor<2x3xi32>) -> tuple, tensor<1xi32>> + """ + run_filecheck(program, roundtrip=True, verify=True) + + +def test_invalid_dynamic_broadcast_in_dim_operations(run_filecheck): + """Test invalid dynamic_broadcast_in_dim cases that should fail verification.""" + + # dims size mismatch (c2) + program_dims_size_mismatch = r""" + %operand = "test.op"() : () -> tensor<1x3xi64> + %out = "test.op"() : () -> tensor<3xi64> + + %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) { + broadcast_dimensions = array + } : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64> + """ + + with pytest.raises( + Exception, match=r"broadcast_dimensions size \(1\) does not match operand rank \(2\)" + ): + run_filecheck(program_dims_size_mismatch, roundtrip=True, verify=True) + + # result rank < operand rank (c3) + program_result_rank_too_small = r""" + %operand = "test.op"() : () -> tensor<1x3xi64> + %out = "test.op"() : () -> tensor<1xi64> + + %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) { + broadcast_dimensions = array + } : (tensor<1x3xi64>, tensor<1xi64>) -> tensor<3xi64> + """ + + with pytest.raises(Exception, match=r"result rank \(1\) is less than operand rank \(2\)"): + run_filecheck(program_result_rank_too_small, roundtrip=True, verify=True) + + # duplicate dims (c4) + program_duplicate_dims = r""" + %operand = "test.op"() : () -> tensor<1x3xi64> + %out = "test.op"() : () -> tensor<2xi64> + + %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) { + broadcast_dimensions = array + } : (tensor<1x3xi64>, tensor<2xi64>) -> tensor<2x3xi64> + """ + + with pytest.raises(Exception, match=r"broadcast_dimensions should not have duplicates"): + run_filecheck(program_duplicate_dims, roundtrip=True, verify=True) + + # dim index out of bounds (c5 bounds) + program_dim_oob = r""" + %operand = "test.op"() : () -> tensor<1x3xi64> + %out = "test.op"() : () -> tensor<2xi64> + + %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) { + broadcast_dimensions = array + } : (tensor<1x3xi64>, tensor<2xi64>) -> tensor<2x3xi64> + """ + + with pytest.raises( + Exception, match=r"broadcast_dimensions contains invalid value 2 for result with rank 2" + ): + run_filecheck(program_dim_oob, roundtrip=True, verify=True) + + # per-dimension size compatibility (c5 compatibility) + program_dim_incompatible = r""" + %operand = "test.op"() : () -> tensor<2x3xi32> + %out = "test.op"() : () -> tensor<3xi64> + + %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) { + broadcast_dimensions = array + } : (tensor<2x3xi32>, tensor<3xi64>) -> tensor<4x3x2xi32> + """ + + with pytest.raises( + Exception, + match=r"size of operand dimension 0 \(2\) is not compatible with size of result dimension 0 \(4\)", + ): + run_filecheck(program_dim_incompatible, roundtrip=True, verify=True) + + # output_dimensions length incompatible with result rank when static (c7) + program_outlen_mismatch = r""" + %operand = "test.op"() : () -> tensor<1x3xi64> + %out = "test.op"() : () -> tensor<2xi64> + + %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) { + broadcast_dimensions = array + } : (tensor<1x3xi64>, tensor<2xi64>) -> tensor<2x3x2xi64> + """ + + with pytest.raises( + Exception, + match=r"length of output_dimensions \(2\) is not compatible with result rank \(3\)", + ): + run_filecheck(program_outlen_mismatch, roundtrip=True, verify=True) + + # duplicate expansion hints across both lists (c8) + program_dup_hints = r""" + %operand = "test.op"() : () -> tensor<1x1xi64> + %out = "test.op"() : () -> tensor<2xi64> + + %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) { + broadcast_dimensions = array, + known_expanding_dimensions = array, + known_nonexpanding_dimensions = array + } : (tensor<1x1xi64>, tensor<2xi64>) -> tensor<2x1xi64> + """ + + with pytest.raises( + Exception, match=r"duplicate expansion hint for at least one operand dimension" + ): + run_filecheck(program_dup_hints, roundtrip=True, verify=True) + + # hint refers to invalid operand dimension (c9/c10) + program_hint_oob = r""" + %operand = "test.op"() : () -> tensor<1x3xi64> + %out = "test.op"() : () -> tensor<2xi64> + + %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) { + broadcast_dimensions = array, + known_expanding_dimensions = array + } : (tensor<1x3xi64>, tensor<2xi64>) -> tensor<2x3xi64> + """ + + with pytest.raises( + Exception, + match=r"hint for expanding dimension 5 does not refer to a valid operand dimension", + ): + run_filecheck(program_hint_oob, roundtrip=True, verify=True) diff --git a/frontend/test/pytest/python_interface/dialects/test_transform_dialect.py b/frontend/test/pytest/python_interface/dialects/test_transform_dialect.py new file mode 100644 index 0000000000..64aba9a85b --- /dev/null +++ b/frontend/test/pytest/python_interface/dialects/test_transform_dialect.py @@ -0,0 +1,147 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test module for pennylane/compiler/python_compiler/transform.py.""" + +from dataclasses import dataclass + +import pytest + +# pylint: disable=wrong-import-position,line-too-long +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +from xdsl import passes +from xdsl.context import Context +from xdsl.dialects import builtin +from xdsl.dialects.builtin import DictionaryAttr, IntegerAttr, i64 +from xdsl.dialects.transform import AnyOpType +from xdsl.passes import PassPipeline +from xdsl.utils.exceptions import VerifyException +from xdsl.utils.test_value import create_ssa_value + +from catalyst.python_interface.conversion import xdsl_from_docstring +from catalyst.python_interface.dialects import transform +from catalyst.python_interface.dialects.transform import ApplyRegisteredPassOp +from catalyst.python_interface.pass_api import ( + ApplyTransformSequence, + compiler_transform, +) + + +def test_dict_options(): + """Test ApplyRegisteredPassOp constructor with dict options.""" + target = create_ssa_value(AnyOpType()) + options = {"option1": 1, "option2": True} + + op = ApplyRegisteredPassOp("canonicalize", target, options) + + assert op.pass_name.data == "canonicalize" + assert isinstance(op.options, DictionaryAttr) + assert op.options == DictionaryAttr({"option1": 1, "option2": True}) + assert op.verify_() is None + + +def test_attr_options(): + """Test ApplyRegisteredPassOp constructor with DictionaryAttr options.""" + target = create_ssa_value(AnyOpType()) + options = DictionaryAttr({"test-option": IntegerAttr(42, i64)}) + + # This should trigger the __init__ method + op = ApplyRegisteredPassOp("canonicalize", target, options) + + assert op.pass_name.data == "canonicalize" + assert isinstance(op.options, DictionaryAttr) + assert op.options == DictionaryAttr({"test-option": IntegerAttr(42, i64)}) + assert op.verify_() is None + + +def test_none_options(): + """Test ApplyRegisteredPassOp constructor with None options.""" + target = create_ssa_value(AnyOpType()) + + # This should trigger the __init__ method + op = ApplyRegisteredPassOp("canonicalize", target, None) + + assert op.pass_name.data == "canonicalize" + assert isinstance(op.options, DictionaryAttr) + assert op.options == DictionaryAttr({}) + assert op.verify_() is None + + +def test_invalid_options(): + """Test ApplyRegisteredPassOp constructor with invalid options type.""" + target = create_ssa_value(AnyOpType()) + + with pytest.raises( + VerifyException, match="invalid_options should be of base attribute dictionary" + ): + ApplyRegisteredPassOp("canonicalize", target, "invalid_options").verify_() + + +def test_transform_dialect_filecheck(run_filecheck): + """Test that the transform dialect operations are parsed correctly.""" + program = """ + "builtin.module"() ({ + "transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({ + ^bb0(%arg0: !transform.any_op): + %0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op + // CHECK: options = {"invalid-option" = 1 : i64} + %1 = "transform.apply_registered_pass"(%0) <{options = {"invalid-option" = 1 : i64}, pass_name = "canonicalize"}> : (!transform.any_op) -> !transform.any_op + "transform.yield"() : () -> () + }) : () -> () + }) {transform.with_named_sequence} : () -> () + """ + + run_filecheck(program) + + +def test_integration_for_transform_interpreter(capsys): + """Test that a pass with options is run via the transform interpreter""" + + @compiler_transform + @dataclass(frozen=True) + class _HelloWorld(passes.ModulePass): + name = "test-hello-world" + + custom_print: str | None = None + + def apply(self, _ctx: Context, _module: builtin.ModuleOp) -> None: + """Apply the pass.""" + if self.custom_print: + print(self.custom_print) + else: + print("hello world") + + @xdsl_from_docstring + def program(): + """ + builtin.module { + builtin.module { + transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) { + %0 = "transform.apply_registered_pass"(%arg0) <{options = {"custom_print" = "Hello from custom option!"}, pass_name = "test-hello-world"}> : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + transform.yield + } + } + } + """ + + ctx = Context() + ctx.load_dialect(builtin.Builtin) + ctx.load_dialect(transform.Transform) + + mod = program() + pipeline = PassPipeline((ApplyTransformSequence(),)) + pipeline.apply(ctx, mod) + + assert "Hello from custom option!" in capsys.readouterr().out diff --git a/frontend/test/pytest/python_interface/test_python_compiler.py b/frontend/test/pytest/python_interface/test_python_compiler.py new file mode 100644 index 0000000000..12c329f829 --- /dev/null +++ b/frontend/test/pytest/python_interface/test_python_compiler.py @@ -0,0 +1,532 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test module for pennylane/compiler/python_compiler/impl.py""" + +from dataclasses import dataclass + +# pylint: disable=wrong-import-position,line-too-long +import pytest + +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +import jax +import pennylane as qml +from jaxlib.mlir.ir import Module as jModule +from pennylane.capture import enabled as capture_enabled +from xdsl import passes +from xdsl.context import Context +from xdsl.dialects import builtin +from xdsl.interpreters import Interpreter +from xdsl.passes import PassPipeline + +from catalyst import CompileError, qjit +from catalyst.passes import apply_pass +from catalyst.passes import cancel_inverses as catalyst_cancel_inverses +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from catalyst.python_interface import Compiler +from catalyst.python_interface.conversion import ( + mlir_from_docstring, + mlir_module, + xdsl_from_docstring, +) +from catalyst.python_interface.dialects import transform +from catalyst.python_interface.pass_api import ( + ApplyTransformSequence, + TransformFunctionsExt, + TransformInterpreterPass, + available_passes, + compiler_transform, +) +from catalyst.python_interface.transforms import ( + iterative_cancel_inverses_pass, + merge_rotations_pass, +) + + +@dataclass(frozen=True) +class HelloWorldPass(passes.ModulePass): + """A simple pass that prints 'hello world' when run.""" + + name = "hello-world" + + def apply(self, _ctx: Context, _module: builtin.ModuleOp) -> None: + """Apply the pass.""" + print("hello world") + + +hello_world_pass = compiler_transform(HelloWorldPass) + + +def test_compiler(): + """Test that we can pass a jax module into the compiler. + + In this particular case, the compiler is not doing anything + because this module does not contain nested modules which is what + is expected of Catalyst. + + So, it just tests that Compiler.run does not trigger an assertion + and returns a valid + """ + + @mlir_module + @jax.jit + def identity(x): + """Identity function""" + return x + + input_module = identity(1) + retval = Compiler.run(input_module) + assert isinstance(retval, jModule) + assert str(retval) == str(input_module) + + +def test_generic_catalyst_program(): + """ + test that actually will trigger the transform interpreter + """ + + @mlir_from_docstring + def program(): + """ + "builtin.module"() <{sym_name = "circuit"}> ({ + "func.func"() <{function_type = () -> tensor<2xcomplex>, sym_name = "jit_circuit", sym_visibility = "public"}> ({ + %8 = "catalyst.launch_kernel"() <{callee = @module_circuit::@circuit}> : () -> tensor<2xcomplex> + "func.return"(%8) : (tensor<2xcomplex>) -> () + }) {llvm.emit_c_interface} : () -> () + "builtin.module"() <{sym_name = "module_circuit"}> ({ + "builtin.module"() ({ + "transform.named_sequence"() <{function_type = (!transform.op<"builtin.module">) -> (), sym_name = "__transform_main"}> ({ + ^bb0(%arg0: !transform.op<"builtin.module">): + "transform.yield"() : () -> () + }) : () -> () + }) {transform.with_named_sequence} : () -> () + "func.func"() <{function_type = () -> tensor<2xcomplex>, sym_name = "circuit", sym_visibility = "public"}> ({ + %0 = "arith.constant"() <{value = 0 : i64}> : () -> i64 + "quantum.device"(%0) <{kwargs = "{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}", lib = "/usr/local/lib/python3.11/dist-packages/pennylane_lightning/liblightning_qubit_catalyst.so", name = "LightningSimulator"}> : (i64) -> () + %1 = "quantum.alloc"() <{nqubits_attr = 1 : i64}> : () -> !quantum.reg + %2 = "quantum.extract"(%1) <{idx_attr = 0 : i64}> : (!quantum.reg) -> !quantum.bit + %3 = "quantum.custom"(%2) <{gate_name = "Hadamard", operandSegmentSizes = array, resultSegmentSizes = array}> : (!quantum.bit) -> !quantum.bit + %4 = "quantum.custom"(%3) <{gate_name = "Hadamard", operandSegmentSizes = array, resultSegmentSizes = array}> : (!quantum.bit) -> !quantum.bit + %5 = "quantum.insert"(%1, %4) <{idx_attr = 0 : i64}> : (!quantum.reg, !quantum.bit) -> !quantum.reg + %6 = "quantum.compbasis"(%5) <{operandSegmentSizes = array}> : (!quantum.reg) -> !quantum.obs + %7 = "quantum.state"(%6) <{operandSegmentSizes = array}> : (!quantum.obs) -> tensor<2xcomplex> + "quantum.dealloc"(%5) : (!quantum.reg) -> () + "quantum.device_release"() : () -> () + "func.return"(%7) : (tensor<2xcomplex>) -> () + }) {diff_method = "parameter-shift", llvm.linkage = #llvm.linkage, qnode} : () -> () + }) : () -> () + "func.func"() <{function_type = () -> (), sym_name = "setup"}> ({ + "quantum.init"() : () -> () + "func.return"() : () -> () + }) : () -> () + "func.func"() <{function_type = () -> (), sym_name = "teardown"}> ({ + "quantum.finalize"() : () -> () + "func.return"() : () -> () + }) : () -> () + }) : () -> () + """ + + retval = Compiler.run(program()) + assert isinstance(retval, jModule) + + +def test_generic_catalyst_program_as_string(): + """ + test that actually will trigger the transform interpreter + """ + + program: str = """ + "builtin.module"() <{sym_name = "circuit"}> ({ + "func.func"() <{function_type = () -> tensor<2xcomplex>, sym_name = "jit_circuit", sym_visibility = "public"}> ({ + %8 = "catalyst.launch_kernel"() <{callee = @module_circuit::@circuit}> : () -> tensor<2xcomplex> + "func.return"(%8) : (tensor<2xcomplex>) -> () + }) {llvm.emit_c_interface} : () -> () + "builtin.module"() <{sym_name = "module_circuit"}> ({ + "builtin.module"() ({ + "transform.named_sequence"() <{function_type = (!transform.op<"builtin.module">) -> (), sym_name = "__transform_main"}> ({ + ^bb0(%arg0: !transform.op<"builtin.module">): + "transform.yield"() : () -> () + }) : () -> () + }) {transform.with_named_sequence} : () -> () + "func.func"() <{function_type = () -> tensor<2xcomplex>, sym_name = "circuit", sym_visibility = "public"}> ({ + %0 = "arith.constant"() <{value = 0 : i64}> : () -> i64 + "quantum.device"(%0) <{kwargs = "{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}", lib = "/usr/local/lib/python3.11/dist-packages/pennylane_lightning/liblightning_qubit_catalyst.so", name = "LightningSimulator"}> : (i64) -> () + %1 = "quantum.alloc"() <{nqubits_attr = 1 : i64}> : () -> !quantum.reg + %2 = "quantum.extract"(%1) <{idx_attr = 0 : i64}> : (!quantum.reg) -> !quantum.bit + %3 = "quantum.custom"(%2) <{gate_name = "Hadamard", operandSegmentSizes = array, resultSegmentSizes = array}> : (!quantum.bit) -> !quantum.bit + %4 = "quantum.custom"(%3) <{gate_name = "Hadamard", operandSegmentSizes = array, resultSegmentSizes = array}> : (!quantum.bit) -> !quantum.bit + %5 = "quantum.insert"(%1, %4) <{idx_attr = 0 : i64}> : (!quantum.reg, !quantum.bit) -> !quantum.reg + %6 = "quantum.compbasis"(%5) <{operandSegmentSizes = array}> : (!quantum.reg) -> !quantum.obs + %7 = "quantum.state"(%6) <{operandSegmentSizes = array}> : (!quantum.obs) -> tensor<2xcomplex> + "quantum.dealloc"(%5) : (!quantum.reg) -> () + "quantum.device_release"() : () -> () + "func.return"(%7) : (tensor<2xcomplex>) -> () + }) {diff_method = "parameter-shift", llvm.linkage = #llvm.linkage, qnode} : () -> () + }) : () -> () + "func.func"() <{function_type = () -> (), sym_name = "setup"}> ({ + "quantum.init"() : () -> () + "func.return"() : () -> () + }) : () -> () + "func.func"() <{function_type = () -> (), sym_name = "teardown"}> ({ + "quantum.finalize"() : () -> () + "func.return"() : () -> () + }) : () -> () + }) : () -> () + """ + + retval = Compiler.run(program) + assert isinstance(retval, str) + + +def test_raises_error_when_pass_does_not_exists(): + """Attempts to run pass "this-pass-does-not-exists" on an empty module. + + This should raise an error + """ + + @xdsl_from_docstring + def empty_module(): + """ + builtin.module {} + """ + + @xdsl_from_docstring + def schedule_module(): + """ + builtin.module { + builtin.module { + transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) { + %0 = transform.apply_registered_pass "this-pass-does-not-exists" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + transform.yield + } + } + } + """ + + ctx = Context() + ctx.load_dialect(builtin.Builtin) + ctx.load_dialect(transform.Transform) + schedule = TransformInterpreterPass.find_transform_entry_point( + schedule_module(), "__transform_main" + ) + interpreter = Interpreter(empty_module()) + interpreter.register_implementations(TransformFunctionsExt(ctx, {})) + with pytest.raises(CompileError): + interpreter.call_op(schedule, (empty_module(),)) + + +def test_decorator(): + """Test that the decorator has modified the available_passes dictionary""" + assert "hello-world" in available_passes + assert available_passes["hello-world"]() == HelloWorldPass + + +def test_integration_for_transform_interpreter(capsys): + """Test that a pass is run via the transform interpreter""" + + # The hello-world pass is in the IR + @xdsl_from_docstring + def program(): + """ + builtin.module { + builtin.module { + transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) { + %0 = transform.apply_registered_pass "hello-world" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + transform.yield + } + } + } + """ + + ctx = Context() + ctx.load_dialect(builtin.Builtin) + ctx.load_dialect(transform.Transform) + + pipeline = PassPipeline((ApplyTransformSequence(),)) + pipeline.apply(ctx, program()) + captured = capsys.readouterr() + assert captured.out.strip() == "hello world" + + +class TestCatalystIntegration: + """Tests for integration of the Python compiler with Catalyst""" + + @pytest.mark.usefixtures("use_capture") + def test_integration_catalyst_no_passes_with_capture(self): + """Test that the xDSL plugin can be used even when no passes are applied + when capture is enabled.""" + + assert capture_enabled() + + @qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def f(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + out = f(1.5) + assert jax.numpy.allclose(out, jax.numpy.cos(1.5)) + + def test_integration_catalyst_no_passes_no_capture(self): + """Test that the xDSL plugin can be used even when no passes are applied + when capture is disabled.""" + + assert not capture_enabled() + + @qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def f(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + out = f(1.5) + assert jax.numpy.allclose(out, jax.numpy.cos(1.5)) + + @pytest.mark.usefixtures("use_capture") + def test_integration_catalyst_xdsl_pass_with_capture(self, capsys): + """Test that a pass is run via the transform interpreter when using with a + qjit workflow and capture is enabled.""" + + assert capture_enabled() + + @qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @hello_world_pass + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def f(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + out = f(1.5) + assert jax.numpy.allclose(out, jax.numpy.cos(1.5)) + captured = capsys.readouterr() + assert captured.out.strip() == "hello world" + + def test_integration_catalyst_xdsl_pass_no_capture(self, capsys): + """Test that a pass is run via the transform interpreter when using with a + qjit workflow and capture is disabled.""" + + assert not capture_enabled() + + @qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @apply_pass("hello-world") + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def f(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + out = f(1.5) + assert jax.numpy.allclose(out, jax.numpy.cos(1.5)) + captured = capsys.readouterr() + assert captured.out.strip() == "hello world" + + @pytest.mark.usefixtures("use_capture") + def test_integration_catalyst_mixed_passes_with_capture(self, capsys): + """Test that both Catalyst and Python compiler passes can be used with qjit + when capture is enabled.""" + + assert capture_enabled() + + @qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @hello_world_pass + @qml.transforms.cancel_inverses + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def f(x): + qml.RX(x, 0) + qml.X(0) + qml.X(0) + return qml.expval(qml.Z(0)) + + out = f(1.5) + assert jax.numpy.allclose(out, jax.numpy.cos(1.5)) + captured = capsys.readouterr() + assert captured.out.strip() == "hello world" + + def test_integration_catalyst_mixed_passes_no_capture(self, capsys): + """Test that both Catalyst and Python compiler passes can be used with qjit + when capture is disabled.""" + + assert not capture_enabled() + + @qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @apply_pass("hello-world") + @catalyst_cancel_inverses + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def f(x): + qml.RX(x, 0) + qml.X(0) + qml.X(0) + return qml.expval(qml.Z(0)) + + out = f(1.5) + assert jax.numpy.allclose(out, jax.numpy.cos(1.5)) + captured = capsys.readouterr() + assert captured.out.strip() == "hello world" + + +class TestCallbackIntegration: + """Test the integration of the callback functionality""" + + def test_callback_integration(self, capsys): + """Test that the callback mechanism works with the transform interpreter""" + + # pylint: disable=unused-variable + @compiler_transform + @dataclass(frozen=True) + class NonePass(passes.ModulePass): + """Dummy pass for testing.""" + + name = "none-pass" + + def apply(self, _ctx: Context, _module: builtin.ModuleOp) -> None: + """Apply the pass. Do nothing; the test if for callbacks.""" + return + + def print_between_passes(*_, pass_level=0): + """Print between passes callback.""" + if pass_level == 0: + return + print("hello world") + + @xdsl_from_docstring + def program(): + """ + builtin.module { + builtin.module { + transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) { + %0 = transform.apply_registered_pass "none-pass" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + transform.yield + } + } + } + """ + + ctx = Context() + ctx.load_dialect(builtin.Builtin) + pipeline = PassPipeline((ApplyTransformSequence(callback=print_between_passes),)) + pipeline.apply(ctx, program()) + captured = capsys.readouterr() + assert captured.out.strip() == "hello world" + + def test_callback_prints_module_after_each_pass(self, capsys): + """Test that the callback prints the module after each pass""" + + def print_between_passes(_, module, __, pass_level=0): + if pass_level == 0: + return + print("=== Between Pass ===") + print(module) + + @xdsl_from_docstring + def program_2_passes(): + """ + builtin.module { + builtin.module @module_foo { + transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) { + %0 = transform.apply_registered_pass "xdsl-cancel-inverses" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + %1 = transform.apply_registered_pass "xdsl-merge-rotations" to %0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + transform.yield %1 : !transform.op<"builtin.module"> + func.func public @foo() { + %2 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %3 = tensor.extract %2[] : tensor + %4 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %5 = quantum.alloc(1) : !quantum.reg + %6 = tensor.extract %2[] : tensor + %7 = quantum.extract %5[%6] : !quantum.reg -> !quantum.bit + %8 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor + %9 = tensor.extract %8[] : tensor + %10 = quantum.custom "RX"(%9) %7 : !quantum.bit + %11 = tensor.extract %8[] : tensor + %12 = quantum.custom "RX"(%11) %10 : !quantum.bit + %13 = quantum.custom "Hadamard"() %12 : !quantum.bit + %14 = quantum.custom "Hadamard"() %13 : !quantum.bit + %15 = tensor.extract %1[] : tensor + %16 = quantum.insert %5[%15], %14 : !quantum.reg, !quantum.bit + %17 = quantum.compbasis qreg %16 : !quantum.obs + %18 = quantum.probs %17 : tensor<2xf64> + } + } + } + } + """ + + ctx = Context() + ctx.load_dialect(builtin.Builtin) + pipeline = PassPipeline((ApplyTransformSequence(callback=print_between_passes),)) + pipeline.apply(ctx, program_2_passes()) + + out = capsys.readouterr().out + printed_modules = out.split("=== Between Pass ===")[1:] + + assert ( + len(printed_modules) == 2 + ), "Callback should have been called twice (after each pass)." + + # callback after cancel-inverses + assert 'quantum.custom "RX"' in printed_modules[0] + assert 'quantum.custom "Hadamard"' not in printed_modules[0] + + # callback after merge-rotations + # We expect an `arith.addf` if rotations were merged + assert "arith.addf" in printed_modules[1], "Expected merged RX gates into a single rotation" + assert 'quantum.custom "RX"' in printed_modules[1] + + assert printed_modules[0] != printed_modules[1], "IR should differ between passes" + + @pytest.mark.usefixtures("use_capture") + def test_callback_run_integration(self, capsys): + """Test that the callback is integrated into the pass pipeline with the Compiler.run() method""" + + def print_between_passes(_, module, __, pass_level=0): + """Callback to print something between passes.""" + if pass_level == 0: + return + print("=== Between Pass ===") + print(module) + + @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @iterative_cancel_inverses_pass + @merge_rotations_pass + @qml.qnode(qml.device("null.qubit", wires=2)) + def circuit(): + qml.RX(0.1, 0) + qml.RX(2.0, 0) + qml.Hadamard(1) + qml.Hadamard(1) + return qml.state() + + Compiler.run(circuit.mlir_module, callback=print_between_passes) + out = capsys.readouterr().out + printed_modules = out.split("=== Between Pass ===")[1:] + + assert ( + len(printed_modules) == 2 + ), "Callback should have been called twice (after each pass)." + + # callback after merge-rotations + # We expect an `arith.addf` if rotations were merged + assert "arith.addf" in printed_modules[0], "Expected merged RX gates into a single rotation" + assert 'quantum.custom "RX"' in printed_modules[0] + assert 'quantum.custom "Hadamard"' in printed_modules[0] + + # callback after cancel-inverses + assert "arith.addf" in printed_modules[1], "Expected merged RX gates into a single rotation" + assert 'quantum.custom "RX"' in printed_modules[1] + assert 'quantum.custom "Hadamard"' not in printed_modules[1] + + assert printed_modules[0] != printed_modules[1], "IR should differ between passes" + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/python_interface/test_xdsl_utils.py b/frontend/test/pytest/python_interface/test_xdsl_utils.py new file mode 100644 index 0000000000..7d5a7dff8b --- /dev/null +++ b/frontend/test/pytest/python_interface/test_xdsl_utils.py @@ -0,0 +1,119 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for xDSL utilities.""" + +import pytest + +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +# pylint: disable=wrong-import-position,line-too-long +from xdsl.dialects import arith, builtin, tensor, test + +from catalyst.python_interface.dialects.stablehlo import ConstantOp as hloConstantOp +from catalyst.python_interface.utils import get_constant_from_ssa + + +class TestGetConstantFromSSA: + """Unit tests for ``get_constant_from_ssa``.""" + + def test_non_constant(self): + """Test that ``None`` is returned if the input is not a constant.""" + val = test.TestOp(result_types=(builtin.Float64Type(),)).results[0] + assert get_constant_from_ssa(val) is None + + @pytest.mark.parametrize( + "const, attr_type, dtype", + [ + (11, builtin.IntegerAttr, builtin.IntegerType(64)), + (5, builtin.IntegerAttr, builtin.IndexType()), + (2.5, builtin.FloatAttr, builtin.Float64Type()), + ], + ) + def test_scalar_constant_arith(self, const, attr_type, dtype): + """Test that constants created by ``arith.constant`` are returned correctly.""" + const_attr = attr_type(const, dtype) + val = arith.ConstantOp(value=const_attr).results[0] + + assert get_constant_from_ssa(val) == const + + @pytest.mark.parametrize( + "const, elt_type", + [ + (11, builtin.IntegerType(64)), + (9, builtin.IndexType()), + (2.5, builtin.Float64Type()), + (-1.1 + 2.3j, builtin.ComplexType(builtin.Float64Type())), + ], + ) + @pytest.mark.parametrize("constant_op", [arith.ConstantOp, hloConstantOp]) + def test_scalar_constant_extracted_from_rank0_tensor(self, const, elt_type, constant_op): + """Test that constants created by ``stablehlo.constant`` are returned correctly.""" + data = const + if isinstance(const, complex): + # For complex numbers, the number must be split into a 2-tuple containing + # the real and imaginary part when initializing a dense elements attr. + data = (const.real, const.imag) + + dense_attr = builtin.DenseIntOrFPElementsAttr.from_list( + type=builtin.TensorType(element_type=elt_type, shape=()), + data=(data,), + ) + tensor_ = constant_op(value=dense_attr).results[0] + val = tensor.ExtractOp(tensor=tensor_, indices=[], result_type=elt_type).results[0] + + assert get_constant_from_ssa(val) == const + + def test_tensor_constant_arith(self): + """Test that ``None`` is returned if the input is a tensor created by ``arith.constant``.""" + dense_attr = builtin.DenseIntOrFPElementsAttr.from_list( + type=builtin.TensorType(element_type=builtin.Float64Type(), shape=(3,)), + data=(1, 2, 3), + ) + val = arith.ConstantOp(value=dense_attr).results[0] + + assert get_constant_from_ssa(val) is None + + def test_tensor_constant_stablehlo(self): + """Test that ``None`` is returned if the input is a tensor created by ``stablehlo.constant``.""" + dense_attr = builtin.DenseIntOrFPElementsAttr.from_list( + type=builtin.TensorType(element_type=builtin.Float64Type(), shape=(3,)), + data=(1.0, 2.0, 3.0), + ) + val = hloConstantOp(value=dense_attr).results[0] + + assert get_constant_from_ssa(val) is None + + def test_extract_scalar_from_constant_tensor_stablehlo(self): + """Test that ``None`` is returned if the input is a scalar constant, but it was extracted + from a non-scalar constant.""" + # Index SSA value to be used for extracting a value from a tensor + dummy_index = test.TestOp(result_types=(builtin.IndexType(),)).results[0] + + dense_attr = builtin.DenseIntOrFPElementsAttr.from_list( + type=builtin.TensorType(element_type=builtin.Float64Type(), shape=(3,)), + data=(1.0, 2.0, 3.0), + ) + tensor_ = hloConstantOp(value=dense_attr).results[0] + val = tensor.ExtractOp( + tensor=tensor_, indices=[dummy_index], result_type=builtin.Float64Type() + ).results[0] + # val is a value that we got by indexing into a constant tensor with rank >= 1 + assert isinstance(val.type, builtin.Float64Type) + + assert get_constant_from_ssa(val) is None + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/python_interface/transforms/mbqc/test_graph_state_utils.py b/frontend/test/pytest/python_interface/transforms/mbqc/test_graph_state_utils.py new file mode 100644 index 0000000000..96c19ed3e1 --- /dev/null +++ b/frontend/test/pytest/python_interface/transforms/mbqc/test_graph_state_utils.py @@ -0,0 +1,198 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the graph state utils""" + +# pylint: disable=wrong-import-position + +import pytest + +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +from pennylane.exceptions import CompileError + +from catalyst.python_interface.transforms.mbqc.graph_state_utils import ( + _adj_matrix_generation_helper, + edge_iter, + get_graph_state_edges, + get_num_aux_wires, + n_vertices_from_packed_adj_matrix, +) + + +@pytest.fixture(scope="module", name="mbqc_single_qubit_graph") +def fixture_mbqc_single_qubit_graph(): + """Fixture that returns the densely packed adjacency matrix for the graph state used for + representing single-qubit gates in the MBQC formalism. + + The graph state is as follows: + + 0 -- 1 -- 2 -- 3 + """ + # fmt: off + packed_adj_matrix = [ + # 0 1 2 + 1, # 1 + 0, 1, # 2 + 0, 0, 1 # 3 + ] + return packed_adj_matrix + + +@pytest.fixture(scope="module", name="mbqc_cnot_graph") +def fixture_mbqc_cnot_graph(): + """Fixture that returns the densely packed adjacency matrix for the graph state used for + representing a CNOT gate in the MBQC formalism. + + The graph state is as follows: + + 0 -- 1 -- 2 -- 3 -- 4 -- 5 + | + 6 + | + 7 -- 8 -- 9 -- 10 - 11 - 12 + """ + # fmt: off + packed_adj_matrix = [ + # 0 1 2 3 4 5 6 7 8 9 10 11 + 1, + 0, 1, + 0, 0, 1, + 0, 0, 0, 1, + 0, 0, 0, 0, 1, + 0, 0, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, + 0, 0, 0, 0, 0, 0, 1, 0, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 + ] + return packed_adj_matrix + + +class TestGraphStateUtils: + """Unit tests for graph state utils.""" + + def test_unsupported_gate(self): + """Test error raised for unsupported gates""" + with pytest.raises(ValueError, match="rotxzx is not supported in the MBQC formalism."): + get_graph_state_edges("rotxzx") + + with pytest.raises(ValueError, match="rotXzx is not supported in the MBQC formalism."): + get_num_aux_wires("rotXzx") + + def test_adj_matrix_generation_helper(self): + """Test that error raised for unsupported gates.""" + num_vertices = 4 + edges = [(0, 1), (1, 2), (2, 3)] + adj_matrix = _adj_matrix_generation_helper(num_vertices, edges) + assert adj_matrix == [1, 0, 1, 0, 0, 1] + + @pytest.mark.parametrize("n_vertices", range(1, 16)) + def test_n_vertices(self, n_vertices: int): + """Test that the ``_n_vertices_from_packed_adj_matrix`` function returns correct results + when given a valid densely packed adjacency matrix as input. + + This test performs the inverse operation of _n_vertices_from_packed_adj_matrix by computing + the number of elements in the densely packed adjacency matrix from the given number of + vertices, `n_vertices`, generates a null list of this length, and checks the function output + given this list is the same as `n_vertices`. + """ + n_elements = int(n_vertices * (n_vertices - 1) / 2) + adj_matrix = [0] * n_elements + n_observed = n_vertices_from_packed_adj_matrix(adj_matrix) + + assert n_observed == n_vertices + + @pytest.mark.parametrize("n", [2, 4, 5, 7, 8, 9]) + def test_n_vertices_raises_on_invalid(self, n): + """Test that the ``_n_vertices_from_packed_adj_matrix`` function raises a CompileError when + given an invalid densely packed adjacency matrix as input. + """ + with pytest.raises(CompileError, match="densely packed adjacency matrix"): + adj_matrix = [0] * n + _ = n_vertices_from_packed_adj_matrix(adj_matrix) + + @pytest.mark.parametrize( + "adj_matrix, expected_edges", + [ + ([], []), + ([0], []), + ([1], [(0, 1)]), + ([1, 0, 0], [(0, 1)]), + ([1, 1, 0], [(0, 1), (0, 2)]), + ([1, 1, 1], [(0, 1), (0, 2), (1, 2)]), + ([False], []), + ([True], [(0, 1)]), + ], + ) + def test_edge_iter(self, adj_matrix, expected_edges): + """Test that the ``_edge_iter`` generator function yields correct results when given a valid + densely packed adjacency matrix as input.""" + edges = list(edge_iter(adj_matrix)) + assert edges == expected_edges + + @pytest.mark.parametrize("n", [2, 4, 5, 7, 8, 9]) + def test_edge_iter_raises_on_invalid(self, n): + """Test that the ``_edge_iter`` generator function raises a CompileError when given an + invalid densely packed adjacency matrix as input. + """ + with pytest.raises(CompileError, match="densely packed adjacency matrix"): + adj_matrix = [0] * n + _ = list(edge_iter(adj_matrix)) + + def test_n_vertices_mbqc_single_qubit(self, mbqc_single_qubit_graph): + """Test that the ``_n_vertices_from_packed_adj_matrix`` function correctly determines that + the number of vertices in the densely packed adjacency matrix for the graph state used for + representing single-qubit gates in the MBQC formalism is equal to 4. + """ + n_observed = n_vertices_from_packed_adj_matrix(mbqc_single_qubit_graph) + assert n_observed == 4 + + def test_n_vertices_mbqc_cnot(self, mbqc_cnot_graph): + """Test that the ``_n_vertices_from_packed_adj_matrix`` function correctly determines that + the number of vertices in the densely packed adjacency matrix for the graph state used for + representing a CNOT gate in the MBQC formalism is equal to 13. + """ + n_observed = n_vertices_from_packed_adj_matrix(mbqc_cnot_graph) + assert n_observed == 13 + + def test_edge_iter_mbqc_single_qubit(self, mbqc_single_qubit_graph): + """Test that the ``_edge_iter`` generator function applied to the densely packed adjacency + matrix for the graph state used for representing single-qubit gates in the MBQC formalism + yields the correct edges. + + For reference, the graph is: + + 0 -- 1 -- 2 -- 3 + """ + edges_observed = list(edge_iter(mbqc_single_qubit_graph)) + + assert edges_observed == get_graph_state_edges("RZ") + + def test_edge_iter_mbqc_cnot(self, mbqc_cnot_graph): + """Test that the ``_edge_iter`` generator function applied to the densely packed adjacency + matrix for the graph state used for representing a CNOT gate in the MBQC formalism yields + the correct edges. + + For reference, the graph is: + + 0 -- 1 -- 2 -- 3 -- 4 -- 5 + | + 6 + | + 7 -- 8 -- 9 -- 10 - 11 - 12 + """ + edges_observed = list(edge_iter(mbqc_cnot_graph)) + assert edges_observed == get_graph_state_edges("CNOT") diff --git a/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_convert_to_mbqc_formalism.py b/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_convert_to_mbqc_formalism.py new file mode 100644 index 0000000000..9bfe8328f0 --- /dev/null +++ b/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_convert_to_mbqc_formalism.py @@ -0,0 +1,573 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the convert to MBQC formalism transform""" +import pytest + +# pylint: disable=wrong-import-position,line-too-long +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +import pennylane as qml +from pennylane.ftqc import RotXZX + +from catalyst.ftqc import mbqc_pipeline +from catalyst.python_interface.transforms import ( + ConvertToMBQCFormalismPass, + convert_to_mbqc_formalism_pass, + decompose_graph_state_pass, + measurements_from_samples_pass, +) + + +class TestConvertToMBQCFormalismPass: + """Unit tests for ConvertToMBQCFormalismPass.""" + + def test_unsupported_gate(self, run_filecheck): + """Test for error threw for unsupported gate""" + program = """ + func.func @test_func() { + %0 = "test.op"() : () -> !quantum.bit + %1 = quantum.custom "IsingXX"() %0 : !quantum.bit + return + } + """ + with pytest.raises(NotImplementedError): + pipeline = (ConvertToMBQCFormalismPass(),) + run_filecheck(program, pipeline) + + def test_unconverted_gate_set(self, run_filecheck): + """Test for supported gates that are not converted in the pass""" + program = """ + func.func @test_func(%arg0 :f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK-NEXT: [[q0:%.+]] = quantum.custom "PauliX"() [[q0:%.+]] : !quantum.bit + %1 = quantum.custom "PauliX"() %0 : !quantum.bit + // CHECK-NEXT: [[q0:%.+]] = quantum.custom "PauliY"() [[q0:%.+]] : !quantum.bit + %2 = quantum.custom "PauliY"() %1 : !quantum.bit + // CHECK-NEXT: [[q0:%.+]] = quantum.custom "PauliZ"() [[q0:%.+]] : !quantum.bit + %3 = quantum.custom "PauliZ"() %2 : !quantum.bit + // CHECK-NEXT: [[q0:%.+]] = quantum.custom "Identity"() [[q0:%.+]] : !quantum.bit + %4 = quantum.custom "Identity"() %3 : !quantum.bit + // CHECK-NEXT: quantum.gphase + quantum.gphase %arg0 + return + } + """ + pipeline = (ConvertToMBQCFormalismPass(),) + run_filecheck(program, pipeline) + + def test_hadamard_gate(self, run_filecheck): + """Test for lowering a Hadamard gate to a MBQC formalism.""" + program = """ + func.func @test_func() { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK-NEXT: [[q0:%.+]] = func.call @hadamard_in_mbqc([[q0:%.+]]) : (!quantum.bit) -> !quantum.bit + %1 = quantum.custom "Hadamard"() %0 : !quantum.bit + return + } + + // CHECK: func.func private @hadamard_in_mbqc(%0 : !quantum.bit) -> !quantum.bit attributes {mbqc_transform = none} { + // CHECK-NEXT: %1 = arith.constant dense<[true, false, true, false, false, true]> : tensor<6xi1> + // CHECK-NEXT: %2 = mbqc.graph_state_prep(%1 : tensor<6xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + // CHECK-NEXT: %3 = quantum.extract %2[0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %4 = quantum.extract %2[1] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %5 = quantum.extract %2[2] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %6 = quantum.extract %2[3] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %7, %8 = quantum.custom "CZ"() %0, %3 : !quantum.bit, !quantum.bit + // CHECK-NEXT: %9 = arith.constant 0.000000e+00 : f64 + // CHECK-NEXT: %10 = arith.constant 1.5707963267948966 : f64 + // CHECK-NEXT: %11, %12 = mbqc.measure_in_basis[XY, %9] %7 : i1, !quantum.bit + // CHECK-NEXT: %13, %14 = mbqc.measure_in_basis[XY, %10] %8 : i1, !quantum.bit + // CHECK-NEXT: %15, %16 = mbqc.measure_in_basis[XY, %10] %4 : i1, !quantum.bit + // CHECK-NEXT: %17, %18 = mbqc.measure_in_basis[XY, %10] %5 : i1, !quantum.bit + // CHECK-NEXT: %19 = arith.xori %11, %15 : i1 + // CHECK-NEXT: %20 = arith.xori %19, %17 : i1 + // CHECK-NEXT: %21 = arith.constant true + // CHECK-NEXT: %22 = arith.cmpi eq, %20, %21 : i1 + // CHECK-NEXT: %23 = scf.if %22 -> (!quantum.bit) { + // CHECK-NEXT: %24 = quantum.custom "PauliX"() %6 : !quantum.bit + // CHECK-NEXT: scf.yield %24 : !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %6 : !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: %25 = arith.xori %13, %15 : i1 + // CHECK-NEXT: %26 = arith.constant true + // CHECK-NEXT: %27 = arith.cmpi eq, %25, %26 : i1 + // CHECK-NEXT: %28 = scf.if %27 -> (!quantum.bit) { + // CHECK-NEXT: %29 = quantum.custom "PauliZ"() %23 : !quantum.bit + // CHECK-NEXT: scf.yield %29 : !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %23 : !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: quantum.dealloc_qb %14 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %16 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %18 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %12 : !quantum.bit + // CHECK-NEXT: func.return %28 : !quantum.bit + // CHECK-NEXT: } + """ + + pipeline = (ConvertToMBQCFormalismPass(),) + run_filecheck(program, pipeline) + + def test_s_gate(self, run_filecheck): + """Test for lowering a S gate to a MBQC formalism.""" + program = """ + func.func @test_func() { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK-NEXT: [[q0:%.+]] = func.call @s_in_mbqc([[q0:%.+]]) : (!quantum.bit) -> !quantum.bit + %1 = quantum.custom "S"() %0 : !quantum.bit + return + } + + // CHECK: func.func private @s_in_mbqc(%0 : !quantum.bit) -> !quantum.bit attributes {mbqc_transform = none} { + // CHECK-NEXT: %1 = arith.constant dense<[true, false, true, false, false, true]> : tensor<6xi1> + // CHECK-NEXT: %2 = mbqc.graph_state_prep(%1 : tensor<6xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + // CHECK-NEXT: %3 = quantum.extract %2[0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %4 = quantum.extract %2[1] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %5 = quantum.extract %2[2] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %6 = quantum.extract %2[3] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %7, %8 = quantum.custom "CZ"() %0, %3 : !quantum.bit, !quantum.bit + // CHECK-NEXT: %9 = arith.constant 0.000000e+00 : f64 + // CHECK-NEXT: %10 = arith.constant 1.5707963267948966 : f64 + // CHECK-NEXT: %11, %12 = mbqc.measure_in_basis[XY, %9] %7 : i1, !quantum.bit + // CHECK-NEXT: %13, %14 = mbqc.measure_in_basis[XY, %9] %8 : i1, !quantum.bit + // CHECK-NEXT: %15, %16 = mbqc.measure_in_basis[XY, %10] %4 : i1, !quantum.bit + // CHECK-NEXT: %17, %18 = mbqc.measure_in_basis[XY, %9] %5 : i1, !quantum.bit + // CHECK-NEXT: %19 = arith.xori %13, %17 : i1 + // CHECK-NEXT: %20 = arith.constant true + // CHECK-NEXT: %21 = arith.cmpi eq, %19, %20 : i1 + // CHECK-NEXT: %22 = scf.if %21 -> (!quantum.bit) { + // CHECK-NEXT: %23 = quantum.custom "PauliX"() %6 : !quantum.bit + // CHECK-NEXT: scf.yield %23 : !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %6 : !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: %24 = arith.xori %11, %13 : i1 + // CHECK-NEXT: %25 = arith.xori %24, %15 : i1 + // CHECK-NEXT: %26 = arith.constant true + // CHECK-NEXT: %27 = arith.xori %25, %26 : i1 + // CHECK-NEXT: %28 = arith.constant true + // CHECK-NEXT: %29 = arith.cmpi eq, %27, %28 : i1 + // CHECK-NEXT: %30 = scf.if %29 -> (!quantum.bit) { + // CHECK-NEXT: %31 = quantum.custom "PauliZ"() %22 : !quantum.bit + // CHECK-NEXT: scf.yield %31 : !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %22 : !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: quantum.dealloc_qb %14 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %16 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %18 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %12 : !quantum.bit + // CHECK-NEXT: func.return %30 : !quantum.bit + // CHECK-NEXT: } + """ + + pipeline = (ConvertToMBQCFormalismPass(),) + run_filecheck(program, pipeline) + + def test_rz_gate(self, run_filecheck): + """Test for lowering a RZ gate to a MBQC formalism.""" + program = """ + func.func @test_func(%param0: f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK-NEXT: [[q0:%.+]] = func.call @rz_in_mbqc(%param0, %0) : (f64, !quantum.bit) -> !quantum.bit + %1 = quantum.custom "RZ"(%param0) %0 : !quantum.bit + return + } + // CHECK: func.func private @rz_in_mbqc(%0 : f64, %1 : !quantum.bit) -> !quantum.bit attributes {mbqc_transform = none} { + // CHECK-NEXT: %2 = arith.constant dense<[true, false, true, false, false, true]> : tensor<6xi1> + // CHECK-NEXT: %3 = mbqc.graph_state_prep(%2 : tensor<6xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + // CHECK-NEXT: %4 = quantum.extract %3[0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %5 = quantum.extract %3[1] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %6 = quantum.extract %3[2] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %7 = quantum.extract %3[3] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %8, %9 = quantum.custom "CZ"() %1, %4 : !quantum.bit, !quantum.bit + // CHECK-NEXT: %10 = arith.constant 0.000000e+00 : f64 + // CHECK-NEXT: %11, %12 = mbqc.measure_in_basis[XY, %10] %8 : i1, !quantum.bit + // CHECK-NEXT: %13, %14 = mbqc.measure_in_basis[XY, %10] %9 : i1, !quantum.bit + // CHECK-NEXT: %15 = arith.constant true + // CHECK-NEXT: %16 = arith.cmpi eq, %13, %15 : i1 + // CHECK-NEXT: %17, %18 = scf.if %16 -> (i1, !quantum.bit) { + // CHECK-NEXT: %19, %20 = mbqc.measure_in_basis[XY, %0] %5 : i1, !quantum.bit + // CHECK-NEXT: scf.yield %19, %20 : i1, !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: %21 = arith.negf %0 : f64 + // CHECK-NEXT: %22, %23 = mbqc.measure_in_basis[XY, %21] %5 : i1, !quantum.bit + // CHECK-NEXT: scf.yield %22, %23 : i1, !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: %24, %25 = mbqc.measure_in_basis[XY, %10] %6 : i1, !quantum.bit + // CHECK-NEXT: %26 = arith.xori %13, %24 : i1 + // CHECK-NEXT: %27 = arith.constant true + // CHECK-NEXT: %28 = arith.cmpi eq, %26, %27 : i1 + // CHECK-NEXT: %29 = scf.if %28 -> (!quantum.bit) { + // CHECK-NEXT: %30 = quantum.custom "PauliX"() %7 : !quantum.bit + // CHECK-NEXT: scf.yield %30 : !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %7 : !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: %31 = arith.xori %11, %17 : i1 + // CHECK-NEXT: %32 = arith.constant true + // CHECK-NEXT: %33 = arith.cmpi eq, %31, %32 : i1 + // CHECK-NEXT: %34 = scf.if %33 -> (!quantum.bit) { + // CHECK-NEXT: %35 = quantum.custom "PauliZ"() %29 : !quantum.bit + // CHECK-NEXT: scf.yield %35 : !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %29 : !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: quantum.dealloc_qb %14 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %18 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %25 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %12 : !quantum.bit + // CHECK-NEXT: func.return %34 : !quantum.bit + // CHECK-NEXT: } + """ + + pipeline = (ConvertToMBQCFormalismPass(),) + run_filecheck(program, pipeline) + + def test_rotxzx_gate(self, run_filecheck): + """Test for lowering a RotXZX gate to a MBQC formalism.""" + program = """ + func.func @test_func(%param0: f64, %param1: f64, %param2: f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK-NEXT: [[q0:%.+]] = func.call @rotxzx_in_mbqc(%param0, %param1, %param2, %0) : (f64, f64, f64, !quantum.bit) -> !quantum.bit + %1 = quantum.custom "RotXZX"(%param0, %param1, %param2) %0 : !quantum.bit + return + } + // CHECK: func.func private @rotxzx_in_mbqc(%0 : f64, %1 : f64, %2 : f64, %3 : !quantum.bit) -> !quantum.bit attributes {mbqc_transform = none} { + // CHECK-NEXT: %4 = arith.constant dense<[true, false, true, false, false, true]> : tensor<6xi1> + // CHECK-NEXT: %5 = mbqc.graph_state_prep(%4 : tensor<6xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + // CHECK-NEXT: %6 = quantum.extract %5[0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %7 = quantum.extract %5[1] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %8 = quantum.extract %5[2] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %9 = quantum.extract %5[3] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %10, %11 = quantum.custom "CZ"() %3, %6 : !quantum.bit, !quantum.bit + // CHECK-NEXT: %12 = arith.constant 0.000000e+00 : f64 + // CHECK-NEXT: %13, %14 = mbqc.measure_in_basis[XY, %12] %10 : i1, !quantum.bit + // CHECK-NEXT: %15 = arith.constant true + // CHECK-NEXT: %16 = arith.cmpi eq, %13, %15 : i1 + // CHECK-NEXT: %17, %18 = scf.if %16 -> (i1, !quantum.bit) { + // CHECK-NEXT: %19, %20 = mbqc.measure_in_basis[XY, %0] %11 : i1, !quantum.bit + // CHECK-NEXT: scf.yield %19, %20 : i1, !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: %21 = arith.negf %0 : f64 + // CHECK-NEXT: %22, %23 = mbqc.measure_in_basis[XY, %21] %11 : i1, !quantum.bit + // CHECK-NEXT: scf.yield %22, %23 : i1, !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: %24 = arith.constant true + // CHECK-NEXT: %25 = arith.cmpi eq, %17, %24 : i1 + // CHECK-NEXT: %26, %27 = scf.if %25 -> (i1, !quantum.bit) { + // CHECK-NEXT: %28, %29 = mbqc.measure_in_basis[XY, %1] %7 : i1, !quantum.bit + // CHECK-NEXT: scf.yield %28, %29 : i1, !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: %30 = arith.negf %1 : f64 + // CHECK-NEXT: %31, %32 = mbqc.measure_in_basis[XY, %30] %7 : i1, !quantum.bit + // CHECK-NEXT: scf.yield %31, %32 : i1, !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: %33 = arith.xori %13, %26 : i1 + // CHECK-NEXT: %34 = arith.constant true + // CHECK-NEXT: %35 = arith.cmpi eq, %33, %34 : i1 + // CHECK-NEXT: %36, %37 = scf.if %35 -> (i1, !quantum.bit) { + // CHECK-NEXT: %38, %39 = mbqc.measure_in_basis[XY, %2] %8 : i1, !quantum.bit + // CHECK-NEXT: scf.yield %38, %39 : i1, !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: %40 = arith.negf %2 : f64 + // CHECK-NEXT: %41, %42 = mbqc.measure_in_basis[XY, %40] %8 : i1, !quantum.bit + // CHECK-NEXT: scf.yield %41, %42 : i1, !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: %43 = arith.xori %17, %36 : i1 + // CHECK-NEXT: %44 = arith.constant true + // CHECK-NEXT: %45 = arith.cmpi eq, %43, %44 : i1 + // CHECK-NEXT: %46 = scf.if %45 -> (!quantum.bit) { + // CHECK-NEXT: %47 = quantum.custom "PauliX"() %9 : !quantum.bit + // CHECK-NEXT: scf.yield %47 : !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %9 : !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: %48 = arith.xori %13, %26 : i1 + // CHECK-NEXT: %49 = arith.constant true + // CHECK-NEXT: %50 = arith.cmpi eq, %48, %49 : i1 + // CHECK-NEXT: %51 = scf.if %50 -> (!quantum.bit) { + // CHECK-NEXT: %52 = quantum.custom "PauliZ"() %46 : !quantum.bit + // CHECK-NEXT: scf.yield %52 : !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %46 : !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: quantum.dealloc_qb %18 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %27 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %37 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %14 : !quantum.bit + // CHECK-NEXT: func.return %51 : !quantum.bit + // CHECK-NEXT:} + """ + + pipeline = (ConvertToMBQCFormalismPass(),) + run_filecheck(program, pipeline) + + def test_cnot_gate(self, run_filecheck): + """Test for lowering a CNOT gate to a MBQC formalism.""" + program = """ + func.func @test_func() { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK-NEXT: [[q1:%.+]] = "test.op"() : () -> !quantum.bit + %1 = "test.op"() : () -> !quantum.bit + // CHECK-NEXT: [[q0:%.+]], [[q1:%.+]]= func.call @cnot_in_mbqc([[q0:%.+]], [[q1:%.+]]) : (!quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit) + %2, %3 = quantum.custom "CNOT"() %0, %1 : !quantum.bit, !quantum.bit + return + } + // CHECK: func.func private @cnot_in_mbqc(%0 : !quantum.bit, %1 : !quantum.bit) -> (!quantum.bit, !quantum.bit) attributes {mbqc_transform = none} { + """ + + pipeline = (ConvertToMBQCFormalismPass(),) + run_filecheck(program, pipeline) + + def test_switch_statement(self, run_filecheck): + """Test that the convert_to_mbqc_formalism_pass works correctly with a switch statement.""" + program = """ + func.func @test_func(%qubits: !quantum.bit, %l : index) { + %0 = scf.index_switch %l -> !quantum.bit + case 0 { + // CHECK-NOT: quantum.custom "Hadamard"() + %q1 = quantum.custom "Hadamard"() %qubits : !quantum.bit + scf.yield %q1 : !quantum.bit + } + default { + // CHECK-NOT: quantum.custom "S"() + %q2 = quantum.custom "S"() %qubits : !quantum.bit + scf.yield %q2 : !quantum.bit + } + + return + } + """ + pipeline = (ConvertToMBQCFormalismPass(),) + run_filecheck(program, pipeline) + + def test_function_no_body(self, run_filecheck): + """Test that the convert_to_mbqc_formalism_pass works correctly with a function that has no body.""" + program = """ + func.func @test_func() { + // CHECK: func.func private @func_1(f64, f64, i1) -> f64 + func.func private @func_1(f64, f64, i1) -> f64 + // CHECK: func.func private @func_2(memref, f64, f64, i1) + func.func private @func_2(memref, f64, f64, i1) + return + } + """ + pipeline = (ConvertToMBQCFormalismPass(),) + run_filecheck(program, pipeline) + + @pytest.mark.usefixtures("use_capture") + def test_gates_in_mbqc_gate_set_lowering(self, run_filecheck_qjit): + """Test that the convert_to_mbqc_formalism_pass works correctly with qjit and unrolled loops.""" + dev = qml.device("null.qubit", wires=1000) + + @qml.qjit( + target="mlir", + pipelines=mbqc_pipeline(), + autograph=True, + ) + @convert_to_mbqc_formalism_pass + @qml.set_shots(1000) + @qml.qnode(dev) + def circuit(): + # CHECK-LABEL: circuit + # CHECK-NOT: quantum.custom "CNOT"() + # CHECK-NOT: quantum.custom "S"() + # CHECK-NOT: quantum.custom "RZ"() + # CHECK-NOT: quantum.custom "RotXZX"() + # CHECK-NOT: quantum.custom "Hadamard"() + # CHECK-NOT: scf.if + # CHECK: scf.for + # CHECK: func.call @hadamard_in_mbqc + # CHECK: func.call @s_in_mbqc + # CHECK: func.call @rotxzx_in_mbqc + # CHECK: func.call @rz_in_mbqc + # CHECK: func.call @cnot_in_mbqc + # CHECK: mbqc.graph_state_prep + # CHECK: quantum.custom "CZ"() + # CHECK: mbqc.measure_in_basis + # CHECK: scf.if + # CHECK: quantum.custom "PauliX"() + # CHECK: quantum.custom "PauliZ"() + # CHECK: quantum.dealloc_qb + for i in range(50): + qml.H(i) + qml.S(i) + RotXZX(0.1, 0.2, 0.3, wires=[i]) + qml.RZ(phi=0.1, wires=[i]) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.Z(wires=0)) + + run_filecheck_qjit(circuit) + + @pytest.mark.usefixtures("use_capture") + def test_gates_in_mbqc_gate_set_lowering_for(self, run_filecheck_qjit): + """Test that the convert_to_mbqc_formalism_pass works correctly with qjit and for-loop structure.""" + dev = qml.device("null.qubit", wires=1000) + + @qml.for_loop(1, 1000, 1) + def loop_for(i): + qml.H(i) + qml.S(i) + RotXZX(0.1, 0.2, 0.3, wires=[i]) + qml.RZ(phi=0.1, wires=[i]) + + @qml.qjit( + target="mlir", + pipelines=mbqc_pipeline(), + autograph=True, + ) + @convert_to_mbqc_formalism_pass + @qml.set_shots(1000) + @qml.qnode(dev) + def circuit(): + # CHECK-LABEL: circuit + # CHECK-NOT: quantum.custom "CNOT"() + # CHECK-NOT: quantum.custom "S"() + # CHECK-NOT: quantum.custom "RZ"() + # CHECK-NOT: quantum.custom "RotXZX"() + # CHECK-NOT: quantum.custom "Hadamard"() + # CHECK: scf.for + # CHECK: func.call @hadamard_in_mbqc + # CHECK: func.call @s_in_mbqc + # CHECK: func.call @rotxzx_in_mbqc + # CHECK: func.call @rz_in_mbqc + # CHECK: func.call @cnot_in_mbqc + # CHECK: mbqc.graph_state_prep + # CHECK: quantum.custom "CZ"() + # CHECK: mbqc.measure_in_basis + # CHECK: scf.if + # CHECK: quantum.custom "PauliX"() + # CHECK: quantum.custom "PauliZ"() + # CHECK: quantum.dealloc_qb + loop_for() # pylint: disable=no-value-for-parameter + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.Z(wires=0)) + + run_filecheck_qjit(circuit) + + @pytest.mark.usefixtures("use_capture") + def test_gates_in_mbqc_gate_set_lowering_graph_state_decomp(self, run_filecheck_qjit): + """Test that the convert_to_mbqc_formalism_pass works correctly with qjit and for-loop structure.""" + dev = qml.device("null.qubit", wires=1000) + + def loop_for(i): + qml.H(i) + qml.S(i) + RotXZX(0.1, 0.2, 0.3, wires=[i]) + qml.RZ(phi=0.1, wires=[i]) + + @qml.qjit( + target="mlir", + pipelines=mbqc_pipeline(), + autograph=True, + ) + @decompose_graph_state_pass + @convert_to_mbqc_formalism_pass + @qml.set_shots(1000) + @qml.qnode(dev) + def circuit(): + # CHECK-NOT: quantum.custom "CNOT"() + # CHECK-NOT: quantum.custom "S"() + # CHECK-NOT: quantum.custom "RZ"() + # CHECK-NOT: quantum.custom "RotXZX"() + # CHECK-NOT: mbqc.graph_state_prep + # CHECK: scf.for + # CHECK: quantum.custom "Hadamard"() + # CHECK: quantum.custom "CZ"() + # CHECK: mbqc.measure_in_basis + # CHECK: scf.if + # CHECK: quantum.custom "PauliX"() + # CHECK: quantum.custom "PauliZ"() + # CHECK: quantum.dealloc_qb + for i in range(1000): + loop_for(i) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.Z(wires=0)) + + run_filecheck_qjit(circuit) + + @pytest.mark.usefixtures("use_capture") + def test_gates_in_mbqc_gate_set_lowering_while(self, run_filecheck_qjit): + """Test that the convert_to_mbqc_formalism_pass works correctly with qjit and while-loop structure.""" + dev = qml.device("null.qubit", wires=1000) + + @qml.while_loop(lambda i: i > 1000) + def while_for(i): + qml.H(i) + qml.S(i) + RotXZX(0.1, 0.2, 0.3, wires=[i]) + qml.RZ(phi=0.1, wires=[i]) + i = i + 1 + return i + + @qml.qjit( + target="mlir", + autograph=True, + ) + @convert_to_mbqc_formalism_pass + @qml.set_shots(1000) + @qml.qnode(dev) + def circuit(): + # CHECK-NOT: quantum.custom "CNOT"() + # CHECK-NOT: quantum.custom "S"() + # CHECK-NOT: quantum.custom "RZ"() + # CHECK-NOT: quantum.custom "RotXZX"() + # CHECK-NOT: quantum.custom "Hadamard"() + # CHECK: scf.while + # CHECK: quantum.custom "CZ"() + # CHECK: mbqc.measure_in_basis + # CHECK: scf.if + # CHECK: quantum.custom "PauliX"() + # CHECK: quantum.custom "PauliZ"() + # CHECK: quantum.dealloc_qb + while_for(0) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.Z(wires=0)) + + run_filecheck_qjit(circuit) + + @pytest.mark.usefixtures("use_capture") + def test_gates_in_mbqc_gate_set_e2e(self): + """Test that the convert_to_mbqc_formalism_pass end to end on null.qubit.""" + dev = qml.device("null.qubit", wires=1000) + + @qml.qjit( + target="mlir", + pipelines=mbqc_pipeline(), + autograph=True, + ) + @decompose_graph_state_pass + @convert_to_mbqc_formalism_pass + @measurements_from_samples_pass + @qml.set_shots(1000) + @qml.qnode(dev) + def circuit(): + for i in range(1000): + qml.H(i) + qml.S(i) + RotXZX(0.1, 0.2, 0.3, wires=[i]) + qml.RZ(phi=0.1, wires=[i]) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.Z(wires=0)) + + res = circuit() + assert res == 1.0 diff --git a/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_decompose_graph_state.py b/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_decompose_graph_state.py new file mode 100644 index 0000000000..9b48c93835 --- /dev/null +++ b/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_decompose_graph_state.py @@ -0,0 +1,505 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit and integration tests for the Python compiler `decompose-graph-state` transform. + +FileCheck notation hint: + + Qubit variable names are written as q0a, q0b, q1a, etc. in the FileCheck program. The leading + number indicates the wire index of that qubit, and the second letter increments by one after + each use. +""" + +# pylint: disable=wrong-import-position,line-too-long + +import pytest + +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +from catalyst.python_interface.transforms import ( + DecomposeGraphStatePass, + NullDecomposeGraphStatePass, +) + + +class TestDecomposeGraphStatePass: + """Unit tests for the decompose-graph-state pass.""" + + def test_1_qubit(self, run_filecheck): + """Test the decompose-graph-state pass for a 1-qubit graph state.""" + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[]> : tensor<0xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: [[graph_reg:%.+]] = quantum.alloc(1) : !quantum.reg + // CHECK: [[q0a:%.+]] = quantum.extract [[graph_reg]][0] : !quantum.reg -> !quantum.bit + // CHECK: [[q0b:%.+]] = quantum.custom "Hadamard"() [[q0a]] : !quantum.bit + // CHECK: [[out_reg0:%.+]] = quantum.insert [[graph_reg]][0], [[q0b]] : !quantum.reg, !quantum.bit + + %adj_matrix = arith.constant dense<[]> : tensor<0xi1> + %qreg = mbqc.graph_state_prep (%adj_matrix : tensor<0xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (DecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_2_qubit_chain(self, run_filecheck): + """Test the decompose-graph-state pass for a 2-qubit graph state. The qubit connectivity is: + + 0 -- 1 + + which has the adjacency matrix representation + + 0 1 + 1 0 + + and densely packed adjacency matrix representation + + [1] + """ + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[1]> : tensor<1xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: [[graph_reg:%.+]] = quantum.alloc(2) : !quantum.reg + + // CHECK: [[q0a:%.+]] = quantum.extract [[graph_reg]][0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: [[q1a:%.+]] = quantum.extract [[graph_reg]][1] : !quantum.reg -> !quantum.bit + + // CHECK: [[q0b:%.+]] = quantum.custom "Hadamard"() [[q0a]] : !quantum.bit + // CHECK-NEXT: [[q1b:%.+]] = quantum.custom "Hadamard"() [[q1a]] : !quantum.bit + + // CHECK: [[q0c:%.+]], [[q1c:%.+]] = quantum.custom "CZ"() [[q0b]], [[q1b]] : !quantum.bit, !quantum.bit + + // CHECK: [[out_reg00:%.+]] = quantum.insert [[graph_reg]][0], [[q0c]] : !quantum.reg, !quantum.bit + // CHECK-NEXT: [[out_reg01:%.+]] = quantum.insert [[out_reg00]][1], [[q1c]] : !quantum.reg, !quantum.bit + + %adj_matrix = arith.constant dense<[1]> : tensor<1xi1> + %qreg = mbqc.graph_state_prep (%adj_matrix : tensor<1xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (DecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_3_qubit_chain(self, run_filecheck): + """Test the decompose-graph-state pass for a 3-qubit graph state. The qubit connectivity is: + + 0 -- 1 -- 2 + + which has the adjacency matrix representation + + 0 1 0 + 1 0 1 + 0 1 0 + + and densely packed adjacency matrix representation + + [1, 0, 1] + """ + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[1, 0, 1]> : tensor<3xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: [[graph_reg:%.+]] = quantum.alloc(3) : !quantum.reg + + // CHECK: [[q0a:%.+]] = quantum.extract [[graph_reg]][0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: [[q1a:%.+]] = quantum.extract [[graph_reg]][1] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: [[q2a:%.+]] = quantum.extract [[graph_reg]][2] : !quantum.reg -> !quantum.bit + + // CHECK: [[q0b:%.+]] = quantum.custom "Hadamard"() [[q0a]] : !quantum.bit + // CHECK-NEXT: [[q1b:%.+]] = quantum.custom "Hadamard"() [[q1a]] : !quantum.bit + // CHECK-NEXT: [[q2b:%.+]] = quantum.custom "Hadamard"() [[q2a]] : !quantum.bit + + // CHECK: [[q0c:%.+]], [[q1c:%.+]] = quantum.custom "CZ"() [[q0b]], [[q1b]] : !quantum.bit, !quantum.bit + // CHECK-NEXT: [[q1d:%.+]], [[q2c:%.+]] = quantum.custom "CZ"() [[q1c]], [[q2b]] : !quantum.bit, !quantum.bit + + // CHECK: [[out_reg00:%.+]] = quantum.insert [[graph_reg]][0], [[q0c]] : !quantum.reg, !quantum.bit + // CHECK-NEXT: [[out_reg01:%.+]] = quantum.insert [[out_reg00]][1], [[q1d]] : !quantum.reg, !quantum.bit + // CHECK-NEXT: [[out_reg02:%.+]] = quantum.insert [[out_reg01]][2], [[q2c]] : !quantum.reg, !quantum.bit + + %adj_matrix = arith.constant dense<[1, 0, 1]> : tensor<3xi1> + %qreg = mbqc.graph_state_prep (%adj_matrix : tensor<3xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (DecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_4_qubit_square_lattice(self, run_filecheck): + """Test the decompose-graph-state pass for a 4-qubit graph state. The qubit connectivity is: + + 0 -- 1 + | | + 2 -- 3 + + which has the adjacency matrix representation + + 0 1 1 0 + 1 0 0 1 + 1 0 0 1 + 0 1 1 0 + + and densely packed adjacency matrix representation + + [1, 1, 0, 0, 1, 1] + + [(0, 1), (0, 2), (1, 3), (2, 3)] + """ + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[1, 1, 0, 0, 1, 1]> : tensor<6xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: [[graph_reg:%.+]] = quantum.alloc(4) : !quantum.reg + + // CHECK: [[q0a:%.+]] = quantum.extract [[graph_reg]][0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: [[q1a:%.+]] = quantum.extract [[graph_reg]][1] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: [[q2a:%.+]] = quantum.extract [[graph_reg]][2] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: [[q3a:%.+]] = quantum.extract [[graph_reg]][3] : !quantum.reg -> !quantum.bit + + // CHECK: [[q0b:%.+]] = quantum.custom "Hadamard"() [[q0a]] : !quantum.bit + // CHECK-NEXT: [[q1b:%.+]] = quantum.custom "Hadamard"() [[q1a]] : !quantum.bit + // CHECK-NEXT: [[q2b:%.+]] = quantum.custom "Hadamard"() [[q2a]] : !quantum.bit + // CHECK-NEXT: [[q3b:%.+]] = quantum.custom "Hadamard"() [[q3a]] : !quantum.bit + + // CHECK: [[q0c:%.+]], [[q1c:%.+]] = quantum.custom "CZ"() [[q0b]], [[q1b]] : !quantum.bit, !quantum.bit + // CHECK-NEXT: [[q0d:%.+]], [[q2c:%.+]] = quantum.custom "CZ"() [[q0c]], [[q2b]] : !quantum.bit, !quantum.bit + // CHECK-NEXT: [[q1d:%.+]], [[q3c:%.+]] = quantum.custom "CZ"() [[q1c]], [[q3b]] : !quantum.bit, !quantum.bit + // CHECK-NEXT: [[q2d:%.+]], [[q3d:%.+]] = quantum.custom "CZ"() [[q2c]], [[q3c]] : !quantum.bit, !quantum.bit + + // CHECK: [[out_reg00:%.+]] = quantum.insert [[graph_reg]][0], [[q0d]] : !quantum.reg, !quantum.bit + // CHECK-NEXT: [[out_reg01:%.+]] = quantum.insert [[out_reg00]][1], [[q1d]] : !quantum.reg, !quantum.bit + // CHECK-NEXT: [[out_reg02:%.+]] = quantum.insert [[out_reg01]][2], [[q2d]] : !quantum.reg, !quantum.bit + // CHECK-NEXT: [[out_reg03:%.+]] = quantum.insert [[out_reg02]][3], [[q3d]] : !quantum.reg, !quantum.bit + + %adj_matrix = arith.constant dense<[1, 1, 0, 0, 1, 1]> : tensor<6xi1> + %qreg = mbqc.graph_state_prep (%adj_matrix : tensor<6xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (DecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_2_qubit_chain_non_standard_init(self, run_filecheck): + """Test the decompose-graph-state pass for a 2-qubit graph state, using non-standard `init` + and `entangle` attributes. + """ + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[1]> : tensor<1xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: [[graph_reg:%.+]] = quantum.alloc(2) : !quantum.reg + + // CHECK: [[q0a:%.+]] = quantum.extract [[graph_reg]][0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: [[q1a:%.+]] = quantum.extract [[graph_reg]][1] : !quantum.reg -> !quantum.bit + + // CHECK: [[q0b:%.+]] = quantum.custom "S"() [[q0a]] : !quantum.bit + // CHECK-NEXT: [[q1b:%.+]] = quantum.custom "S"() [[q1a]] : !quantum.bit + + // CHECK: [[q0c:%.+]], [[q1c:%.+]] = quantum.custom "CNOT"() [[q0b]], [[q1b]] : !quantum.bit, !quantum.bit + + // CHECK: [[out_reg00:%.+]] = quantum.insert [[graph_reg]][0], [[q0c]] : !quantum.reg, !quantum.bit + // CHECK-NEXT: [[out_reg01:%.+]] = quantum.insert [[out_reg00]][1], [[q1c]] : !quantum.reg, !quantum.bit + + %adj_matrix = arith.constant dense<[1]> : tensor<1xi1> + %qreg = mbqc.graph_state_prep (%adj_matrix : tensor<1xi1>) [init "S", entangle "CNOT"] : !quantum.reg + func.return + } + """ + + pipeline = (DecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_register_use(self, run_filecheck): + """Test that the uses of the register resulting from a graph_state_prep op are still correct + after decomposing it into its quantum ops with the decompose-graph-state pass. + + We do not rigorously test the decomposition here, only that the last resulting register from + the decomposition (specifically, from the last quantum.insert op) is correctly picked up by + the ops that used the register resulting from the original graph_state_prep op. + """ + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[1]> : tensor<1xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: [[graph_reg:%.+]] = quantum.alloc(2) : !quantum.reg + + // CHECK: quantum.extract [[graph_reg]][0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: quantum.extract [[graph_reg]][1] : !quantum.reg -> !quantum.bit + + // CHECK: quantum.custom "Hadamard"() {{%.+}} : !quantum.bit + // CHECK-NEXT: quantum.custom "Hadamard"() {{%.+}} : !quantum.bit + + // CHECK: quantum.custom "CZ"() {{%.+}}, {{%.+}} : !quantum.bit, !quantum.bit + + // CHECK: quantum.insert {{%.+}}[0], {{%.+}} : !quantum.reg, !quantum.bit + // CHECK-NEXT: [[out_qreg0:%.+]] = quantum.insert {{%.+}}[1], {{%.+}} : !quantum.reg, !quantum.bit + + %adj_matrix = arith.constant dense<[1]> : tensor<1xi1> + %qreg = mbqc.graph_state_prep (%adj_matrix : tensor<1xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + + // CHECK: quantum.extract [[out_qreg0]][0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: quantum.extract [[out_qreg0]][1] : !quantum.reg -> !quantum.bit + %q0a = quantum.extract %qreg[0] : !quantum.reg -> !quantum.bit + %q1a = quantum.extract %qreg[1] : !quantum.reg -> !quantum.bit + + // CHECK: arith.constant + // CHECK: mbqc.measure_in_basis + %angle_0 = arith.constant 0.0 : f64 + %m0, %q0b = mbqc.measure_in_basis [XY, %angle_0] %q0a : i1, !quantum.bit + + // CHECK: arith.constant + // CHECK: mbqc.measure_in_basis + %angle_pi_2 = arith.constant 1.5707963267948966 : f64 + %m1, %q1b = mbqc.measure_in_basis [XY, %angle_pi_2] %q1a : i1, !quantum.bit + + // CHECK: [[out_qreg1:%.+]] = quantum.insert [[out_qreg0]][0], {{%.+}} : !quantum.reg, !quantum.bit + // CHECK: [[out_qreg2:%.+]] = quantum.insert [[out_qreg1]][1], {{%.+}} : !quantum.reg, !quantum.bit + %reg0 = quantum.insert %qreg[0], %q0b : !quantum.reg, !quantum.bit + %reg1 = quantum.insert %reg0[1], %q1b : !quantum.reg, !quantum.bit + + func.return + } + """ + + pipeline = (DecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_adj_matrix_reuse(self, run_filecheck): + """Test that the decompose-graph-state pass supports the case where we reuse the adjacency + matrix resulting from a constant op multiple times. + """ + + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[1, 0, 1]> : tensor<3xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: quantum.alloc(3) + // CHECK: quantum.alloc(3) + + %adj_matrix = arith.constant dense<[1, 0, 1]> : tensor<3xi1> + %qreg1 = mbqc.graph_state_prep (%adj_matrix : tensor<3xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + %qreg2 = mbqc.graph_state_prep (%adj_matrix : tensor<3xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (DecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_with_stablehlo_constant(self, run_filecheck): + """Test that the decompose-graph-state pass supports the case where the adjacency matrix + results from a `stablehlo.constant` op (rather than an `arith.constant` op). + """ + + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: stablehlo.constant + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: quantum.alloc(3) + + %adj_matrix = "stablehlo.constant"() <{value = dense<[1, 0, 1]> : tensor<3xi1>}> : () -> tensor<3xi1> + %qreg1 = mbqc.graph_state_prep (%adj_matrix : tensor<3xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (DecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + +class TestNullDecomposeGraphStatePass: + """Unit tests for the null-decompose-graph-state pass.""" + + def test_1_qubit(self, run_filecheck): + """Test the null-decompose-graph-state pass for a 1-qubit graph state.""" + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[]> : tensor<0xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: [[graph_reg:%.+]] = quantum.alloc(1) : !quantum.reg + // CHECK-NOT: quantum.extract + // CHECK-NOT: quantum.custom "Hadamard" + // CHECK-NOT: quantum.custom "CZ" + // CHECK-NOT: quantum.insert + + %adj_matrix = arith.constant dense<[]> : tensor<0xi1> + %qreg = mbqc.graph_state_prep (%adj_matrix : tensor<0xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (NullDecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_2_qubit_chain(self, run_filecheck): + """Test the null-decompose-graph-state pass for a 2-qubit graph state.""" + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[1]> : tensor<1xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: [[graph_reg:%.+]] = quantum.alloc(2) : !quantum.reg + // CHECK-NOT: quantum.extract + // CHECK-NOT: quantum.custom "Hadamard" + // CHECK-NOT: quantum.custom "CZ" + // CHECK-NOT: quantum.insert + + %adj_matrix = arith.constant dense<[1]> : tensor<1xi1> + %qreg = mbqc.graph_state_prep (%adj_matrix : tensor<1xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (NullDecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_3_qubit_chain(self, run_filecheck): + """Test the null-decompose-graph-state pass for a 3-qubit graph state.""" + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[1, 0, 1]> : tensor<3xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: [[graph_reg:%.+]] = quantum.alloc(3) : !quantum.reg + // CHECK-NOT: quantum.extract + // CHECK-NOT: quantum.custom "Hadamard" + // CHECK-NOT: quantum.custom "CZ" + // CHECK-NOT: quantum.insert + + %adj_matrix = arith.constant dense<[1, 0, 1]> : tensor<3xi1> + %qreg = mbqc.graph_state_prep (%adj_matrix : tensor<3xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (NullDecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_register_use(self, run_filecheck): + """Test that the uses of the register resulting from a graph_state_prep op are still correct + after decomposing it into its quantum ops with the null-decompose-graph-state pass. + + We do not rigorously test the decomposition here, only that the resulting register from the + decomposition (specifically, from the quantum.alloc op) is correctly picked up by the ops + that used the register resulting from the original graph_state_prep op. + """ + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[1]> : tensor<1xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: [[graph_reg:%.+]] = quantum.alloc(2) : !quantum.reg + + %adj_matrix = arith.constant dense<[1]> : tensor<1xi1> + %qreg = mbqc.graph_state_prep (%adj_matrix : tensor<1xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + + // CHECK: quantum.extract [[graph_reg]][0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: quantum.extract [[graph_reg]][1] : !quantum.reg -> !quantum.bit + %q0a = quantum.extract %qreg[0] : !quantum.reg -> !quantum.bit + %q1a = quantum.extract %qreg[1] : !quantum.reg -> !quantum.bit + + // CHECK: arith.constant + // CHECK: mbqc.measure_in_basis + %angle_0 = arith.constant 0.0 : f64 + %m0, %q0b = mbqc.measure_in_basis [XY, %angle_0] %q0a : i1, !quantum.bit + + // CHECK: arith.constant + // CHECK: mbqc.measure_in_basis + %angle_pi_2 = arith.constant 1.5707963267948966 : f64 + %m1, %q1b = mbqc.measure_in_basis [XY, %angle_pi_2] %q1a : i1, !quantum.bit + + // CHECK: [[out_qreg1:%.+]] = quantum.insert [[graph_reg]][0], {{%.+}} : !quantum.reg, !quantum.bit + // CHECK: [[out_qreg2:%.+]] = quantum.insert [[out_qreg1]][1], {{%.+}} : !quantum.reg, !quantum.bit + %reg0 = quantum.insert %qreg[0], %q0b : !quantum.reg, !quantum.bit + %reg1 = quantum.insert %reg0[1], %q1b : !quantum.reg, !quantum.bit + + func.return + } + """ + + pipeline = (NullDecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_adj_matrix_reuse(self, run_filecheck): + """Test that the null-decompose-graph-state pass supports the case where we reuse the + adjacency matrix resulting from a constant op multiple times. + """ + + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[1, 0, 1]> : tensor<3xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: quantum.alloc(3) + // CHECK: quantum.alloc(3) + + %adj_matrix = arith.constant dense<[1, 0, 1]> : tensor<3xi1> + %qreg1 = mbqc.graph_state_prep (%adj_matrix : tensor<3xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + %qreg2 = mbqc.graph_state_prep (%adj_matrix : tensor<3xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (NullDecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_with_stablehlo_constant(self, run_filecheck): + """Test that the null-decompose-graph-state pass supports the case where the adjacency matrix + results from a `stablehlo.constant` op (rather than an `arith.constant` op). + """ + + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: stablehlo.constant + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: quantum.alloc(3) + + %adj_matrix = "stablehlo.constant"() <{value = dense<[1, 0, 1]> : tensor<3xi1>}> : () -> tensor<3xi1> + %qreg1 = mbqc.graph_state_prep (%adj_matrix : tensor<3xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (NullDecomposeGraphStatePass(),) + run_filecheck(program, pipeline) diff --git a/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_outline_state_evolution.py b/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_outline_state_evolution.py new file mode 100644 index 0000000000..eb9525ec5e --- /dev/null +++ b/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_outline_state_evolution.py @@ -0,0 +1,394 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the outline state evolution transform""" +import pytest + +# pylint: disable=wrong-import-position +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +import pennylane as qml +from pennylane.ftqc import RotXZX + +from catalyst.ftqc import mbqc_pipeline +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from catalyst.python_interface.transforms import ( + OutlineStateEvolutionPass, + convert_to_mbqc_formalism_pass, + decompose_graph_state_pass, + diagonalize_final_measurements_pass, + measurements_from_samples_pass, + outline_state_evolution_pass, +) + + +@qml.while_loop(lambda i: i < 1000) +def _while_for(i): + qml.H(i) + qml.S(i) + RotXZX(0.1, 0.2, 0.3, wires=[i]) + qml.RZ(phi=0.1, wires=[i]) + i = i + 1 + return i + + +class TestOutlineStateEvolutionPass: + """Unit tests for OutlineStateEvolutionPass.""" + + def test_func_wo_qnode_attr(self, run_filecheck): + """Test outline state evolution pass is not applied to a func without a qnode attribute.""" + program = """ + module @module_circuit { + func.func public @circuit() -> tensor { + %0 = arith.constant 0 : i64 + quantum.device shots(%0) ["", "", ""] + %1 = quantum.alloc( 50) : !quantum.reg + %2 = quantum.extract %1[ 0] : !quantum.reg -> !quantum.bit + // CHECK: quantum.custom "PauliX"() + // CHECK-NOT: call @circuit.state_evolution + %out_qubits = quantum.custom "PauliX"() %2 : !quantum.bit + %3 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.insert %1[ 0], %out_qubits : !quantum.reg, !quantum.bit + quantum.dealloc %5 : !quantum.reg + quantum.device_release + return %from_elements : tensor + } + // CHECK-NOT: func.func public @circuit.state_evolution + } + """ + + pipeline = (OutlineStateEvolutionPass(),) + run_filecheck(program, pipeline) + + def test_func_w_qnode_attr(self, run_filecheck): + """Test outline state evolution pass would be applied to a func with a qnode attribute.""" + program = """ + module @module_circuit { + func.func public @circuit() -> tensor attributes {qnode} { + %0 = arith.constant 0 : i64 + quantum.device shots(%0) ["", "", ""] + %1 = quantum.alloc( 50) : !quantum.reg + %2 = quantum.extract %1[ 0] : !quantum.reg -> !quantum.bit + // CHECK-NOT: quantum.custom "PauliX"() + // CHECK: call @circuit.state_evolution + %out_qubits = quantum.custom "PauliX"() %2 : !quantum.bit + %3 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.insert %1[ 0], %out_qubits : !quantum.reg, !quantum.bit + quantum.dealloc %5 : !quantum.reg + quantum.device_release + return %from_elements : tensor + } + // CHECK: func.func public @circuit.state_evolution + } + """ + + pipeline = (OutlineStateEvolutionPass(),) + run_filecheck(program, pipeline) + + def test_multiple_func_w_qnode_attr(self, run_filecheck): + """Test outline state evolution pass would be applied to a func with a qnode attribute.""" + program = """ + module @module_circuit { + func.func public @circuit1() -> tensor attributes {qnode} { + %0 = arith.constant 0 : i64 + quantum.device shots(%0) ["", "", ""] + %1 = quantum.alloc( 50) : !quantum.reg + %2 = quantum.extract %1[ 0] : !quantum.reg -> !quantum.bit + // CHECK: call @circuit1.state_evolution + %out_qubits = quantum.custom "PauliX"() %2 : !quantum.bit + %3 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.insert %1[ 0], %out_qubits : !quantum.reg, !quantum.bit + quantum.dealloc %5 : !quantum.reg + quantum.device_release + return %from_elements : tensor + } + func.func public @circuit2() -> tensor attributes {qnode} { + %0 = arith.constant 0 : i64 + quantum.device shots(%0) ["", "", ""] + %1 = quantum.alloc( 50) : !quantum.reg + %2 = quantum.extract %1[ 0] : !quantum.reg -> !quantum.bit + // CHECK: call @circuit2.state_evolution + %out_qubits = quantum.custom "PauliX"() %2 : !quantum.bit + %3 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.insert %1[ 0], %out_qubits : !quantum.reg, !quantum.bit + quantum.dealloc %5 : !quantum.reg + quantum.device_release + return %from_elements : tensor + } + // CHECK: func.func public @circuit1.state_evolution + // CHECK: func.func public @circuit2.state_evolution + } + """ + + pipeline = (OutlineStateEvolutionPass(),) + run_filecheck(program, pipeline) + + @pytest.mark.usefixtures("use_capture") + def test_outline_state_evolution_no_error(self): + """Test outline_state_evolution_pass does not raise error for circuit with classical + operations only.""" + + @qml.qjit( + target="mlir", + pass_plugins=[getXDSLPluginAbsolutePath()], + ) + @outline_state_evolution_pass + def circuit(x, y): + return x * y + 5 + + circuit(1, 4) + + @pytest.mark.usefixtures("use_capture") + def test_outline_state_evolution_no_terminal_op_error(self): + """Test outline_state_evolution_pass raises error when no terminal_boundary_op is found in + circuit with quantum operation.""" + # TODOs: we can resolve this issue if the boundary op is inserted when + # the program is captured. + dev = qml.device("null.qubit", wires=10) + + @qml.qjit( + target="mlir", + pass_plugins=[getXDSLPluginAbsolutePath()], + ) + @outline_state_evolution_pass + @qml.qnode(dev) + def circuit(): + return qml.state() + + with pytest.raises( + RuntimeError, match="A terminal_boundary_op op is not found in the circuit." + ): + circuit() + + @pytest.mark.usefixtures("use_capture") + def test_outline_state_evolution_pass_only(self, run_filecheck_qjit): + """Test the outline_state_evolution_pass only.""" + dev = qml.device("lightning.qubit", wires=1000) + + @qml.qjit( + target="mlir", + pass_plugins=[getXDSLPluginAbsolutePath()], + ) + @outline_state_evolution_pass + @qml.set_shots(1000) + @qml.qnode(dev) + def circuit(): + # CHECK-LABEL: func.func public @circuit() + # CHECK-NOT: scf.while + # CHECK-NOT: quantum.custom "Hadamard"() + # CHECK-NOT: quantum.custom "S"() + # CHECK-NOT: quantum.custom "RotXZX" + # CHECK-NOT: quantum.custom "RZ" + # CHECK-NOT: quantum.custom "CNOT" + # CHECK-NOT: func.func public @circuit.state_evolution + # CHECK: quantum.alloc + # CHECK-NEXT: func.call @circuit.state_evolution + # CHECK-LABEL: func.func public @circuit.state_evolution + # CHECK-NOT: quantum.alloc + # CHECK-NOT: quantum.namedobs + # CHECK: scf.while + # CHECK: quantum.custom "Hadamard"() + # CHECK: quantum.custom "S"() + # CHECK: quantum.custom "RotXZX" + # CHECK: quantum.custom "RZ" + # CHECK: quantum.custom "CNOT" + _while_for(0) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.Z(wires=0)) + + run_filecheck_qjit(circuit) + + @pytest.mark.usefixtures("use_capture") + def test_outline_state_evolution_pass_with_convert_to_mbqc_formalism(self, run_filecheck_qjit): + """Test if the outline_state_evolution_pass works with the convert-to-mbqc-formalism pass + on lightning.qubit.""" + dev = qml.device("lightning.qubit", wires=1000) + + @qml.qjit( + target="mlir", + pass_plugins=[getXDSLPluginAbsolutePath()], + pipelines=mbqc_pipeline(), + ) + @decompose_graph_state_pass + @convert_to_mbqc_formalism_pass + @outline_state_evolution_pass + @qml.set_shots(1000) + @qml.qnode(dev) + def circuit(): + # CHECK-LABEL: func.func public @circuit() + # CHECK-NOT: quantum.custom "Hadamard"() + # CHECK-NOT: quantum.custom "S"() + # CHECK-NOT: quantum.custom "RotXZX" + # CHECK-NOT: quantum.custom "RZ" + # CHECK-NOT: quantum.custom "CNOT"() + # CHECK-NOT: mbqc.measure_in_basis + # CHECK-NOT: scf.if + # CHECK-NOT: quantum.dealloc_qb + # CHECK-LABEL: func.func public @circuit.state_evolution + # CHECK-NOT: quantum.custom "S"() + # CHECK-NOT: quantum.custom "RotXZX" + # CHECK-NOT: quantum.custom "RZ" + # CHECK-NOT: quantum.custom "CNOT"() + # CHECK-NOT: quantum.namedobs + # CHECK: scf.while + # CHECK: quantum.custom "Hadamard"() + # CHECK: quantum.custom "CZ"() + # CHECK: mbqc.measure_in_basis + # CHECK: scf.if + # CHECK: quantum.custom "PauliX"() + # CHECK: quantum.custom "PauliZ"() + # CHECK: quantum.dealloc_qb + _while_for(0) + qml.H(0) + qml.S(1) + RotXZX(0.1, 0.2, 0.3, wires=[2]) + qml.RZ(phi=0.1, wires=[3]) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.X(wires=0)) + + run_filecheck_qjit(circuit) + + @pytest.mark.usefixtures("use_capture") + def test_outline_state_evolution_pass_with_mbqc_pipeline(self, run_filecheck_qjit): + """Test if the outline_state_evolution_pass works with all mbqc transform pipeline on + null.qubit.""" + dev = qml.device("null.qubit", wires=1000) + + @qml.qjit( + target="mlir", + pass_plugins=[getXDSLPluginAbsolutePath()], + pipelines=mbqc_pipeline(), + ) + @decompose_graph_state_pass + @convert_to_mbqc_formalism_pass + @measurements_from_samples_pass + @diagonalize_final_measurements_pass + @outline_state_evolution_pass + @qml.set_shots(1000) + @qml.qnode(dev) + def circuit(): + # CHECK-LABEL: func.func public @circuit() + # NOTE: There is scf.if, mbqc.measure_in_basis in the circuit() + # scope as X obs is decomposed into H@Z and H is converted to MBQC formalism. + # CHECK-NOT: quantum.custom "S"() + # CHECK-NOT: quantum.custom "RotXZX" + # CHECK-NOT: quantum.custom "RZ" + # CHECK-NOT: quantum.custom "CNOT"() + # CHECK-LABEL: func.func public @circuit.state_evolution + # CHECK-NOT: quantum.custom "S"() + # CHECK-NOT: quantum.custom "RotXZX" + # CHECK-NOT: quantum.custom "RZ" + # CHECK-NOT: quantum.custom "CNOT"() + # CHECK-NOT: quantum.namedobs + # CHECK: scf.while + # CHECK: quantum.custom "Hadamard"() + # CHECK: quantum.custom "CZ"() + # CHECK: mbqc.measure_in_basis + # CHECK: scf.if + # CHECK: quantum.custom "PauliX"() + # CHECK: quantum.custom "PauliZ"() + # CHECK: quantum.dealloc_qb + _while_for(0) + qml.H(0) + qml.S(1) + RotXZX(0.1, 0.2, 0.3, wires=[2]) + qml.RZ(phi=0.1, wires=[3]) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.X(wires=0)) + + run_filecheck_qjit(circuit) + + @pytest.mark.usefixtures("use_capture") + def test_outline_state_evolution_pass_with_mbqc_pipeline_run_on_nullqubit(self): + """Test if a circuit can be transfored with the outline_state_evolution_pass and all mbqc + transform pipeline can be executed on null.qubit.""" + dev = qml.device("null.qubit", wires=1000) + + @qml.qjit( + target="mlir", + pass_plugins=[getXDSLPluginAbsolutePath()], + pipelines=mbqc_pipeline(), + ) + @decompose_graph_state_pass + @convert_to_mbqc_formalism_pass + @measurements_from_samples_pass + @diagonalize_final_measurements_pass + @outline_state_evolution_pass + @qml.set_shots(1000) + @qml.qnode(dev) + def circuit(): + _while_for(0) + qml.H(0) + qml.S(1) + RotXZX(0.1, 0.2, 0.3, wires=[2]) + qml.RZ(phi=0.1, wires=[3]) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.X(wires=0)) + + res = circuit() + assert res == 1.0 + + @pytest.mark.usefixtures("use_capture") + def test_lightning_execution_with_structure(self): + """Test that the outline_state_evolution_pass on lightning.qubit for a circuit with program + structure is executable and returns results as expected.""" + dev = qml.device("lightning.qubit", wires=10) + + @qml.for_loop(0, 10, 1) + def for_fn(i): + qml.H(i) + qml.S(i) + qml.RZ(phi=0.1, wires=[i]) + + @qml.while_loop(lambda i: i < 10) + def while_fn(i): + qml.H(i) + qml.S(i) + qml.RZ(phi=0.1, wires=[i]) + i = i + 1 + return i + + @qml.qjit( + target="mlir", + pass_plugins=[getXDSLPluginAbsolutePath()], + ) + @outline_state_evolution_pass + @qml.qnode(dev) + def circuit(): + for_fn() # pylint: disable=no-value-for-parameter + while_fn(0) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.prod(qml.X(0), qml.Z(1))) + + res = circuit() + + @qml.qjit( + target="mlir", + ) + @qml.qnode(dev) + def circuit_ref(): + for_fn() # pylint: disable=no-value-for-parameter + while_fn(0) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.prod(qml.X(0), qml.Z(1))) + + res_ref = circuit_ref() + assert res == res_ref diff --git a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_cancel_inverses.py b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_cancel_inverses.py new file mode 100644 index 0000000000..f5b8eb1159 --- /dev/null +++ b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_cancel_inverses.py @@ -0,0 +1,230 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the iterative cancel inverses transform""" + +import pytest + +# pylint: disable=wrong-import-position +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +import pennylane as qml + +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from catalyst.python_interface.transforms import ( + IterativeCancelInversesPass, + iterative_cancel_inverses_pass, +) + + +class TestIterativeCancelInversesPass: + """Unit tests for IterativeCancelInversesPass.""" + + def test_no_inverses_same_qubit(self, run_filecheck): + """Test that nothing changes when there are no inverses.""" + program = """ + func.func @test_func() { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = quantum.custom "PauliX"() [[q0]] : !quantum.bit + // CHECK: quantum.custom "PauliY"() [[q1]] : !quantum.bit + %1 = quantum.custom "PauliX"() %0 : !quantum.bit + %2 = quantum.custom "PauliY"() %1 : !quantum.bit + return + } + """ + + pipeline = (IterativeCancelInversesPass(),) + run_filecheck(program, pipeline) + + def test_inverses_different_qubits(self, run_filecheck): + """Test that nothing changes when there are no inverses.""" + program = """ + func.func @test_func() { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + %1 = "test.op"() : () -> !quantum.bit + // CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit + // CHECK: quantum.custom "PauliX"() [[q1]] : !quantum.bit + %2 = quantum.custom "PauliX"() %0 : !quantum.bit + %3 = quantum.custom "PauliX"() %1 : !quantum.bit + return + } + """ + + pipeline = (IterativeCancelInversesPass(),) + run_filecheck(program, pipeline) + + def test_simple_self_inverses(self, run_filecheck): + """Test that inverses are cancelled.""" + program = """ + func.func @test_func() { + %0 = "test.op"() : () -> !quantum.bit + // CHECK-NOT: quantum.custom + %1 = quantum.custom "PauliX"() %0 : !quantum.bit + %2 = quantum.custom "PauliX"() %1 : !quantum.bit + return + } + """ + + pipeline = (IterativeCancelInversesPass(),) + run_filecheck(program, pipeline) + + def test_nested_self_inverses(self, run_filecheck): + """Test that nested self-inverses are cancelled.""" + program = """ + func.func @test_func() { + %0 = "test.op"() : () -> !quantum.bit + // CHECK-NOT: quantum.custom + %1 = quantum.custom "PauliX"() %0 : !quantum.bit + %2 = quantum.custom "PauliY"() %1 : !quantum.bit + %3 = quantum.custom "PauliZ"() %2 : !quantum.bit + %4 = quantum.custom "PauliZ"() %3 : !quantum.bit + %5 = quantum.custom "PauliY"() %4 : !quantum.bit + %6 = quantum.custom "PauliX"() %5 : !quantum.bit + return + } + """ + + pipeline = (IterativeCancelInversesPass(),) + run_filecheck(program, pipeline) + + def test_cancel_ops_with_control_qubits(self, run_filecheck): + """Test that ops with control qubits can be cancelled.""" + program = """ + func.func @test_func() { + %0 = "arith.constant"() <{value = true}> : () -> i1 + %1 = "test.op"() : () -> !quantum.bit + %2 = "test.op"() : () -> !quantum.bit + %3 = "test.op"() : () -> !quantum.bit + // CHECK-NOT: quantum.custom + %4, %5, %6 = quantum.custom "PauliY"() %1 ctrls(%2, %3) ctrlvals(%0, %0) : !quantum.bit ctrls !quantum.bit, !quantum.bit + %7, %8, %9 = quantum.custom "PauliY"() %4 ctrls(%5, %6) ctrlvals(%0, %0) : !quantum.bit ctrls !quantum.bit, !quantum.bit + return + } + """ + + pipeline = (IterativeCancelInversesPass(),) + run_filecheck(program, pipeline) + + def test_cancel_ops_with_same_control_qubits_and_values(self, run_filecheck): + """Test that ops with control qubits and control values can be + cancelled.""" + program = """ + func.func @test_func() { + %0 = arith.constant false + %1 = arith.constant true + %2 = "test.op"() : () -> !quantum.bit + %3 = "test.op"() : () -> !quantum.bit + %4 = "test.op"() : () -> !quantum.bit + // CHECK-NOT: "quantum.custom" + %5, %6, %7 = quantum.custom "PauliY"() %2 ctrls(%3, %4) ctrlvals(%1, %0) : !quantum.bit ctrls !quantum.bit, !quantum.bit + %8, %9, %10 = quantum.custom "PauliY"() %5 ctrls(%6, %7) ctrlvals(%1, %0) : !quantum.bit ctrls !quantum.bit, !quantum.bit + return + } + """ + + pipeline = (IterativeCancelInversesPass(),) + run_filecheck(program, pipeline) + + def test_ops_with_control_qubits_different_control_values(self, run_filecheck): + """Test that ops with the same control qubits but different + control values don't cancel.""" + program = """ + func.func @test_func() { + // CHECK-DAG: [[cval0:%.+]] = arith.constant false + // CHECK-DAG: [[cval1:%.+]] = arith.constant true + %0 = arith.constant false + %1 = arith.constant true + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q2:%.+]] = "test.op"() : () -> !quantum.bit + %2 = "test.op"() : () -> !quantum.bit + %3 = "test.op"() : () -> !quantum.bit + %4 = "test.op"() : () -> !quantum.bit + // CHECK: [[q3:%.+]], [[q4:%.+]], [[q5:%.+]] = quantum.custom "PauliY"() [[q0]] ctrls([[q1]], [[q2]]) ctrlvals([[cval1]], [[cval0]]) : !quantum.bit ctrls !quantum.bit, !quantum.bit + // CHECK: quantum.custom "PauliY"() [[q3]] ctrls([[q4]], [[q5]]) ctrlvals([[cval0]], [[cval1]]) : !quantum.bit ctrls !quantum.bit, !quantum.bit + %5, %6, %7 = quantum.custom "PauliY"() %2 ctrls(%3, %4) ctrlvals(%1, %0) : !quantum.bit ctrls !quantum.bit, !quantum.bit + %8, %9, %10 = quantum.custom "PauliY"() %5 ctrls(%6, %7) ctrlvals(%0, %1) : !quantum.bit ctrls !quantum.bit, !quantum.bit + return + } + """ + + pipeline = (IterativeCancelInversesPass(),) + run_filecheck(program, pipeline) + + def test_non_consecutive_self_inverse_ops(self, run_filecheck): + """Test that self-inverse gates on the same qubit that are not + consecutive are not cancelled.""" + program = """ + func.func @test_func() { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = quantum.custom "PauliX"() [[q0]] : !quantum.bit + // CHECK: [[q2:%.+]] = quantum.custom "PauliY"() [[q1]] : !quantum.bit + // CHECK: quantum.custom "PauliX"() [[q2]] : !quantum.bit + %1 = quantum.custom "PauliX"() %0 : !quantum.bit + %2 = quantum.custom "PauliY"() %1 : !quantum.bit + %3 = quantum.custom "PauliX"() %2 : !quantum.bit + return + } + """ + + pipeline = (IterativeCancelInversesPass(),) + run_filecheck(program, pipeline) + + +@pytest.mark.usefixtures("use_capture") +class TestIterativeCancelInversesIntegration: + """Integration tests for the IterativeCancelInversesPass.""" + + def test_qjit(self, run_filecheck_qjit): + """Test that the IterativeCancelInversesPass works correctly with qjit.""" + dev = qml.device("lightning.qubit", wires=2) + + @qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()]) + @iterative_cancel_inverses_pass + @qml.qnode(dev) + def circuit(): + # CHECK-NOT: quantum.custom + qml.H(0) + qml.X(0) + qml.X(0) + qml.H(0) + return qml.state() + + run_filecheck_qjit(circuit) + + def test_qjit_no_cancellation(self, run_filecheck_qjit): + """Test that the IterativeCancelInversesPass works correctly with qjit when + there are no operations that can be cancelled.""" + dev = qml.device("lightning.qubit", wires=2) + + @qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()]) + @iterative_cancel_inverses_pass + @qml.qnode(dev) + def circuit(): + # CHECK-NOT: quantum.custom + qml.H(1) + qml.X(1) + qml.X(0) + qml.H(0) + return qml.state() + + with pytest.raises(AssertionError, match="filecheck failed"): + run_filecheck_qjit(circuit) + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_combine_global_phases.py b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_combine_global_phases.py new file mode 100644 index 0000000000..6d35b7c2ef --- /dev/null +++ b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_combine_global_phases.py @@ -0,0 +1,241 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the combine global phases transform""" +import pytest + +# pylint: disable=wrong-import-position +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +import pennylane as qml + +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from catalyst.python_interface.transforms import ( + CombineGlobalPhasesPass, + combine_global_phases_pass, +) + + +class TestCombineGlobalPhasesPass: + """Unit tests for CombineGlobalPhasesPass.""" + + def test_combinable_ops_without_control_flow(self, run_filecheck): + """Test that combines global phases in a func without control flow.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK: [[phi_sum:%.+]] = arith.addf %arg1, %arg0 : f64 + // CHECK: quantum.gphase([[phi_sum]]) + quantum.gphase %arg0 + quantum.gphase %arg1 + // CHECK: [[q1:%.+]] = quantum.custom "PauliX"() [[q0]] : !quantum.bit + %2 = quantum.custom "PauliX"() %0 : !quantum.bit + return + } + """ + + pipeline = (CombineGlobalPhasesPass(),) + run_filecheck(program, pipeline) + + def test_combinable_ops_with_control_flow(self, run_filecheck): + """Test that combines global phases in a func with control flow.""" + program = """ + func.func @test_func(%cond: i32, %arg0: f64, %arg1: f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + quantum.gphase %arg0 + // CHECK: [[ret:%.+]] = scf.if %cond -> (f64) { + %ret = scf.if %cond -> (f64) { + // CHECK: [[two:%.+]] = arith.constant {{2.+}} : f64 + %two = arith.constant 2 : f64 + // CHECK: [[arg02:%.+]] = arith.mulf %arg0, [[two]] : f64 + %arg0x2 = arith.mulf %arg0, %two : f64 + // CHECK: scf.yield [[arg02]] : f64 + scf.yield %arg0x2 : f64 + } else { + // CHECK: [[two_1:%.+]] = arith.constant {{2.+}} : f64 + %two_1 = arith.constant 2 : f64 + // CHECK: [[arg12:%.+]] = arith.mulf %arg1, [[two_1]] : f64 + %arg1x2 = arith.mulf %arg1, %two_1 : f64 + // CHECK: scf.yield [[arg12:%.+]] : f64 + scf.yield %arg1x2 : f64 + } + // CHECK: [[phi_sum:%.+]] = arith.addf [[ret]], %arg0 : f64 + // CHECK: quantum.gphase([[phi_sum]]) : + quantum.gphase %ret + // CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit + %2 = quantum.custom "PauliX"() %0 : !quantum.bit + return + } + """ + + pipeline = (CombineGlobalPhasesPass(),) + run_filecheck(program, pipeline) + + def test_combinable_ops_in_control_flow_if(self, run_filecheck): + """Test that combines global phases in a func without control a flow. + Here the control flow is an `if` operation. + """ + program = """ + // CHECK: func.func @test_func(%cond : i32, [[arg0:%.+]] : f64, [[arg1:%.+]] : f64) + func.func @test_func(%cond : i32, %arg0 : f64, %arg1 : f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK: [[ret:%.+]] = scf.if [[cond:%.+]] -> (f64) { + %ret = scf.if %cond -> (f64) { + // CHECK: [[two0:%.+]] = arith.constant {{2.+}} : f64 + %two0 = arith.constant 2 : f64 + // CHECK: [[arg02:%.+]] = arith.mulf [[arg0]], [[two0]] : f64 + %arg02 = arith.mulf %arg0, %two0 : f64 + // CHECK: [[t0:%.+]] = "test.op"() : () -> f64 + %t0 = "test.op"() : () -> f64 + // CHECK: [[phi_sum_0:%.+]] = arith.addf [[t0]], [[t0]] : f64 + // CHECK: quantum.gphase([[phi_sum_0]]) + quantum.gphase %t0 + quantum.gphase %t0 + // CHECK: scf.yield [[arg02]] : f64 + scf.yield %arg02 : f64 + // CHECK: } else { + } else { + // CHECK: [[two1:%.+]] = arith.constant {{2.+}} : f64 + %two1 = arith.constant 2 : f64 + // CHECK: [[arg1x2:%.+]] = arith.mulf [[arg1]], [[two1]] : f64 + %arg1x2 = arith.mulf %arg1, %two1 : f64 + // CHECK: [[phi_sum_1:%.+]] = arith.addf [[arg1x2]], [[arg1x2]] : f64 + // CHECK: quantum.gphase([[phi_sum_1]]) + quantum.gphase %arg1x2 + quantum.gphase %arg1x2 + // CHECK: scf.yield [[arg1x2]] : f64 + scf.yield %arg1x2 : f64 + // CHECK: } + } + // CHECK: [[phi_sum_2:%.+]] = arith.addf [[arg1]], [[arg0]] : f64 + // CHECK: quantum.gphase([[phi_sum_2:%.+]]) + quantum.gphase %arg0 + quantum.gphase %arg1 + // CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit + %2 = quantum.custom "PauliX"() %0 : !quantum.bit + return + } + """ + + pipeline = (CombineGlobalPhasesPass(),) + run_filecheck(program, pipeline) + + def test_combinable_ops_in_control_flow_for(self, run_filecheck): + """Test that combines global phases in a func with a control flow. + Here the control flow is a `for` operation. + """ + program = """ + // CHECK: func.func @test_func(%n : i32, [[arg0:%.+]] : f64, [[arg1:%.+]] : f64) + func.func @test_func(%n : i32, %arg0 : f64, %arg1 : f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[c0:%.+]] = arith.constant 0 : i32 + // CHECK: [[c1:%.+]] = arith.constant 1 : i32 + %0 = "test.op"() : () -> !quantum.bit + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + scf.for %i = %c0 to %n step %c1 { + // CHECK: [[two0:%.+]] = arith.constant {{2.+}} : f64 + %two0 = arith.constant 2 : f64 + // CHECK: [[arg02:%.+]] = arith.mulf [[arg0]], [[two0]] : f64 + %arg02 = arith.mulf %arg0, %two0 : f64 + // CHECK: [[t0:%.+]] = "test.op"() : () -> f64 + %t0 = "test.op"() : () -> f64 + // CHECK: [[phi_sum_0:%.+]] = arith.addf [[t0]], [[t0]] : f64 + // CHECK: quantum.gphase([[phi_sum_0]]) + quantum.gphase %t0 + quantum.gphase %t0 + } + // CHECK: [[phi_sum_1:%.+]] = arith.addf [[arg1]], [[arg0]] : f64 + // CHECK: quantum.gphase([[phi_sum_1]]) + quantum.gphase %arg0 + quantum.gphase %arg1 + // CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit + %2 = quantum.custom "PauliX"() %0 : !quantum.bit + return + } + """ + + pipeline = (CombineGlobalPhasesPass(),) + run_filecheck(program, pipeline) + + def test_combinable_ops_in_control_flow_while(self, run_filecheck): + """Test that combines global phases in a func with control flow. + Here the control flow is a `while` operation. + """ + program = """ + // CHECK: func.func @test_func(%n : i32, [[arg0:%.+]] : f64, [[arg1:%.+]] : f64) + func.func @test_func(%n : i32, %arg0 : f64, %arg1 : f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + scf.while (%current_i = %n) : (i32) -> () { + %cond = arith.cmpi slt, %current_i, %n : i32 + scf.condition(%cond) %current_i : i32 + } do { + ^bb0(%i : i32): + %next_i = arith.addi %i, %c1 : i32 + // CHECK: [[two0:%.+]] = arith.constant {{2.+}} : f64 + %two0 = arith.constant 2 : f64 + // CHECK: [[arg02:%.+]] = arith.mulf [[arg0]], [[two0]] : f64 + %arg02 = arith.mulf %arg0, %two0 : f64 + // CHECK: [[t0:%.+]] = "test.op"() : () -> f64 + %t0 = "test.op"() : () -> f64 + // CHECK: [[phi_sum_0:%.+]] = arith.addf [[t0]], [[t0]] : f64 + // CHECK: quantum.gphase([[phi_sum_0]]) + quantum.gphase %t0 + quantum.gphase %t0 + scf.yield %next_i : i32 + } + // CHECK: [[phi_sum_1:%.+]] = arith.addf [[arg1]], [[arg0]] : f64 + // CHECK: quantum.gphase([[phi_sum_1]]) : + quantum.gphase %arg0 + quantum.gphase %arg1 + // CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit + %2 = quantum.custom "PauliX"() %0 : !quantum.bit + return + } + """ + + pipeline = (CombineGlobalPhasesPass(),) + run_filecheck(program, pipeline) + + +# pylint: disable=too-few-public-methods +@pytest.mark.usefixtures("use_capture") +class TestCombineGlobalPhasesIntegration: + """Integration tests for the CombineGlobalPhasesPass.""" + + def test_qjit(self, run_filecheck_qjit): + """Test that the CombineGlobalPhasesPass works correctly with qjit.""" + dev = qml.device("lightning.qubit", wires=2) + + @qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()]) + @combine_global_phases_pass + @qml.qnode(dev) + def circuit(x: float, y: float): + # CHECK: [[phi:%.+]] = arith.addf + # CHECK: quantum.gphase([[phi]]) + # CHECK-NOT: quantum.gphase + qml.GlobalPhase(x) + qml.GlobalPhase(y) + return qml.state() + + run_filecheck_qjit(circuit) + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_diagonalize_measurements.py b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_diagonalize_measurements.py new file mode 100644 index 0000000000..d719ca9f7e --- /dev/null +++ b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_diagonalize_measurements.py @@ -0,0 +1,588 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the xDSL implementation of the diagonalize_final_measurements pass""" + + +# pylint: disable=wrong-import-position + +import numpy as np +import pytest + +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +import pennylane as qml + +from catalyst.passes import apply_pass, xdsl_plugin +from catalyst.python_interface.transforms import ( + DiagonalizeFinalMeasurementsPass, + diagonalize_final_measurements_pass, +) + + +class TestDiagonalizeFinalMeasurementsPass: + """Unit tests for the diagonalize-final-measurements pass.""" + + def test_with_pauli_z(self, run_filecheck): + """Test that a PauliZ observable is not affected by diagonalization""" + + program = """ + func.func @test_func() { + %0 = "test.op"() : () -> !quantum.bit + + // CHECK: quantum.namedobs %0[PauliZ] : !quantum.obs + %1 = quantum.namedobs %0[PauliZ] : !quantum.obs + + // CHECK: quantum.expval %1 : f64 + %2 = quantum.expval %1 : f64 + return + } + """ + + pipeline = (DiagonalizeFinalMeasurementsPass(),) + run_filecheck(program, pipeline) + + def test_with_identity(self, run_filecheck): + """Test that an Identity observable is not affected by diagonalization.""" + + program = """ + func.func @test_func() { + // CHECK: [[q0:%.*]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + + // CHECK: quantum.namedobs %0[Identity] : !quantum.obs + %1 = quantum.namedobs %0[Identity] : !quantum.obs + + // CHECK: quantum.var %1 : f64 + %2 = quantum.var %1 : f64 + return + } + """ + pipeline = (DiagonalizeFinalMeasurementsPass(),) + run_filecheck(program, pipeline) + + def test_with_pauli_x(self, run_filecheck): + """Test that when diagonalizing a PauliX observable, the expected diagonalizing + gates are inserted and the observable becomes PauliZ.""" + + program = """ + func.func @test_func() { + // CHECK: [[q0:%.*]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + + // CHECK: [[q0_1:%.*]] = quantum.custom "Hadamard"() [[q0]] + // CHECK-NEXT: [[q0_2:%.*]] = quantum.namedobs [[q0_1]][PauliZ] + // CHECK-NOT: quantum.namedobs [[q:%.+]][PauliX] + %1 = quantum.namedobs %0[PauliX] : !quantum.obs + + // CHECK: quantum.expval [[q0_2]] + %2 = quantum.expval %1 : f64 + return + } + """ + + pipeline = (DiagonalizeFinalMeasurementsPass(),) + run_filecheck(program, pipeline) + + def test_with_pauli_y(self, run_filecheck): + """Test that when diagonalizing a PauliY observable, the expected diagonalizing + gates are inserted and the observable becomes PauliZ.""" + + program = """ + func.func @test_func() { + // CHECK: [[q0:%.*]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + + // CHECK: [[q0_1:%.*]] = quantum.custom "PauliZ"() [[q0]] + // CHECK-NEXT: [[q0_2:%.*]] = quantum.custom "S"() [[q0_1]] + // CHECK-NEXT: [[q0_3:%.*]] = quantum.custom "Hadamard"() [[q0_2]] + // CHECK-NEXT: [[q0_4:%.*]] = quantum.namedobs [[q0_3]][PauliZ] + // CHECK-NOT: quantum.namedobs [[q:%.+]][PauliY] + %1 = quantum.namedobs %0[PauliY] : !quantum.obs + + // CHECK: quantum.expval [[q0_4]] + %2 = quantum.expval %1 : f64 + return + } + """ + + pipeline = (DiagonalizeFinalMeasurementsPass(),) + run_filecheck(program, pipeline) + + def test_with_hadamard(self, run_filecheck): + """Test that when diagonalizing a Hadamard observable, the expected diagonalizing + gates are inserted and the observable becomes PauliZ.""" + + program = """ + func.func @test_func() { + // CHECK: [[q0:%.*]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + + // CHECK: [[quarter_pi:%.*]] = arith.constant -0.78539816339744828 : f64 + // CHECK-NEXT: [[q0_1:%.*]] = quantum.custom "RY"([[quarter_pi]]) [[q0]] + // CHECK-NEXT: [[q0_2:%.*]] = quantum.namedobs [[q0_1]][PauliZ] + // CHECK-NOT: quantum.namedobs [[q:%.+]][Hadamard] + %1 = quantum.namedobs %0[Hadamard] : !quantum.obs + + // CHECK: quantum.expval [[q0_2]] + %2 = quantum.expval %1 : f64 + } + """ + + pipeline = (DiagonalizeFinalMeasurementsPass(),) + run_filecheck(program, pipeline) + + def test_with_composite_observable(self, run_filecheck): + """Test transform on a measurement process with a composite observable. In this + case, the simplified program is based on the MLIR generated by the circuit + + @qml.qjit(target="mlir") + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circuit(): + return qml.expval(qml.Y(0)@qml.X(1) + qml.Z(2)) + """ + + program = """ + func.func @test_func() { + // CHECK: [[q0:%.*]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.*]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q2:%.*]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + %1 = "test.op"() : () -> !quantum.bit + %2 = "test.op"() : () -> !quantum.bit + + // CHECK: [[q0_1:%.*]] = quantum.custom "PauliZ"() [[q0]] + // CHECK: [[q0_2:%.*]] = quantum.custom "S"() [[q0_1]] + // CHECK: [[q0_3:%.*]] = quantum.custom "Hadamard"() [[q0_2]] + // CHECK: [[q_y:%.*]] = quantum.namedobs [[q0_3]][PauliZ] + // CHECK-NOT: quantum.namedobs [[q:%.+]][PauliY] + %3 = quantum.namedobs %0[PauliY] : !quantum.obs + + // CHECK: [[q1_1:%.*]] = quantum.custom "Hadamard"() [[q1]] + // CHECK: [[q_x:%.*]] = quantum.namedobs [[q1_1]][PauliZ] + // CHECK-NOT: quantum.namedobs [[q:%.+]][PauliX] + %4 = quantum.namedobs %1[PauliX] : !quantum.obs + + // CHECK: [[tensor0:%.*]] = quantum.tensor [[q_y]], [[q_x]] : !quantum.obs + %5 = quantum.tensor %3, %4 : !quantum.obs + + // CHECK: [[q_z:%.*]] = quantum.namedobs [[q2]][PauliZ] : !quantum.obs + %6 = quantum.namedobs %2[PauliZ] : !quantum.obs + + // CHECK: [[size:%.*]] = "test.op"() : () -> tensor<2xf64> + %size_info = "test.op"() : () -> tensor<2xf64> + + // CHECK: quantum.hamiltonian([[size]] : tensor<2xf64>) [[tensor0]], [[q_z]] : !quantum.obs + %7 = quantum.hamiltonian(%size_info : tensor<2xf64>) %5, %6 : !quantum.obs + + // CHECK: quantum.expval + %8 = quantum.expval %7 : f64 + return + } + """ + + pipeline = (DiagonalizeFinalMeasurementsPass(),) + run_filecheck(program, pipeline) + + def test_with_multiple_measurements(self, run_filecheck): + """Test diagonalizing a circuit with multiple measurements. The simplified program + for this test is based on the circuit + + @qml.qjit(target="mlir") + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circuit(): + return qml.var(qml.Y(0)), qml.var(qml.X(1)) + """ + + program = """ + func.func @test_func() { + %0 = "test.op"() : () -> !quantum.bit + %1 = "test.op"() : () -> !quantum.bit + + // CHECK: quantum.custom "PauliZ"() + // CHECK-NEXT: quantum.custom "S"() + // CHECK-NEXT: quantum.custom "Hadamard"() + // CHECK-NEXT: quantum.namedobs [[q:%.+]][PauliZ] + %2 = quantum.namedobs %0[PauliY] : !quantum.obs + // CHECK: quantum.var + %3 = quantum.var %2 : f64 + + + // CHECK: quantum.custom "Hadamard"() + // CHECK-NEXT: quantum.namedobs [[q:%.+]][PauliZ] + %4 = quantum.namedobs %1[PauliX] : !quantum.obs + + // CHECK: quantum.expval + %5 = quantum.expval %4 : f64 + return + } + """ + + pipeline = (DiagonalizeFinalMeasurementsPass(),) + run_filecheck(program, pipeline) + + def test_overlapping_observables_raises_error(self, run_filecheck): + """Test the case where multiple overlapping (commuting) observables exist in + the same circuit (an error is raised - split_non_commuting should have been applied). + + @qml.qjit(target="mlir") + @qml.qnode(qml.device("lightning.qubit", wires=1)) + def circuit(): + return qml.var(qml.X(0)), qml.var(qml.X(0)) + """ + + program = """ + func.func @test_func() { + %0 = "test.op"() : () -> !quantum.bit + // CHECK: quantum.custom "Hadamard"() + // CHECK-NEXT: quantum.namedobs [[q:%.+]][PauliZ] + %1 = quantum.namedobs %0[PauliX] : !quantum.obs + %2 = quantum.var %1 : f64 + // CHECK: quantum.custom "Hadamard"() + // CHECK-NEXT: quantum.namedobs [[q:%.+]][PauliZ] + %3 = quantum.namedobs %0[PauliX] : !quantum.obs + %4 = quantum.var %3 : f64 + } + """ + + pipeline = (DiagonalizeFinalMeasurementsPass(),) + + with pytest.raises( + RuntimeError, match="the circuit contains multiple observables with the same wire" + ): + run_filecheck(program, pipeline) + + def test_additional_qubit_uses_are_updated(self, run_filecheck): + """Test that when diagonalizing the circuit, if the MLIR contains + later manipulations of the qubit going into the observable, these are + updated as well. While quantum.custom operations can't be applied to + the same SSA value that is passed to the observable, it can still + be inserted into a register or deallocated. + + The simplified program for this test is based on the circuit + + @qml.qjit(target="mlir") + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circuit(): + return qml.expval(qml.X(1)) + """ + + # we expect that instead of the SSA value that comes out of quantum.extract being passed to + # both quantum.namedobs and the quantum.insert, it will be passed to the Hadamard, and the + # SSA value that is output by the *Hadmard* operation will be passed to namedobs and insert. + program = """ + func.func @test_func() { + %0 = quantum.alloc(3) : !quantum.reg + %1 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %2 = tensor.extract %1[] : tensor + // CHECK: [[q0:%.*]] = quantum.extract + %3 = quantum.extract %0[%2] : !quantum.reg -> !quantum.bit + + // CHECK: [[q0_1:%.*]] = quantum.custom "Hadamard"() [[q0]] + // CHECK-NEXT: quantum.namedobs [[q0_1]][PauliZ] + %4 = quantum.namedobs %3[PauliX] : !quantum.obs + %5 = quantum.expval %4 : f64 + + // CHECK: quantum.insert [[q:%.+]][[[q:%.+]]], [[q0_1]] + %6 = tensor.extract %1[] : tensor + %7 = quantum.insert %0[%6], %3 : !quantum.reg, !quantum.bit + quantum.dealloc %7 : !quantum.reg + } + """ + + pipeline = (DiagonalizeFinalMeasurementsPass(),) + run_filecheck(program, pipeline) + + +class TestDiagonalizeFinalMeasurementsProgramCaptureExecution: + """Integration tests going through plxpr (program capture enabled)""" + + # pylint: disable=unnecessary-lambda + @pytest.mark.usefixtures("use_capture") + @pytest.mark.parametrize( + "mp, obs, expected_res", + [ + (qml.expval, qml.Identity, lambda x: 1), + (qml.var, qml.Identity, lambda x: 0), + (qml.expval, qml.X, lambda x: 0), + (qml.var, qml.X, lambda x: 1), + (qml.expval, qml.Y, lambda x: -np.sin(x)), + (qml.var, qml.Y, lambda x: 1 - np.sin(x) ** 2), + (qml.expval, qml.Z, lambda x: np.cos(x)), + (qml.var, qml.Z, lambda x: 1 - np.cos(x) ** 2), + (qml.expval, qml.Hadamard, lambda x: np.cos(x) / np.sqrt(2)), + (qml.var, qml.Hadamard, lambda x: (2 - np.cos(x) ** 2) / 2), + ], + ) + def test_with_single_obs(self, mp, obs, expected_res): + """Test the diagonalization transform for a circuit with a single measurement + of a single supported observable""" + + dev = qml.device("lightning.qubit", wires=1) + + @qml.qnode(dev) + def circuit_ref(phi): + qml.RX(phi, 0) + return mp(obs(0)) + + angle = 0.7692 + + assert np.allclose( + expected_res(angle), circuit_ref(angle) + ), "Sanity check failed, is expected_res correct?" + circuit_compiled = qml.qjit( + diagonalize_final_measurements_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + assert np.allclose(expected_res(angle), circuit_compiled(angle)) + + @pytest.mark.usefixtures("use_capture") + def test_with_composite_observables(self): + """Test the transform works for an observable built using operator arithmetic + (sprod, prod, sum)""" + + dev = qml.device("lightning.qubit", wires=3) + + @qml.qnode(dev) + def circuit_ref(x, y): + qml.RX(x, 0) + qml.RY(y, 1) + qml.RY(y / 2, 2) + return qml.expval(qml.Y(0) @ qml.X(1) + 3 * qml.X(2)) + + def expected_res(x, y): + y0_res = -np.sin(x) + x1_res = np.sin(y) + x2_res = np.sin(y / 2) + return y0_res * x1_res + 3 * x2_res + + phi = 0.3867 + theta = 1.394 + + assert np.allclose( + expected_res(phi, theta), circuit_ref(phi, theta) + ), "Sanity check failed, is expected_res correct?" + circuit_compiled = qml.qjit( + diagonalize_final_measurements_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + assert np.allclose(expected_res(phi, theta), circuit_compiled(phi, theta)) + + @pytest.mark.usefixtures("use_capture") + def test_with_multiple_measurements(self): + """Test that the transform runs and returns the expected results for + a circuit with multiple measurements""" + + dev = qml.device("lightning.qubit", wires=2) + + @qml.qnode(dev) + def circuit_ref(x, y): + qml.RX(x, 0) + qml.RY(y, 1) + return qml.expval(qml.Y(0)), qml.var(qml.X(1)) + + def expected_res(x, y): + return -np.sin(x), 1 - np.sin(y) ** 2 + + phi = 0.3867 + theta = 1.394 + + assert np.allclose( + expected_res(phi, theta), circuit_ref(phi, theta) + ), "Sanity check failed, is expected_res correct?" + + circuit_compiled = qml.qjit( + diagonalize_final_measurements_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + assert np.allclose(expected_res(phi, theta), circuit_compiled(phi, theta)) + + @pytest.mark.usefixtures("use_capture") + def test_overlapping_observables_raises_error(self): + """Test the case where multiple overlapping (commuting) observables exist in + the same circuit (an error is raised - split_non_commuting should have been applied).""" + + dev = qml.device("lightning.qubit", wires=2) + + @qml.qjit(pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()]) + @diagonalize_final_measurements_pass + @qml.qnode(dev) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Y(0)), qml.var(qml.Y(0)) + + with pytest.raises( + RuntimeError, match="the circuit contains multiple observables with the same wire" + ): + _ = circuit(1.23) + + @pytest.mark.xfail(reason="for now, assume split_non_commuting is always applied") + @pytest.mark.usefixtures("use_capture") + def test_non_commuting_observables_raise_error(self): + """Check that an error is raised if we try to diagonalize a circuit that contains + non-commuting observables.""" + dev = qml.device("lightning.qubit", wires=1) + + @qml.qjit(pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()]) + @diagonalize_final_measurements_pass + @qml.qnode(dev) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Y(0)), qml.expval(qml.X(0)) + + with pytest.raises( + RuntimeError, match="cannot diagonalize circuit with non-commuting observables" + ): + _ = circuit(0.7) + + +class TestDiagonalizeFinalMeasurementsCatalystFrontend: + """Integration tests going through the catalyst frontend (program capture disabled)""" + + # pylint: disable=unnecessary-lambda + @pytest.mark.parametrize( + "mp, obs, expected_res", + [ + (qml.expval, qml.Identity, lambda x: 1), + (qml.var, qml.Identity, lambda x: 0), + (qml.expval, qml.X, lambda x: 0), + (qml.var, qml.X, lambda x: 1), + (qml.expval, qml.Y, lambda x: -np.sin(x)), + (qml.var, qml.Y, lambda x: 1 - np.sin(x) ** 2), + (qml.expval, qml.Z, lambda x: np.cos(x)), + (qml.var, qml.Z, lambda x: 1 - np.cos(x) ** 2), + (qml.expval, qml.Hadamard, lambda x: np.cos(x) / np.sqrt(2)), + (qml.var, qml.Hadamard, lambda x: (2 - np.cos(x) ** 2) / 2), + ], + ) + def test_with_single_obs(self, mp, obs, expected_res): + """Test the diagonalization transform for a circuit with a single measurement + of a single supported observable""" + + dev = qml.device("lightning.qubit", wires=1) + + @qml.qnode(dev) + def circuit_ref(phi): + qml.RX(phi, 0) + return mp(obs(0)) + + angle = 0.7692 + + assert np.allclose( + expected_res(angle), circuit_ref(angle) + ), "Sanity check failed, is expected_res correct?" + + circuit_compiled = qml.qjit( + apply_pass("catalyst_xdsl_plugin.diagonalize-final-measurements")(circuit_ref), + ) + + np.allclose(expected_res(angle), circuit_compiled(angle)) + + def test_with_composite_observables(self): + """Test the transform works for an observable built using operator arithmetic + (sprod, prod, sum)""" + + dev = qml.device("lightning.qubit", wires=3) + + @qml.qnode(dev) + def circuit_ref(x, y): + qml.RX(x, 0) + qml.RY(y, 1) + qml.RY(y / 2, 2) + return qml.expval(qml.Y(0) @ qml.X(1) + 3 * qml.X(2)) + + def expected_res(x, y): + y0_res = -np.sin(x) + x1_res = np.sin(y) + x2_res = np.sin(y / 2) + return y0_res * x1_res + 3 * x2_res + + phi = 0.3867 + theta = 1.394 + + assert np.allclose( + expected_res(phi, theta), circuit_ref(phi, theta) + ), "Sanity check failed, is expected_res correct?" + + circuit_compiled = qml.qjit( + apply_pass("catalyst_xdsl_plugin.diagonalize-final-measurements")(circuit_ref), + ) + + assert np.allclose(expected_res(phi, theta), circuit_compiled(phi, theta)) + + def test_with_multiple_measurements(self): + """Test that the transform runs and returns the expected results for + a circuit with multiple measurements""" + + dev = qml.device("lightning.qubit", wires=2) + + @qml.qnode(dev) + def circuit_ref(x, y): + qml.RX(x, 0) + qml.RY(y, 1) + return qml.expval(qml.Y(0)), qml.var(qml.X(1)) + + def expected_res(x, y): + return -np.sin(x), 1 - np.sin(y) ** 2 + + phi = 0.3867 + theta = 1.394 + + assert np.allclose( + expected_res(phi, theta), circuit_ref(phi, theta) + ), "Sanity check failed, is expected_res correct?" + + circuit_compiled = qml.qjit( + apply_pass("catalyst_xdsl_plugin.diagonalize-final-measurements")(circuit_ref), + ) + + assert np.allclose(expected_res(phi, theta), circuit_compiled(phi, theta)) + + def test_overlapping_observables_raises_error(self): + """Test the case where multiple overlapping (commuting) observables exist in + the same circuit (an error is raised - split_non_commuting should have been applied).""" + + dev = qml.device("lightning.qubit", wires=2) + + @qml.qjit() + @apply_pass("catalyst_xdsl_plugin.diagonalize-final-measurements") + @qml.qnode(dev) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Y(0)), qml.var(qml.Y(0)) + + with pytest.raises( + RuntimeError, match="the circuit contains multiple observables with the same wire" + ): + _ = circuit(1.23) + + @pytest.mark.xfail(reason="for now, assume split_non_commuting is always applied") + def test_non_commuting_observables_raise_error(self): + """Check that an error is raised if we try to diagonalize a circuit that contains + non-commuting observables.""" + dev = qml.device("lightning.qubit", wires=1) + + @qml.qjit() + @apply_pass("catalyst_xdsl_plugin.diagonalize-final-measurements") + @qml.qnode(dev) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Y(0)), qml.expval(qml.X(0)) + + with pytest.raises( + RuntimeError, match="cannot diagonalize circuit with non-commuting observables" + ): + _ = circuit(0.7) diff --git a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_measurements_from_samples.py b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_measurements_from_samples.py new file mode 100644 index 0000000000..8889bb08ac --- /dev/null +++ b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_measurements_from_samples.py @@ -0,0 +1,889 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit and integration tests for the Python compiler `measurements_from_samples` transform.""" + +# pylint: disable=wrong-import-position,line-too-long + +from functools import partial + +import numpy as np +import pytest + +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +import pennylane as qml + +from catalyst.passes import xdsl_plugin +from catalyst.python_interface.transforms import ( + MeasurementsFromSamplesPass, + measurements_from_samples_pass, +) + + +class TestMeasurementsFromSamplesPass: + """Unit tests for the measurements-from-samples pass.""" + + def test_1_wire_expval(self, run_filecheck): + """Test the measurements-from-samples pass on a 1-wire circuit terminating with an expval(Z) + measurement. + """ + program = """ + builtin.module @module_circuit { + // CHECK-LABEL: circuit + func.func public @circuit() -> (tensor) { + %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["", "", ""] + + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %2 = "test.op"() : () -> !quantum.bit + + // CHECK-NOT: quantum.namedobs + %3 = quantum.namedobs %2[PauliZ] : !quantum.obs + + // CHECK: [[obs:%.+]] = quantum.compbasis qubits [[q0]] : !quantum.obs + // CHECK: [[samples:%.+]] = quantum.sample [[obs]] : tensor<1x1xf64> + // CHECK: [[c0:%.+]] = arith.constant dense<0> : tensor + // CHECK: [[res:%.+]] = func.call @expval_from_samples.tensor.1x1xf64([[samples]], [[c0]]) : + // CHECK-SAME: (tensor<1x1xf64>, tensor) -> tensor + // CHECK-NOT: quantum.expval + %4 = quantum.expval %3 : f64 + %5 = "tensor.from_elements"(%4) : (f64) -> tensor + + // CHECK: func.return [[res]] : tensor + func.return %5 : tensor + } + // CHECK-LABEL: func.func public @expval_from_samples.tensor.1x1xf64 + } + """ + + pipeline = (MeasurementsFromSamplesPass(),) + run_filecheck(program, pipeline) + + def test_1_wire_expval_shots_from_arith_constantop(self, run_filecheck): + """Test the measurements-from-samples pass on a 1-wire circuit with shots from an arith.constant op and an expval(Z) measurement.""" + program = """ + builtin.module @module_circuit { + // CHECK-LABEL: circuit + func.func public @circuit() -> (tensor) { + %0 = arith.constant 1 : i64 + quantum.device shots(%0) ["", "", ""] + + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %1 = "test.op"() : () -> !quantum.bit + + // CHECK-NOT: quantum.namedobs + %2 = quantum.namedobs %1[PauliZ] : !quantum.obs + + // CHECK: [[obs:%.+]] = quantum.compbasis qubits [[q0]] : !quantum.obs + // CHECK: [[samples:%.+]] = quantum.sample [[obs]] : tensor<1x1xf64> + // CHECK: [[c0:%.+]] = arith.constant dense<0> : tensor + // CHECK: [[res:%.+]] = func.call @expval_from_samples.tensor.1x1xf64([[samples]], [[c0]]) : + // CHECK-SAME: (tensor<1x1xf64>, tensor) -> tensor + // CHECK-NOT: quantum.expval + %3 = quantum.expval %2 : f64 + %4 = "tensor.from_elements"(%3) : (f64) -> tensor + + // CHECK: func.return [[res]] : tensor + func.return %4 : tensor + } + // CHECK-LABEL: func.func public @expval_from_samples.tensor.1x1xf64 + } + """ + + pipeline = (MeasurementsFromSamplesPass(),) + run_filecheck(program, pipeline) + + def test_1_wire_var(self, run_filecheck): + """Test the measurements-from-samples pass on a 1-wire circuit terminating with a var(Z) + measurement. + """ + program = """ + builtin.module @module_circuit { + // CHECK-LABEL: circuit + func.func public @circuit() -> (tensor) { + %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["", "", ""] + + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %2 = "test.op"() : () -> !quantum.bit + + // CHECK-NOT: quantum.namedobs + %3 = quantum.namedobs %2[PauliZ] : !quantum.obs + + // CHECK: [[obs:%.+]] = quantum.compbasis qubits [[q0]] : !quantum.obs + // CHECK: [[samples:%.+]] = quantum.sample [[obs]] : tensor<1x1xf64> + // CHECK: [[c0:%.+]] = arith.constant dense<0> : tensor + // CHECK: [[res:%.+]] = func.call @var_from_samples.tensor.1x1xf64([[samples]], [[c0]]) : + // CHECK-SAME: (tensor<1x1xf64>, tensor) -> tensor + // CHECK-NOT: quantum.var + %4 = quantum.var %3 : f64 + %5 = "tensor.from_elements"(%4) : (f64) -> tensor + + // CHECK: func.return [[res]] : tensor + func.return %5 : tensor + } + // CHECK-LABEL: func.func public @var_from_samples.tensor.1x1xf64 + } + """ + + pipeline = (MeasurementsFromSamplesPass(),) + run_filecheck(program, pipeline) + + def test_1_wire_probs(self, run_filecheck): + """Test the measurements-from-samples pass on a 1-wire circuit terminating with a probs + measurement. + """ + program = """ + builtin.module @module_circuit { + // CHECK-LABEL: circuit + func.func public @circuit() -> (tensor) { + %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["", "", ""] + + // CHECK: [[q0:%.+]] = "test.op"() <{nqubits_attr = 1 : i64}> : () -> !quantum.reg + %2 = "test.op"() <{nqubits_attr = 1 : i64}> : () -> !quantum.reg + + // CHECK: [[compbasis:%.+]] = quantum.compbasis qreg [[q0]] : !quantum.obs + %3 = quantum.compbasis qreg %2 : !quantum.obs + + // CHECK: [[samples:%.+]] = quantum.sample [[compbasis]] : tensor<1x1xf64> + // CHECK: [[res:%.+]] = func.call @probs_from_samples.tensor.1x1xf64([[samples]]) : + // CHECK-SAME: (tensor<1x1xf64>) -> tensor<2xf64> + // CHECK-NOT: quantum.probs + %4 = quantum.probs %3 : tensor<2xf64> + + // CHECK: func.return [[res]] : tensor<2xf64> + func.return %4 : tensor<2xf64> + } + // CHECK-LABEL: func.func public @probs_from_samples.tensor.1x1xf64 + } + """ + + pipeline = (MeasurementsFromSamplesPass(),) + run_filecheck(program, pipeline) + + def test_1_wire_sample(self, run_filecheck): + """Test the measurements-from-samples pass on a 1-wire circuit terminating with a sample + measurement. + + This pass should be a no-op. + """ + program = """ + builtin.module @module_circuit { + // CHECK-LABEL: circuit + func.func public @circuit() -> (tensor) { + %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["", "", ""] + + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.reg + %2 = "test.op"() : () -> !quantum.reg + + // CHECK: [[compbasis:%.+]] = quantum.compbasis qreg [[q0]] : !quantum.obs + %3 = quantum.compbasis qreg %2 : !quantum.obs + + // CHECK: [[samples:%.+]] = quantum.sample [[compbasis]] : tensor<1x1xf64> + %4 = quantum.sample %3 : tensor<1x1xf64> + + // CHECK: func.return [[samples]] : tensor<1x1xf64> + func.return %4 : tensor<1x1xf64> + } + } + """ + + pipeline = (MeasurementsFromSamplesPass(),) + run_filecheck(program, pipeline) + + @pytest.mark.xfail(reason="Counts not supported", strict=True, raises=NotImplementedError) + def test_1_wire_counts(self, run_filecheck): + """Test the measurements-from-samples pass on a 1-wire circuit terminating with a counts + measurement. + """ + program = """ + builtin.module @module_circuit { + // CHECK-LABEL: circuit + func.func public @circuit() -> (tensor) { + %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["", "", ""] + + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.reg + %2 = "test.op"() : () -> !quantum.reg + + // CHECK: [[compbasis:%.+]] = quantum.compbasis qreg [[q0]] : !quantum.obs + %3 = quantum.compbasis qreg %2 : !quantum.obs + + // CHECK: [[samples:%.+]] = quantum.sample %3 : tensor<1x1xf64> + // CHECK: [[eigvals:%.+]], [[counts:%.+]] = func.call @counts_from_samples.tensor.1x1xf64([[samples]]) : + // CHECK-SAME: (tensor<1x1xf64>) -> tensor<2xf64>, tensor<2xi64> + // CHECK-NOT: quantum.counts + %eigvals, %counts = quantum.counts %3 : tensor<2xf64>, tensor<2xi64> + + // CHECK: [[eigvals_converted:%.+]] = {{.*}}stablehlo.convert{{.+}}[[eigvals]] : + // CHECK-SAME: (tensor<2xf64>) -> tensor<2xi64> + %4 = "stablehlo.convert"(%eigvals) : (tensor<2xf64>) -> tensor<2xi64> + + // CHECK: func.return [[eigvals_converted]], [[counts]] : tensor<1x1xf64> + func.return %4, %counts : tensor<2xi64>, tensor<2xi64> + } + // CHECK-LABEL: func.func public @counts_from_samples.tensor.1x1xf64 + } + """ + + pipeline = (MeasurementsFromSamplesPass(),) + run_filecheck(program, pipeline) + + def test_2_wire_expval(self, run_filecheck): + """Test the measurements-from-samples pass on a 2-wire circuit terminating with an expval(Z) + measurement on each wire. + """ + program = """ + builtin.module @module_circuit { + // CHECK-LABEL: circuit + func.func public @circuit() -> (tensor) { + %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["", "", ""] + + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %2 = "test.op"() : () -> !quantum.bit + + // CHECK: [[q1:%.+]] = "test.op"() : () -> !quantum.bit + %3 = "test.op"() : () -> !quantum.bit + + // CHECK-NOT: quantum.namedobs + %4 = quantum.namedobs %2[PauliZ] : !quantum.obs + %5 = quantum.namedobs %3[PauliZ] : !quantum.obs + + // CHECK: [[obs0:%.+]] = quantum.compbasis qubits [[q0]] : !quantum.obs + // CHECK: [[samples0:%.+]] = quantum.sample [[obs0]] : tensor<1x1xf64> + // CHECK: [[c0:%.+]] = arith.constant dense<0> : tensor + // CHECK: [[obs1:%.+]] = quantum.compbasis qubits [[q1]] : !quantum.obs + // CHECK: [[samples1:%.+]] = quantum.sample [[obs1]] : tensor<1x1xf64> + // CHECK: [[c1:%.+]] = arith.constant dense<0> : tensor + // CHECK: [[res0:%.+]] = func.call @expval_from_samples.tensor.1x1xf64([[samples0]], [[c0]]) : + // CHECK-SAME: (tensor<1x1xf64>, tensor) -> tensor + // CHECK: [[res1:%.+]] = func.call @expval_from_samples.tensor.1x1xf64([[samples1]], [[c1]]) : + // CHECK-SAME: (tensor<1x1xf64>, tensor) -> tensor + // CHECK-NOT: quantum.expval + %6 = quantum.expval %4 : f64 + %7 = "tensor.from_elements"(%6) : (f64) -> tensor + %8 = quantum.expval %5 : f64 + %9 = "tensor.from_elements"(%8) : (f64) -> tensor + + // CHECK: func.return [[res0]], [[res1]] : tensor, tensor + func.return %7, %9 : tensor, tensor + } + // CHECK-LABEL: func.func public @expval_from_samples.tensor.1x1xf64 + } + """ + + pipeline = (MeasurementsFromSamplesPass(),) + run_filecheck(program, pipeline) + + def test_2_wire_var(self, run_filecheck): + """Test the measurements-from-samples pass on a 2-wire circuit terminating with a var(Z) + measurement on each wire. + """ + program = """ + builtin.module @module_circuit { + // CHECK-LABEL: circuit + func.func public @circuit() -> (tensor) { + %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["", "", ""] + + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %2 = "test.op"() : () -> !quantum.bit + + // CHECK: [[q1:%.+]] = "test.op"() : () -> !quantum.bit + %3 = "test.op"() : () -> !quantum.bit + + // CHECK-NOT: quantum.namedobs + %4 = quantum.namedobs %2[PauliZ] : !quantum.obs + %5 = quantum.namedobs %3[PauliZ] : !quantum.obs + + // CHECK: [[obs0:%.+]] = quantum.compbasis qubits [[q0]] : !quantum.obs + // CHECK: [[samples0:%.+]] = quantum.sample [[obs0]] : tensor<1x1xf64> + // CHECK: [[c0:%.+]] = arith.constant dense<0> : tensor + // CHECK: [[obs1:%.+]] = quantum.compbasis qubits [[q1]] : !quantum.obs + // CHECK: [[samples1:%.+]] = quantum.sample [[obs1]] : tensor<1x1xf64> + // CHECK: [[c1:%.+]] = arith.constant dense<0> : tensor + // CHECK: [[res0:%.+]] = func.call @var_from_samples.tensor.1x1xf64([[samples0]], [[c0]]) : + // CHECK-SAME: (tensor<1x1xf64>, tensor) -> tensor + // CHECK: [[res1:%.+]] = func.call @var_from_samples.tensor.1x1xf64([[samples1]], [[c1]]) : + // CHECK-SAME: (tensor<1x1xf64>, tensor) -> tensor + // CHECK-NOT: quantum.var + %6 = quantum.var %4 : f64 + %7 = "tensor.from_elements"(%6) : (f64) -> tensor + %8 = quantum.var %5 : f64 + %9 = "tensor.from_elements"(%8) : (f64) -> tensor + + // CHECK: func.return [[res0]], [[res1]] : tensor, tensor + func.return %7, %9 : tensor, tensor + } + // CHECK-LABEL: func.func public @var_from_samples.tensor.1x1xf64 + } + """ + + pipeline = (MeasurementsFromSamplesPass(),) + run_filecheck(program, pipeline) + + def test_2_wire_probs_global(self, run_filecheck): + """Test the measurements-from-samples pass on a 2-wire circuit terminating with a "global" + probs measurement (one that implicitly acts on all wires). + """ + program = """ + builtin.module @module_circuit { + // CHECK-LABEL: circuit + func.func public @circuit() -> (tensor) { + %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["", "", ""] + + // CHECK: [[qreg:%.+]] = quantum.alloc + %2 = quantum.alloc(2) : !quantum.reg + + // CHECK: [[compbasis:%.+]] = quantum.compbasis qreg [[qreg]] : !quantum.obs + %3 = quantum.compbasis qreg %2 : !quantum.obs + + // CHECK: [[samples:%.+]] = quantum.sample [[compbasis]] : tensor<1x2xf64> + // CHECK: [[res:%.+]] = func.call @probs_from_samples.tensor.1x2xf64([[samples]]) : + // CHECK-SAME: (tensor<1x2xf64>) -> tensor<4xf64> + // CHECK-NOT: quantum.probs + %4 = quantum.probs %3 : tensor<4xf64> + + // CHECK: func.return [[res]] : tensor<4xf64> + func.return %4 : tensor<4xf64> + } + // CHECK-LABEL: func.func public @probs_from_samples.tensor.1x2xf64 + } + """ + + pipeline = (MeasurementsFromSamplesPass(),) + run_filecheck(program, pipeline) + + def test_2_wire_probs_per_wire(self, run_filecheck): + """Test the measurements-from-samples pass on a 2-wire circuit terminating with separate + probs measurements per wire. + """ + program = """ + builtin.module @module_circuit { + // CHECK-LABEL: circuit + func.func public @circuit() -> (tensor) { + %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["", "", ""] + + // CHECK: [[qreg:%.+]] = quantum.alloc + %2 = quantum.alloc(2) : !quantum.reg + + // CHECK: [[q0:%.+]] = quantum.extract [[qreg]][0] + %3 = quantum.extract %2[0] : !quantum.reg -> !quantum.bit + + // CHECK: [[compbasis0:%.+]] = quantum.compbasis qubits [[q0]] : !quantum.obs + %4 = quantum.compbasis qubits %3 : !quantum.obs + + // CHECK: [[samples0:%.+]] = quantum.sample [[compbasis0]] : tensor<1x1xf64> + // CHECK: [[res0:%.+]] = func.call @probs_from_samples.tensor.1x1xf64([[samples0]]) : + // CHECK-SAME: (tensor<1x1xf64>) -> tensor<2xf64> + // CHECK-NOT: quantum.probs + %5 = quantum.probs %4 : tensor<2xf64> + + // CHECK: [[q1:%.+]] = quantum.extract [[qreg]][1] + %6 = quantum.extract %2[1] : !quantum.reg -> !quantum.bit + + // CHECK: [[compbasis1:%.+]] = quantum.compbasis qubits [[q1]] : !quantum.obs + %7 = quantum.compbasis qubits %6 : !quantum.obs + + // CHECK: [[samples1:%.+]] = quantum.sample [[compbasis1]] : tensor<1x1xf64> + // CHECK: [[res1:%.+]] = func.call @probs_from_samples.tensor.1x1xf64([[samples1]]) : + // CHECK-SAME: (tensor<1x1xf64>) -> tensor<2xf64> + // CHECK-NOT: quantum.probs + %8 = quantum.probs %7 : tensor<2xf64> + + // CHECK: func.return [[res0]], [[res1]] : tensor<2xf64>, tensor<2xf64> + func.return %5, %8 : tensor<2xf64>, tensor<2xf64> + } + // CHECK-LABEL: func.func public @probs_from_samples.tensor.1x1xf64 + } + """ + + pipeline = (MeasurementsFromSamplesPass(),) + run_filecheck(program, pipeline) + + +@pytest.mark.usefixtures("use_capture") +class TestMeasurementsFromSamplesIntegration: + """Tests of the execution of simple workloads with the xDSL-based MeasurementsFromSamplesPass + transform. + """ + + @pytest.mark.parametrize("shots", [1, 2]) + @pytest.mark.parametrize( + "initial_op, mp, obs, expected_res", + [ + # PauliZ observables + (qml.I, qml.expval, qml.Z, 1.0), + (qml.X, qml.expval, qml.Z, -1.0), + (qml.I, qml.var, qml.Z, 0.0), + (qml.X, qml.var, qml.Z, 0.0), + # PauliX observables + pytest.param( + partial(qml.RY, phi=np.pi / 2), + qml.expval, + qml.X, + 1.0, + marks=pytest.mark.xfail( + reason="Only PauliZ-basis measurements supported", + strict=True, + raises=NotImplementedError, + ), + ), + pytest.param( + partial(qml.RY, phi=-np.pi / 2), + qml.expval, + qml.X, + -1.0, + marks=pytest.mark.xfail( + reason="Only PauliZ-basis measurements supported", + strict=True, + raises=NotImplementedError, + ), + ), + pytest.param( + partial(qml.RY, phi=np.pi / 2), + qml.var, + qml.X, + 0.0, + marks=pytest.mark.xfail( + reason="Only PauliZ-basis measurements supported", + strict=True, + raises=NotImplementedError, + ), + ), + pytest.param( + partial(qml.RY, phi=-np.pi / 2), + qml.var, + qml.X, + 0.0, + marks=pytest.mark.xfail( + reason="Only PauliZ-basis measurements supported", + strict=True, + raises=NotImplementedError, + ), + ), + # PauliY observables + pytest.param( + partial(qml.RX, phi=-np.pi / 2), + qml.expval, + qml.Y, + 1.0, + marks=pytest.mark.xfail( + reason="Only PauliZ-basis measurements supported", + strict=True, + raises=NotImplementedError, + ), + ), + pytest.param( + partial(qml.RX, phi=np.pi / 2), + qml.expval, + qml.Y, + -1.0, + marks=pytest.mark.xfail( + reason="Only PauliZ-basis measurements supported", + strict=True, + raises=NotImplementedError, + ), + ), + pytest.param( + partial(qml.RX, phi=-np.pi / 2), + qml.var, + qml.Y, + 0.0, + marks=pytest.mark.xfail( + reason="Only PauliZ-basis measurements supported", + strict=True, + raises=NotImplementedError, + ), + ), + pytest.param( + partial(qml.RX, phi=np.pi / 2), + qml.var, + qml.Y, + 0.0, + marks=pytest.mark.xfail( + reason="Only PauliZ-basis measurements supported", + strict=True, + raises=NotImplementedError, + ), + ), + ], + ) + # pylint: disable=too-many-arguments,too-many-positional-arguments + def test_exec_1_wire_mp_with_obs(self, shots, initial_op, mp, obs, expected_res): + """Test the measurements_from_samples transform on a device with a single wire and terminal + measurements that require an observable (i.e. expval and var). + """ + + dev = qml.device("lightning.qubit", wires=1) + + @qml.qnode(dev, shots=shots) + def circuit_ref(): + initial_op(wires=0) + return mp(obs(wires=0)) + + assert expected_res == circuit_ref(), "Sanity check failed, is expected_res correct?" + circuit_compiled = qml.qjit( + measurements_from_samples_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + assert expected_res == circuit_compiled() + + # -------------------------------------------------------------------------------------------- # + + @pytest.mark.parametrize("shots", [1, 2]) + @pytest.mark.parametrize( + "initial_op, expected_res", + [ + (qml.I, [1.0, 0.0]), + (qml.X, [0.0, 1.0]), + ], + ) + def test_exec_1_wire_probs(self, shots, initial_op, expected_res): + """Test the measurements_from_samples transform on a device with a single wire and terminal + probs measurements. + """ + + dev = qml.device("lightning.qubit", wires=1) + + @qml.qnode(dev, shots=shots) + def circuit_ref(): + initial_op(wires=0) + return qml.probs(wires=0) + + assert np.array_equal( + expected_res, circuit_ref() + ), "Sanity check failed, is expected_res correct?" + circuit_compiled = qml.qjit( + measurements_from_samples_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + assert np.array_equal(expected_res, circuit_compiled()) + + # -------------------------------------------------------------------------------------------- # + + @pytest.mark.xfail( + reason="Counts not supported in Catalyst with program capture", + strict=True, + raises=NotImplementedError, + ) + @pytest.mark.parametrize("shots", [1, 2]) + @pytest.mark.parametrize( + "initial_op, expected_res", + [ + (qml.I, {"0": 10, "1": 0}), + (qml.X, {"0": 0, "1": 10}), + ], + ) + def test_exec_1_wire_counts(self, shots, initial_op, expected_res): + """Test the measurements_from_samples transform on a device with a single wire and terminal + counts measurements. + """ + + dev = qml.device("lightning.qubit", wires=1) + + @qml.qnode(dev, shots=shots) + def circuit_ref(): + initial_op(wires=0) + return qml.counts(wires=0) + + assert np.array_equal( + expected_res, circuit_ref() + ), "Sanity check failed, is expected_res correct?" + + circuit_compiled = qml.qjit( + measurements_from_samples_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + assert np.array_equal(expected_res, _counts_catalyst_to_pl(*circuit_compiled())) + + # -------------------------------------------------------------------------------------------- # + + @pytest.mark.parametrize("shots", [1, 2]) + @pytest.mark.parametrize( + "initial_op, expected_res_base", + [ + (qml.I, 0), + (qml.X, 1), + ], + ) + def test_exec_1_wire_sample(self, shots, initial_op, expected_res_base): + """Test the measurements_from_samples transform on a device with a single wire and terminal + sample measurements. + + In this case, the measurements_from_samples pass should effectively be a no-op. + """ + dev = qml.device("lightning.qubit", wires=1) + + @qml.qnode(dev, shots=shots) + def circuit_ref(): + initial_op(wires=0) + return qml.sample(wires=0) + + circuit_compiled = qml.qjit( + measurements_from_samples_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + expected_res = expected_res_base * np.ones(shape=(shots, 1), dtype=int) + + assert np.array_equal(expected_res, circuit_compiled()) + + # -------------------------------------------------------------------------------------------- # + + @pytest.mark.parametrize("shots", [1, 2]) + @pytest.mark.parametrize( + "initial_ops, mp, obs, expected_res", + [ + ((qml.I, qml.I), qml.expval, qml.Z, (1.0, 1.0)), + ((qml.I, qml.X), qml.expval, qml.Z, (1.0, -1.0)), + ((qml.X, qml.I), qml.expval, qml.Z, (-1.0, 1.0)), + ((qml.X, qml.X), qml.expval, qml.Z, (-1.0, -1.0)), + ((qml.I, qml.I), qml.var, qml.Z, (0.0, 0.0)), + ((qml.I, qml.X), qml.var, qml.Z, (0.0, 0.0)), + ((qml.X, qml.I), qml.var, qml.Z, (0.0, 0.0)), + ((qml.X, qml.X), qml.var, qml.Z, (0.0, 0.0)), + ], + ) + # pylint: disable=too-many-arguments,too-many-positional-arguments + def test_exec_2_wire_with_obs_separate(self, shots, initial_ops, mp, obs, expected_res): + """Test the measurements_from_samples transform on a device with two wires and terminal + measurements that require an observable (i.e. expval and var). + + In this test, the terminal measurements are performed separately per wire. + """ + + dev = qml.device("lightning.qubit", wires=2) + + @qml.qnode(dev, shots=shots) + def circuit_ref(): + initial_ops[0](wires=0) + initial_ops[1](wires=1) + return mp(obs(wires=0)), mp(obs(wires=1)) + + assert expected_res == circuit_ref(), "Sanity check failed, is expected_res correct?" + circuit_compiled = qml.qjit( + measurements_from_samples_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + assert expected_res == circuit_compiled() + + # -------------------------------------------------------------------------------------------- # + + @pytest.mark.xfail( + reason="Operator arithmetic not yet supported with capture enabled", strict=True + ) + @pytest.mark.parametrize("shots", [1, 2]) + @pytest.mark.parametrize( + "initial_ops, mp, expected_res", + [ + ((qml.I, qml.I), qml.expval, 1.0), + ((qml.I, qml.X), qml.expval, -1.0), + ((qml.X, qml.I), qml.expval, -1.0), + ((qml.X, qml.X), qml.expval, 1.0), + ((qml.I, qml.I), qml.var, 0.0), + ((qml.I, qml.X), qml.var, 0.0), + ((qml.X, qml.I), qml.var, 0.0), + ((qml.X, qml.X), qml.var, 0.0), + ], + ) + def test_exec_2_wire_with_obs_combined(self, shots, initial_ops, mp, expected_res): + """Test the measurements_from_samples transform on a device with two wires and terminal + measurements that require an observable (i.e. expval and var). + + In this test, the terminal measurements are performed on the combination of both wires. + """ + + dev = qml.device("lightning.qubit", wires=2) + + @qml.qnode(dev, shots=shots) + def circuit_ref(): + initial_ops[0](wires=0) + initial_ops[1](wires=1) + return mp(qml.Z(wires=0) @ qml.Z(wires=1)) + + assert expected_res == circuit_ref(), "Sanity check failed, is expected_res correct?" + + circuit_compiled = qml.qjit( + measurements_from_samples_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + assert expected_res == circuit_compiled() + + # -------------------------------------------------------------------------------------------- # + + @pytest.mark.parametrize("shots", [1, 2]) + @pytest.mark.parametrize( + "initial_ops, expected_res", + [ + ((qml.I, qml.I), [1.0, 0.0, 0.0, 0.0]), + ((qml.I, qml.X), [0.0, 1.0, 0.0, 0.0]), + ((qml.X, qml.I), [0.0, 0.0, 1.0, 0.0]), + ((qml.X, qml.X), [0.0, 0.0, 0.0, 1.0]), + ], + ) + def test_exec_2_wire_probs_global(self, shots, initial_ops, expected_res): + """Test the measurements_from_samples transform on a device with two wires and a terminal, + "global" probs measurements (one that implicitly acts on all wires). + """ + dev = qml.device("lightning.qubit", wires=2) + + @qml.qnode(dev, shots=shots) + def circuit_ref(): + initial_ops[0](wires=0) + initial_ops[1](wires=1) + return qml.probs() + + assert np.array_equal( + expected_res, circuit_ref() + ), "Sanity check failed, is expected_res correct?" + circuit_compiled = qml.qjit( + measurements_from_samples_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + assert np.array_equal(expected_res, circuit_compiled()) + + # -------------------------------------------------------------------------------------------- # + + @pytest.mark.parametrize("shots", [1, 2]) + @pytest.mark.parametrize( + "initial_ops, expected_res", + [ + ((qml.I, qml.I), ([1.0, 0.0], [1.0, 0.0])), + ((qml.I, qml.X), ([1.0, 0.0], [0.0, 1.0])), + ((qml.X, qml.I), ([0.0, 1.0], [1.0, 0.0])), + ((qml.X, qml.X), ([0.0, 1.0], [0.0, 1.0])), + ], + ) + def test_exec_2_wire_probs_per_wire(self, shots, initial_ops, expected_res): + """Test the measurements_from_samples transform on a device with two wires and a terminal, + "global" probs measurements (one that implicitly acts on all wires). + """ + dev = qml.device("lightning.qubit", wires=2) + + @qml.qnode(dev, shots=shots) + def circuit_ref(): + initial_ops[0](wires=0) + initial_ops[1](wires=1) + return qml.probs(wires=0), qml.probs(wires=1) + + assert np.array_equal( + expected_res, circuit_ref() + ), "Sanity check failed, is expected_res correct?" + circuit_compiled = qml.qjit( + measurements_from_samples_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + assert np.array_equal(expected_res, circuit_compiled()) + + # -------------------------------------------------------------------------------------------- # + + @pytest.mark.xfail(reason="Dynamic shots not supported") + def test_exec_expval_dynamic_shots(self): + """Test the measurements_from_samples transform where the number of shots is dynamic. + + This use case is not currently supported. + """ + + @qml.qjit(pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()]) + def workload(shots): + dev = qml.device("lightning.qubit", wires=1) + + @measurements_from_samples_pass + @qml.qnode(dev, shots=shots) + def circuit(): + return qml.expval(qml.Z(wires=0)) + + return circuit() + + result = workload(2) + assert result == 1.0 + + def test_qjit_filecheck(self, run_filecheck_qjit): + """Test that the measurements_from_samples_pass works correctly with qjit.""" + dev = qml.device("lightning.qubit", wires=2) + + @qml.qjit(target="mlir", pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()]) + @measurements_from_samples_pass + @qml.qnode(dev, shots=25) + def circuit(): + # CHECK-NOT: quantum.namedobs + # CHECK: [[obs:%.+]] = quantum.compbasis + # CHECK: [[samples:%.+]] = quantum.sample [[obs]] : tensor<25x1xf64> + # CHECK: [[c0:%.+]] = arith.constant dense<0> : tensor + # CHECK: [[res:%.+]] = func.call @expval_from_samples.tensor.25x1xf64([[samples]], [[c0]]) : + # CHECK-SAME: (tensor<25x1xf64>, tensor) -> tensor + # CHECK-NOT: quantum.expval + return qml.expval(qml.Z(wires=0)) + + run_filecheck_qjit(circuit) + + @pytest.mark.usefixtures("use_capture") + def test_integrate_with_decompose(self): + """Test that the measurements_from_samples pass works correctly when used in combination + with the decompose pass.""" + dev = qml.device("null.qubit", wires=4) + + @qml.qjit(target="mlir", pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()]) + @measurements_from_samples_pass + @partial( + qml.transforms.decompose, + gate_set={"X", "Y", "Z", "S", "H", "CNOT", "RZ", "GlobalPhase"}, + ) + @qml.set_shots(1000) + @qml.qnode(dev, shots=1000) + def circuit(): + qml.CRX(0.1, wires=[0, 1]) + return qml.expval(qml.Z(0)) + + res = circuit() + assert res == 1.0 + + +def _counts_catalyst_to_pl(basis_states, counts): + """Helper function to convert counts in the Catalyst format to the PennyLane format. + + Example: + + >>> basis_states, counts = ([0, 1], [6, 4]) + >>> _counts_catalyst_to_pl(basis_states, counts) + {'0': 6, '1': 4} + """ + return {format(int(state), "01b"): count for state, count in zip(basis_states, counts)} + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_merge_rotations.py b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_merge_rotations.py new file mode 100644 index 0000000000..3dfb9dad77 --- /dev/null +++ b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_merge_rotations.py @@ -0,0 +1,244 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the merge rotations transform""" +import pytest + +# pylint: disable=wrong-import-position,line-too-long +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +import pennylane as qml + +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from catalyst.python_interface.transforms import MergeRotationsPass, merge_rotations_pass + + +class TestMergeRotationsPass: + """Unit tests for MergeRotationsPass.""" + + pipeline = (MergeRotationsPass(),) + + def test_no_composable_ops(self, run_filecheck): + """Test that nothing changes when there are no composable gates.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = quantum.custom "RX"(%arg0) [[q0]] : !quantum.bit + // CHECK: quantum.custom "RY"(%arg1) [[q1]] : !quantum.bit + %1 = quantum.custom "RX"(%arg0) %0 : !quantum.bit + %2 = quantum.custom "RY"(%arg1) %1 : !quantum.bit + return + } + """ + + run_filecheck(program, self.pipeline) + + def test_composable_ops(self, run_filecheck): + """Test that composable gates are merged.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK: [[phi0:%.+]] = arith.addf %arg0, %arg1 : f64 + // CHECK: quantum.custom "RX"([[phi0]]) [[q0]] : !quantum.bit + // CHECK-NOT: "quantum.custom" + %1 = quantum.custom "RX"(%arg0) %0 : !quantum.bit + %2 = quantum.custom "RX"(%arg1) %1 : !quantum.bit + return + } + """ + + run_filecheck(program, self.pipeline) + + def test_many_composable_ops(self, run_filecheck): + """Test that more than 2 composable ops are merged correctly.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64, %arg2: f64, %arg3: f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK: [[phi0:%.+]] = arith.addf %arg0, %arg1 : f64 + // CHECK: [[phi1:%.+]] = arith.addf [[phi0]], %arg2 : f64 + // CHECK: [[phi2:%.+]] = arith.addf [[phi1]], %arg3 : f64 + // CHECK: quantum.custom "RX"([[phi2]]) [[q0]] : !quantum.bit + // CHECK-NOT: "quantum.custom" + %1 = quantum.custom "RX"(%arg0) %0 : !quantum.bit + %2 = quantum.custom "RX"(%arg1) %1 : !quantum.bit + %3 = quantum.custom "RX"(%arg2) %2 : !quantum.bit + %4 = quantum.custom "RX"(%arg3) %3 : !quantum.bit + return + } + """ + + run_filecheck(program, self.pipeline) + + def test_non_consecutive_composable_ops(self, run_filecheck): + """Test that non-consecutive composable gates are not merged.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = quantum.custom "RX"(%arg0) [[q0]] : !quantum.bit + // CHECK: [[q2:%.+]] = quantum.custom "RY"(%arg0) [[q1]] : !quantum.bit + // CHECK: quantum.custom "RX"(%arg1) [[q2]] : !quantum.bit + %1 = quantum.custom "RX"(%arg0) %0 : !quantum.bit + %2 = quantum.custom "RY"(%arg0) %1 : !quantum.bit + %3 = quantum.custom "RX"(%arg1) %2 : !quantum.bit + return + } + """ + + run_filecheck(program, self.pipeline) + + def test_composable_ops_different_qubits(self, run_filecheck): + """Test that composable gates on different qubits are not merged.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + %1 = "test.op"() : () -> !quantum.bit + // CHECK: quantum.custom "RX"(%arg0) [[q0]] : !quantum.bit + // CHECK: quantum.custom "RX"(%arg1) [[q1]] : !quantum.bit + %2 = quantum.custom "RX"(%arg0) %0 : !quantum.bit + %3 = quantum.custom "RX"(%arg1) %1 : !quantum.bit + return + } + """ + + run_filecheck(program, self.pipeline) + + def test_controlled_composable_ops(self, run_filecheck): + """Test that controlled composable ops can be merged.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64) { + %cst = "arith.constant"() <{value = true}> : () -> i1 + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q2:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + %1 = "test.op"() : () -> !quantum.bit + %2 = "test.op"() : () -> !quantum.bit + // CHECK: [[phi0:%.+]] = arith.addf %arg0, %arg1 : f64 + // CHECK: quantum.custom "RX"([[phi0]]) [[q0]] ctrls([[q1]], [[q2]]) ctrlvals(%cst, %cst) : !quantum.bit ctrls !quantum.bit, !quantum.bit + // CHECK-NOT: "quantum.custom" + %3, %4, %5 = quantum.custom "RX"(%arg0) %0 ctrls(%1, %2) ctrlvals(%cst, %cst) : !quantum.bit ctrls !quantum.bit, !quantum.bit + %6, %7, %8 = quantum.custom "RX"(%arg1) %3 ctrls(%4, %5) ctrlvals(%cst, %cst) : !quantum.bit ctrls !quantum.bit, !quantum.bit + return + } + """ + + run_filecheck(program, self.pipeline) + + def test_controlled_composable_ops_same_control_values(self, run_filecheck): + """Test that controlled composable ops with the same control values + can be merged.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64) { + %cst0 = "arith.constant"() <{value = true}> : () -> i1 + %cst1 = "arith.constant"() <{value = false}> : () -> i1 + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q2:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + %1 = "test.op"() : () -> !quantum.bit + %2 = "test.op"() : () -> !quantum.bit + // CHECK: [[phi0:%.+]] = arith.addf %arg0, %arg1 : f64 + // CHECK: quantum.custom "RX"([[phi0]]) [[q0]] ctrls([[q1]], [[q2]]) ctrlvals(%cst0, %cst1) : !quantum.bit ctrls !quantum.bit, !quantum.bit + // CHECK-NOT: quantum.custom + %3, %4, %5 = quantum.custom "RX"(%arg0) %0 ctrls(%1, %2) ctrlvals(%cst0, %cst1) : !quantum.bit ctrls !quantum.bit, !quantum.bit + %6, %7, %8 = quantum.custom "RX"(%arg1) %3 ctrls(%4, %5) ctrlvals(%cst0, %cst1) : !quantum.bit ctrls !quantum.bit, !quantum.bit + return + } + """ + + run_filecheck(program, self.pipeline) + + def test_controlled_composable_ops_different_control_values(self, run_filecheck): + """Test that controlled composable ops with different control values + are not merged.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64) { + %cst0 = "arith.constant"() <{value = true}> : () -> i1 + %cst1 = "arith.constant"() <{value = false}> : () -> i1 + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q2:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + %1 = "test.op"() : () -> !quantum.bit + %2 = "test.op"() : () -> !quantum.bit + // CHECK: [[q3:%.+]], [[q4:%.+]], [[q5:%.+]] = quantum.custom "RX"(%arg0) [[q0]] ctrls([[q1]], [[q2]]) ctrlvals(%cst1, %cst0) : !quantum.bit ctrls !quantum.bit, !quantum.bit + // CHECK: quantum.custom "RX"(%arg1) [[q3]] ctrls([[q4]], [[q5]]) ctrlvals(%cst0, %cst1) : !quantum.bit ctrls !quantum.bit, !quantum.bit + %3, %4, %5 = quantum.custom "RX"(%arg0) %0 ctrls(%1, %2) ctrlvals(%cst1, %cst0) : !quantum.bit ctrls !quantum.bit, !quantum.bit + %6, %7, %8 = quantum.custom "RX"(%arg1) %3 ctrls(%4, %5) ctrlvals(%cst0, %cst1) : !quantum.bit ctrls !quantum.bit, !quantum.bit + return + } + """ + + run_filecheck(program, self.pipeline) + + @pytest.mark.parametrize( + "first_adj, second_adj, sign", + [ + (False, False, "+"), + (False, True, "-"), + (True, False, "-"), + (True, True, "+"), + ], + ) + def test_adjoint_property(self, first_adj, second_adj, sign, run_filecheck): + """Test that composable ops with and without adjoint property are merged correctly.""" + adj_string_0 = "adj" if first_adj else "" + adj_string_1 = "adj" if second_adj else "" + arith_op_str = "arith.addf" if sign == "+" else "arith.subf" + program = f""" + func.func @test_func(%arg0: f64, %arg1: f64) {{ + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK: [[new_angle:%.+]] = {arith_op_str} %arg0, %arg1 : f64 + // CHECK: [[q1:%.+]] = quantum.custom "RX"([[new_angle]]) [[q0]] {adj_string_0} : !quantum.bit + // CHECK-NOT: "quantum.custom" + %1 = quantum.custom "RX"(%arg0) %0 {adj_string_0}: !quantum.bit + %2 = quantum.custom "RX"(%arg1) %1 {adj_string_1}: !quantum.bit + return + }} + """ + + run_filecheck(program, self.pipeline) + + +# pylint: disable=too-few-public-methods +@pytest.mark.usefixtures("use_capture") +class TestMergeRotationsIntegration: + """Integration tests for the MergeRotationsPass.""" + + def test_qjit(self, run_filecheck_qjit): + """Test that the MergeRotationsPass works correctly with qjit.""" + dev = qml.device("lightning.qubit", wires=1) + + @qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()]) + @merge_rotations_pass + @qml.qnode(dev) + def circuit(x: float, y: float): + # CHECK: [[phi:%.+]] = arith.addf + # CHECK: quantum.custom "RX"([[phi]]) + # CHECK-NOT: quantum.custom + qml.RX(x, 0) + qml.RX(y, 0) + return qml.state() + + run_filecheck_qjit(circuit) + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_split_non_commuting.py b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_split_non_commuting.py new file mode 100644 index 0000000000..30179ec736 --- /dev/null +++ b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_split_non_commuting.py @@ -0,0 +1,311 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the split non-commuting transform""" +import pytest + +# pylint: disable=wrong-import-position +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +import pennylane as qml + +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from catalyst.python_interface.transforms import ( + SplitNonCommutingPass, + split_non_commuting_pass, +) + + +class TestSplitNonCommutingPass: + """Unit tests for SplitNonCommutingPass.""" + + def test_func_w_qnode_attr(self, run_filecheck): + """Test split non-commuting pass would be applied to a func with a qnode attribute.""" + program = """ + module @module_circuit { + func.func public @circuit() -> tensor attributes {qnode} { + // CHECK-NOT: arith.constant + // CHECK: [[value:%.+]] = func.call [[dup_func:@[a-zA-Z0-9_.]+]] + // CHECK: func.return [[value]] + %0 = arith.constant 0 : i64 + quantum.device shots(%0) ["", "", ""] + %1 = quantum.alloc( 50) : !quantum.reg + %2 = quantum.extract %1[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "PauliX"() %2 : !quantum.bit + %3 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.insert %1[ 0], %out_qubits : !quantum.reg, !quantum.bit + quantum.dealloc %5 : !quantum.reg + quantum.device_release + return %from_elements : tensor + } + // CHECK: func.func [[dup_func]] + } + """ + + pipeline = (SplitNonCommutingPass(),) + run_filecheck(program, pipeline) + + def test_multiple_func_w_qnode_attr(self, run_filecheck): + """Test split non-commuting pass would be applied to a func with a qnode attribute.""" + program = """ + module @module_circuit { + func.func public @circuit1() -> tensor attributes {qnode} { + // CHECK-NOT: arith.constant + // CHECK: [[value:%.+]] = func.call [[dup_func1:@[a-zA-Z0-9_.]+]] + // CHECK: func.return [[value]] + %0 = arith.constant 0 : i64 + quantum.device shots(%0) ["", "", ""] + %1 = quantum.alloc( 50) : !quantum.reg + %2 = quantum.extract %1[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "PauliX"() %2 : !quantum.bit + %3 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.insert %1[ 0], %out_qubits : !quantum.reg, !quantum.bit + quantum.dealloc %5 : !quantum.reg + quantum.device_release + return %from_elements : tensor + } + func.func public @circuit2() -> tensor attributes {qnode} { + // CHECK-NOT: arith.constant + // CHECK: [[value:%.+]] = func.call [[dup_func2:@[a-zA-Z0-9_.]+]] + // CHECK: func.return [[value]] + %0 = arith.constant 0 : i64 + quantum.device shots(%0) ["", "", ""] + %1 = quantum.alloc( 50) : !quantum.reg + %2 = quantum.extract %1[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "PauliX"() %2 : !quantum.bit + %3 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.insert %1[ 0], %out_qubits : !quantum.reg, !quantum.bit + quantum.dealloc %5 : !quantum.reg + quantum.device_release + return %from_elements : tensor + } + // CHECK: func.func [[dup_func1]] + // CHECK: func.func [[dup_func2]] + } + """ + + pipeline = (SplitNonCommutingPass(),) + run_filecheck(program, pipeline) + + def test_func_w_commuting_measurements(self, run_filecheck): + """Test split non-commuting pass would be applied to a func with commuting ops.""" + program = """ + module @module_circuit { + func.func public @circuit() -> (tensor, tensor) attributes {qnode} { + // CHECK-NOT: arith.constant + // CHECK: [[v0:%.+]], [[v1:%.+]] = func.call [[dup_func:@[a-zA-Z0-9_.]+]] + // CHECK: func.return [[v0]], [[v1]] + %c0 = arith.constant 0 : i64 + quantum.device shots(%c0) ["", "", ""] + %0 = quantum.alloc( 5) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + %out_qubits_0 = quantum.custom "Hadamard"() %2 : !quantum.bit + %3 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.namedobs %out_qubits_0[ PauliZ] : !quantum.obs + %6 = quantum.expval %5 : f64 + %from_elements_1 = tensor.from_elements %6 : tensor + %7 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit + %8 = quantum.insert %7[ 1], %out_qubits_0 : !quantum.reg, !quantum.bit + quantum.dealloc %8 : !quantum.reg + quantum.device_release + return %from_elements, %from_elements_1 : tensor, tensor + } + // CHECK: func.func [[dup_func]] + } + """ + + pipeline = (SplitNonCommutingPass(),) + run_filecheck(program, pipeline) + + def test_func_w_non_commuting_measurements(self, run_filecheck): + """Test split non-commuting pass would be applied to a func with non-commuting ops.""" + program = """ + module @module_circuit { + func.func public @circuit() -> (tensor, tensor) attributes {qnode} { + // CHECK-NOT: arith.constant + // CHECK: [[v0:%.+]] = func.call [[dup_func0:@[a-zA-Z0-9_.]+]] + // CHECK: [[v1:%.+]] = func.call [[dup_func1:@[a-zA-Z0-9_.]+]] + // CHECK: func.return [[v0]], [[v1]] + %c0 = arith.constant 0 : i64 + quantum.device shots(%c0) ["", "", ""] + %0 = quantum.alloc( 5) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + %out_qubits_0 = quantum.custom "Hadamard"() %2 : !quantum.bit + %3 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.namedobs %out_qubits[ PauliZ] : !quantum.obs + %6 = quantum.expval %5 : f64 + %from_elements_1 = tensor.from_elements %6 : tensor + %7 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit + %8 = quantum.insert %7[ 1], %out_qubits_0 : !quantum.reg, !quantum.bit + quantum.dealloc %8 : !quantum.reg + quantum.device_release + return %from_elements, %from_elements_1 : tensor, tensor + } + // CHECK: func.func [[dup_func0]] + // CHECK: PauliX + // CHECK: func.func [[dup_func1]] + // CHECK: PauliZ + } + """ + + pipeline = (SplitNonCommutingPass(),) + run_filecheck(program, pipeline) + + def test_func_w_mixed_measurements(self, run_filecheck): + """Test split non-commuting pass would be applied to a func with mixed ops.""" + program = """ + module @module_circuit { + func.func public @circuit() -> (tensor, tensor) attributes {qnode} { + // CHECK-NOT: arith.constant + // CHECK: [[v0:%.+]], [[v1:%.+]] = func.call [[dup_func0:@[a-zA-Z0-9_.]+]] + // CHECK: [[v2:%.+]] = func.call [[dup_func1:@[a-zA-Z0-9_.]+]] + // CHECK: func.return [[v0]], [[v2]], [[v1]] + %c0 = arith.constant 0 : i64 + quantum.device shots(%c0) ["", "", ""] + %0 = quantum.alloc( 5) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + %out_qubits_0 = quantum.custom "Hadamard"() %2 : !quantum.bit + %3 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.namedobs %out_qubits[ PauliZ] : !quantum.obs + %6 = quantum.expval %5 : f64 + %from_elements_1 = tensor.from_elements %6 : tensor + %7 = quantum.namedobs %out_qubits_0[ PauliY] : !quantum.obs + %8 = quantum.expval %7 : f64 + %from_elements_2 = tensor.from_elements %8 : tensor + %9 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit + %10 = quantum.insert %9[ 1], %out_qubits_0 : !quantum.reg, !quantum.bit + quantum.dealloc %10 : !quantum.reg + quantum.device_release + return %from_elements, %from_elements_1, %from_elements_2 : tensor, tensor, tensor + } + // CHECK: func.func [[dup_func0]] + // CHECK: PauliX + // CHECK: PauliY + // CHECK: func.func [[dup_func1]] + // CHECK: PauliZ + } + """ + + pipeline = (SplitNonCommutingPass(),) + run_filecheck(program, pipeline) + + @pytest.mark.usefixtures("use_capture") + def test_split_non_commuting_pass_only(self, run_filecheck_qjit): + """Test the split non-commuting pass only.""" + dev = qml.device("lightning.qubit", wires=5) + + @qml.while_loop(lambda i: i < 5) + def _while_for(i): + qml.H(i) + i = i + 1 + return i + + @qml.qjit( + target="mlir", + pass_plugins=[getXDSLPluginAbsolutePath()], + ) + @split_non_commuting_pass + @qml.set_shots(10) + @qml.qnode(dev) + def circuit(): + # CHECK-LABEL: func.func public @circuit() + # CHECK: [[v0:%.+]], [[v1:%.+]] = func.call [[dup_func:@[a-zA-Z0-9_.]+]] + # CHECK: [[v2:%.+]] = func.call [[dup_func1:@[a-zA-Z0-9_.]+]] + # CHECK: func.return [[v0]], [[v1]], [[v2]] + _while_for(0) + qml.CNOT(wires=[0, 1]) + return ( + qml.expval(qml.Z(wires=0)), + qml.expval(qml.Y(wires=1)), + qml.expval(qml.X(wires=0)), + ) + # CHECK: func.func [[dup_func]] + # CHECK: PauliZ + # CHECK: PauliY + # CHECK: func.func [[dup_func1]] + # CHECK: PauliX + + run_filecheck_qjit(circuit) + + @pytest.mark.usefixtures("use_capture") + def test_lightning_execution_with_structure(self): + """Test that the split non-commuting pass on lightning.qubit for a circuit with program + structure is executable and returns results as expected.""" + dev = qml.device("lightning.qubit", wires=10) + + @qml.for_loop(0, 10, 1) + def for_fn(i): + qml.H(i) + qml.S(i) + qml.RZ(phi=0.1, wires=[i]) + + @qml.while_loop(lambda i: i < 10) + def while_fn(i): + qml.H(i) + qml.S(i) + qml.RZ(phi=0.1, wires=[i]) + i = i + 1 + return i + + @qml.qjit( + target="mlir", + pass_plugins=[getXDSLPluginAbsolutePath()], + ) + @split_non_commuting_pass + @qml.qnode(dev) + def circuit(): + for_fn() # pylint: disable=no-value-for-parameter + while_fn(0) + qml.CNOT(wires=[0, 1]) + return ( + qml.expval(qml.Z(wires=0)), + qml.expval(qml.Y(wires=1)), + qml.expval(qml.X(wires=0)), + ) + + res = circuit() + + @qml.qjit( + target="mlir", + ) + @qml.qnode(dev) + def circuit_ref(): + for_fn() # pylint: disable=no-value-for-parameter + while_fn(0) + qml.CNOT(wires=[0, 1]) + return ( + qml.expval(qml.Z(wires=0)), + qml.expval(qml.Y(wires=1)), + qml.expval(qml.X(wires=0)), + ) + + res_ref = circuit_ref() + assert res == res_ref diff --git a/frontend/test/pytest/python_interface/visualization/test_draw_unified_compiler.py b/frontend/test/pytest/python_interface/visualization/test_draw_unified_compiler.py new file mode 100644 index 0000000000..26fa8d47ea --- /dev/null +++ b/frontend/test/pytest/python_interface/visualization/test_draw_unified_compiler.py @@ -0,0 +1,542 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the draw function in the Python Compiler visualization module.""" + + +import pytest + +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +# pylint: disable=wrong-import-position,unnecessary-lambda +import jax +import pennylane as qml + +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from catalyst.python_interface.transforms import ( + iterative_cancel_inverses_pass, + merge_rotations_pass, +) +from catalyst.python_interface.visualization import draw + + +@pytest.mark.usefixtures("use_capture") +class Testdraw: + """Unit tests for the draw function in the Python Compiler visualization module.""" + + @pytest.fixture + def transforms_circuit(self): + """Fixture for a circuit.""" + + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circ(): + qml.RX(1, 0) + qml.RX(2.0, 0) + qml.RY(3.0, 1) + qml.RY(4.0, 1) + qml.RZ(5.0, 2) + qml.RZ(6.0, 2) + qml.Hadamard(0) + qml.Hadamard(0) + qml.CNOT([0, 1]) + qml.CNOT([0, 1]) + qml.Hadamard(1) + qml.Hadamard(1) + qml.RZ(7.0, 0) + qml.RZ(8.0, 0) + qml.CNOT([0, 2]) + qml.CNOT([0, 2]) + return qml.state() + + return circ + + @pytest.mark.parametrize("qjit", [True, False]) + @pytest.mark.parametrize( + "level, expected", + [ + ( + 0, + "0: ──RX──RX──H──H─╭●─╭●──RZ──RZ─╭●─╭●── State\n" + "1: ──RY──RY───────╰X─╰X──H───H──│──│─── State\n" + "2: ──RZ──RZ─────────────────────╰X─╰X── State", + ), + ( + 1, + "0: ──RX──H──H─╭●─╭●──RZ────╭●─╭●── State\n" + "1: ──RY───────╰X─╰X──H───H─│──│─── State\n" + "2: ──RZ────────────────────╰X─╰X── State", + ), + (2, "0: ──RX──RZ── State\n1: ──RY────── State\n2: ──RZ────── State"), + (None, "0: ──RX──RZ── State\n1: ──RY────── State\n2: ──RZ────── State"), + (50, "0: ──RX──RZ── State\n1: ──RY────── State\n2: ──RZ────── State"), + ], + ) + def test_multiple_levels_xdsl(self, transforms_circuit, level, qjit, expected): + """Test that multiple levels of transformation are applied correctly with xDSL + compilation passes.""" + + transforms_circuit = iterative_cancel_inverses_pass( + merge_rotations_pass(transforms_circuit) + ) + + if qjit: + transforms_circuit = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])( + transforms_circuit + ) + + assert draw(transforms_circuit, level=level)() == expected + + @pytest.mark.parametrize("qjit", [True, False]) + @pytest.mark.parametrize( + "level, expected", + [ + ( + 0, + "0: ──RX──RX──H──H─╭●─╭●──RZ──RZ─╭●─╭●── State\n" + "1: ──RY──RY───────╰X─╰X──H───H──│──│─── State\n" + "2: ──RZ──RZ─────────────────────╰X─╰X── State", + ), + ( + 1, + "0: ──RX──H──H─╭●─╭●──RZ────╭●─╭●── State\n" + "1: ──RY───────╰X─╰X──H───H─│──│─── State\n" + "2: ──RZ────────────────────╰X─╰X── State", + ), + (2, "0: ──RX──RZ── State\n1: ──RY────── State\n2: ──RZ────── State"), + (None, "0: ──RX──RZ── State\n1: ──RY────── State\n2: ──RZ────── State"), + (50, "0: ──RX──RZ── State\n1: ──RY────── State\n2: ──RZ────── State"), + ], + ) + def test_multiple_levels_catalyst(self, transforms_circuit, level, qjit, expected): + """Test that multiple levels of transformation are applied correctly with Catalyst + compilation passes.""" + + transforms_circuit = qml.transforms.cancel_inverses( + qml.transforms.merge_rotations(transforms_circuit) + ) + + if qjit: + transforms_circuit = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])( + transforms_circuit + ) + + assert draw(transforms_circuit, level=level)() == expected + + @pytest.mark.parametrize("qjit", [True, False]) + @pytest.mark.parametrize( + "level, expected", + [ + ( + 0, + "0: ──RX──RX──H──H─╭●─╭●──RZ──RZ─╭●─╭●── State\n" + "1: ──RY──RY───────╰X─╰X──H───H──│──│─── State\n" + "2: ──RZ──RZ─────────────────────╰X─╰X── State", + ), + ( + 1, + "0: ──RX──H──H─╭●─╭●──RZ────╭●─╭●── State\n" + "1: ──RY───────╰X─╰X──H───H─│──│─── State\n" + "2: ──RZ────────────────────╰X─╰X── State", + ), + (2, "0: ──RX──RZ── State\n1: ──RY────── State\n2: ──RZ────── State"), + (None, "0: ──RX──RZ── State\n1: ──RY────── State\n2: ──RZ────── State"), + (50, "0: ──RX──RZ── State\n1: ──RY────── State\n2: ──RZ────── State"), + ], + ) + def test_multiple_levels_xdsl_catalyst(self, transforms_circuit, level, qjit, expected): + """Test that multiple levels of transformation are applied correctly with xDSL and + Catalyst compilation passes.""" + + transforms_circuit = iterative_cancel_inverses_pass( + qml.transforms.merge_rotations(transforms_circuit) + ) + if qjit: + transforms_circuit = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])( + transforms_circuit + ) + + assert draw(transforms_circuit, level=level)() == expected + + @pytest.mark.parametrize("qjit", [True, False]) + @pytest.mark.parametrize( + "level, expected", + [ + ( + 0, + "0: ──RX──RX──H──H─╭●─╭●──RZ──RZ─╭●─╭●── State\n" + "1: ──RY──RY───────╰X─╰X──H───H──│──│─── State\n" + "2: ──RZ──RZ─────────────────────╰X─╰X── State", + ), + ( + 1, + "0: ──RX──RX──H──H─╭●─╭●──RZ──RZ─╭●─╭●── State\n" + "1: ──RY──RY───────╰X─╰X──H───H──│──│─── State\n" + "2: ──RZ──RZ─────────────────────╰X─╰X── State", + ), + ( + 2, + "0: ──RX──RX──H──H─╭●─╭●──RZ──RZ─╭●─╭●── State\n" + "1: ──RY──RY───────╰X─╰X──H───H──│──│─── State\n" + "2: ──RZ──RZ─────────────────────╰X─╰X── State", + ), + ( + None, + "0: ──RX──RX──H──H─╭●─╭●──RZ──RZ─╭●─╭●── State\n" + "1: ──RY──RY───────╰X─╰X──H───H──│──│─── State\n" + "2: ──RZ──RZ─────────────────────╰X─╰X── State", + ), + ( + 50, + "0: ──RX──RX──H──H─╭●─╭●──RZ──RZ─╭●─╭●── State\n" + "1: ──RY──RY───────╰X─╰X──H───H──│──│─── State\n" + "2: ──RZ──RZ─────────────────────╰X─╰X── State", + ), + ], + ) + def test_no_passes(self, transforms_circuit, level, qjit, expected): + """Test that if no passes are applied, the circuit is still visualized.""" + + if qjit: + transforms_circuit = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])( + transforms_circuit + ) + + assert draw(transforms_circuit, level=level)() == expected + + @pytest.mark.parametrize( + "op, expected", + [ + ( + lambda: qml.ctrl(qml.RX(0.1, 0), control=(1, 2, 3)), + "1: ─╭●─── State\n2: β”€β”œβ—β”€β”€β”€ State\n3: β”€β”œβ—β”€β”€β”€ State\n0: ─╰RX── State", + ), + ( + lambda: qml.ctrl(qml.RX(0.1, 0), control=(1, 2, 3), control_values=(0, 1, 0)), + "1: ─╭○─── State\n2: β”€β”œβ—β”€β”€β”€ State\n3: β”€β”œβ—‹β”€β”€β”€ State\n0: ─╰RX── State", + ), + ( + lambda: qml.adjoint(qml.ctrl(qml.RX(0.1, 0), (1, 2, 3), control_values=(0, 1, 0))), + "1: ─╭○──── State\n2: β”€β”œβ—β”€β”€β”€β”€ State\n3: β”€β”œβ—‹β”€β”€β”€β”€ State\n0: ─╰RX†── State", + ), + ( + lambda: qml.ctrl(qml.adjoint(qml.RX(0.1, 0)), (1, 2, 3), control_values=(0, 1, 0)), + "1: ─╭○──── State\n2: β”€β”œβ—β”€β”€β”€β”€ State\n3: β”€β”œβ—‹β”€β”€β”€β”€ State\n0: ─╰RX†── State", + ), + ], + ) + def test_ctrl_adjoint_variants(self, op, expected): + """ + Test the visualization of control and adjoint variants. + """ + + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circuit(): + op() + return qml.state() + + assert draw(circuit)() == expected + + def test_ctrl_before_custom_op(self): + """ + Test the visualization of control operations before custom ops. + """ + + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circuit(): + qml.ctrl(qml.X(3), control=[0, 1, 2], control_values=[1, 0, 1]) + qml.RX(0.1, 2) + return qml.state() + + assert ( + draw(circuit)() + == "0: ─╭●────── State\n1: β”€β”œβ—‹β”€β”€β”€β”€β”€β”€ State\n2: β”€β”œβ—β”€β”€RX── State\n3: ─╰X────── State" + ) + + @pytest.mark.parametrize( + "measurement, expected", + [ + ( + lambda: (qml.probs(0), qml.probs(1), qml.probs(2)), + "0: ──RX── Probs\n1: ──RY── Probs\n2: ──RZ── Probs", + ), + ( + lambda: qml.probs(), + "0: ──RX── Probs\n1: ──RY── Probs\n2: ──RZ── Probs", + ), + ( + lambda: qml.sample(), + "0: ──RX── Sample\n1: ──RY── Sample\n2: ──RZ── Sample", + ), + ( + lambda: (qml.expval(qml.X(0)), qml.expval(qml.Y(1)), qml.expval(qml.Z(2))), + "0: ──RX── \n1: ──RY── \n2: ──RZ── ", + ), + ( + lambda: ( + qml.expval(qml.X(0) @ qml.Y(1)), + qml.expval(qml.Y(1) @ qml.Z(2) @ qml.X(0)), + qml.expval(qml.Z(2) @ qml.X(0) @ qml.Y(1)), + ), + "0: ──RX── β•­ β•­ β•­\n" + "1: ──RY── β•° β”œ β”œ\n" + "2: ──RZ── β•° β•°", + ), + ( + lambda: ( + qml.expval( + qml.Hamiltonian([0.2, 0.2], [qml.PauliX(0), qml.Y(1)]) + @ qml.Hamiltonian([0.1, 0.1], [qml.PauliZ(2), qml.PauliZ(3)]) + ) + ), + "0: ──RX── β•­<(𝓗)@(𝓗)>\n" + "1: ──RY── β”œ<(𝓗)@(𝓗)>\n" + "2: ──RZ── β”œ<(𝓗)@(𝓗)>\n" + "3: ────── β•°<(𝓗)@(𝓗)>", + ), + ( + lambda: (qml.var(qml.X(0)), qml.var(qml.Y(1)), qml.var(qml.Z(2))), + "0: ──RX── Var[X]\n1: ──RY── Var[Y]\n2: ──RZ── Var[Z]", + ), + ( + lambda: ( + qml.var(qml.X(0) @ qml.Y(1)), + qml.var(qml.Y(1) @ qml.Z(2) @ qml.X(0)), + qml.var(qml.Z(2) @ qml.X(0) @ qml.Y(1)), + ), + "0: ──RX── β•­Var[X@Y] β•­Var[Y@Z@X] β•­Var[Z@X@Y]\n" + "1: ──RY── β•°Var[X@Y] β”œVar[Y@Z@X] β”œVar[Z@X@Y]\n" + "2: ──RZ── β•°Var[Y@Z@X] β•°Var[Z@X@Y]", + ), + ], + ) + def test_measurements(self, measurement, expected): + """ + Test the visualization of measurements. + """ + + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circuit(): + qml.RX(0.1, 0) + qml.RY(0.2, 1) + qml.RZ(0.3, 2) + return measurement() + + if isinstance(measurement(), qml.measurements.SampleMP): + circuit = qml.set_shots(10)(circuit) + + assert draw(circuit)() == expected + + def test_global_phase(self): + """Test the visualization of global phase shifts.""" + + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circuit(): + qml.H(0) + qml.H(1) + qml.H(2) + qml.GlobalPhase(0.5) + return qml.state() + + assert draw(circuit)() == ( + "0: ──H─╭GlobalPhase── State\n" + "1: ──Hβ”€β”œGlobalPhase── State\n" + "2: ──H─╰GlobalPhase── State" + ) + + @pytest.mark.parametrize( + "postselect, mid_measure_label", + [ + (None, "β”€β†—β”œ"), + (0, "β”€β†—β‚€β”œ"), + (1, "β”€β†—β‚β”œ"), + ], + ) + def test_draw_mid_circuit_measurement_postselect(self, postselect, mid_measure_label): + """Test that mid-circuit measurements are drawn correctly.""" + + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def circuit(): + qml.Hadamard(0) + qml.measure(0, postselect=postselect) + qml.PauliX(0) + return qml.expval(qml.PauliZ(0)) + + drawing = draw(circuit)() + expected_drawing = "0: ──H──" + mid_measure_label + "──X── " + + assert drawing == expected_drawing + + @pytest.mark.jax + @pytest.mark.parametrize( + "ops, expected", + [ + ( + [ + (qml.QubitUnitary, jax.numpy.array([[0, 1], [1, 0]]), [0]), + ( + qml.QubitUnitary, + jax.numpy.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]]), + [0, 1], + ), + (qml.QubitUnitary, jax.numpy.zeros((8, 8)), [0, 1, 2]), + ( + qml.QubitUnitary, + jax.numpy.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]]), + [0, 1], + ), + (qml.QubitUnitary, jax.numpy.array([[0, 1], [1, 0]]), [0]), + ], + "0: ──U(M0)─╭U(M1)─╭U(M2)─╭U(M1)──U(M0)── State\n" + "1: ────────╰U(M1)β”€β”œU(M2)─╰U(M1)───────── State\n" + "2: ───────────────╰U(M2)──────────────── State", + ), + ( + [ + (qml.StatePrep, jax.numpy.array([1, 0]), [0]), + (qml.StatePrep, jax.numpy.array([1, 0, 0, 0]), [0, 1]), + (qml.StatePrep, jax.numpy.array([1, 0, 0, 0, 1, 0, 0, 0]), [0, 1, 2]), + (qml.StatePrep, jax.numpy.array([1, 0, 0, 0]), [0, 1]), + (qml.StatePrep, jax.numpy.array([1, 0]), [0]), + ], + "0: ──|Ξ¨βŸ©β”€β•­|Ξ¨βŸ©β”€β•­|Ξ¨βŸ©β”€β•­|Ξ¨βŸ©β”€β”€|Ξ¨βŸ©β”€β”€ State\n" + "1: ──────╰|Ξ¨βŸ©β”€β”œ|Ξ¨βŸ©β”€β•°|Ξ¨βŸ©β”€β”€β”€β”€β”€β”€β”€ State\n" + "2: ───────────╰|Ξ¨βŸ©β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ State", + ), + ( + [ + (qml.MultiRZ, 0.1, [0]), + (qml.MultiRZ, 0.1, [0, 1]), + (qml.MultiRZ, 0.1, [0, 1, 2]), + (qml.MultiRZ, 0.1, [0, 1]), + (qml.MultiRZ, 0.1, [0]), + ], + "0: ──MultiRZ─╭MultiRZ─╭MultiRZ─╭MultiRZ──MultiRZ── State\n" + "1: ──────────╰MultiRZβ”€β”œMultiRZ─╰MultiRZ─────────── State\n" + "2: ───────────────────╰MultiRZ──────────────────── State", + ), + ( + [ + (qml.BasisState, jax.numpy.array([1]), [0]), + (qml.BasisState, jax.numpy.array([1, 0]), [0, 1]), + (qml.BasisState, jax.numpy.array([1, 0, 0]), [0, 1, 2]), + (qml.BasisState, jax.numpy.array([1, 0]), [0, 1]), + (qml.BasisState, jax.numpy.array([1]), [0]), + ], + "0: ──|Ξ¨βŸ©β”€β•­|Ξ¨βŸ©β”€β•­|Ξ¨βŸ©β”€β•­|Ξ¨βŸ©β”€β”€|Ξ¨βŸ©β”€β”€ State\n" + "1: ──────╰|Ξ¨βŸ©β”€β”œ|Ξ¨βŸ©β”€β•°|Ξ¨βŸ©β”€β”€β”€β”€β”€β”€β”€ State\n" + "2: ───────────╰|Ξ¨βŸ©β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ State", + ), + ], + ) + def test_visualization_cases(self, ops, expected): + """ + Test the visualization of the quantum operations defined in the unified compiler dialect. + """ + + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circuit(): + for op, param, wires in ops: + op(param, wires=wires) + return qml.state() + + assert draw(circuit)() == expected + + def test_reshape(self): + """Test that the visualization works when the parameters are reshaped.""" + + one_dim = jax.numpy.array([1, 0]) + two_dim = jax.numpy.array([[0, 1], [1, 0]]) + eight_dim = jax.numpy.zeros((8, 8)) + + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def circuit(): + qml.RX(one_dim[0], wires=0) + qml.RZ(two_dim[0, 0], wires=0) + qml.QubitUnitary(eight_dim[:2, :2], wires=0) + qml.QubitUnitary(eight_dim[0:4, 0:4], wires=[0, 1]) + return qml.state() + + expected = ( + "0: ──RX(M0)──RZ(M0)──U(M1)─╭U(M2)── State\n" + "1: ────────────────────────╰U(M2)── State" + ) + assert draw(circuit)() == expected + + def test_args_warning(self): + """Test that a warning is raised when dynamic arguments are used.""" + + # pylint: disable=unused-argument + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circ(arg): + qml.RX(0.1, wires=0) + return qml.state() + + with pytest.warns(UserWarning): + draw(circ)(0.1) + + def adjoint_op_not_implemented(self): + """Test that NotImplementedError is raised when AdjointOp is used.""" + + @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + def circuit(): + qml.adjoint(qml.QubitUnitary)(jax.numpy.array([[0, 1], [1, 0]]), wires=[0]) + return qml.expval(qml.PauliZ(0)) + + with pytest.raises(NotImplementedError, match="not yet supported"): + print(draw(circuit)()) + + def test_cond_not_implemented(self): + """Test that NotImplementedError is raised when cond is used.""" + + @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def circuit(): + m0 = qml.measure(0, reset=False, postselect=0) + qml.cond(m0, qml.RX, qml.RY)(1.23, 1) + return qml.expval(qml.PauliZ(0)) + + with pytest.raises(NotImplementedError, match="not yet supported"): + print(draw(circuit)()) + + def test_for_loop_not_implemented(self): + """Test that NotImplementedError is raised when for loop is used.""" + + @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()], autograph=True) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + def circuit(): + for _ in range(3): + qml.RX(0.1, 0) + return qml.expval(qml.PauliZ(0)) + + with pytest.raises(NotImplementedError, match="not yet supported"): + print(draw(circuit)()) + + def test_while_loop_not_implemented(self): + """Test that NotImplementedError is raised when while loop is used.""" + + @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()], autograph=True) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + def circuit(): + i = 0 + while i < 3: + qml.RX(0.1, 0) + i += 1 + return qml.expval(qml.PauliZ(0)) + + with pytest.raises(NotImplementedError, match="not yet supported"): + print(draw(circuit)()) + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/python_interface/visualization/test_mlir_graph.py b/frontend/test/pytest/python_interface/visualization/test_mlir_graph.py new file mode 100644 index 0000000000..9753bcf88a --- /dev/null +++ b/frontend/test/pytest/python_interface/visualization/test_mlir_graph.py @@ -0,0 +1,289 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the MLIR graph generation in the Unified Compiler visualization module.""" + +from pathlib import Path + +import pytest + +# pylint: disable=wrong-import-position +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +import pennylane as qml + +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from catalyst.python_interface.transforms import ( + iterative_cancel_inverses_pass, + merge_rotations_pass, +) +from catalyst.python_interface.visualization import generate_mlir_graph + + +@pytest.fixture(autouse=True) +def _chdir_tmp(monkeypatch, tmp_path: Path): + """Ensure all tests run inside a temp directory.""" + monkeypatch.chdir(tmp_path) + return tmp_path + + +def collect_files(tmp_path: Path) -> set[str]: + """Return the set of generated SVG files.""" + out_dir = tmp_path / "mlir_generated_graphs" + return {f.name for f in out_dir.glob("*.svg")} + + +def assert_files(tmp_path: Path, expected: set[str]): + """Check that the generated files match the expected set.""" + files = collect_files(tmp_path) + assert files == expected, f"Expected {expected}, got {files}" + + +@pytest.mark.usefixtures("use_capture") +class TestMLIRGraph: + """Test the MLIR graph generation""" + + @pytest.mark.parametrize("qjit", [True, False]) + def test_no_transforms(self, tmp_path: Path, qjit: bool): + """Test the MLIR graph is still generated when no transforms are applied""" + + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def _(): + qml.RX(0.1, 0) + qml.RX(2.0, 0) + qml.CNOT([0, 2]) + qml.CNOT([0, 2]) + return qml.state() + + if qjit: + _ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_) + + generate_mlir_graph(_)() + assert collect_files(tmp_path) == {"QNode_level_0_no_transforms.svg"} + + @pytest.mark.parametrize("qjit", [True, False]) + def test_xdsl_transforms_no_args(self, tmp_path: Path, qjit: bool): + """Test the MLIR graph generation with no arguments to the QNode with and without qjit""" + + @merge_rotations_pass + @iterative_cancel_inverses_pass + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def _(): + qml.RX(0.1, 0) + qml.RX(2.0, 0) + qml.CNOT([0, 2]) + qml.CNOT([0, 2]) + return qml.state() + + if qjit: + _ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_) + + generate_mlir_graph(_)() + assert_files( + tmp_path, + { + "QNode_level_0_no_transforms.svg", + "QNode_level_1_after_xdsl-cancel-inverses.svg", + "QNode_level_2_after_xdsl-merge-rotations.svg", + }, + ) + + @pytest.mark.parametrize("qjit", [True, False]) + def test_xdsl_transforms_args(self, tmp_path: Path, qjit: bool): + """Test the MLIR graph generation with arguments to the QNode for xDSL transforms""" + + @merge_rotations_pass + @iterative_cancel_inverses_pass + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def _(x, y, w1, w2): + qml.RX(x, w1) + qml.RX(y, w2) + return qml.state() + + if qjit: + _ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_) + + generate_mlir_graph(_)(0.1, 0.2, 0, 1) + assert_files( + tmp_path, + { + "QNode_level_0_no_transforms.svg", + "QNode_level_1_after_xdsl-cancel-inverses.svg", + "QNode_level_2_after_xdsl-merge-rotations.svg", + }, + ) + + @pytest.mark.parametrize("qjit", [True, False]) + def test_catalyst_transforms_args(self, tmp_path: Path, qjit: bool): + """Test the MLIR graph generation with arguments to the QNode for catalyst transforms""" + + @qml.transforms.merge_rotations + @qml.transforms.cancel_inverses + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def _(x, y, w1, w2): + qml.RX(x, w1) + qml.RX(y, w2) + return qml.state() + + if qjit: + _ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_) + + generate_mlir_graph(_)(0.1, 0.2, 0, 1) + assert_files( + tmp_path, + { + "QNode_level_0_no_transforms.svg", + "QNode_level_1_after_cancel-inverses.svg", + "QNode_level_2_after_merge-rotations.svg", + }, + ) + + @pytest.mark.parametrize("qjit", [True, False]) + def test_catalyst_xdsl_transforms_args(self, tmp_path: Path, qjit: bool): + """Test the MLIR graph generation with arguments to the QNode for catalyst and xDSL + transforms""" + + @qml.transforms.merge_rotations + @iterative_cancel_inverses_pass + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def _(x, y, w1, w2): + qml.RX(x, w1) + qml.RX(y, w2) + return qml.state() + + if qjit: + _ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_) + + generate_mlir_graph(_)(0.1, 0.2, 0, 1) + assert_files( + tmp_path, + { + "QNode_level_0_no_transforms.svg", + "QNode_level_1_after_xdsl-cancel-inverses.svg", + "QNode_level_2_after_merge-rotations.svg", + }, + ) + + def test_cond(self, tmp_path: Path): + """Test the MLIR graph generation for a conditional""" + + @merge_rotations_pass + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def _(pred, arg1, arg2): + """Quantum circuit with conditional branches.""" + + qml.RX(0.10, wires=0) + + def true_fn(arg1, arg2): + qml.RY(arg1, wires=0) + qml.RX(arg2, wires=0) + qml.RZ(arg1, wires=0) + + def false_fn(arg1, arg2): + qml.RX(arg1, wires=0) + qml.RX(arg2, wires=0) + + qml.cond(pred > 0, true_fn, false_fn)(arg1, arg2) + qml.RX(0.10, wires=0) + return qml.expval(qml.Z(wires=0)) + + generate_mlir_graph(_)(0.5, 0.1, 0.2) + assert_files( + tmp_path, + { + "QNode_level_0_no_transforms.svg", + "QNode_level_1_after_xdsl-merge-rotations.svg", + }, + ) + + def test_cond_with_mcm(self, tmp_path: Path): + """Test the MLIR graph generation for a conditional with MCM""" + + def true_fn(arg): + qml.RX(arg, 0) + + def false_fn(arg): + qml.RY(3 * arg, 0) + + @merge_rotations_pass + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def _(x, y): + """Quantum circuit with conditional branches.""" + + qml.RX(x, 0) + m = qml.measure(0) + + qml.cond(m, true_fn, false_fn)(y) + return qml.expval(qml.Z(0)) + + generate_mlir_graph(_)(0.5, 0.1) + assert_files( + tmp_path, + { + "QNode_level_0_no_transforms.svg", + "QNode_level_1_after_xdsl-merge-rotations.svg", + }, + ) + + def test_for_loop(self, tmp_path: Path): + """Test the MLIR graph generation for a for loop""" + + @merge_rotations_pass + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def _(): + @qml.for_loop(0, 100) + def loop(_): + qml.RX(0.1, 0) + qml.RX(0.1, 0) + + # pylint: disable=no-value-for-parameter + loop() + return qml.state() + + generate_mlir_graph(_)() + assert_files( + tmp_path, + { + "QNode_level_0_no_transforms.svg", + "QNode_level_1_after_xdsl-merge-rotations.svg", + }, + ) + + def test_while_loop(self, tmp_path: Path): + """Test the MLIR graph generation for a while loop""" + + @merge_rotations_pass + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def _(x): + def cond_fn(x): + return x < 2 + + @qml.while_loop(cond_fn) + def loop(x): + return x**2 + + loop(x) + return qml.expval(qml.PauliZ(0)) + + generate_mlir_graph(_)(0.5) + assert_files( + tmp_path, + { + "QNode_level_0_no_transforms.svg", + "QNode_level_1_after_xdsl-merge-rotations.svg", + }, + ) + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/python_interface/xdsl_extras/test_constraints.py b/frontend/test/pytest/python_interface/xdsl_extras/test_constraints.py new file mode 100644 index 0000000000..46a53070f0 --- /dev/null +++ b/frontend/test/pytest/python_interface/xdsl_extras/test_constraints.py @@ -0,0 +1,531 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the constraints defined within the xdsl_extras module.""" + +import pytest + +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +# pylint: disable=wrong-import-position +from xdsl.context import Context +from xdsl.dialects import builtin, test +from xdsl.dialects.builtin import MemRefType, TensorType, TupleType, i1, i32 +from xdsl.ir import Dialect +from xdsl.irdl import ( + BaseAttr, + IRDLOperation, + irdl_op_definition, + operand_def, + result_def, +) +from xdsl.irdl.constraints import ConstraintContext +from xdsl.utils.exceptions import VerifyException +from xdsl_jax.dialects.stablehlo import TokenType + +from catalyst.python_interface import QuantumParser +from catalyst.python_interface.xdsl_extras import ( + MemRefConstraint, + NestedTupleOfConstraint, + TensorConstraint, +) + + +@pytest.fixture(scope="module", name="my_dialect") +def my_dialect_fixture(): + """Returns a test dialect, called 'my_dialect', with simple ops that operate on memref and + tensor types. + """ + + @irdl_op_definition + class Float64MemrefOp(IRDLOperation): + """A test op with float memref types""" + + name = "my_dialect.memref_float64" + in_value = operand_def(MemRefConstraint(element_type=builtin.Float64Type())) + out_value = result_def(MemRefConstraint(element_type=builtin.Float64Type())) + + @irdl_op_definition + class Float64TensorOp(IRDLOperation): + """A test op with float tensor types""" + + name = "my_dialect.tensor_float64" + in_value = operand_def(TensorConstraint(element_type=builtin.Float64Type())) + out_value = result_def(TensorConstraint(element_type=builtin.Float64Type())) + + @irdl_op_definition + class Rank1MemrefOp(IRDLOperation): + """A test op with rank-1 memref types""" + + name = "my_dialect.memref_rank1" + in_value = operand_def(MemRefConstraint(rank=1)) + out_value = result_def(MemRefConstraint(rank=1)) + + @irdl_op_definition + class Rank1TensorOp(IRDLOperation): + """A test op with rank-1 tensor types""" + + name = "my_dialect.tensor_rank1" + in_value = operand_def(TensorConstraint(rank=1)) + out_value = result_def(TensorConstraint(rank=1)) + + @irdl_op_definition + class Rank2Or4MemrefOp(IRDLOperation): + """A test op with rank-2 or -4 memref types""" + + name = "my_dialect.memref_rank24" + in_value = operand_def(MemRefConstraint(rank=(2, 4))) + out_value = result_def(MemRefConstraint(rank=(2, 4))) + + @irdl_op_definition + class Rank2Or4TensorOp(IRDLOperation): + """A test op with rank-2 or -4 tensor types""" + + name = "my_dialect.tensor_rank24" + in_value = operand_def(TensorConstraint(rank=(2, 4))) + out_value = result_def(TensorConstraint(rank=(2, 4))) + + @irdl_op_definition + class Shape123MemrefOp(IRDLOperation): + """A test op with shape-(1, 2, 3) memref types""" + + name = "my_dialect.memref_shape123" + in_value = operand_def(MemRefConstraint(shape=(1, 2, 3))) + out_value = result_def(MemRefConstraint(shape=(1, 2, 3))) + + @irdl_op_definition + class Shape123TensorOp(IRDLOperation): + """A test op with shape-(1, 2, 3) tensor types""" + + name = "my_dialect.tensor_shape123" + in_value = operand_def(TensorConstraint(shape=(1, 2, 3))) + out_value = result_def(TensorConstraint(shape=(1, 2, 3))) + + MyDialect = Dialect( + "my_dialect", + [ + Float64MemrefOp, + Float64TensorOp, + Rank1MemrefOp, + Rank1TensorOp, + Rank2Or4MemrefOp, + Rank2Or4TensorOp, + Shape123MemrefOp, + Shape123TensorOp, + ], + ) + + return MyDialect + + +class TestMemRefConstraint: + """Tests for the MemRefConstraint class.""" + + def test_memref_constraint_init_invalid_element_type(self): + """Test that an error is raised if the provided element_type is invalid""" + + element_type = int + with pytest.raises( + TypeError, match="is not a valid constraint for the 'element_type' argument" + ): + MemRefConstraint(element_type=element_type) + + def test_memref_constraint_init_invalid_shape(self): + """Test that an error is raised if the provided shape is invalid""" + + shape = {1, 2} + with pytest.raises(TypeError, match="is not a valid constraint for the 'shape' argument"): + MemRefConstraint(shape=shape) + + def test_memref_constraint_init_invalid_rank(self): + """Test that an error is raised if the provided rank is invalid""" + + rank = 1.5 + with pytest.raises(TypeError, match="is not a valid constraint for the 'rank' argument"): + MemRefConstraint(rank=rank) + + @pytest.mark.parametrize("element_type", [None, builtin.Float64Type(), builtin.i64]) + def test_memref_constraint_init_rank_and_shape_error(self, element_type): + """Test that an error is raised if both rank and shape are provided.""" + shape = (1, 2) + rank = 2 + + with pytest.raises(ValueError, match="Only one of 'shape' or 'rank' may be provided"): + MemRefConstraint(element_type=element_type, shape=shape, rank=rank) + + def test_memref_constraint_properties(self): + """Test that the properties of MemRefConstraint object are correct.""" + rank = 1 + constraint = MemRefConstraint(rank=rank) + + assert constraint.expected_type == builtin.MemRefType + assert constraint.type_name == "memref" + assert constraint.mapping_type_vars({}) is constraint + + @pytest.mark.parametrize("rank", [0, 1, 2]) + def test_memref_single_rank_constraint_verify_valid(self, rank): + """Test that verifying a MemRefType attribute with the same rank as the MemRefConstraint + does not raise an exception.""" + constraint = MemRefConstraint(rank=rank) + attr = builtin.MemRefType(builtin.i32, [1] * rank) + + constraint.verify(attr, None) + + def test_memref_multi_rank_constraint_verify_valid(self): + """Test that verifying a MemRefType attribute with any of the ranks as the MemRefConstraint + does not raise an exception.""" + rank = (3, 4, 5) + constraint = MemRefConstraint(rank=rank) + + for r in rank: + attr = builtin.MemRefType(builtin.i32, [1] * r) + constraint.verify(attr, None) + + def test_memref_shape_constraint_verify_valid(self): + """Test that verifying an attribute with a valid shape does not raise an exception.""" + constraint = MemRefConstraint(shape=(1, 2, 3)) + attr = builtin.MemRefType(builtin.i32, (1, 2, 3)) + + constraint.verify(attr, None) + + def test_memref_element_type_constraint_verify_valid(self): + """Test that verifying an attribute with a valid element type does not raise an + exception.""" + constraint = MemRefConstraint(element_type=builtin.i32) + attr = builtin.MemRefType(builtin.i32, [1]) + + constraint.verify(attr, None) + + @pytest.mark.parametrize("rank", [0, 1, 2]) + def test_memref_single_rank_constraint_verify_invalid(self, rank): + """Test that verifying a MemRefType attribute with a different rank as the + MemRefConstraint raises a VerifyException.""" + constraint = MemRefConstraint(rank=rank) + attr = builtin.MemRefType(builtin.i32, [1] * (rank + 1)) + + with pytest.raises(VerifyException, match=f"Invalid value {rank + 1}, expected {rank}"): + constraint.verify(attr, None) + + def test_memref_multi_rank_constraint_verify_invalid(self): + """Test that verifying a MemRefType attribute with a different rank as the + MemRefConstraint raises a VerifyException.""" + + rank = {3, 4, 5} + invalid_rank = 2 + constraint = MemRefConstraint(rank=rank) + attr = builtin.MemRefType(builtin.i32, [1] * invalid_rank) + + with pytest.raises( + VerifyException, match=f"Invalid value {invalid_rank}, expected one of {rank}" + ): + constraint.verify(attr, None) + + def test_memref_constraint_verify_invalid_type(self): + """Test that verifying an attribute with a type other than MemRefType raises a + VerifyException.""" + constraint = MemRefConstraint(rank=1) + attr = builtin.TensorType(builtin.i32, [1]) + + with pytest.raises(VerifyException, match=f"{attr} should be of type MemRefType"): + constraint.verify(attr, None) + + def test_memref_shape_constraint_verify_invalid(self): + """Test that verifying an attribute with an invalid shape raises an exception.""" + constraint = MemRefConstraint(shape=(1, 2, 3)) + attr = builtin.MemRefType(builtin.i32, (2, 2, 3)) + + with pytest.raises(VerifyException, match=r"Expected attribute \[.*\] but got \[.*\]"): + constraint.verify(attr, None) + + def test_memref_element_type_constraint_verify_invalid(self): + """Test that verifying an attribute with an invalid element type raises an exception.""" + constraint = MemRefConstraint(element_type=builtin.i32) + attr = builtin.MemRefType(builtin.i64, [1]) + + with pytest.raises( + VerifyException, match=f"Expected attribute {builtin.i32} but got {builtin.i64}" + ): + constraint.verify(attr, None) + + def test_memref_constraint_integration(self, my_dialect): + """Test that verification of legal operations with memref operand/result constraints does + not raise an exception.""" + program = """ + func.func public @test_workload() -> () { + %0 = "test.op"() : () -> memref<2xf64> + %1 = "test.op"() : () -> memref<2x4xf64> + %2 = "test.op"() : () -> memref<2x4x5x3xf64> + %3 = "test.op"() : () -> memref<1x2x3xf64> + %4 = "my_dialect.memref_float64"(%0) : (memref<2xf64>) -> memref<2xf64> + %5 = "my_dialect.memref_rank1"(%0) : (memref<2xf64>) -> memref<2xf64> + %6 = "my_dialect.memref_rank24"(%1) : (memref<2x4xf64>) -> memref<2x4xf64> + %7 = "my_dialect.memref_rank24"(%2) : (memref<2x4x5x3xf64>) -> memref<2x4x5x3xf64> + %8 = "my_dialect.memref_shape123"(%3) : (memref<1x2x3xf64>) -> memref<1x2x3xf64> + func.return + } + """ + + ctx = Context(allow_unregistered=False) + xdsl_module: builtin.ModuleOp = QuantumParser( + ctx, program, extra_dialects=(test.Test, my_dialect) + ).parse_module() + xdsl_module.verify() + + def test_memref_constraint_integration_invalid(self, my_dialect): + """Test that verification of illegal operations with memref operands/result constraints + raises a VerifyException.""" + program = """ + func.func public @test_workload() -> () { + %0 = "test.op"() : () -> memref<2x2xi64> + %1 = "my_dialect.memref_float64"(%0) : (memref<2x2xi64>) -> memref<2x2xi64> + func.return + } + """ + + ctx = Context(allow_unregistered=False) + xdsl_module: builtin.ModuleOp = QuantumParser( + ctx, program, extra_dialects=(test.Test, my_dialect) + ).parse_module() + + with pytest.raises( + VerifyException, + match=f"Expected attribute {builtin.Float64Type()} but got {builtin.i64}", + ): + xdsl_module.verify() + + +class TestTensorConstraint: + """Tests for the TensorConstraint class.""" + + def test_tensor_constraint_init_invalid_element_type(self): + """Test that an error is raised if the provided element_type is invalid""" + + element_type = int + with pytest.raises( + TypeError, match="is not a valid constraint for the 'element_type' argument" + ): + TensorConstraint(element_type=element_type) + + def test_tensor_constraint_init_invalid_shape(self): + """Test that an error is raised if the provided shape is invalid""" + + shape = {1, 2} + with pytest.raises(TypeError, match="is not a valid constraint for the 'shape' argument"): + TensorConstraint(shape=shape) + + def test_tensor_constraint_init_invalid_rank(self): + """Test that an error is raised if the provided rank is invalid""" + + rank = 1.5 + with pytest.raises(TypeError, match="is not a valid constraint for the 'rank' argument"): + TensorConstraint(rank=rank) + + @pytest.mark.parametrize("element_type", [None, builtin.Float64Type(), builtin.i64]) + def test_tensor_constraint_init_rank_and_shape_error(self, element_type): + """Test that an error is raised if both rank and shape are provided.""" + shape = (1, 2) + rank = 2 + + with pytest.raises(ValueError, match="Only one of 'shape' or 'rank' may be provided"): + TensorConstraint(element_type=element_type, shape=shape, rank=rank) + + def test_tensor_constraint_properties(self): + """Test that the properties of TensorConstraint object are correct.""" + rank = 1 + constraint = TensorConstraint(rank=rank) + + assert constraint.expected_type == builtin.TensorType + assert constraint.type_name == "tensor" + assert constraint.mapping_type_vars({}) is constraint + + @pytest.mark.parametrize("rank", [0, 1, 2]) + def test_tensor_single_rank_constraint_verify_valid(self, rank): + """Test that verifying a TensorType attribute with the same rank as the TensorConstraint + does not raise an exception.""" + constraint = TensorConstraint(rank=rank) + attr = builtin.TensorType(builtin.i32, [1] * rank) + + constraint.verify(attr, None) + + def test_tensor_multi_rank_constraint_verify_valid(self): + """Test that verifying a TensorType attribute with any of the ranks as the TensorConstraint + does not raise an exception.""" + rank = (3, 4, 5) + constraint = TensorConstraint(rank=rank) + + for r in rank: + attr = builtin.TensorType(builtin.i32, [1] * r) + constraint.verify(attr, None) + + def test_tensor_shape_constraint_verify_valid(self): + """Test that verifying an attribute with a valid shape does not raise an exception.""" + constraint = TensorConstraint(shape=(1, 2, 3)) + attr = builtin.TensorType(builtin.i32, (1, 2, 3)) + + constraint.verify(attr, None) + + def test_tensor_element_type_constraint_verify_valid(self): + """Test that verifying an attribute with a valid element type does not raise an + exception.""" + constraint = TensorConstraint(element_type=builtin.i32) + attr = builtin.TensorType(builtin.i32, [1]) + + constraint.verify(attr, None) + + @pytest.mark.parametrize("rank", [0, 1, 2]) + def test_tensor_single_rank_constraint_verify_invalid(self, rank): + """Test that verifying a TensorType attribute with a different rank as the + TensorConstraint raises a VerifyException.""" + constraint = TensorConstraint(rank=rank) + attr = builtin.TensorType(builtin.i32, [1] * (rank + 1)) + + with pytest.raises(VerifyException, match=f"Invalid value {rank + 1}, expected {rank}"): + constraint.verify(attr, None) + + def test_tensor_multi_rank_constraint_verify_invalid(self): + """Test that verifying a TensorType attribute with a different rank as the + TensorConstraint raises a VerifyException.""" + + rank = {3, 4, 5} + invalid_rank = 2 + constraint = TensorConstraint(rank=rank) + attr = builtin.TensorType(builtin.i32, [1] * invalid_rank) + + with pytest.raises( + VerifyException, match=f"Invalid value {invalid_rank}, expected one of {rank}" + ): + constraint.verify(attr, None) + + def test_tensor_constraint_verify_invalid_type(self): + """Test that verifying an attribute with a type other than TensorType raises a + VerifyException.""" + constraint = TensorConstraint(rank=1) + attr = builtin.MemRefType(builtin.i32, [1]) + + with pytest.raises(VerifyException, match=f"{attr} should be of type TensorType"): + constraint.verify(attr, None) + + def test_tensor_shape_constraint_verify_invalid(self): + """Test that verifying an attribute with an invalid shape raises an exception.""" + constraint = TensorConstraint(shape=(1, 2, 3)) + attr = builtin.TensorType(builtin.i32, (2, 2, 3)) + + with pytest.raises(VerifyException, match=r"Expected attribute \[.*\] but got \[.*\]"): + constraint.verify(attr, None) + + def test_tensor_element_type_constraint_verify_invalid(self): + """Test that verifying an attribute with an invalid element type raises an exception.""" + constraint = TensorConstraint(element_type=builtin.i32) + attr = builtin.TensorType(builtin.i64, [1]) + + with pytest.raises( + VerifyException, match=f"Expected attribute {builtin.i32} but got {builtin.i64}" + ): + constraint.verify(attr, None) + + def test_tensor_constraint_integration(self, my_dialect): + """Test that verification of legal operations with tensor operand/result constraints does + not raise an exception.""" + program = """ + func.func public @test_workload() -> () { + %0 = "test.op"() : () -> tensor<2xf64> + %1 = "test.op"() : () -> tensor<2x4xf64> + %2 = "test.op"() : () -> tensor<2x4x5x3xf64> + %3 = "test.op"() : () -> tensor<1x2x3xf64> + %4 = "my_dialect.tensor_float64"(%0) : (tensor<2xf64>) -> tensor<2xf64> + %5 = "my_dialect.tensor_rank1"(%0) : (tensor<2xf64>) -> tensor<2xf64> + %6 = "my_dialect.tensor_rank24"(%1) : (tensor<2x4xf64>) -> tensor<2x4xf64> + %7 = "my_dialect.tensor_rank24"(%2) : (tensor<2x4x5x3xf64>) -> tensor<2x4x5x3xf64> + %8 = "my_dialect.tensor_shape123"(%3) : (tensor<1x2x3xf64>) -> tensor<1x2x3xf64> + func.return + } + """ + + ctx = Context(allow_unregistered=False) + xdsl_module: builtin.ModuleOp = QuantumParser( + ctx, program, extra_dialects=(test.Test, my_dialect) + ).parse_module() + xdsl_module.verify() + + def test_tensor_constraint_integration_invalid(self, my_dialect): + """Test that verification of illegal operations with tensor operands/result constraints + raises a VerifyException.""" + program = """ + func.func public @test_workload() -> () { + %0 = "test.op"() : () -> tensor<2x2xi64> + %1 = "my_dialect.tensor_float64"(%0) : (tensor<2x2xi64>) -> tensor<2x2xi64> + func.return + } + """ + + ctx = Context(allow_unregistered=False) + xdsl_module: builtin.ModuleOp = QuantumParser( + ctx, program, extra_dialects=(test.Test, my_dialect) + ).parse_module() + + with pytest.raises( + VerifyException, + match=f"Expected attribute {builtin.Float64Type()} but got {builtin.i64}", + ): + xdsl_module.verify() + + +class TestNestedTupleOfConstraint: + """Tests for the NestedTupleOfConstraint class.""" + + constraint = NestedTupleOfConstraint([TensorType, TokenType]) + + def test_nested_tuple_of_constraint(self): + """Test that the properties of NestedTupleOfConstraint object are correct.""" + assert self.constraint.elem_constraints == (BaseAttr(TensorType), BaseAttr(TokenType)) + + def test_nested_tuple_of_constraint_verify_valid(self): + """Test that verifying a valid tuple of tensor and token types passes.""" + tensor = TensorType(i32, [2]) + token = TokenType() + tup = TupleType([tensor, token]) + self.constraint.verify(tup, ConstraintContext()) + + def test_nested_tuple_of_constraint_accepts_two_tensors(self): + """Test that any mix of allowed types is accepted.""" + tensor1 = TensorType(i32, [2]) + tensor2 = TensorType(i1, [1]) + tup = TupleType([tensor1, tensor2]) + self.constraint.verify(tup, ConstraintContext()) + + def test_nested_tuple_of_constraint_accepts_reversed_order(self): + """Test that the order of the tuple is not enforced.""" + tensor = TensorType(i32, [2]) + token = TokenType() + tup = TupleType([token, tensor]) + self.constraint.verify(tup, ConstraintContext()) + + def test_nested_tuple_of_constraint_accepts_nested(self): + """Test that nested tuples are accepted.""" + tensor1 = TensorType(i32, [2]) + tensor2 = TensorType(i1, [1]) + token = TokenType() + inner = TupleType([token, tensor2]) + outer = TupleType([tensor1, inner]) + self.constraint.verify(outer, ConstraintContext()) + + def test_nested_tuple_of_constraint_rejects_disallowed_type(self): + """Test that a tuple with a disallowed type raises a VerifyException.""" + tensor = TensorType(i32, [2]) + memref = MemRefType(i32, [2]) + tup = TupleType([tensor, memref]) + with pytest.raises( + VerifyException, match="tuple leaf 1 failed all allowed constraints: memref<2xi32>" + ): + self.constraint.verify(tup, ConstraintContext()) diff --git a/frontend/test/pytest/python_interface/xdsl_extras/test_traits.py b/frontend/test/pytest/python_interface/xdsl_extras/test_traits.py new file mode 100644 index 0000000000..45ba543a42 --- /dev/null +++ b/frontend/test/pytest/python_interface/xdsl_extras/test_traits.py @@ -0,0 +1,321 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the traits defined within the xdsl_extras module.""" + +from typing import TypeAlias + +import pytest + +# pylint: disable=wrong-import-position +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +from xdsl.dialects.builtin import AnyAttr, TensorType, f32, i32, i64 +from xdsl.ir import Attribute + +# xdsl imports +from xdsl.irdl import ( + IRDLOperation, + attr_def, + irdl_op_definition, + operand_def, + result_def, + traits_def, + var_operand_def, +) +from xdsl.utils.exceptions import VerifyException +from xdsl.utils.test_value import create_ssa_value + +# Import the custom traits we want to test +from catalyst.python_interface.xdsl_extras.traits import ( + AllMatchSameOperatorTrait, + Elementwise, + SameOperandsAndResultElementType, + SameOperandsAndResultShape, + SameOperandsElementType, +) + +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +AnyTensorType: TypeAlias = TensorType[Attribute] + + +# Test the SameOperandsAndResultShape trait +def test_same_operands_and_result_shape_trait(): + """Test the SameOperandsAndResultShape trait.""" + + @irdl_op_definition + class ShapeTestOp(IRDLOperation): + """Test operation for the SameOperandsAndResultShape trait.""" + + name = "test.shape_test" + traits = traits_def(SameOperandsAndResultShape()) + operand = operand_def(AnyTensorType) + result = result_def(AnyTensorType) + + assert ShapeTestOp.has_trait(SameOperandsAndResultShape) + + operand = create_ssa_value(TensorType(i32, [2, 3])) + op = ShapeTestOp.create(operands=[operand], result_types=[TensorType(i32, [2, 3])]) + op.verify() + + op = ShapeTestOp.create(operands=[operand], result_types=[TensorType(i32, [3, 2])]) + with pytest.raises(VerifyException, match="requires the same shape"): + op.verify() + + +# Test the SameOperandsElementType trait +def test_same_operands_element_type_trait(): + """Test the SameOperandsElementType trait.""" + + @irdl_op_definition + class ShapeTestOp(IRDLOperation): + """Test operation for the SameOperandsElementType trait.""" + + name = "test.element_type_test" + traits = traits_def(SameOperandsElementType()) + + operand1 = operand_def(AnyTensorType) + operand2 = operand_def(AnyTensorType) + result = result_def(AnyTensorType) + + assert ShapeTestOp.has_trait(SameOperandsElementType) + + operand1 = create_ssa_value(TensorType(i32, [2, 3])) + operand2 = create_ssa_value(TensorType(i32, [3, 2])) + op = ShapeTestOp.create(operands=[operand1, operand2], result_types=[TensorType(i32, [2, 2])]) + op.verify() + + operand3 = create_ssa_value(TensorType(f32, [2, 3])) + op = ShapeTestOp.create(operands=[operand1, operand3], result_types=[TensorType(f32, [2, 3])]) + + with pytest.raises(VerifyException, match="requires the same element type for all operands"): + op.verify() + + +# Test the SameOperandsAndResultElementType trait +def test_same_operands_and_result_element_type_trait(): + """Test the SameOperandsAndResultElementType trait.""" + + @irdl_op_definition + class ElementTypeTestOp(IRDLOperation): + """Test operation for the SameOperandsAndResultElementType trait.""" + + name = "test.element_type_test" + traits = traits_def(SameOperandsAndResultElementType()) + + operand = operand_def(AnyTensorType) + result = result_def(AnyTensorType) + + assert ElementTypeTestOp.has_trait(SameOperandsAndResultElementType) + + op = ElementTypeTestOp.create( + operands=[create_ssa_value(TensorType(i32, [2, 3]))], + result_types=[TensorType(i32, [2, 3])], + ) + op.verify() + + op = ElementTypeTestOp.create( + operands=[create_ssa_value(TensorType(i32, [2, 3]))], + result_types=[TensorType(f32, [2, 3])], + ) + with pytest.raises( + VerifyException, match="requires the same element type for all operands and results" + ): + op.verify() + + +# Test the Elementwise trait +@irdl_op_definition +class ElementwiseTestOp(IRDLOperation): + """Test operation for the Elementwise trait.""" + + name = "test.elementwise_test" + traits = traits_def(Elementwise()) + + operand = var_operand_def(AnyAttr()) + result = result_def(AnyAttr()) + + +def test_elementwise_trait(): + """Test the Elementwise trait.""" + assert ElementwiseTestOp.has_trait(Elementwise) + + operand = create_ssa_value(TensorType(i32, [2, 3])) + op = ElementwiseTestOp.create(operands=[operand], result_types=[TensorType(i32, [2, 3])]) + op.verify() + + operand = create_ssa_value(i32) + op = ElementwiseTestOp.create(operands=[operand], result_types=[i32]) + op.verify() + + scalar_operand = create_ssa_value(i32) + tensor_operand = create_ssa_value(TensorType(i32, [2, 3])) + op = ElementwiseTestOp.create( + operands=[scalar_operand, tensor_operand], result_types=[TensorType(i32, [2, 3])] + ) + op.verify() + + +def test_elementwise_trait_failure_no_tensor_result(): + """Test that Elementwise trait fails when operand is tensor but result is scalar.""" + operand = create_ssa_value(TensorType(i32, [2, 3])) + op = ElementwiseTestOp.create(operands=[operand], result_types=[i32]) + + with pytest.raises( + VerifyException, + match="if an operand is non-scalar, then there must be at least one non-scalar result", + ): + op.verify() + + +def test_elementwise_trait_failure_no_tensor_operand(): + """Test that Elementwise trait fails when result is tensor but operand is scalar.""" + operand = create_ssa_value(i32) + op = ElementwiseTestOp.create(operands=[operand], result_types=[TensorType(i32, [2, 3])]) + + with pytest.raises( + VerifyException, + match="if a result is non-scalar, then at least one operand must be non-scalar", + ): + op.verify() + + +def test_elementwise_trait_failure_shape_mismatch(): + """Test that Elementwise trait fails for shape mismatches.""" + operand1 = create_ssa_value(TensorType(i32, [2, 3])) + operand2 = create_ssa_value(TensorType(i32, [3, 2])) + op = ElementwiseTestOp.create( + operands=[operand1, operand2], result_types=[TensorType(i32, [2, 3])] + ) + + with pytest.raises( + VerifyException, + match="all non-scalar operands/results must have the same shape and base type", + ): + op.verify() + + +def test_all_shapes_match(): + """AllMatchSameOperatorTrait: shape equality on TensorType attributes.""" + + @irdl_op_definition + class ShapeMockOp(IRDLOperation): + """Test operation for the AllMatchSameOperatorTrait trait.""" + + name = "test.shapes_match" + traits = traits_def() + a = attr_def(AnyAttr()) + b = attr_def(AnyAttr()) + + op = ShapeMockOp.create(attributes={"a": TensorType(i64, [2, 3]), "b": TensorType(i64, [2, 3])}) + trait = AllMatchSameOperatorTrait(("a", "b"), lambda a: a.get_shape(), "shape") + trait.verify(op) + + op = ShapeMockOp.create(attributes={"a": TensorType(i64, [2, 3]), "b": TensorType(i64, [2, 4])}) + with pytest.raises(VerifyException, match=r"all of \{a, b\} must have the same shape"): + trait.verify(op) + + +def test_all_ranks_match(): + """AllMatchSameOperatorTrait: rank equality on TensorType attributes.""" + + @irdl_op_definition + class RankMockOp(IRDLOperation): + """Test operation for the AllMatchSameOperatorTrait trait.""" + + name = "test.ranks_match" + traits = traits_def() + a = attr_def(AnyAttr()) + b = attr_def(AnyAttr()) + + op = RankMockOp.create(attributes={"a": TensorType(i64, [2, 3]), "b": TensorType(i64, [4, 5])}) + trait = AllMatchSameOperatorTrait(("a", "b"), lambda a: a.get_num_dims(), "rank") + trait.verify(op) + + op = RankMockOp.create( + attributes={"a": TensorType(i64, [2, 3]), "b": TensorType(i64, [1, 2, 3])} + ) + with pytest.raises(VerifyException, match=r"all of \{a, b\} must have the same rank"): + trait.verify(op) + + +def test_all_element_types_match(): + """AllMatchSameOperatorTrait: element type equality on TensorType attributes.""" + + @irdl_op_definition + class ElemTypeMockOp(IRDLOperation): + """Test operation for the AllMatchSameOperatorTrait trait.""" + + name = "test.elem_types_match" + traits = traits_def() + a = attr_def(AnyAttr()) + b = attr_def(AnyAttr()) + + op = ElemTypeMockOp.create( + attributes={"a": TensorType(i64, [2, 3]), "b": TensorType(i64, [1, 6])} + ) + trait = AllMatchSameOperatorTrait(("a", "b"), lambda a: a.get_element_type(), "element type") + trait.verify(op) + + op = ElemTypeMockOp.create( + attributes={"a": TensorType(i64, [2, 3]), "b": TensorType(f32, [2, 3])} + ) + with pytest.raises(VerifyException, match=r"all of \{a, b\} must have the same element type"): + trait.verify(op) + + +def test_all_element_counts_match(): + """AllMatchSameOperatorTrait: element count equality on TensorType attributes.""" + + @irdl_op_definition + class ElemCountMockOp(IRDLOperation): + """Test operation for the AllMatchSameOperatorTrait trait.""" + + name = "test.elem_counts_match" + traits = traits_def() + a = attr_def(AnyAttr()) + b = attr_def(AnyAttr()) + + op = ElemCountMockOp.create( + attributes={"a": TensorType(i64, [2, 3]), "b": TensorType(i64, [6])} + ) + trait = AllMatchSameOperatorTrait(("a", "b"), lambda a: a.element_count(), "element count") + trait.verify(op) + + op = ElemCountMockOp.create( + attributes={"a": TensorType(i64, [2, 3]), "b": TensorType(i64, [2, 4])} + ) + with pytest.raises(VerifyException, match=r"all of \{a, b\} must have the same element count"): + trait.verify(op) + + +def test_operator_cannot_compute_raises_verifyexception(): + """Trait should raise when it cannot compute the property for given attributes.""" + + @irdl_op_definition + class CannotComputeMockOp(IRDLOperation): + """Test operation with no traits.""" + + name = "test.cannot_compute" + traits = traits_def() + a = attr_def(AnyAttr()) + b = attr_def(AnyAttr()) + + # Use non-shaped attributes; calling get_shape should fail + op = CannotComputeMockOp.create(attributes={"a": i64, "b": f32}) + trait = AllMatchSameOperatorTrait(("a", "b"), lambda x: x.get_shape(), "shape") + + with pytest.raises(VerifyException, match=r"cannot compute shape for \{a, b\}:"): + trait.verify(op) diff --git a/frontend/test/pytest/test_custom_devices.py b/frontend/test/pytest/test_custom_devices.py index b757ff63e7..afeb101a8a 100644 --- a/frontend/test/pytest/test_custom_devices.py +++ b/frontend/test/pytest/test_custom_devices.py @@ -16,7 +16,7 @@ import pennylane as qml import pytest -from conftest import CONFIG_CUSTOM_DEVICE +from utils import CONFIG_CUSTOM_DEVICE from catalyst import measure, qjit from catalyst.compiler import get_lib_path diff --git a/frontend/test/pytest/test_jit_behaviour.py b/frontend/test/pytest/test_jit_behaviour.py index 6d33ec22eb..2c9737daee 100644 --- a/frontend/test/pytest/test_jit_behaviour.py +++ b/frontend/test/pytest/test_jit_behaviour.py @@ -905,7 +905,7 @@ def g(x: float): def test_mlir_opt_using_xdsl_passes(self, backend): """Test mlir opt using xDSL passes.""" # pylint: disable-next=import-outside-toplevel - from pennylane.compiler.python_compiler.transforms import iterative_cancel_inverses_pass + from catalyst.python_interface.transforms import iterative_cancel_inverses_pass @qjit @iterative_cancel_inverses_pass diff --git a/frontend/test/pytest/test_measurement_transforms.py b/frontend/test/pytest/test_measurement_transforms.py index 2ccfee0c7c..e2aae968e6 100644 --- a/frontend/test/pytest/test_measurement_transforms.py +++ b/frontend/test/pytest/test_measurement_transforms.py @@ -24,10 +24,10 @@ import numpy as np import pennylane as qml import pytest -from conftest import CONFIG_CUSTOM_DEVICE from pennylane.devices import Device from pennylane.devices.capabilities import OperatorProperties from pennylane.transforms import split_non_commuting, split_to_single_terms +from utils import CONFIG_CUSTOM_DEVICE from catalyst import qjit from catalyst.compiler import get_lib_path diff --git a/frontend/test/pytest/test_measurements_results.py b/frontend/test/pytest/test_measurements_results.py index f4938afd22..1f5df00e1e 100644 --- a/frontend/test/pytest/test_measurements_results.py +++ b/frontend/test/pytest/test_measurements_results.py @@ -18,8 +18,8 @@ import numpy as np import pennylane as qml import pytest -from conftest import CONFIG_CUSTOM_DEVICE from jax import numpy as jnp +from utils import CONFIG_CUSTOM_DEVICE from catalyst import CompileError, qjit from catalyst.device import get_device_capabilities diff --git a/frontend/test/pytest/test_preprocess.py b/frontend/test/pytest/test_preprocess.py index a918ec09d4..5e89db1ca9 100644 --- a/frontend/test/pytest/test_preprocess.py +++ b/frontend/test/pytest/test_preprocess.py @@ -19,10 +19,10 @@ import numpy as np import pennylane as qml import pytest -from conftest import CONFIG_CUSTOM_DEVICE from pennylane.devices import Device, NullQubit from pennylane.devices.capabilities import DeviceCapabilities, OperatorProperties from pennylane.tape import QuantumScript +from utils import CONFIG_CUSTOM_DEVICE from catalyst import CompileError, ctrl, qjit from catalyst.api_extensions.control_flow import ( diff --git a/frontend/test/pytest/utils.py b/frontend/test/pytest/utils.py new file mode 100644 index 0000000000..68435015f6 --- /dev/null +++ b/frontend/test/pytest/utils.py @@ -0,0 +1,22 @@ +# Copyright 2023 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Pytest utilities for the Catalyst test suite. +""" + +import os +import pathlib + +TEST_PATH = os.path.dirname(__file__) +CONFIG_CUSTOM_DEVICE = pathlib.Path(f"{TEST_PATH}/../custom_device/custom_device.toml")