From df8b97bb724d00ac5626bf1ffc63b0f5c2bb334e Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 17 Nov 2025 10:48:54 -0500 Subject: [PATCH 01/38] Migrate unified compiler to Catalyst --- .dep-versions | 2 + frontend/catalyst/compiler.py | 6 +- .../catalyst/python_interface/__init__.py | 27 + .../catalyst/python_interface/compiler.py | 83 + .../catalyst/python_interface/conversion.py | 171 ++ .../python_interface/dialects/__init__.py | 24 + .../python_interface/dialects/catalyst.py | 270 ++++ .../python_interface/dialects/mbqc.py | 182 +++ .../catalyst/python_interface/dialects/qec.py | 218 +++ .../python_interface/dialects/quantum.py | 1139 ++++++++++++++ .../dialects/stablehlo/__init__.py | 168 ++ .../dialects/stablehlo/attributes.py | 394 +++++ .../dialects/stablehlo/control_flow.py | 160 ++ .../dialects/stablehlo/data_movement.py | 415 +++++ .../dialects/stablehlo/dialect.py | 207 +++ .../dialects/stablehlo/dynamism.py | 180 +++ .../dialects/stablehlo/elementwise_binary.py | 216 +++ .../dialects/stablehlo/elementwise_other.py | 238 +++ .../dialects/stablehlo/elementwise_unary.py | 554 +++++++ .../dialects/stablehlo/extensibility.py | 167 ++ .../dialects/stablehlo/reduction.py | 160 ++ .../dialects/stablehlo/types.py | 249 +++ .../python_interface/dialects/transform.py | 123 ++ .../doc/unified_compiler_cookbook.rst | 1376 +++++++++++++++++ .../doc/xdsl_dummy_quantum_subroutines.rst | 216 +++ .../doc/xdsl_post_processing.rst | 225 +++ .../doc/xdsl_utils_tutorial.rst | 404 +++++ frontend/catalyst/python_interface/parser.py | 69 + .../python_interface/pass_api/__init__.py | 34 + .../pass_api/apply_transform_sequence.py | 85 + .../pass_api/compiler_transform.py | 64 + .../pass_api/transform_interpreter.py | 146 ++ .../python_interface/transforms/__init__.py | 66 + .../transforms/mbqc/__init__.py | 48 + .../mbqc/convert_to_mbqc_formalism.py | 733 +++++++++ .../transforms/mbqc/decompose_graph_state.py | 209 +++ .../transforms/mbqc/graph_state_utils.py | 256 +++ .../mbqc/outline_state_evolution.py | 470 ++++++ .../transforms/quantum/__init__.py | 41 + .../transforms/quantum/cancel_inverses.py | 102 ++ .../quantum/combine_global_phases.py | 86 ++ .../quantum/diagonalize_measurements.py | 158 ++ .../quantum/measurements_from_samples.py | 659 ++++++++ .../transforms/quantum/merge_rotations.py | 141 ++ .../transforms/quantum/split_non_commuting.py | 498 ++++++ frontend/catalyst/python_interface/utils.py | 77 + .../visualization/__init__.py | 23 + .../visualization/collector.py | 146 ++ .../python_interface/visualization/draw.py | 103 ++ .../visualization/mlir_graph.py | 119 ++ .../visualization/xdsl_conversion.py | 330 ++++ .../python_interface/xdsl_extras/__init__.py | 37 + .../xdsl_extras/constraints.py | 234 +++ .../python_interface/xdsl_extras/traits.py | 240 +++ .../test/pytest/python_interface/conftest.py | 203 +++ .../dialects/test_catalyst_dialect.py | 108 ++ .../dialects/test_mbqc_dialect.py | 178 +++ .../dialects/test_qec_dialect.py | 98 ++ .../dialects/test_quantum_dialect.py | 697 +++++++++ .../dialects/test_stablehlo_dialect.py | 964 ++++++++++++ .../dialects/test_transform_dialect.py | 149 ++ .../python_interface/test_python_compiler.py | 531 +++++++ .../python_interface/test_xdsl_utils.py | 120 ++ .../transforms/mbqc/test_graph_state_utils.py | 201 +++ .../test_xdsl_convert_to_mbqc_formalism.py | 576 +++++++ .../mbqc/test_xdsl_decompose_graph_state.py | 507 ++++++ .../mbqc/test_xdsl_outline_state_evolution.py | 390 +++++ .../quantum/test_xdsl_cancel_inverses.py | 234 +++ .../test_xdsl_combine_global_phases.py | 244 +++ .../test_xdsl_diagonalize_measurements.py | 597 +++++++ .../test_xdsl_measurements_from_samples.py | 889 +++++++++++ .../quantum/test_xdsl_merge_rotations.py | 247 +++ .../quantum/test_xdsl_split_non_commuting.py | 313 ++++ .../test_draw_unified_compiler.py | 534 +++++++ .../visualization/test_mlir_graph.py | 292 ++++ .../xdsl_extras/test_constraints.py | 547 +++++++ .../xdsl_extras/test_traits.py | 321 ++++ 77 files changed, 21685 insertions(+), 3 deletions(-) create mode 100644 frontend/catalyst/python_interface/__init__.py create mode 100644 frontend/catalyst/python_interface/compiler.py create mode 100644 frontend/catalyst/python_interface/conversion.py create mode 100644 frontend/catalyst/python_interface/dialects/__init__.py create mode 100644 frontend/catalyst/python_interface/dialects/catalyst.py create mode 100644 frontend/catalyst/python_interface/dialects/mbqc.py create mode 100644 frontend/catalyst/python_interface/dialects/qec.py create mode 100644 frontend/catalyst/python_interface/dialects/quantum.py create mode 100644 frontend/catalyst/python_interface/dialects/stablehlo/__init__.py create mode 100644 frontend/catalyst/python_interface/dialects/stablehlo/attributes.py create mode 100644 frontend/catalyst/python_interface/dialects/stablehlo/control_flow.py create mode 100644 frontend/catalyst/python_interface/dialects/stablehlo/data_movement.py create mode 100644 frontend/catalyst/python_interface/dialects/stablehlo/dialect.py create mode 100644 frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py create mode 100644 frontend/catalyst/python_interface/dialects/stablehlo/elementwise_binary.py create mode 100644 frontend/catalyst/python_interface/dialects/stablehlo/elementwise_other.py create mode 100644 frontend/catalyst/python_interface/dialects/stablehlo/elementwise_unary.py create mode 100644 frontend/catalyst/python_interface/dialects/stablehlo/extensibility.py create mode 100644 frontend/catalyst/python_interface/dialects/stablehlo/reduction.py create mode 100644 frontend/catalyst/python_interface/dialects/stablehlo/types.py create mode 100644 frontend/catalyst/python_interface/dialects/transform.py create mode 100644 frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst create mode 100644 frontend/catalyst/python_interface/doc/xdsl_dummy_quantum_subroutines.rst create mode 100644 frontend/catalyst/python_interface/doc/xdsl_post_processing.rst create mode 100644 frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst create mode 100644 frontend/catalyst/python_interface/parser.py create mode 100644 frontend/catalyst/python_interface/pass_api/__init__.py create mode 100644 frontend/catalyst/python_interface/pass_api/apply_transform_sequence.py create mode 100644 frontend/catalyst/python_interface/pass_api/compiler_transform.py create mode 100644 frontend/catalyst/python_interface/pass_api/transform_interpreter.py create mode 100644 frontend/catalyst/python_interface/transforms/__init__.py create mode 100644 frontend/catalyst/python_interface/transforms/mbqc/__init__.py create mode 100644 frontend/catalyst/python_interface/transforms/mbqc/convert_to_mbqc_formalism.py create mode 100644 frontend/catalyst/python_interface/transforms/mbqc/decompose_graph_state.py create mode 100644 frontend/catalyst/python_interface/transforms/mbqc/graph_state_utils.py create mode 100644 frontend/catalyst/python_interface/transforms/mbqc/outline_state_evolution.py create mode 100644 frontend/catalyst/python_interface/transforms/quantum/__init__.py create mode 100644 frontend/catalyst/python_interface/transforms/quantum/cancel_inverses.py create mode 100644 frontend/catalyst/python_interface/transforms/quantum/combine_global_phases.py create mode 100644 frontend/catalyst/python_interface/transforms/quantum/diagonalize_measurements.py create mode 100644 frontend/catalyst/python_interface/transforms/quantum/measurements_from_samples.py create mode 100644 frontend/catalyst/python_interface/transforms/quantum/merge_rotations.py create mode 100644 frontend/catalyst/python_interface/transforms/quantum/split_non_commuting.py create mode 100644 frontend/catalyst/python_interface/utils.py create mode 100644 frontend/catalyst/python_interface/visualization/__init__.py create mode 100644 frontend/catalyst/python_interface/visualization/collector.py create mode 100644 frontend/catalyst/python_interface/visualization/draw.py create mode 100644 frontend/catalyst/python_interface/visualization/mlir_graph.py create mode 100644 frontend/catalyst/python_interface/visualization/xdsl_conversion.py create mode 100644 frontend/catalyst/python_interface/xdsl_extras/__init__.py create mode 100644 frontend/catalyst/python_interface/xdsl_extras/constraints.py create mode 100644 frontend/catalyst/python_interface/xdsl_extras/traits.py create mode 100644 frontend/test/pytest/python_interface/conftest.py create mode 100644 frontend/test/pytest/python_interface/dialects/test_catalyst_dialect.py create mode 100644 frontend/test/pytest/python_interface/dialects/test_mbqc_dialect.py create mode 100644 frontend/test/pytest/python_interface/dialects/test_qec_dialect.py create mode 100644 frontend/test/pytest/python_interface/dialects/test_quantum_dialect.py create mode 100644 frontend/test/pytest/python_interface/dialects/test_stablehlo_dialect.py create mode 100644 frontend/test/pytest/python_interface/dialects/test_transform_dialect.py create mode 100644 frontend/test/pytest/python_interface/test_python_compiler.py create mode 100644 frontend/test/pytest/python_interface/test_xdsl_utils.py create mode 100644 frontend/test/pytest/python_interface/transforms/mbqc/test_graph_state_utils.py create mode 100644 frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_convert_to_mbqc_formalism.py create mode 100644 frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_decompose_graph_state.py create mode 100644 frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_outline_state_evolution.py create mode 100644 frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_cancel_inverses.py create mode 100644 frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_combine_global_phases.py create mode 100644 frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_diagonalize_measurements.py create mode 100644 frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_measurements_from_samples.py create mode 100644 frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_merge_rotations.py create mode 100644 frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_split_non_commuting.py create mode 100644 frontend/test/pytest/python_interface/visualization/test_draw_unified_compiler.py create mode 100644 frontend/test/pytest/python_interface/visualization/test_mlir_graph.py create mode 100644 frontend/test/pytest/python_interface/xdsl_extras/test_constraints.py create mode 100644 frontend/test/pytest/python_interface/xdsl_extras/test_traits.py diff --git a/.dep-versions b/.dep-versions index e26b195774..e6aaadbec6 100644 --- a/.dep-versions +++ b/.dep-versions @@ -5,6 +5,8 @@ jax=0.6.2 stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d llvm=113f01aa82d055410f22a9d03b3468fa68600589 enzyme=v0.0.203 +xdsl +xdsl-jax=0.1.0 # Always remove custom PL/LQ versions before release. diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index a61f4ccfe3..2b29fdcc2f 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -542,7 +542,7 @@ def check_nested_operations(op): return False @debug_logger - def is_using_python_compiler(self, mlir_module=None): + def is_using_unified_compiler(self, mlir_module=None): """Returns true if we need the Python compiler path. This happens when: @@ -598,11 +598,11 @@ def run(self, mlir_module, *args, **kwargs): (str): filename of shared object """ - if self.is_using_python_compiler(mlir_module): + if self.is_using_unified_compiler(mlir_module): # We keep this module here to keep xDSL requirement optional # Only move this is it has been decided that xDSL is no longer optional. # pylint: disable-next=import-outside-toplevel - from pennylane.compiler.python_compiler import Compiler as PythonCompiler + from catalyst.python_interface import Compiler as PythonCompiler compiler = PythonCompiler() mlir_module = compiler.run(mlir_module) diff --git a/frontend/catalyst/python_interface/__init__.py b/frontend/catalyst/python_interface/__init__.py new file mode 100644 index 0000000000..b38cdd4744 --- /dev/null +++ b/frontend/catalyst/python_interface/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Python Compiler API for integration of Catalyst with xDSL.""" + +from .compiler import Compiler +from .parser import QuantumParser +from .pass_api import compiler_transform +from .visualization import QMLCollector + + +__all__ = [ + "Compiler", + "compiler_transform", + "QuantumParser", + "QMLCollector", +] diff --git a/frontend/catalyst/python_interface/compiler.py b/frontend/catalyst/python_interface/compiler.py new file mode 100644 index 0000000000..bded5d4552 --- /dev/null +++ b/frontend/catalyst/python_interface/compiler.py @@ -0,0 +1,83 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This file contains the implementation of the PennyLane-xDSL integration API.""" + + +import io + +from jax._src.interpreters import mlir +from jaxlib.mlir.dialects import stablehlo +from jaxlib.mlir.ir import Context as jaxContext # pylint: disable=no-name-in-module +from jaxlib.mlir.ir import Module as jaxModule # pylint: disable=no-name-in-module +from pennylane.typing import Callable +from xdsl.context import Context as xContext +from xdsl.dialects.builtin import ModuleOp +from xdsl.passes import ModulePass, PassPipeline +from xdsl.printer import Printer + +from catalyst.python_interface.parser import QuantumParser +from catalyst.python_interface.pass_api import ApplyTransformSequence + + +# pylint: disable=too-few-public-methods +class Compiler: + """Compiler namespace""" + + @staticmethod + def run( + module: jaxModule | str, + callback: Callable[[ModulePass, ModuleOp, ModulePass], None] | None = None, + ) -> jaxModule | str: + """Runs the apply-transform-sequence pass. + + The apply-transform-sequence pass is a "meta-pass". In other words, + it is a pass that runs other passes. + + Args: + module: Either a Jax MLIR module or MLIR IR as a string + callback: Optional callback function called between passes + + Returns: + jaxModule | str: jaxModule if the input was a jaxModule, else a string. + """ + # Convert to generic text format + is_jax_module = isinstance(module, jaxModule) + if is_jax_module: + gentxtmod = module.operation.get_asm( + binary=False, print_generic_op_form=True, assume_verified=True + ) + else: + gentxtmod = module + + # Parse and transform with xDSL + ctx = xContext(allow_unregistered=True) + parser = QuantumParser(ctx, gentxtmod) + # xmod is modified in place + xmod = parser.parse_module() + pipeline = PassPipeline((ApplyTransformSequence(callback=callback),)) + pipeline.apply(ctx, xmod) + + # Convert back to string + buffer = io.StringIO() + Printer(stream=buffer, print_generic_format=True).print_op(xmod) + + # Convert back to jaxModule if input was jaxModule + if is_jax_module: + with jaxContext() as jctx: + jctx.allow_unregistered_dialects = True + jctx.append_dialect_registry(mlir.upstream_dialects) + stablehlo.register_dialect(jctx) # pylint: disable=no-member + newmod: jaxModule = jaxModule.parse(buffer.getvalue()) + return newmod + return buffer.getvalue() diff --git a/frontend/catalyst/python_interface/conversion.py b/frontend/catalyst/python_interface/conversion.py new file mode 100644 index 0000000000..614bc65ca0 --- /dev/null +++ b/frontend/catalyst/python_interface/conversion.py @@ -0,0 +1,171 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for converting to xDSL module.""" + +from collections.abc import Callable, Sequence +from functools import wraps +from typing import TypeAlias + +from jax._src.lib import _jax +from jaxlib.mlir.dialects import stablehlo as jstablehlo # pylint: disable=no-name-in-module +from jaxlib.mlir.ir import Context as jContext # pylint: disable=no-name-in-module +from jaxlib.mlir.ir import Module as jModule # pylint: disable=no-name-in-module +from xdsl.context import Context as xContext +from xdsl.dialects import builtin as xbuiltin +from xdsl.dialects import func as xfunc +from xdsl.ir import Dialect as xDialect +from xdsl.traits import SymbolTable as xSymbolTable + +from catalyst import QJIT +from catalyst.python_interface.parser import QuantumParser + +JaxJittedFunction: TypeAlias = _jax.PjitFunction # pylint: disable=c-extension-no-member + + +def _mlir_module_inline(func: JaxJittedFunction, *args, **kwargs) -> jModule: + """Get the MLIR module from a jax.jitted function""" + return func.lower(*args, **kwargs).compiler_ir() + + +def mlir_module(func: JaxJittedFunction) -> Callable[..., jModule]: + """Returns a wrapper that creates an MLIR module from a jax.jitted function.""" + + @wraps(func) + def wrapper(*args, **kwargs) -> jModule: + return _mlir_module_inline(func, *args, **kwargs) + + return wrapper + + +def _generic_str_inline(func: JaxJittedFunction, *args, **kwargs) -> str: # pragma: no cover + """Create the generic textual representation for a jax.jitted function""" + lowered = func.lower(*args, **kwargs) + mod = lowered.compiler_ir() + return mod.operation.get_asm(binary=False, print_generic_op_form=True, assume_verified=True) + + +def generic_str(func: JaxJittedFunction) -> Callable[..., str]: # pragma: no cover + """Returns a wrapper that creates the generic textual representation for a + jax.jitted function.""" + + @wraps(func) + def wrapper(*args, **kwargs) -> str: + return _generic_str_inline(func, *args, **kwargs) + + return wrapper + + +def parse_generic_to_xdsl_module( + program: str, extra_dialects: Sequence[xDialect] | None = None +) -> xbuiltin.ModuleOp: # pragma: no cover + """Parses a generic MLIR program string to an xDSL module.""" + ctx = xContext(allow_unregistered=True) + parser = QuantumParser(ctx, program, extra_dialects=extra_dialects) + moduleOp: xbuiltin.ModuleOp = parser.parse_module() + return moduleOp + + +def parse_generic_to_mlir_module(program: str) -> jModule: # pragma: no cover + """Parses a generic MLIR program string to an MLIR module.""" + with jContext() as ctx: + ctx.allow_unregistered_dialects = True + jstablehlo.register_dialect(ctx) # pylint: disable=no-member + return jModule.parse(program) + + +def mlir_from_docstring(func: Callable) -> jModule: # pragma: no cover + """Returns a wrapper that parses an MLIR program string located in the docstring + into an MLIR module.""" + + @wraps(func) + def wrapper(*_, **__): + return parse_generic_to_mlir_module(func.__doc__) + + return wrapper + + +def _xdsl_module_inline( + func: JaxJittedFunction, *args, **kwargs +) -> xbuiltin.ModuleOp: # pragma: no cover + """Get the xDSL module from a jax.jitted function""" + generic_repr = _generic_str_inline(func, *args, **kwargs) + return parse_generic_to_xdsl_module(generic_repr) + + +def xdsl_from_docstring(func: Callable) -> xbuiltin.ModuleOp: # pragma: no cover + """Returns a wrapper that parses an MLIR program string located in the docstring + into an xDSL module.""" + + @wraps(func) + def wrapper(*_, **__): + return parse_generic_to_xdsl_module(func.__doc__) + + return wrapper + + +def xdsl_module(func: JaxJittedFunction) -> Callable[..., xbuiltin.ModuleOp]: # pragma: no cover + """Returns a wrapper that creates an xDSL module from a jax.jitted function.""" + + @wraps(func) + def wrapper(*args, **kwargs) -> xbuiltin.ModuleOp: + return _xdsl_module_inline(func, *args, **kwargs) + + return wrapper + + +def inline_module( + from_mod: xbuiltin.ModuleOp, to_mod: xbuiltin.ModuleOp, change_main_to: str = None +) -> None: + """Inline the contents of one xDSL module into another xDSL module. The inlined body is appended + to the end of ``to_mod``. + + If ``from_mod`` has a ``main`` function, its name is changed to ``change_main_to`` if specified. + """ + if change_main_to: + main = xSymbolTable.lookup_symbol(from_mod, "main") + if main is not None: + assert isinstance(main, xfunc.FuncOp) + main.properties["sym_name"] = xbuiltin.StringAttr(change_main_to) + + for op in from_mod.body.ops: + xSymbolTable.insert_or_update(to_mod, op.clone()) + + +def inline_jit_to_module(func: JaxJittedFunction, mod: xbuiltin.ModuleOp) -> Callable[..., None]: + """Inline a ``jax.jit``-ed Python function to an xDSL module. The inlined body is appended + to the end of ``mod`` in-place. The name of the entry point function of ``func`` is the same + as the name of ``func``.""" + + @wraps(func) + def wrapper(*args, **kwargs): + func_mod = _xdsl_module_inline(func, *args, **kwargs) + inline_module(func_mod, mod, change_main_to=func.__name__) + + return wrapper + + +def xdsl_from_qjit(func: QJIT) -> Callable[..., xbuiltin.ModuleOp]: + """Decorator to convert QJIT-ed functions into xDSL modules.""" + + @wraps(func) + def wrapper(*args, **kwargs): + func.jaxpr, *_ = func.capture(args, **kwargs) + _mlir_module = func.generate_ir() + _generic_str = _mlir_module.operation.get_asm( + binary=False, print_generic_op_form=True, assume_verified=True + ) + return parse_generic_to_xdsl_module(_generic_str) + + return wrapper diff --git a/frontend/catalyst/python_interface/dialects/__init__.py b/frontend/catalyst/python_interface/dialects/__init__.py new file mode 100644 index 0000000000..1b0ab7d1a9 --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This submodule contains xDSL dialects for the unified compiler.""" + +from .catalyst import Catalyst +from .mbqc import MBQC +from .qec import QEC +from .quantum import Quantum +from .stablehlo import StableHLO +from .transform import Transform + +__all__ = ["Catalyst", "MBQC", "Quantum", "QEC", "StableHLO", "Transform"] diff --git a/frontend/catalyst/python_interface/dialects/catalyst.py b/frontend/catalyst/python_interface/dialects/catalyst.py new file mode 100644 index 0000000000..7b3c474063 --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/catalyst.py @@ -0,0 +1,270 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This file contains the Catalyst dialect for the Python compiler. + +This file was originally ported automatically by xDSL (using the ``xdsl-tblgen`` tool) and modified manually +to support the unified compiler. + +The catalyst dialect serves as a standard library for the Catalyst compiler. +It contains data structures that support core compiler functionality. +""" + +# pylint: disable=too-few-public-methods + +from typing import ClassVar + +from xdsl.dialects.builtin import ( + I64, + ArrayAttr, + DenseArrayBase, + DictionaryAttr, + FlatSymbolRefAttrConstr, + FunctionType, + IntegerAttr, + IntegerType, + MemRefType, + StringAttr, + SymbolRefAttr, + UnitAttr, + i32, +) +from xdsl.ir import AttributeCovT, Dialect, Generic, ParametrizedAttribute, TypeAttribute +from xdsl.irdl import ( + AnyAttr, + IRDLOperation, + ParsePropInAttrDict, + VarConstraint, + base, + irdl_attr_definition, + irdl_op_definition, + operand_def, + opt_operand_def, + opt_prop_def, + prop_def, + region_def, + result_def, + var_operand_def, + var_result_def, +) + + +@irdl_attr_definition +class ArrayListType(Generic[AttributeCovT], ParametrizedAttribute, TypeAttribute): + """A dynamically resizable array""" + + name = "catalyst.arraylist" + + element_type: AttributeCovT + + +@irdl_op_definition +class AssertionOp(IRDLOperation): + """Asserts condition at runtime.""" + + name = "catalyst.assert" + + assertion = operand_def(IntegerType(1)) + + error = prop_def(StringAttr) + + +@irdl_op_definition +class CallbackCallOp(IRDLOperation): + """CallbackCallOp operation.""" + + name = "catalyst.callback_call" + + assembly_format = """ + $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + """ + + callee = prop_def(FlatSymbolRefAttrConstr) + + inputs = var_operand_def() + + arg_attrs = opt_prop_def(ArrayAttr[DictionaryAttr]) + + res_attrs = opt_prop_def(ArrayAttr[DictionaryAttr]) + + callback_results = var_result_def() + + irdl_options = [ParsePropInAttrDict()] + + +@irdl_op_definition +class CallbackOp(IRDLOperation): + """Operation denoting a symbol to refer to user callbacks.""" + + name = "catalyst.callback" + + sym_name = prop_def(StringAttr) + + function_type = prop_def(FunctionType) + + id = prop_def(IntegerAttr[I64]) + + argc = prop_def(IntegerAttr[I64]) + + resc = prop_def(IntegerAttr[I64]) + + arg_attrs = opt_prop_def(ArrayAttr[DictionaryAttr]) + + res_attrs = opt_prop_def(ArrayAttr[DictionaryAttr]) + + body = region_def() + + +@irdl_op_definition +class CustomCallOp(IRDLOperation): + """CustomCall operation""" + + name = "catalyst.custom_call" + + assembly_format = """ + `fn` `(`$call_target_name`)` `(` $inputs `)` + attr-dict `:` functional-type(operands, results) + """ + + inputs = var_operand_def() + + call_target_name = prop_def(StringAttr) + + number_original_arg = opt_prop_def(DenseArrayBase.constr(i32)) + + custom_results = var_result_def() + + irdl_options = [ParsePropInAttrDict()] + + +@irdl_op_definition +class LaunchKernelOp(IRDLOperation): + """LaunchKernelOp operation.""" + + name = "catalyst.launch_kernel" + + assembly_format = """ + $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + """ + + callee = prop_def(SymbolRefAttr) + + inputs = var_operand_def() + + arg_attrs = opt_prop_def(ArrayAttr[DictionaryAttr]) + + res_attrs = opt_prop_def(ArrayAttr[DictionaryAttr]) + + kernel_results = var_result_def() + + irdl_options = [ParsePropInAttrDict()] + + +@irdl_op_definition +class ListDeallocOp(IRDLOperation): + """Deallocate the underlying memory of an arraylist.""" + + name = "catalyst.list_dealloc" + + assembly_format = """ $list attr-dict `:` type($list) """ + + list = operand_def(ArrayListType) + + +@irdl_op_definition +class ListInitOp(IRDLOperation): + """Initialize a dynamically resizable arraylist.""" + + name = "catalyst.list_init" + + assembly_format = """ attr-dict `:` type($list) """ + + list = result_def(ArrayListType) + + +@irdl_op_definition +class ListLoadDataOp(IRDLOperation): + """Get the underlying memref storing the data of an array list.""" + + name = "catalyst.list_load_data" + + assembly_format = """ $list attr-dict `:` type($list) `->` type($data) """ + + list = operand_def(ArrayListType) + + data = result_def(MemRefType) + + +@irdl_op_definition +class ListPopOp(IRDLOperation): + """Remove an element from the end of an array list and return it.""" + + name = "catalyst.list_pop" + + assembly_format = """ $list attr-dict `:` type($list) """ + + T: ClassVar = VarConstraint("T", AnyAttr()) + + list = operand_def(base(ArrayListType[T])) + + result = result_def(T) + + +@irdl_op_definition +class ListPushOp(IRDLOperation): + """Append an element to the end of an array list.""" + + name = "catalyst.list_push" + + assembly_format = """ $value `,` $list attr-dict `:` type($list) """ + + T: ClassVar = VarConstraint("T", AnyAttr()) + + value = operand_def(T) + + list = operand_def(base(ArrayListType[T])) + + +@irdl_op_definition +class PrintOp(IRDLOperation): + """Prints numeric values or constant strings at runtime.""" + + name = "catalyst.print" + + val = opt_operand_def() + + const_val = opt_prop_def(StringAttr) + + print_descriptor = prop_def(UnitAttr) + + +Catalyst = Dialect( + "catalyst", + [ + AssertionOp, + CallbackCallOp, + CallbackOp, + CustomCallOp, + LaunchKernelOp, + ListDeallocOp, + ListInitOp, + ListLoadDataOp, + ListPopOp, + ListPushOp, + PrintOp, + ], + [ + ArrayListType, + ], +) diff --git a/frontend/catalyst/python_interface/dialects/mbqc.py b/frontend/catalyst/python_interface/dialects/mbqc.py new file mode 100644 index 0000000000..8cf96b4910 --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/mbqc.py @@ -0,0 +1,182 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module contains the definition of the MBQC dialect for the Python compiler. + +The MBQC dialect is a set of operations and types used to represent measurement-based +quantum-computing instructions in the xDSL framework. + +It was initially generated by xDSL (using the ``xdsl-tblgen`` tool) starting from the +catalyst/mlir/include/MBQC/IR/MBQCDialect.td file in the catalyst repository. + +For detailed documentation on the operations contained in this dialect, please refer to the MBQC +dialect documentation in Catalyst. +""" + +from typing import TypeAlias + +from xdsl.dialects.builtin import ( + I32, + AnyAttr, + Float64Type, + IntegerAttr, + IntegerType, + StringAttr, + i1, +) +from xdsl.ir import Dialect, EnumAttribute, Operation, SpacedOpaqueSyntaxAttribute, SSAValue +from xdsl.irdl import ( + IRDLOperation, + irdl_attr_definition, + irdl_op_definition, + operand_def, + opt_prop_def, + prop_def, + result_def, +) +from xdsl.utils.exceptions import VerifyException +from xdsl.utils.str_enum import StrEnum # StrEnum is standard in Python>=3.11 + +from catalyst.python_interface.xdsl_extras import MemRefConstraint, TensorConstraint + +from .quantum import QubitType, QuregType + +QubitSSAValue: TypeAlias = SSAValue[QubitType] + + +class MeasurementPlaneEnum(StrEnum): + """Enum containing supported measurement-plane attributes""" + + # pylint: disable=too-few-public-methods + + XY = "XY" + YZ = "YZ" + ZX = "ZX" + + +@irdl_attr_definition +class MeasurementPlaneAttr(EnumAttribute[MeasurementPlaneEnum], SpacedOpaqueSyntaxAttribute): + """Planes in the Bloch sphere representation with support for arbitrary-basis measurements""" + + # pylint: disable=too-few-public-methods + + name = "mbqc.measurement_plane" + + +@irdl_op_definition +class MeasureInBasisOp(IRDLOperation): + """A parametric single-qubit projective measurement in an arbitrary basis.""" + + # pylint: disable=too-few-public-methods + + name = "mbqc.measure_in_basis" + + assembly_format = """ + `[` $plane `,` $angle `]` $in_qubit (`postselect` $postselect^)? attr-dict `:` type(results) + """ + + in_qubit = operand_def(QubitType) + + plane = prop_def(MeasurementPlaneAttr) + + angle = operand_def(Float64Type()) + + postselect = opt_prop_def(IntegerAttr[I32]) + + mres = result_def(IntegerType(1)) + + out_qubit = result_def(QubitType) + + def __init__( + self, + in_qubit: QubitSSAValue | Operation, + plane: MeasurementPlaneAttr, + angle: SSAValue[Float64Type], + postselect: int | IntegerAttr | None = None, + ): + properties = {"plane": plane} + + if isinstance(postselect, int): + postselect = IntegerAttr.from_int_and_width(postselect, 32) + + if postselect is not None: + properties["postselect"] = postselect + + super().__init__( + operands=(in_qubit, angle), + properties=properties, + result_types=(IntegerType(1), QubitType()), + ) + + def verify_(self): + """Verify operation when rewriting.""" + if self.postselect is None: + return + + if self.postselect.value.data not in [0, 1]: + raise VerifyException("'postselect' must be 0 or 1.") + + +@irdl_op_definition +class GraphStatePrepOp(IRDLOperation): + """Allocate resources for a new graph state.""" + + # pylint: disable=too-few-public-methods + + name = "mbqc.graph_state_prep" + + assembly_format = """ + `(` $adj_matrix `:` type($adj_matrix) `)` `[` `init` $init_op `,` `entangle` $entangle_op `]` attr-dict `:` type(results) + """ + + adj_matrix = operand_def( + TensorConstraint(element_type=i1, rank=1) | MemRefConstraint(element_type=i1, rank=1) + ) + + init_op = prop_def(StringAttr) + + entangle_op = prop_def(StringAttr) + + qreg = result_def(QuregType) + + def __init__( + self, adj_matrix: AnyAttr, init_op: str | StringAttr, entangle_op: str | StringAttr + ): + if isinstance(init_op, str): + init_op = StringAttr(data=init_op) + + if isinstance(entangle_op, str): + entangle_op = StringAttr(data=entangle_op) + + properties = {"init_op": init_op, "entangle_op": entangle_op} + + qreg = QuregType() + + super().__init__( + operands=(adj_matrix,), + result_types=(qreg,), + properties=properties, + ) + + +MBQC = Dialect( + "mbqc", + [ + MeasureInBasisOp, + GraphStatePrepOp, + ], + [ + MeasurementPlaneAttr, + ], +) diff --git a/frontend/catalyst/python_interface/dialects/qec.py b/frontend/catalyst/python_interface/dialects/qec.py new file mode 100644 index 0000000000..962adaba9f --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/qec.py @@ -0,0 +1,218 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module contains the QEC dialect for the Python compiler. + +The QEC dialect is a set of operations and types used to represent quantum error correction +instructions in the xDSL framework. + +It was initially generated by xDSL (using the ``xdsl-tblgen`` tool) starting from the +catalyst/mlir/include/QEC/IR/QECDialect.td file in the catalyst repository. +""" + +# pylint: disable=too-few-public-methods + +from xdsl.dialects.builtin import I16, ArrayAttr, IntegerAttr, IntegerType, StringAttr, i16 +from xdsl.dialects.utils import AbstractYieldOperation +from xdsl.ir import Attribute, Dialect, EnumAttribute, SpacedOpaqueSyntaxAttribute +from xdsl.irdl import ( + AttrSizedOperandSegments, + IRDLOperation, + irdl_attr_definition, + irdl_op_definition, + lazy_traits_def, + operand_def, + opt_operand_def, + opt_prop_def, + prop_def, + region_def, + result_def, + traits_def, + var_operand_def, + var_result_def, +) +from xdsl.traits import HasParent, IsTerminator, Pure, SingleBlockImplicitTerminator +from xdsl.utils.str_enum import StrEnum + +from .quantum import QubitType + + +class LogicalInitKind(StrEnum): + """The initial state of a logical qubit such as |0⟩, |1⟩, |+⟩, |−⟩, |Y⟩, |-Y⟩, |m⟩, or |m̅⟩.""" + + Zero = "zero" # |0⟩ Non-magic state + One = "one" # |1⟩ Non-magic state + Plus = "plus" # |+⟩ = (|0⟩ + |1⟩) / sqrt(2) Non-magic state + Minus = "minus" # |-⟩ = (|0⟩ - |1⟩) / sqrt(2) Non-magic state + PlusI = "plus_i" # |Y⟩ = (|0⟩ + i|1⟩) / sqrt(2) Non-magic / Magic state + MinusI = "minus_i" # |-Y⟩ = (|0⟩ - i|1⟩) / sqrt(2) Non-magic / Magic state + Magic = "magic" # |m⟩ = |0⟩ + e^{iπ/4}|1⟩ Magic state + MagicConj = "magic_conj" # |m̅⟩ = |0⟩ + e^{-iπ/4}|1⟩ Magic state + + +@irdl_attr_definition +class LogicalInit(EnumAttribute[LogicalInitKind], SpacedOpaqueSyntaxAttribute): + """The initial state of a logical qubit such as |0⟩, |1⟩, |+⟩, |−⟩, |Y⟩, |-Y⟩, |m⟩, or |m̅⟩.""" + + name = "qec.enum" + + +# Type alias for a product of Pauli operators, aka a Pauli word. +PauliWord = ArrayAttr[StringAttr] + + +@irdl_op_definition +class YieldOp(AbstractYieldOperation[Attribute]): + """Return results from a layer region""" + + name = "qec.yield" + + traits = lazy_traits_def(lambda: (IsTerminator(), HasParent(LayerOp), Pure())) + + +@irdl_op_definition +class FabricateOp(IRDLOperation): + """Fabricate axillary qubits from qubit factories.""" + + name = "qec.fabricate" + + assembly_format = """ + $init_state attr-dict `:` type($out_qubits) + """ + + init_state = prop_def(LogicalInit) + + out_qubits = var_result_def(QubitType) + + +@irdl_op_definition +class LayerOp(IRDLOperation): + """A layer operation""" + + name = "qec.layer" + + initArgs = var_operand_def() + + results = var_result_def() + + region = region_def("single_block") + + traits = traits_def(SingleBlockImplicitTerminator(YieldOp)) + + # TODO: add a custom parse and print for this operation + + +@irdl_op_definition +class PPMeasurementOp(IRDLOperation): + """Pauli Product Measurement on qubits.""" + + name = "qec.ppm" + + assembly_format = """ + $pauli_product (`(` $rotation_sign^ `)`)? $in_qubits (`cond` `(` $condition^ `)`)? attr-dict `:` type(results) + """ + + irdl_options = [AttrSizedOperandSegments(as_property=True)] + + pauli_product = prop_def(PauliWord) + + rotation_sign = opt_prop_def(IntegerAttr[I16], default_value=IntegerAttr(1, i16)) + + in_qubits = var_operand_def(QubitType) + + condition = opt_operand_def(IntegerType(1)) + + mres = result_def(IntegerType(1)) + + out_qubits = var_result_def(QubitType) + + +@irdl_op_definition +class PPRotationOp(IRDLOperation): + """Pauli Product Rotation on qubits.""" + + name = "qec.ppr" + + assembly_format = """ + $pauli_product `(` $rotation_kind `)` $in_qubits attr-dict (`cond` `(` $condition^ `)`)? `:` type($out_qubits) + """ + + irdl_options = [AttrSizedOperandSegments(as_property=True)] + + pauli_product = prop_def(PauliWord) + + rotation_kind = prop_def(IntegerAttr[IntegerType(16)]) + + in_qubits = var_operand_def(QubitType) + + condition = opt_operand_def(IntegerType(1)) + + out_qubits = var_result_def(QubitType) + + +@irdl_op_definition +class PrepareStateOp(IRDLOperation): + """Initialize existing qubits into a given state.""" + + name = "qec.prepare" + + assembly_format = """ + $init_state $in_qubits attr-dict `:` type($out_qubits) + """ + + init_state = prop_def(LogicalInit) + + in_qubits = var_operand_def(QubitType) + + out_qubits = var_result_def(QubitType) + + +@irdl_op_definition +class SelectPPMeasurementOp(IRDLOperation): + """Multiplexed Pauli product measurement.""" + + name = "qec.select.ppm" + + assembly_format = """ + `(` $select_switch `,` $pauli_product_0 `,` $pauli_product_1 `)` $in_qubits attr-dict `:` type(results) + """ + + select_switch = operand_def(IntegerType(1)) + + pauli_product_0 = prop_def(PauliWord) + + pauli_product_1 = prop_def(PauliWord) + + in_qubits = var_operand_def(QubitType) + + mres = result_def(IntegerType(1)) + + out_qubits = var_result_def(QubitType) + + +QEC = Dialect( + "qec", + [ + FabricateOp, + LayerOp, + PPMeasurementOp, + PPRotationOp, + PrepareStateOp, + SelectPPMeasurementOp, + YieldOp, + ], + [ + LogicalInit, + ], +) diff --git a/frontend/catalyst/python_interface/dialects/quantum.py b/frontend/catalyst/python_interface/dialects/quantum.py new file mode 100644 index 0000000000..a4915fd5fc --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/quantum.py @@ -0,0 +1,1139 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This file contains the definition of the Quantum dialect for the unified compiler. + +The Quantum dialect is a set of operations and types used to represent quantum computations +in the xDSL framework. + +It was initially generated by xDSL (using the ``xdsl-tblgen`` tool) +starting from the catalyst/mlir/include/Quantum/IR/QuantumOps.td file in the catalyst repository. +""" + +# pylint: disable=too-few-public-methods + +from collections.abc import Sequence +from typing import TypeAlias + +from xdsl.dialects.builtin import ( + I32, + I64, + ComplexType, + Float64Type, + FloatAttr, + IntegerAttr, + IntegerType, + MemRefType, + StringAttr, + TensorType, + UnitAttr, + i1, + i64, +) +from xdsl.ir import ( + Block, + Dialect, + EnumAttribute, + Operation, + ParametrizedAttribute, + Region, + SpacedOpaqueSyntaxAttribute, + SSAValue, + StrEnum, + TypeAttribute, +) +from xdsl.irdl import ( + AtLeast, + AttrSizedOperandSegments, + AttrSizedResultSegments, + IntSetConstraint, + IRDLOperation, + ParsePropInAttrDict, + SameVariadicResultSize, + irdl_attr_definition, + irdl_op_definition, + lazy_traits_def, + operand_def, + opt_operand_def, + opt_prop_def, + opt_result_def, + prop_def, + region_def, + result_def, + traits_def, + var_operand_def, + var_result_def, +) +from xdsl.traits import ( + HasParent, + IsTerminator, + NoMemoryEffect, + Pure, + ReturnLike, + SingleBlockImplicitTerminator, +) + +from catalyst.python_interface.xdsl_extras import MemRefConstraint, TensorConstraint + +################################################################ +######################## ATTRIBUTES ############################ +################################################################ + + +@irdl_attr_definition +class ObservableType(ParametrizedAttribute, TypeAttribute): + """A quantum observable for use in measurements.""" + + name = "quantum.obs" + + +@irdl_attr_definition +class QubitType(ParametrizedAttribute, TypeAttribute): + """A value-semantic qubit (state).""" + + name = "quantum.bit" + + +@irdl_attr_definition +class QuregType(ParametrizedAttribute, TypeAttribute): + """An array of value-semantic qubits (i.e. quantum register).""" + + name = "quantum.reg" + + +@irdl_attr_definition +class ResultType(ParametrizedAttribute, TypeAttribute): + """A quantum measurement result.""" + + name = "quantum.res" + + +class NamedObservable(StrEnum): + """Known named observables""" + + Identity = "Identity" + PauliX = "PauliX" + PauliY = "PauliY" + PauliZ = "PauliZ" + Hadamard = "Hadamard" + + +@irdl_attr_definition +class NamedObservableAttr(EnumAttribute[NamedObservable], SpacedOpaqueSyntaxAttribute): + """Known named observables""" + + name = "quantum.named_observable" + + +################################################################ +######################## OPERATIONS ############################ +################################################################ + + +QubitSSAValue: TypeAlias = SSAValue[QubitType] +QuregSSAValue: TypeAlias = SSAValue[QuregType] +ObservableSSAValue: TypeAlias = SSAValue[ObservableType] + + +@irdl_op_definition +class AdjointOp(IRDLOperation): + """Calculate the adjoint of the enclosed operations""" + + name = "quantum.adjoint" + + assembly_format = """ + `(` $qreg `)` attr-dict `:` type(operands) $region + """ + + qreg = operand_def(QuregType) + + out_qreg = result_def(QuregType) + + region = region_def("single_block") + + traits = lazy_traits_def(lambda: (NoMemoryEffect(), SingleBlockImplicitTerminator(YieldOp))) + + def __init__( + self, + qreg: QuregSSAValue | Operation, + region: Region | Sequence[Operation] | Sequence[Block], + ): + super().__init__(operands=(qreg,), result_types=(QuregType(),), regions=(region,)) + + +@irdl_op_definition +class AllocOp(IRDLOperation): + """Allocate n qubits into a quantum register.""" + + name = "quantum.alloc" + + assembly_format = """ + `(` ($nqubits^):($nqubits_attr)? `)` attr-dict `:` type(results) + """ + + nqubits = opt_operand_def(i64) + + nqubits_attr = opt_prop_def(IntegerAttr.constr(type=I64, value=AtLeast(0))) + + qreg = result_def(QuregType) + + def __init__(self, nqubits): + if isinstance(nqubits, int): + nqubits = IntegerAttr.from_int_and_width(nqubits, 64) + + if isinstance(nqubits, IntegerAttr): + operands = (None,) + properties = {"nqubits_attr": nqubits} + else: + operands = (nqubits,) + properties = {} + + super().__init__(operands=operands, properties=properties, result_types=(QuregType(),)) + + +@irdl_op_definition +class AllocQubitOp(IRDLOperation): + """Allocate a single qubit.""" + + name = "quantum.alloc_qb" + + assembly_format = """attr-dict `:` type(results)""" + + qubit = result_def(QubitType) + + def __init__(self): + super().__init__( + result_types=(QubitType(),), + ) + + +@irdl_op_definition +class ComputationalBasisOp(IRDLOperation): + """Define a pseudo-obeservable of the computational basis for use in measurements""" + + name = "quantum.compbasis" + + assembly_format = """ + (`qubits` $qubits^)? (`qreg` $qreg^)? attr-dict `:` type(results) + """ + + irdl_options = [AttrSizedOperandSegments(as_property=True)] + + qubits = var_operand_def(QubitType) + + qreg = opt_operand_def(QuregType) + + obs = result_def(ObservableType) + + +@irdl_op_definition +class CountsOp(IRDLOperation): + """Compute sample counts for the given observable for the current state""" + + name = "quantum.counts" + + assembly_format = """ + $obs ( `shape` $dynamic_shape^ )? + ( `in` `(` $in_eigvals^ `:` type($in_eigvals) `,` $in_counts `:` type($in_counts) `)` )? + attr-dict ( `:` type($eigvals)^ `,` type($counts) )? + """ + + irdl_options = [ + AttrSizedOperandSegments(as_property=True), + SameVariadicResultSize(), + ] + + obs = operand_def(ObservableType) + + dynamic_shape = opt_operand_def(i64) + + in_eigvals = opt_operand_def(MemRefConstraint(element_type=Float64Type(), rank=1)) + + in_counts = opt_operand_def(MemRefConstraint(element_type=i64, rank=1)) + + eigvals = opt_result_def(TensorConstraint(element_type=Float64Type(), rank=1)) + + counts = opt_result_def(TensorConstraint(element_type=i64, rank=1)) + + +@irdl_op_definition +class CustomOp(IRDLOperation): + """A generic quantum gate on n qubits with m floating point parameters.""" + + name = "quantum.custom" + + assembly_format = """ + $gate_name `(` $params `)` $in_qubits + (`adj` $adjoint^)? + attr-dict + ( `ctrls` `(` $in_ctrl_qubits^ `)` )? + ( `ctrlvals` `(` $in_ctrl_values^ `)` )? + `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )? + """ + + irdl_options = [ + AttrSizedOperandSegments(as_property=True), + AttrSizedResultSegments(as_property=True), + ] + + params = var_operand_def(Float64Type()) + + in_qubits = var_operand_def(QubitType) + + gate_name = prop_def(StringAttr) + + adjoint = opt_prop_def(UnitAttr) + + in_ctrl_qubits = var_operand_def(QubitType) + + in_ctrl_values = var_operand_def(i1) + + out_qubits = var_result_def(QubitType) + + out_ctrl_qubits = var_result_def(QubitType) + + traits = traits_def(NoMemoryEffect()) + + # pylint: disable=too-many-arguments + def __init__( + self, + *, + gate_name: str | StringAttr, + params: SSAValue[Float64Type] | Sequence[SSAValue[Float64Type]] | None = None, + in_qubits: QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation], + in_ctrl_qubits: ( + QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation] | None + ) = None, + in_ctrl_values: ( + SSAValue[IntegerType] + | Operation + | Sequence[SSAValue[IntegerType]] + | Sequence[Operation] + | None + ) = None, + adjoint: UnitAttr | bool = False, + ): + params = () if params is None else params + in_ctrl_qubits = () if in_ctrl_qubits is None else in_ctrl_qubits + in_ctrl_values = () if in_ctrl_values is None else in_ctrl_values + + if not isinstance(params, Sequence): + params = (params,) + if not isinstance(in_qubits, Sequence): + in_qubits = (in_qubits,) + if not isinstance(in_ctrl_qubits, Sequence): + in_ctrl_qubits = (in_ctrl_qubits,) + if not isinstance(in_ctrl_values, Sequence): + in_ctrl_values = (in_ctrl_values,) + + if isinstance(gate_name, str): + gate_name = StringAttr(data=gate_name) + + out_qubits = tuple(QubitType() for _ in in_qubits) + out_ctrl_qubits = tuple(QubitType() for _ in in_ctrl_qubits) + properties = {"gate_name": gate_name} + if adjoint: + properties["adjoint"] = UnitAttr() + + super().__init__( + operands=(params, in_qubits, in_ctrl_qubits, in_ctrl_values), + result_types=(out_qubits, out_ctrl_qubits), + properties=properties, + ) + + +@irdl_op_definition +class DeallocOp(IRDLOperation): + """Deallocate a quantum register.""" + + name = "quantum.dealloc" + + assembly_format = """ + $qreg attr-dict `:` type(operands) + """ + + qreg = operand_def(QuregType) + + def __init__(self, qreg: QuregSSAValue | Operation): + super().__init__(operands=(qreg,)) + + +@irdl_op_definition +class DeallocQubitOp(IRDLOperation): + """Deallocate a single qubit.""" + + name = "quantum.dealloc_qb" + + assembly_format = """$qubit attr-dict `:` type(operands)""" + + qubit = operand_def(QubitType) + + def __init__(self, qubit: QubitSSAValue | Operation): + super().__init__( + operands=(qubit,), + ) + + +@irdl_op_definition +class DeviceInitOp(IRDLOperation): + """Initialize a quantum device.""" + + name = "quantum.device" + + assembly_format = """ + (`shots` `(` $shots^ `)`)? `[` $lib `,` $device_name `,` $kwargs `]` attr-dict + """ + + irdl_options = [ParsePropInAttrDict()] + + shots = opt_operand_def(i64) + + auto_qubit_management = opt_prop_def(UnitAttr) + + lib = prop_def(StringAttr) + + device_name = prop_def(StringAttr) + + kwargs = prop_def(StringAttr) + + +@irdl_op_definition +class DeviceReleaseOp(IRDLOperation): + """Release the active quantum device.""" + + name = "quantum.device_release" + + assembly_format = "attr-dict" + + +@irdl_op_definition +class ExpvalOp(IRDLOperation): + """Compute the expectation value of the given observable for the current state""" + + name = "quantum.expval" + + assembly_format = "$obs attr-dict `:` type(results)" + + obs = operand_def(ObservableType) + + expval = result_def(Float64Type()) + + def __init__(self, obs: ObservableSSAValue | Operation): + super().__init__(operands=(obs,), result_types=(Float64Type(),)) + + +@irdl_op_definition +class ExtractOp(IRDLOperation): + """Extract a qubit value from a register.""" + + name = "quantum.extract" + + assembly_format = """ + $qreg `[` ($idx^):($idx_attr)? `]` attr-dict `:` type($qreg) `->` type(results) + """ + + qreg = operand_def(QuregType) + + idx = opt_operand_def(i64) + + idx_attr = opt_prop_def(IntegerAttr.constr(type=i64, value=AtLeast(0))) + + qubit = result_def(QubitType) + + traits = traits_def(NoMemoryEffect()) + + def __init__( + self, + qreg: QuregSSAValue | Operation, + idx: int | SSAValue[IntegerType] | Operation | IntegerAttr, + ): + if isinstance(idx, int): + idx = IntegerAttr.from_int_and_width(idx, 64) + + if isinstance(idx, IntegerAttr): + operands = (qreg, None) + properties = {"idx_attr": idx} + else: + operands = (qreg, idx) + properties = {} + + super().__init__( + operands=operands, + result_types=(QubitType(),), + properties=properties, + ) + + +@irdl_op_definition +class FinalizeOp(IRDLOperation): + """Teardown the quantum runtime.""" + + name = "quantum.finalize" + + assembly_format = "attr-dict" + + +@irdl_op_definition +class GlobalPhaseOp(IRDLOperation): + """Global Phase.""" + + name = "quantum.gphase" + + assembly_format = """ + `(` $params `)` + attr-dict + ( `ctrls` `(` $in_ctrl_qubits^ `)` )? + ( `ctrlvals` `(` $in_ctrl_values^ `)` )? + `:` type(results) + """ + + irdl_options = [AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()] + + params = operand_def(Float64Type()) + + adjoint = opt_prop_def(UnitAttr) + + in_ctrl_qubits = var_operand_def(QubitType) + + in_ctrl_values = var_operand_def(i1) + + out_ctrl_qubits = var_result_def(QubitType) + + def __init__( + self, + *, + params: float | SSAValue[Float64Type], + in_ctrl_qubits: ( + QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation] | None + ) = None, + in_ctrl_values: ( + SSAValue[IntegerType] + | Operation + | Sequence[SSAValue[IntegerType]] + | Sequence[Operation] + | None + ) = None, + ): + if isinstance(params, float): + params = FloatAttr(data=params, type=Float64Type()) + in_ctrl_qubits = () if in_ctrl_qubits is None else in_ctrl_qubits + in_ctrl_values = () if in_ctrl_values is None else in_ctrl_values + + if not isinstance(in_ctrl_qubits, Sequence): + in_ctrl_qubits = (in_ctrl_qubits,) + if not isinstance(in_ctrl_values, Sequence): + in_ctrl_values = (in_ctrl_values,) + + out_ctrl_qubits = tuple(QubitType() for _ in in_ctrl_qubits) + + super().__init__( + operands=(params, in_ctrl_qubits, in_ctrl_values), + result_types=(out_ctrl_qubits,), + ) + + +@irdl_op_definition +class HamiltonianOp(IRDLOperation): + """Define a Hamiltonian observable for use in measurements""" + + name = "quantum.hamiltonian" + + assembly_format = """ + `(` $coeffs `:` type($coeffs) `)` $terms attr-dict `:` type(results) + """ + + coeffs = operand_def( + TensorConstraint(element_type=Float64Type(), rank=1) + | (MemRefConstraint(element_type=Float64Type(), rank=1)) + ) + + terms = var_operand_def(ObservableType) + + obs = result_def(ObservableType) + + +@irdl_op_definition +class HermitianOp(IRDLOperation): + """Define a Hermitian observable for use in measurements""" + + name = "quantum.hermitian" + + assembly_format = """ + `(` $matrix `:` type($matrix) `)` $qubits attr-dict `:` type(results) + """ + + matrix = operand_def( + TensorConstraint(element_type=ComplexType(Float64Type()), rank=2) + | MemRefConstraint(element_type=ComplexType(Float64Type()), rank=2) + ) + + qubits = var_operand_def(QubitType) + + obs = result_def(ObservableType) + + +@irdl_op_definition +class InitializeOp(IRDLOperation): + """Initialize the quantum runtime.""" + + name = "quantum.init" + + assembly_format = "attr-dict" + + +@irdl_op_definition +class InsertOp(IRDLOperation): + """Update the qubit value of a register.""" + + name = "quantum.insert" + + assembly_format = """ + $in_qreg `[` ($idx^):($idx_attr)? `]` `,` $qubit attr-dict `:` type($in_qreg) `,` type($qubit) + """ + + in_qreg = operand_def(QuregType) + + idx = opt_operand_def(i64) + + idx_attr = opt_prop_def(IntegerAttr.constr(type=i64, value=AtLeast(0))) + + qubit = operand_def(QubitType) + + out_qreg = result_def(QuregType) + + traits = traits_def(NoMemoryEffect()) + + def __init__( + self, + in_qreg: QuregSSAValue | Operation, + idx: SSAValue[IntegerType] | Operation | int | IntegerAttr, + qubit: QubitSSAValue | Operation, + ): + if isinstance(idx, int): + idx = IntegerAttr.from_int_and_width(idx, 64) + + if isinstance(idx, IntegerAttr): + operands = (in_qreg, None, qubit) + properties = {"idx_attr": idx} + else: + operands = (in_qreg, idx, qubit) + properties = {} + + super().__init__(operands=operands, properties=properties, result_types=(QuregType(),)) + + +@irdl_op_definition +class MeasureOp(IRDLOperation): + """A single-qubit projective measurement in the computational basis.""" + + name = "quantum.measure" + + assembly_format = """ + $in_qubit (`postselect` $postselect^)? attr-dict `:` type(results) + """ + + in_qubit = operand_def(QubitType) + + postselect = opt_prop_def( + IntegerAttr.constr(type=I32, value=IntSetConstraint(frozenset((0, 1)))) + ) + + mres = result_def(i1) + + out_qubit = result_def(QubitType) + + def __init__( + self, in_qubit: QubitSSAValue | Operation, postselect: int | IntegerAttr | None = None + ): + if isinstance(postselect, int): + postselect = IntegerAttr.from_int_and_width(postselect, 32) + + if postselect is None: + properties = {} + else: + properties = {"postselect": postselect} + + super().__init__( + operands=(in_qubit,), properties=properties, result_types=(i1, QubitType()) + ) + + +@irdl_op_definition +class MultiRZOp(IRDLOperation): + """Apply an arbitrary multi Z rotation""" + + name = "quantum.multirz" + + assembly_format = """ + `(` $theta `)` $in_qubits + (`adj` $adjoint^)? + attr-dict + ( `ctrls` `(` $in_ctrl_qubits^ `)` )? + ( `ctrlvals` `(` $in_ctrl_values^ `)` )? + `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )? + """ + + irdl_options = [ + AttrSizedOperandSegments(as_property=True), + AttrSizedResultSegments(as_property=True), + ] + + theta = operand_def(Float64Type()) + + in_qubits = var_operand_def(QubitType) + + adjoint = opt_prop_def(UnitAttr) + + in_ctrl_qubits = var_operand_def(QubitType) + + in_ctrl_values = var_operand_def(i1) + + out_qubits = var_result_def(QubitType) + + out_ctrl_qubits = var_result_def(QubitType) + + traits = traits_def(NoMemoryEffect()) + + # pylint: disable=too-many-arguments + def __init__( + self, + *, + theta: SSAValue[Float64Type], + in_qubits: QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation], + in_ctrl_qubits: ( + QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation] | None + ) = None, + in_ctrl_values: ( + SSAValue[IntegerType] + | Operation + | Sequence[SSAValue[IntegerType]] + | Sequence[Operation] + | None + ) = None, + adjoint: UnitAttr | bool = False, + ): + in_ctrl_qubits = () if in_ctrl_qubits is None else in_ctrl_qubits + in_ctrl_values = () if in_ctrl_values is None else in_ctrl_values + + if not isinstance(in_qubits, Sequence): + in_qubits = (in_qubits,) + if not isinstance(in_ctrl_qubits, Sequence): + in_ctrl_qubits = (in_ctrl_qubits,) + if not isinstance(in_ctrl_values, Sequence): + in_ctrl_values = (in_ctrl_values,) + + out_qubits = tuple(QubitType() for _ in in_qubits) + out_ctrl_qubits = tuple(QubitType() for _ in in_ctrl_qubits) + properties = {"adjoint": UnitAttr()} if adjoint else {} + + super().__init__( + operands=(theta, in_qubits, in_ctrl_qubits, in_ctrl_values), + result_types=(out_qubits, out_ctrl_qubits), + properties=properties, + ) + + +@irdl_op_definition +class NamedObsOp(IRDLOperation): + """Define a Named observable for use in measurements""" + + name = "quantum.namedobs" + + assembly_format = """ + $qubit `[` $type `]` attr-dict `:` type(results) + """ + + qubit = operand_def(QubitType) + + type = prop_def(NamedObservableAttr) + + obs = result_def(ObservableType) + + def __init__(self, qubit: QubitSSAValue | Operation, obs_type: NamedObservableAttr): + super().__init__( + operands=(qubit,), properties={"type": obs_type}, result_types=(ObservableType(),) + ) + + +@irdl_op_definition +class NumQubitsOp(IRDLOperation): + """Get the number of currently allocated qubits.""" + + name = "quantum.num_qubits" + + assembly_format = """ + attr-dict `:` type(results) + """ + + num_qubits = result_def(i64) + + +@irdl_op_definition +class PCPhaseOp(IRDLOperation): + """Apply a projector-controlled phase gate""" + + name = "quantum.pcphase" + + assembly_format = """ + `(` $theta `,` $dim `)` $in_qubits + (`adj` $adjoint^)? + attr-dict + ( `ctrls` `(` $in_ctrl_qubits^ `)` )? + ( `ctrlvals` `(` $in_ctrl_values^ `)` )? + `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )? + """ + + irdl_options = [ + AttrSizedOperandSegments(as_property=True), + AttrSizedResultSegments(as_property=True), + ] + + theta = operand_def(Float64Type()) + + dim = operand_def(Float64Type()) + + in_qubits = var_operand_def(QubitType) + + adjoint = opt_prop_def(UnitAttr) + + in_ctrl_qubits = var_operand_def(QubitType) + + in_ctrl_values = var_operand_def(i1) + + out_qubits = var_result_def(QubitType) + + out_ctrl_qubits = var_result_def(QubitType) + + traits = traits_def(NoMemoryEffect()) + + # pylint: disable=too-many-arguments + def __init__( + self, + *, + theta: SSAValue[Float64Type], + dim: SSAValue[Float64Type], + in_qubits: QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation], + in_ctrl_qubits: ( + QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation] | None + ) = None, + in_ctrl_values: ( + SSAValue[IntegerType] + | Operation + | Sequence[SSAValue[IntegerType]] + | Sequence[Operation] + | None + ) = None, + adjoint: UnitAttr | bool = False, + ): + in_ctrl_qubits = () if in_ctrl_qubits is None else in_ctrl_qubits + in_ctrl_values = () if in_ctrl_values is None else in_ctrl_values + + if not isinstance(in_qubits, Sequence): + in_qubits = (in_qubits,) + if not isinstance(in_ctrl_qubits, Sequence): + in_ctrl_qubits = (in_ctrl_qubits,) + if not isinstance(in_ctrl_values, Sequence): + in_ctrl_values = (in_ctrl_values,) + + out_qubits = tuple(QubitType() for _ in in_qubits) + out_ctrl_qubits = tuple(QubitType() for _ in in_ctrl_qubits) + properties = {"adjoint": UnitAttr()} if adjoint else {} + + super().__init__( + operands=(theta, dim, in_qubits, in_ctrl_qubits, in_ctrl_values), + result_types=(out_qubits, out_ctrl_qubits), + properties=properties, + ) + + +@irdl_op_definition +class ProbsOp(IRDLOperation): + """Compute computational basis probabilities for the current state""" + + name = "quantum.probs" + + assembly_format = """ + $obs ( `shape` $dynamic_shape^ )? + ( `in` `(` $state_in^ `:` type($state_in) `)` )? + attr-dict ( `:` type($probabilities)^ )? + """ + + irdl_options = [AttrSizedOperandSegments(as_property=True)] + + obs = operand_def(ObservableType) + + dynamic_shape = opt_operand_def(i64) + + state_in = opt_operand_def(MemRefConstraint(element_type=Float64Type(), rank=1)) + + probabilities = opt_result_def(TensorConstraint(element_type=Float64Type(), rank=1)) + + +@irdl_op_definition +class QubitUnitaryOp(IRDLOperation): + """Apply an arbitrary fixed unitary matrix""" + + name = "quantum.unitary" + + assembly_format = """ + `(` $matrix `:` type($matrix) `)` $in_qubits + (`adj` $adjoint^)? + attr-dict + ( `ctrls` `(` $in_ctrl_qubits^ `)` )? + ( `ctrlvals` `(` $in_ctrl_values^ `)` )? + `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )? + """ + + irdl_options = [ + AttrSizedOperandSegments(as_property=True), + AttrSizedResultSegments(as_property=True), + ] + + matrix = operand_def( + (TensorConstraint(element_type=ComplexType(Float64Type()), rank=2)) + | (MemRefConstraint(element_type=ComplexType(Float64Type()), rank=2)) + ) + + in_qubits = var_operand_def(QubitType) + + adjoint = opt_prop_def(UnitAttr) + + in_ctrl_qubits = var_operand_def(QubitType) + + in_ctrl_values = var_operand_def(i1) + + out_qubits = var_result_def(QubitType) + + out_ctrl_qubits = var_result_def(QubitType) + + traits = traits_def(NoMemoryEffect()) + + # pylint: disable=too-many-arguments + def __init__( + self, + *, + matrix: SSAValue[TensorType | MemRefType], + in_qubits: QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation], + in_ctrl_qubits: ( + QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation] | None + ) = None, + in_ctrl_values: ( + SSAValue[IntegerType] + | Operation + | Sequence[SSAValue[IntegerType]] + | Sequence[Operation] + | None + ) = None, + adjoint: UnitAttr | bool = False, + ): + in_ctrl_qubits = () if in_ctrl_qubits is None else in_ctrl_qubits + in_ctrl_values = () if in_ctrl_values is None else in_ctrl_values + + if not isinstance(in_qubits, Sequence): + in_qubits = (in_qubits,) + if not isinstance(in_ctrl_qubits, Sequence): + in_ctrl_qubits = (in_ctrl_qubits,) + if not isinstance(in_ctrl_values, Sequence): + in_ctrl_values = (in_ctrl_values,) + + out_qubits = tuple(QubitType() for _ in in_qubits) + out_ctrl_qubits = tuple(QubitType() for _ in in_ctrl_qubits) + properties = {} + if adjoint: + properties["adjoint"] = UnitAttr() + + super().__init__( + operands=(matrix, in_qubits, in_ctrl_qubits, in_ctrl_values), + result_types=(out_qubits, out_ctrl_qubits), + properties=properties, + ) + + +@irdl_op_definition +class SampleOp(IRDLOperation): + """Sample eigenvalues from the given observable for the current state""" + + name = "quantum.sample" + + assembly_format = """ + $obs ( `shape` $dynamic_shape^ )? + ( `in` `(` $in_data^ `:` type($in_data) `)` )? + attr-dict ( `:` type($samples)^ )? + """ + + irdl_options = [AttrSizedOperandSegments(as_property=True)] + + obs = operand_def(ObservableType) + + dynamic_shape = var_operand_def(i64) + + in_data = opt_operand_def(MemRefConstraint(element_type=Float64Type(), rank=(1, 2))) + + samples = opt_result_def(TensorConstraint(element_type=Float64Type(), rank=(1, 2))) + + +@irdl_op_definition +class SetBasisStateOp(IRDLOperation): + """Set basis state.""" + + name = "quantum.set_basis_state" + + assembly_format = """ + `(` $basis_state`)` $in_qubits attr-dict `:` functional-type(operands, results) + """ + + basis_state = operand_def( + (TensorConstraint(element_type=i1, rank=1)) | (MemRefConstraint(element_type=i1, rank=1)) + ) + + in_qubits = var_operand_def(QubitType) + + out_qubits = var_result_def(QubitType) + + +@irdl_op_definition +class SetStateOp(IRDLOperation): + """Set state to a complex vector.""" + + name = "quantum.set_state" + + assembly_format = """ + `(` $in_state `)` $in_qubits attr-dict `:` functional-type(operands, results) + """ + + in_state = operand_def( + (TensorConstraint(element_type=ComplexType(Float64Type()), rank=1)) + | (MemRefConstraint(element_type=ComplexType(Float64Type()), rank=1)) + ) + + in_qubits = var_operand_def(QubitType) + + out_qubits = var_result_def(QubitType) + + +@irdl_op_definition +class StateOp(IRDLOperation): + """Return the current statevector""" + + name = "quantum.state" + + assembly_format = """ + $obs ( `shape` $dynamic_shape^ )? + ( `in` `(` $state_in^ `:` type($state_in) `)` )? + attr-dict ( `:` type($state)^ )? + """ + + irdl_options = [AttrSizedOperandSegments(as_property=True)] + + obs = operand_def(ObservableType) + + dynamic_shape = opt_operand_def(i64) + + state_in = opt_operand_def(MemRefConstraint(element_type=ComplexType(Float64Type()), rank=1)) + + state = opt_result_def(TensorConstraint(element_type=ComplexType(Float64Type()), rank=1)) + + +@irdl_op_definition +class TensorOp(IRDLOperation): + """Define a tensor product of observables for use in measurements""" + + name = "quantum.tensor" + + assembly_format = """ + $terms attr-dict `:` type(results) + """ + + terms = var_operand_def(ObservableType) + + obs = result_def(ObservableType) + + +@irdl_op_definition +class VarianceOp(IRDLOperation): + """Compute the variance of the given observable for the current state""" + + name = "quantum.var" + + assembly_format = """ + $obs attr-dict `:` type(results) + """ + + obs = operand_def(ObservableType) + + variance = result_def(Float64Type()) + + def __init__(self, obs: ObservableSSAValue | Operation): + super().__init__(operands=(obs,), result_types=(Float64Type(),)) + + +@irdl_op_definition +class YieldOp(IRDLOperation): + """Return results from quantum program regions""" + + name = "quantum.yield" + + assembly_format = """ + attr-dict ($retvals^ `:` type($retvals))? + """ + + retvals = var_operand_def(QuregType) + + traits = traits_def(HasParent(AdjointOp), IsTerminator(), Pure(), ReturnLike()) + + +Quantum = Dialect( + "quantum", + [ + AdjointOp, + AllocOp, + AllocQubitOp, + ComputationalBasisOp, + CountsOp, + CustomOp, + DeallocOp, + DeallocQubitOp, + DeviceInitOp, + DeviceReleaseOp, + ExpvalOp, + ExtractOp, + FinalizeOp, + GlobalPhaseOp, + HamiltonianOp, + HermitianOp, + InitializeOp, + InsertOp, + MeasureOp, + MultiRZOp, + NamedObsOp, + NumQubitsOp, + PCPhaseOp, + ProbsOp, + QubitUnitaryOp, + SampleOp, + SetBasisStateOp, + SetStateOp, + StateOp, + TensorOp, + VarianceOp, + YieldOp, + ], + [ + ObservableType, + QubitType, + QuregType, + ResultType, + NamedObservableAttr, + ], +) diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/__init__.py b/frontend/catalyst/python_interface/dialects/stablehlo/__init__.py new file mode 100644 index 0000000000..2eb1cb67ce --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/__init__.py @@ -0,0 +1,168 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +StableHLO dialect package for PennyLane's compiler infrastructure. + +This package contains organized elementwise operations and other StableHLO-related +functionality. +""" + +# Import all elementwise operations explicitly +from .elementwise_unary import ( + ConvertOp, + CosineOp, + ExponentialMinusOneOp, + ExponentialOp, + FloorOp, + ImagOp, + IsFiniteOp, + LogOp, + LogPlusOneOp, + LogisticOp, + NegateOp, + RealOp, + RoundNearestAfzOp, + RoundNearestEvenOp, + RsqrtOp, + SignOp, + SineOp, + SqrtOp, + TanOp, + TanhOp, +) + +from .elementwise_binary import ( + ComplexOp, + DivideOp, + MaximumOp, + MinimumOp, + PowerOp, + RemainderOp, +) + +from .elementwise_other import ( + ClampOp, + CompareOp, + ConstantOp, + MapOp, + ReducePrecisionOp, + SelectOp, +) + +from .control_flow import ( + IfOp, + WhileOp, + OptimizationBarrierOp, +) + +from .data_movement import ( + BroadcastInDimOp, + ConcatenateOp, + DynamicSliceOp, + GatherOp, + ReshapeOp, + ScatterOp, + SliceOp, +) + +from .dynamism import ( + DynamicBroadcastInDimOp, +) + +from .reduction import ( + ReduceOp, +) + +from .extensibility import ( + CustomCallOp, +) + +from .attributes import ( + GatherDimensionNumbers, + ResultAccuracyModeAttr, + ScatterDimensionNumbers, + CustomCallApiVersion, + CustomCallApiVersionAttr, + OutputOperandAlias, +) + +# Import the main StableHLO dialect +from .dialect import StableHLO + +# Export all operations and the dialect for external use +__all__ = [ + # Main dialect + "StableHLO", + # Elementwise unary operations + "ConvertOp", + "CosineOp", + "ExponentialMinusOneOp", + "ExponentialOp", + "FloorOp", + "ImagOp", + "IsFiniteOp", + "LogOp", + "LogPlusOneOp", + "LogisticOp", + "NegateOp", + "RealOp", + "RoundNearestAfzOp", + "RoundNearestEvenOp", + "RsqrtOp", + "SignOp", + "SineOp", + "SqrtOp", + "TanOp", + "TanhOp", + # Elementwise binary operations + "ComplexOp", + "DivideOp", + "MaximumOp", + "MinimumOp", + "PowerOp", + "RemainderOp", + # Elementwise other operations + "ClampOp", + "CompareOp", + "ConstantOp", + "MapOp", + "ReducePrecisionOp", + "SelectOp", + # Control flow operations + "IfOp", + "WhileOp", + "OptimizationBarrierOp", + # Data movement operations + "BroadcastInDimOp", + "ConcatenateOp", + "DynamicSliceOp", + "GatherOp", + "ReshapeOp", + "ScatterOp", + "SliceOp", + # Dynamism operations + "DynamicBroadcastInDimOp", + # Reduction operations + "ReduceOp", + # Extensibility operations + "CustomCallOp", + # Attributes + "GatherDimensionNumbers", + "ResultAccuracyModeAttr", + "ScatterDimensionNumbers", + "CustomCallApiVersion", + "CustomCallApiVersionAttr", + "OutputOperandAlias", +] diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/attributes.py b/frontend/catalyst/python_interface/dialects/stablehlo/attributes.py new file mode 100644 index 0000000000..7f2e60e2f9 --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/attributes.py @@ -0,0 +1,394 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +StableHLO attribute definitions for PennyLane's compiler infrastructure. + +This module provides attribute definitions based on the StableHLO specification +(https://github.com/openxla/stablehlo/blob/main/docs/spec.md), including +attributes for StableHLO operations. +""" + +# pylint: disable=too-few-public-methods + +from collections.abc import Sequence + +from xdsl.dialects.builtin import I64, ArrayAttr, IntegerAttr, i64 +from xdsl.ir import ( + Attribute, + EnumAttribute, + ParametrizedAttribute, + SpacedOpaqueSyntaxAttribute, + StrEnum, +) +from xdsl.irdl import irdl_attr_definition +from xdsl.parser import AttrParser +from xdsl.printer import Printer + + +# Utility functions for dimension array parsing/printing +def parse_dims(parser: AttrParser) -> ArrayAttr[IntegerAttr[I64]]: + """Parse dimension array in [1, 2, 3] format""" + value = parser.parse_comma_separated_list( + AttrParser.Delimiter.SQUARE, + lambda: IntegerAttr(parser.parse_integer(), i64), + ) + return ArrayAttr(value) + + +def print_dims(printer: Printer, dims: ArrayAttr[IntegerAttr[I64]]): + """Print dimension array in [1, 2, 3] format""" + printer.print_string("[") + printer.print_list( + dims.data, + lambda dim: printer.print_string(f"{dim.value.data}"), + ) + printer.print_string("]") + + +class ResultAccuracyMode(StrEnum): + """ + XLA result accuracy mode. + """ + + DEFAULT = "DEFAULT" + HIGH = "HIGHEST" + HIGHEST = "TOLERANCE" + + +@irdl_attr_definition +class ResultAccuracyModeAttr(EnumAttribute[ResultAccuracyMode], SpacedOpaqueSyntaxAttribute): + """ + XLA result accuracy mode. + + See external [documentation](https://github.com/openxla/stablehlo/blob/7c50d4efeaea30bff6aa5e46c7f71170f5aa06af/stablehlo/dialect/StablehloEnums.td#L49-L70). + """ + + name = "stablehlo.result_accuracy_mode" + + +@irdl_attr_definition +class GatherDimensionNumbers(ParametrizedAttribute): + """ + XLA gather dimension numbers. + + This attribute models the dimension information for gather operations. + See external [documentation](https://github.com/openxla/stablehlo/blob/b075e948092d8a27ed0be48f4f8dbaa6df7e2e3e/stablehlo/dialect/StablehloAttrs.td#L42). + """ + + name = "stablehlo.gather" + + offset_dims: ArrayAttr[IntegerAttr[I64]] + collapsed_slice_dims: ArrayAttr[IntegerAttr[I64]] + operand_batching_dims: ArrayAttr[IntegerAttr[I64]] + start_indices_batching_dims: ArrayAttr[IntegerAttr[I64]] + start_index_map: ArrayAttr[IntegerAttr[I64]] + index_vector_dim: IntegerAttr[I64] + + def print_parameters(self, printer: Printer) -> None: + """Print gather dimension numbers in structured format""" + with printer.in_angle_brackets(): + with printer.indented(): + # Print offset_dims + printer.print_string("\noffset_dims = ") + print_dims(printer, self.offset_dims) + printer.print_string(",") + + # Print collapsed_slice_dims + printer.print_string("\ncollapsed_slice_dims = ") + print_dims(printer, self.collapsed_slice_dims) + printer.print_string(",") + + # Print operand_batching_dims + printer.print_string("\noperand_batching_dims = ") + print_dims(printer, self.operand_batching_dims) + printer.print_string(",") + + # Print start_indices_batching_dims + printer.print_string("\nstart_indices_batching_dims = ") + print_dims(printer, self.start_indices_batching_dims) + printer.print_string(",") + + # Print start_index_map + printer.print_string("\nstart_index_map = ") + print_dims(printer, self.start_index_map) + printer.print_string(",") + + # Print index_vector_dim + printer.print_string(f"\nindex_vector_dim = {self.index_vector_dim.value.data}") + printer.print_string("\n") + + @classmethod + def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]: + """Parse gather dimension numbers from structured format""" + with parser.in_angle_brackets(): + # Initialize default values for all fields + offset_dims = ArrayAttr([]) + collapsed_slice_dims = ArrayAttr([]) + operand_batching_dims = ArrayAttr([]) + start_indices_batching_dims = ArrayAttr([]) + start_index_map = ArrayAttr([]) + index_vector_dim = IntegerAttr(0, i64) + + # Try to parse offset_dims + if parser.parse_optional_characters("offset_dims") is not None: + parser.parse_punctuation("=") + offset_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse collapsed_slice_dims + if parser.parse_optional_characters("collapsed_slice_dims") is not None: + parser.parse_punctuation("=") + collapsed_slice_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse operand_batching_dims + if parser.parse_optional_characters("operand_batching_dims") is not None: + parser.parse_punctuation("=") + operand_batching_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse start_indices_batching_dims + if parser.parse_optional_characters("start_indices_batching_dims") is not None: + parser.parse_punctuation("=") + start_indices_batching_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse start_index_map + if parser.parse_optional_characters("start_index_map") is not None: + parser.parse_punctuation("=") + start_index_map = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse index_vector_dim + if parser.parse_optional_characters("index_vector_dim") is not None: + parser.parse_punctuation("=") + index_vector_dim = IntegerAttr(parser.parse_integer(), i64) + + return ( + offset_dims, + collapsed_slice_dims, + operand_batching_dims, + start_indices_batching_dims, + start_index_map, + index_vector_dim, + ) + + +@irdl_attr_definition +class ScatterDimensionNumbers(ParametrizedAttribute): + """ + XLA scatter dimension numbers. + + This attribute models the dimension information for scatter operations. + See external [documentation](https://github.com/openxla/stablehlo/blob/b075e948092d8a27ed0be48f4f8dbaa6df7e2e3e/stablehlo/dialect/StablehloAttrs.td#L28). + """ + + name = "stablehlo.scatter" + + update_window_dims: ArrayAttr[IntegerAttr[I64]] + inserted_window_dims: ArrayAttr[IntegerAttr[I64]] + input_batching_dims: ArrayAttr[IntegerAttr[I64]] + scatter_indices_batching_dims: ArrayAttr[IntegerAttr[I64]] + scatter_dims_to_operand_dims: ArrayAttr[IntegerAttr[I64]] + index_vector_dim: IntegerAttr[I64] + + def print_parameters(self, printer: Printer) -> None: + """Print scatter dimension numbers in structured format""" + with printer.in_angle_brackets(): + with printer.indented(): + # Print update_window_dims + printer.print_string("\nupdate_window_dims = ") + print_dims(printer, self.update_window_dims) + printer.print_string(",") + + # Print inserted_window_dims + printer.print_string("\ninserted_window_dims = ") + print_dims(printer, self.inserted_window_dims) + printer.print_string(",") + + # Print input_batching_dims + printer.print_string("\ninput_batching_dims = ") + print_dims(printer, self.input_batching_dims) + printer.print_string(",") + + # Print scatter_indices_batching_dims + printer.print_string("\nscatter_indices_batching_dims = ") + print_dims(printer, self.scatter_indices_batching_dims) + printer.print_string(",") + + # Print scatter_dims_to_operand_dims + printer.print_string("\nscatter_dims_to_operand_dims = ") + print_dims(printer, self.scatter_dims_to_operand_dims) + printer.print_string(",") + + # Print index_vector_dim + printer.print_string(f"\nindex_vector_dim = {self.index_vector_dim.value.data}") + printer.print_string("\n") + + @classmethod + def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]: + """Parse scatter dimension numbers from structured format""" + with parser.in_angle_brackets(): + # Initialize default values for all fields + update_window_dims = ArrayAttr([]) + inserted_window_dims = ArrayAttr([]) + input_batching_dims = ArrayAttr([]) + scatter_indices_batching_dims = ArrayAttr([]) + scatter_dims_to_operand_dims = ArrayAttr([]) + index_vector_dim = IntegerAttr(0, i64) + + # Try to parse update_window_dims + if parser.parse_optional_characters("update_window_dims") is not None: + parser.parse_punctuation("=") + update_window_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse inserted_window_dims + if parser.parse_optional_characters("inserted_window_dims") is not None: + parser.parse_punctuation("=") + inserted_window_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse input_batching_dims + if parser.parse_optional_characters("input_batching_dims") is not None: + parser.parse_punctuation("=") + input_batching_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse scatter_indices_batching_dims + if parser.parse_optional_characters("scatter_indices_batching_dims") is not None: + parser.parse_punctuation("=") + scatter_indices_batching_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse scatter_dims_to_operand_dims + if parser.parse_optional_characters("scatter_dims_to_operand_dims") is not None: + parser.parse_punctuation("=") + scatter_dims_to_operand_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse index_vector_dim + if parser.parse_optional_characters("index_vector_dim") is not None: + parser.parse_punctuation("=") + index_vector_dim = IntegerAttr(parser.parse_integer(), i64) + + return ( + update_window_dims, + inserted_window_dims, + input_batching_dims, + scatter_indices_batching_dims, + scatter_dims_to_operand_dims, + index_vector_dim, + ) + + +# ===== CustomCall and layout-related attributes ===== + + +class CustomCallApiVersion(StrEnum): + """StableHLO CustomCall API version.""" + + API_VERSION_UNSPECIFIED = "API_VERSION_UNSPECIFIED" + API_VERSION_ORIGINAL = "API_VERSION_ORIGINAL" + API_VERSION_STATUS_RETURNING = "API_VERSION_STATUS_RETURNING" + API_VERSION_STATUS_RETURNING_UNIFIED = "API_VERSION_STATUS_RETURNING_UNIFIED" + API_VERSION_TYPED_FFI = "API_VERSION_TYPED_FFI" + + +@irdl_attr_definition +class CustomCallApiVersionAttr(EnumAttribute[CustomCallApiVersion], SpacedOpaqueSyntaxAttribute): + """StableHLO custom call API version attribute. + + Mirrors StableHLO enum for CustomCall API versions. + """ + + name = "stablehlo.custom_call_api_version" + + +@irdl_attr_definition +class OutputOperandAlias(ParametrizedAttribute): + """ + This attribute captures the alias relationship of the output to one of the + operands for a ``CustomCall`` op, denoted by ``operand_index``. The + ``output_tuple_indices`` and ``operand_tuple_indices`` are used to index into + output and operand types. These indices lists are empty if the corresponding + types are not tuple types, and can be arbitrarily long in case of + arbitrarily nested tuple types. + + See https://www.tensorflow.org/xla/aliasing. + + Example when used as array with in stablehlo.custom-call: + + ```mlir + %0 = "stablehlo.custom_call"(%arg0, %arg1) { + // other attributes + output_operand_alias = [ + #stablehlo.output_operand_alias + ] + } : (tuple, tensor<2x3xf32>>, tensor<5x5xf32>) -> tuple> + + The output and the 0th operand are both tuples. The aliasing shows the + relationship between the 0th element in output tuple with the 1st element in + the 0th operand. And both of them are of the same type: ``tensor<2x3xf32>``. + ``` + """ + + name = "stablehlo.output_operand_alias" + + output_tuple_indices: ArrayAttr[IntegerAttr[I64]] + operand_index: IntegerAttr[I64] + operand_tuple_indices: ArrayAttr[IntegerAttr[I64]] + + def print_parameters(self, printer: Printer) -> None: + """Print the OutputOperandAlias attribute.""" + with printer.in_angle_brackets(): + with printer.indented(): + printer.print_string("\noutput_tuple_indices = ") + print_dims(printer, self.output_tuple_indices) + printer.print_string(",") + + printer.print_string("\noperand_index = ") + printer.print_string(f"{self.operand_index.value.data}") + printer.print_string(",") + + printer.print_string("\noperand_tuple_indices = ") + print_dims(printer, self.operand_tuple_indices) + printer.print_string("\n") + + @classmethod + def parse_parameters(cls, parser: AttrParser): + """Parse the OutputOperandAlias attribute.""" + with parser.in_angle_brackets(): + output_tuple_indices = ArrayAttr([]) + operand_index = IntegerAttr(0, i64) + operand_tuple_indices = ArrayAttr([]) + + if parser.parse_optional_characters("output_tuple_indices") is not None: + parser.parse_punctuation("=") + output_tuple_indices = parse_dims(parser) + parser.parse_optional_punctuation(",") + + if parser.parse_optional_characters("operand_index") is not None: + parser.parse_punctuation("=") + operand_index = IntegerAttr(parser.parse_integer(), i64) + parser.parse_optional_punctuation(",") + + if parser.parse_optional_characters("operand_tuple_indices") is not None: + parser.parse_punctuation("=") + operand_tuple_indices = parse_dims(parser) + + return (output_tuple_indices, operand_index, operand_tuple_indices) diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/control_flow.py b/frontend/catalyst/python_interface/dialects/stablehlo/control_flow.py new file mode 100644 index 0000000000..293c831eaa --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/control_flow.py @@ -0,0 +1,160 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=too-few-public-methods + +""" +Control flow operations for the StableHLO dialect. +""" + +from typing import TypeVar + +from xdsl.dialects.builtin import AnyTensorType +from xdsl.irdl import ( + IRDLOperation, + irdl_op_definition, + operand_def, + region_def, + traits_def, + var_operand_def, + var_result_def, +) +from xdsl.traits import ( + Pure, + RecursivelySpeculatable, + RecursiveMemoryEffect, + SingleBlockImplicitTerminator, +) +from xdsl_jax.dialects.stablehlo import ReturnOp + +# Import our custom StableHLO types +from .types import HLO_PredTensor, HLO_TensorOrPerAxisQuantizedTensorOrToken, HLO_TensorOrToken + +# Generic type variables for templating +T_IN = TypeVar("T_IN", bound=AnyTensorType) +T_OUT = TypeVar("T_OUT", bound=AnyTensorType) + + +@irdl_op_definition +class IfOp(IRDLOperation): + """ + Produces the output from executing exactly one branch from `true_branch` or + `false_branch` depending on the value of `pred`. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#if + + Example: + %result = "stablehlo.if"(%pred) ({ + "stablehlo.return"(%result_true_branch) : (tensor) -> () + }, { + "stablehlo.return"(%result_false_branch) : (tensor) -> () + }) : (tensor) -> tensor + """ + + name = "stablehlo.if" + + pred = operand_def(HLO_PredTensor) + + res = var_result_def(HLO_TensorOrPerAxisQuantizedTensorOrToken) + + true_branch = region_def("single_block") + + false_branch = region_def("single_block") + + traits = traits_def( + RecursiveMemoryEffect(), + RecursivelySpeculatable(), + SingleBlockImplicitTerminator(ReturnOp), + # TODO: InferTypeOpInterface + # TODO: OpAsmOpInterface + ) + + # TODO: Add custom assembly format + + +@irdl_op_definition +class WhileOp(IRDLOperation): + """ + Produces the output from executing `body` function 0 or more times while the + `cond` function outputs `true`. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#while + + Example: + ```mlir + %results0, %results1 = stablehlo.while(%arg0 = %init_i, %arg1 = %init_sum) : tensor, tensor + cond { + %cond = stablehlo.compare LT, %arg0, %ten : (tensor, tensor) -> tensor + stablehlo.return %cond : tensor + } do { + %new_sum = stablehlo.add %arg1, %one : tensor + %new_i = stablehlo.add %arg0, %one : tensor + stablehlo.return %new_i, %new_sum : tensor, tensor + } + """ + + name = "stablehlo.while" + + operand = var_operand_def(HLO_TensorOrPerAxisQuantizedTensorOrToken) + + res = var_result_def(HLO_TensorOrPerAxisQuantizedTensorOrToken) + + cond = region_def("single_block") + + body = region_def("single_block") + + traits = traits_def( + RecursiveMemoryEffect(), + RecursivelySpeculatable(), + SingleBlockImplicitTerminator(ReturnOp), + # TODO: InferTypeOpInterface + # TODO: OpAsmOpInterface + ) + + +@irdl_op_definition +class OptimizationBarrierOp(IRDLOperation): + """ + Ensures that the operations that produce the `operand` are executed before any + operations that depend on the `result` and prevents compiler transformations + from moving operations across the barrier. Other than that, the operation is + an identity, i.e. `result` = `operand`. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#optimization_barrier + + Example: + ```mlir + %result0, %result1 = stablehlo.optimization_barrier %operand0, %operand1 : tensor, tensor + ``` + """ + + name = "stablehlo.optimization_barrier" + + operand = var_operand_def(HLO_TensorOrToken) + + res = var_result_def(HLO_TensorOrToken) + + traits = traits_def( + Pure(), + # TODO: HLO_PairwiseSameOperandAndResultType + # TODO: InferTypeOpInterface + ) + + # TODO: Add custom assembly format + # assembly_format = """ + # attr-dict ($operand^ `:` custom(type($operand), type($result))):(`(` `)`)? + # """ diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/data_movement.py b/frontend/catalyst/python_interface/dialects/stablehlo/data_movement.py new file mode 100644 index 0000000000..ee6a9a3109 --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/data_movement.py @@ -0,0 +1,415 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=too-few-public-methods + +""" +Data movement operations for the StableHLO dialect. +""" + +from xdsl.dialects.builtin import BoolAttr, DenseArrayBase, IntegerAttr, TensorType, i64 +from xdsl.irdl import ( + IRDLOperation, + irdl_op_definition, + operand_def, + opt_prop_def, + prop_def, + region_def, + result_def, + traits_def, + var_operand_def, + var_result_def, +) +from xdsl.irdl.attributes import eq +from xdsl.irdl.constraints import AtLeast +from xdsl.irdl.operations import SameVariadicOperandSize +from xdsl.traits import ( + ConditionallySpeculatable, + NoMemoryEffect, + Pure, + RecursiveMemoryEffect, +) +from xdsl.utils.exceptions import VerifyException +from xdsl.utils.type import get_element_type_or_self + +from catalyst.python_interface.xdsl_extras import ( + AllMatchSameOperatorTrait, + SameOperandsAndResultElementType, + TensorConstraint, +) + +from .attributes import GatherDimensionNumbers, ScatterDimensionNumbers +from .types import HLO_AnyIntegerOrIndexTensor, HLO_AnyTensor, HLO_Int, HLO_IntTensor, HLO_Tensor + + +@irdl_op_definition +class BroadcastInDimOp(IRDLOperation): + """ + Expands the dimensions and/or rank of an input tensor by duplicating the + data in the ``operand`` tensor and produces a ``result`` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#broadcast_in_dim + + Example: + ```mlir + %result = stablehlo.broadcast_in_dim %operand, dims = [2, 1] : (tensor<1x3xi32>) -> tensor<2x3x2xi32> + ``` + """ + + name = "stablehlo.broadcast_in_dim" + operand = operand_def(HLO_AnyTensor) + broadcast_dimensions = prop_def(DenseArrayBase.constr(i64)) + result = result_def(HLO_AnyTensor) + + assembly_format = """ + $operand `,` `dims` `=` $broadcast_dimensions + attr-dict `:` functional-type(operands, results) + """ + + traits = traits_def( + NoMemoryEffect(), + # TODO: HLO_SpeculatableIfAllInputsStatic, + # TODO: HLO_CompatibleOperandsAndResultElementType, + ) + + def verify_(self) -> None: + """Verify non-quantized broadcast_in_dim constraints.""" + o_type = self.operand_types[0] + r_type = self.result_types[0] + + # These are constrained to tensors by the op definition + assert isinstance(o_type, TensorType) and isinstance(r_type, TensorType) + + # broadcast_in_dim_c2: broadcast_dimensions size == operand rank + dims = tuple(self.broadcast_dimensions.get_values()) + operand_rank = o_type.get_num_dims() + if len(dims) != operand_rank: + raise VerifyException( + "broadcast_dimensions size (" + f"{len(dims)}" + ") does not match operand rank (" + f"{operand_rank}" + ")" + ) + + # broadcast_in_dim_c4: broadcast_dimensions should not have duplicates + if len(set(dims)) != len(dims): + raise VerifyException("broadcast_dimensions should not have duplicates") + + # Result rank and per-dimension checks + result_rank = r_type.get_num_dims() + o_shape = o_type.get_shape() + r_shape = r_type.get_shape() + + for i, dim_index in enumerate(dims): + # broadcast_in_dim_c3: each dim index in bounds of result rank + if dim_index < 0 or dim_index >= result_rank: + raise VerifyException( + "broadcast_dimensions contains invalid value " + f"{dim_index} for result with rank {result_rank}" + ) + + # If operand dim is static, enforce broadcast_in_dim_c5 + if o_shape[i] != -1: + dim_size = o_shape[i] + result_dim_size = r_shape[dim_index] + if dim_size not in (1, result_dim_size): + raise VerifyException( + "size of operand dimension " + f"{i} ({dim_size}) is not equal to 1 or size of result dimension " + f"{dim_index} ({result_dim_size})" + ) + + +@irdl_op_definition +class ConcatenateOp(IRDLOperation): + """ + Concatenates a variadic number of tensors in ``inputs`` along ``dimension`` + dimension in the same order as the given arguments and produces a ``result`` + tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#concatenate + + Example: + ```mlir + %result = stablehlo.concatenate %input0, %input1, dim = 0 : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64> + ``` + """ + + name = "stablehlo.concatenate" + + inputs = var_operand_def(HLO_Tensor) + result = result_def(HLO_Tensor) + dimension = prop_def(IntegerAttr.constr(type=eq(i64), value=AtLeast(0))) + + traits = traits_def( + NoMemoryEffect(), + ConditionallySpeculatable(), + SameOperandsAndResultElementType(), + # InferTypeOpInterface(), + ) + + # TODO: Implement CustomDirective + # assembly_format = """ + # custom($inputs) `dim` `=` $dimension attr-dict `:` functional-type(operands, results) + # """ + + +@irdl_op_definition +class DynamicSliceOp(IRDLOperation): + """ + Extracts a slice from the ``operand`` using dynamically-computed starting + indices and produces a ``result`` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_slice + + Example: + ```mlir + %result = stablehlo.dynamic_slice %operand, %start_indices0, %start_indices1, sizes = [2, 2] + : (tensor<4x4xi32>, tensor, tensor) -> tensor<2x2xi32> + ``` + """ + + name = "stablehlo.dynamic_slice" + operand = operand_def(HLO_Tensor) + start_indices = var_operand_def(TensorConstraint(element_type=HLO_Int, rank=0)) + slice_sizes = prop_def(DenseArrayBase.constr(i64)) + result = result_def(HLO_Tensor) + + # TODO: Implement CustomDirective + # assembly_format = """ + # $operand `,` custom($start_indices) + # `sizes` `=` $slice_sizes attr-dict `:` functional-type(operands, results) + # """ + + traits = traits_def( + Pure(), + AllMatchSameOperatorTrait( + ("operand", "result"), lambda x: get_element_type_or_self(x.type), "element type" + ), + # TODO: InferTensorType(), + ) + + +@irdl_op_definition +class GatherOp(IRDLOperation): + """ + Gathers slices from ``operand`` tensor from offsets specified in + ``start_indices`` and produces a ``result`` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#gather + + Example: + ```mlir + %result = "stablehlo.gather"(%operand, %start_indices) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3, 4], + collapsed_slice_dims = [1], + operand_batching_dims = [0], + start_indices_batching_dims = [1], + start_index_map = [2, 1], + index_vector_dim = 3>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi64> + ``` + """ + + name = "stablehlo.gather" + operand = operand_def(HLO_Tensor) + start_indices = operand_def(HLO_IntTensor) + dimension_numbers = prop_def(GatherDimensionNumbers) + slice_sizes = prop_def(DenseArrayBase.constr(i64)) + indices_are_sorted = opt_prop_def(BoolAttr, default_value=BoolAttr.from_bool(False)) + result = result_def(HLO_Tensor) + + traits = traits_def( + NoMemoryEffect(), + ConditionallySpeculatable(), + AllMatchSameOperatorTrait( + ("operand", "result"), lambda x: get_element_type_or_self(x.type), "element type" + ), + # TODO: InferTensorTypeWithReify(), + ) + + # TODO: Implement CustomDirective + # assembly_format = """ + # custom($inputs) `dim` `=` $dimension attr-dict `:` functional-type(operands, results) + # """ + + +@irdl_op_definition +class ReshapeOp(IRDLOperation): + """ + Performs reshape of ``operand`` tensor to a ``result`` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reshape + + Example: + ```mlir + %result = stablehlo.reshape %operand : (tensor<2xf32>) -> tensor<1x2xf32> + ``` + """ + + name = "stablehlo.reshape" + operand = operand_def(HLO_AnyTensor) + result = result_def(HLO_AnyTensor) + + assembly_format = """ + operands attr-dict `:` functional-type(operands, results) + """ + + traits = traits_def( + NoMemoryEffect(), + ConditionallySpeculatable(), + # TODO: HLO_CompatibleOperandsAndResultElementType, + ) + + def verify_(self) -> None: + """Verify that the operation has the same shape for all operands and results.""" + o_type = self.operand_types[0] + r_type = self.result_types[0] + + # These are constrained to tensors by the op definition + assert isinstance(o_type, TensorType) and isinstance(r_type, TensorType) + + # If o_type or r_type is dynamically shaped there is nothing to verify. + if not o_type.has_static_shape() or not r_type.has_static_shape(): + return + + # If the operand type is statically shaped (not required) the number of + # elements must match that of the result type. + num_operand_elements = 1 + for dim in o_type.get_shape(): + num_operand_elements *= dim + + num_result_elements = 1 + for dim in r_type.get_shape(): + num_result_elements *= dim + + if num_result_elements != num_operand_elements: + raise VerifyException( + "number of output elements (" + f"{num_result_elements}" + ") doesn't match expected number of elements (" + f"{num_operand_elements}" + ")" + ) + + +@irdl_op_definition +class ScatterOp(IRDLOperation): + """ + Produces ``results`` tensors which are equal to ``inputs`` tensors except that + several slices specified by ``scatter_indices`` are updated with the values + ``updates`` using ``update_computation``. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter + + Example: + ```mlir + %result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = stablehlo.add %arg0, %arg1 : tensor + stablehlo.return %0 : tensor + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [3, 4], + inserted_window_dims = [1], + input_batching_dims = [0], + scatter_indices_batching_dims = [1], + scatter_dims_to_operand_dims = [2, 1], + index_vector_dim = 3>, + indices_are_sorted = false, + unique_indices = false + } : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64> + ``` + """ + + name = "stablehlo.scatter" + inputs = var_operand_def(HLO_Tensor) + scatter_indices = operand_def(HLO_AnyIntegerOrIndexTensor) + updates = var_operand_def(HLO_Tensor) + scatter_dimension_numbers = prop_def(ScatterDimensionNumbers) + indices_are_sorted = opt_prop_def(BoolAttr, default_value=BoolAttr.from_bool(False)) + unique_indices = opt_prop_def(BoolAttr, default_value=BoolAttr.from_bool(False)) + result = var_result_def(HLO_Tensor) + update_computation = region_def("single_block") + # TODO: The MLIR implementation doesn't have the SingleBlockImplicitTerminator trait, + # However, it is checked to have a terminator in the verifier, + # which does not specifically check the terminator to be stablehlo.return. + + traits = traits_def( + RecursiveMemoryEffect(), + ConditionallySpeculatable(), + # TODO: InferTypeOpInterface(), + ) + + irdl_options = [SameVariadicOperandSize()] + + # TODO: MLIR has a custom verifier for the scatter operation. + + +@irdl_op_definition +class SliceOp(IRDLOperation): + """ + Extracts a slice from the ``operand`` using statically-computed starting + indices and produces a ``result`` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#slice + + Example: + ```mlir + %result = stablehlo.slice %operand [1:3, 4:8:2] + : (tensor<3x8xi64>) -> tensor<2x2xi64> + + // Same in generic form: the `1:3` above is mapped to the first entry in + // `start_indices` and `limit_indices`, while `strides` is implicitly 1. + // The `4:8:2` above is parsed into the second entry of `start_indices`, + // `limit_indices` and `strides` respectively. + %result = "stablehlo.slice" (%operand) { + start_indices = array, + limit_indices = array, + strides = array + } : (tensor<3x8xi64>) -> tensor<2x2xi64> + ``` + """ + + name = "stablehlo.slice" + + operand = operand_def(HLO_Tensor) + start_indices = prop_def(DenseArrayBase.constr(i64)) + limit_indices = prop_def(DenseArrayBase.constr(i64)) + strides = prop_def(DenseArrayBase.constr(i64)) + result = result_def(HLO_Tensor) + + # TODO: Implement CustomDirective + # assembly_format = """ + # $operand custom($start_indices, $limit_indices, $strides) + # attr-dict `:` functional-type(operands, results) + # """ + + traits = traits_def( + NoMemoryEffect(), + ConditionallySpeculatable(), + AllMatchSameOperatorTrait(("start_indices", "limit_indices", "strides"), len, "size"), + SameOperandsAndResultElementType(), + ) diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/dialect.py b/frontend/catalyst/python_interface/dialects/stablehlo/dialect.py new file mode 100644 index 0000000000..ce880bba2b --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/dialect.py @@ -0,0 +1,207 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Extended StableHLO dialect that dynamically includes all upstream operations +plus custom operations for PennyLane's compiler infrastructure. + +This module automatically imports all operations and attributes from the upstream +xdsl_jax.dialects.stablehlo and adds custom ones without needing to hardcode +the upstream operation list. +""" + +import xdsl_jax.dialects.stablehlo as xstablehlo +from xdsl.ir import Dialect + +from .attributes import ( + CustomCallApiVersionAttr, + GatherDimensionNumbers, + OutputOperandAlias, + ResultAccuracyModeAttr, + ScatterDimensionNumbers, +) +from .control_flow import ( + IfOp, + OptimizationBarrierOp, + WhileOp, +) +from .data_movement import ( + BroadcastInDimOp, + ConcatenateOp, + DynamicSliceOp, + GatherOp, + ReshapeOp, + ScatterOp, + SliceOp, +) +from .dynamism import ( + DynamicBroadcastInDimOp, +) +from .elementwise_binary import ( + ComplexOp, + DivideOp, + MaximumOp, + MinimumOp, + PowerOp, + RemainderOp, +) +from .elementwise_other import ( + ClampOp, + CompareOp, + ConstantOp, + MapOp, + ReducePrecisionOp, + SelectOp, +) + +# Import all elementwise operations from organized files +from .elementwise_unary import ( + ConvertOp, + CosineOp, + ExponentialMinusOneOp, + ExponentialOp, + FloorOp, + ImagOp, + IsFiniteOp, + LogisticOp, + LogOp, + LogPlusOneOp, + NegateOp, + RealOp, + RoundNearestAfzOp, + RoundNearestEvenOp, + RsqrtOp, + SignOp, + SineOp, + SqrtOp, + TanhOp, + TanOp, +) +from .extensibility import ( + CustomCallOp, +) +from .reduction import ( + ReduceOp, +) +from .types import UniformQuantizedPerAxisType, UniformQuantizedType + +# Operations to add to the dialect +OPERATIONS = [ + ClampOp, + CompareOp, + ComplexOp, + ConstantOp, + ConvertOp, + CosineOp, + DivideOp, + ExponentialMinusOneOp, + ExponentialOp, + FloorOp, + ImagOp, + IsFiniteOp, + LogOp, + LogPlusOneOp, + LogisticOp, + MapOp, + MaximumOp, + MinimumOp, + NegateOp, + PowerOp, + RealOp, + ReducePrecisionOp, + RemainderOp, + RoundNearestAfzOp, + RoundNearestEvenOp, + RsqrtOp, + SelectOp, + SignOp, + SineOp, + SqrtOp, + TanOp, + TanhOp, + # Data movement operations + BroadcastInDimOp, + ConcatenateOp, + DynamicSliceOp, + GatherOp, + ReshapeOp, + ScatterOp, + SliceOp, + # Control flow operations + IfOp, + WhileOp, + OptimizationBarrierOp, + # Dynamism operations + DynamicBroadcastInDimOp, + # Reduction operations + ReduceOp, + # Extensibility operations + CustomCallOp, +] + +# Attributes to add to the dialect +ATTRIBUTES = [ + CustomCallApiVersionAttr, + GatherDimensionNumbers, + ResultAccuracyModeAttr, + OutputOperandAlias, + ScatterDimensionNumbers, + UniformQuantizedPerAxisType, + UniformQuantizedType, +] + +# Operations/attributes from upstream that should be deleted/replaced in the local version +UPSTREAM_OPERATIONS_TO_DELETE = [ + xstablehlo.ConstantOp, +] +UPSTREAM_ATTRIBUTES_TO_DELETE = [] + + +def filter_and_extend_upstream(upstream_list, to_delete, to_add): + """Filter out operations/attributes from upstream list and add new ones. + + Args: + upstream_list: List of operations/attributes to filter + to_delete: List of operations/attributes to remove + to_add: List of operations/attributes to add + + Returns: + Modified list of operations/attributes + """ + filtered_ops = list(upstream_list) + + # Remove operations that should be deleted + for op_to_delete in to_delete: + if op_to_delete in filtered_ops: + filtered_ops.remove(op_to_delete) + + # Add new operations + filtered_ops.extend(to_add) + + return filtered_ops + + +all_operations = filter_and_extend_upstream( + xstablehlo.StableHLO.operations, UPSTREAM_OPERATIONS_TO_DELETE, OPERATIONS +) +all_attributes = filter_and_extend_upstream( + xstablehlo.StableHLO.attributes, UPSTREAM_ATTRIBUTES_TO_DELETE, ATTRIBUTES +) + +# Create the extended StableHLO dialect by dynamically getting upstream components +StableHLO = Dialect( + "stablehlo", + all_operations, + all_attributes, +) diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py b/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py new file mode 100644 index 0000000000..e012c07529 --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py @@ -0,0 +1,180 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=too-few-public-methods + +""" +Dynamism operations for the StableHLO dialect. +""" + +from xdsl.dialects.builtin import DenseArrayBase, TensorType, i64 +from xdsl.irdl import ( + IRDLOperation, + ParsePropInAttrDict, + irdl_op_definition, + operand_def, + opt_prop_def, + prop_def, + result_def, + traits_def, +) +from xdsl.traits import ( + ConditionallySpeculatable, + NoMemoryEffect, +) +from xdsl.utils.exceptions import VerifyException + +from catalyst.python_interface.xdsl_extras import TensorConstraint + +from .types import HLO_AnyTensor, HLO_DimensionValue + + +@irdl_op_definition +class DynamicBroadcastInDimOp(IRDLOperation): + """ + This operation is functionally identical to + [broadcast_in_dim](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#broadcast_in_dim) + op, but the result shape is specified dynamically via ``output_dimensions``. + + It also accepts optional attributes to express static knowledge about the + expanding behavior of dimensions. If not specified, all dimensions are + assumed to be possibly expanding. The sets of dimensions that are known to + be expanding and the set of dimensions that are known to be non-expanding + must be disjoint and they must be a subset of the operand's dimensions. + + See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_broadcast_in_dim + + Example: + ```mlir + %operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64> + %output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64> + %result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) { + broadcast_dimensions = array, + known_expanding_dimensions = array, + known_nonexpanding_dimensions = array + } : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64> + ``` + """ + + name = "stablehlo.dynamic_broadcast_in_dim" + + operand = operand_def(HLO_AnyTensor) + output_dimensions = operand_def(TensorConstraint(element_type=HLO_DimensionValue, rank=1)) + broadcast_dimensions = prop_def(DenseArrayBase.constr(i64)) + known_expanding_dimensions = opt_prop_def(DenseArrayBase.constr(i64)) + known_nonexpanding_dimensions = opt_prop_def(DenseArrayBase.constr(i64)) + result = result_def(HLO_AnyTensor) + + assembly_format = ( + "$operand `,` $output_dimensions `,` `dims` `=` $broadcast_dimensions " + "attr-dict `:` functional-type(operands, results)" + ) + + traits = traits_def( + ConditionallySpeculatable(), + NoMemoryEffect(), + # TODO: InferShapedTypeOpInterface(), + ) + + irdl_options = [ParsePropInAttrDict()] + + # pylint: disable=too-many-branches + def verify_(self): + """Verify the operation.""" + # Operand and result must be tensors + operand_ty = self.operand_types[0] + result_ty = self.result_types[0] + assert isinstance(operand_ty, TensorType) and isinstance(result_ty, TensorType) + + # dynamic_broadcast_in_dim_c2: broadcast_dimensions size == operand rank + bcast_dims = tuple(self.broadcast_dimensions.get_values()) # pylint: disable=no-member + operand_rank = operand_ty.get_num_dims() + if len(bcast_dims) != operand_rank: + raise VerifyException( + "broadcast_dimensions size (" + f"{len(bcast_dims)}" + ") does not match operand rank (" + f"{operand_rank}" + ")" + ) + + # dynamic_broadcast_in_dim_c3: result rank >= operand rank + result_rank = result_ty.get_num_dims() + if result_rank < operand_rank: + raise VerifyException( + "result rank (" + f"{result_rank}" + ") is less than operand rank (" + f"{operand_rank}" + ")" + ) + + # dynamic_broadcast_in_dim_c4: broadcast_dimensions should not have duplicates + if len(set(bcast_dims)) != len(bcast_dims): + raise VerifyException("broadcast_dimensions should not have duplicates") + + # dynamic_broadcast_in_dim_c5: bounds and per-dimension compatibility + operand_shape = operand_ty.get_shape() + result_shape = result_ty.get_shape() + for i, dim_index in enumerate(bcast_dims): + if dim_index < 0 or dim_index >= result_rank: + raise VerifyException( + "broadcast_dimensions contains invalid value " + f"{dim_index} for result with rank {result_rank}" + ) + op_dim = operand_shape[i] + res_dim = result_shape[dim_index] + # If operand dim is static and not size-1, require compatibility with result dim + if op_dim not in (-1, 1): + if res_dim not in (-1, op_dim): + raise VerifyException( + "size of operand dimension " + f"{i} ({op_dim}) is not compatible with size of result dimension " + f"{dim_index} ({res_dim})" + ) + + # dynamic_broadcast_in_dim_c7: output_dimensions shape compatible with result rank + out_dims_ty = self.output_dimensions.type # pylint: disable=no-member + assert isinstance(out_dims_ty, TensorType) + # Must be rank-1 tensor (enforced by type constraint), and length must match result rank when statically known + out_shape = out_dims_ty.get_shape() + if len(out_shape) != 1: + raise VerifyException("output_dimensions must be a 1D tensor") + if out_shape[0] != -1 and out_shape[0] != result_rank: + raise VerifyException( + "length of output_dimensions (" + f"{out_shape[0]}" + ") is not compatible with result rank (" + f"{result_rank}" + ")" + ) + + # dynamic_broadcast_in_dim_c8: no duplicate expansion hints across both lists + hints = [] + if self.known_expanding_dimensions is not None: + hints.extend(self.known_expanding_dimensions.get_values()) # pylint: disable=no-member + if self.known_nonexpanding_dimensions is not None: + hints.extend( + self.known_nonexpanding_dimensions.get_values() # pylint: disable=no-member + ) + if len(set(hints)) != len(hints): + raise VerifyException("duplicate expansion hint for at least one operand dimension") + + # dynamic_broadcast_in_dim_c9/c10: each hint must reference a valid operand dimension + for h in set(hints): + if h < 0 or h >= operand_rank: + raise VerifyException( + "hint for expanding dimension " + f"{h} does not refer to a valid operand dimension" + ) diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_binary.py b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_binary.py new file mode 100644 index 0000000000..135c865005 --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_binary.py @@ -0,0 +1,216 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Binary elementwise operations for the StableHLO dialect. +""" + +# pylint: disable=too-few-public-methods + +import abc +from typing import Generic, TypeVar + +from xdsl.dialects.builtin import AnyTensorType, ComplexType, Float32Type, Float64Type, TensorType +from xdsl.ir import Attribute, SSAValue +from xdsl.irdl import ( + IRDLOperation, + irdl_op_definition, + operand_def, + result_def, + traits_def, +) +from xdsl.traits import NoMemoryEffect + +from catalyst.python_interface.xdsl_extras import ( + Elementwise, + SameOperandsAndResultShape, + SameOperandsElementType, +) + +from .types import ( + HLO_ComplexTensor, + HLO_Fp32Or64Tensor, + HLO_IntFpOrComplexOrQuantizedIntTensor, + HLO_Tensor, +) + +# Type aliases +F32Or64Type = Float32Type | Float64Type +F32Or64TensorType = TensorType[F32Or64Type] +ComplexTensorType = TensorType[ComplexType] + +# Generic type variables for templating +T_LHS = TypeVar("T_LHS", bound=AnyTensorType) +T_RHS = TypeVar("T_RHS", bound=AnyTensorType) +T_OUT = TypeVar("T_OUT", bound=AnyTensorType) + + +class ElementwiseBinaryOperation(IRDLOperation, abc.ABC, Generic[T_LHS, T_RHS, T_OUT]): + """ + Templated base class for elementwise binary operations. + + This class provides a flexible template for binary operations that can work + with different tensor types. + + For more information about the semantics, see: + https://openxla.org/xla/operation_semantics#element-wise_binary_arithmetic_operations + """ + + lhs = operand_def(T_LHS) + rhs = operand_def(T_RHS) + result = result_def(T_OUT) + + traits = traits_def( + NoMemoryEffect(), + SameOperandsAndResultShape(), + Elementwise(), + # TODO: HLO_SpeculatableIfAllInputsStatic(), + ) + + # TODO: Implement CustomDirective + # assembly_format = """ + # $lhs `,` $rhs attr-dict + # `:` custom(type($lhs), type($rhs), type($result)) + # """ + + def __init__(self, lhs: SSAValue, rhs: SSAValue, result_type: Attribute | None = None): + if result_type is None: + result_type = lhs.type + super().__init__(operands=(lhs, rhs), result_types=(result_type,)) + + +@irdl_op_definition +class ComplexOp( + ElementwiseBinaryOperation[HLO_Fp32Or64Tensor, HLO_Fp32Or64Tensor, HLO_ComplexTensor] +): + """ + Performs element-wise conversion to a complex value from a pair of real and + imaginary values, `lhs` and `rhs`, and produces a `result` tensor. + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#complex + Example: + ```mlir + %result = stablehlo.complex %lhs, %rhs : tensor<2xcomplex> + ``` + """ + + name = "stablehlo.complex" + + # assembly_format = """ + # operands attr-dict + # `:` custom(type($lhs), type($rhs), type($result)) + # """ + + traits = traits_def( + NoMemoryEffect(), + SameOperandsElementType(), + SameOperandsAndResultShape(), + # TODO: HLO_SpeculatableIfAllInputsStatic(), + ) + + +@irdl_op_definition +class DivideOp( + ElementwiseBinaryOperation[ + HLO_IntFpOrComplexOrQuantizedIntTensor, + HLO_IntFpOrComplexOrQuantizedIntTensor, + HLO_IntFpOrComplexOrQuantizedIntTensor, + ] +): + """ + Performs element-wise division of dividend `lhs` and divisor `rhs` tensors + and produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#divide + + Example: + ```mlir + %result = stablehlo.divide %lhs, %rhs : tensor<4xf32> + ``` + """ + + name = "stablehlo.divide" + + +@irdl_op_definition +class MaximumOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]): + """ + Performs element-wise max operation on tensors `lhs` and `rhs` and produces + a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#maximum + + Example: + ```mlir + %result = stablehlo.maximum %lhs, %rhs : tensor<4xf32> + ``` + """ + + name = "stablehlo.maximum" + + +@irdl_op_definition +class MinimumOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]): + """ + Performs element-wise min operation on tensors `lhs` and `rhs` and produces a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#minimum + + Example: + ```mlir + %result = stablehlo.minimum %lhs, %rhs : tensor<4xf32> + ``` + """ + + name = "stablehlo.minimum" + + +@irdl_op_definition +class PowerOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]): + """ + Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and + produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power + + Example: + ```mlir + %result = stablehlo.power %lhs, %rhs : tensor<6xf64> + ``` + """ + + name = "stablehlo.power" + + +@irdl_op_definition +class RemainderOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]): + """ + Performs element-wise remainder of dividend `lhs` and divisor `rhs` tensors + and produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#remainder + + Example: + ```mlir + %result = stablehlo.remainder %lhs, %rhs : tensor<4xi64> + ``` + """ + + name = "stablehlo.remainder" diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_other.py b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_other.py new file mode 100644 index 0000000000..e4d4c7087d --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_other.py @@ -0,0 +1,238 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Other elementwise operations for the StableHLO dialect. +""" + +# pylint: disable=too-few-public-methods + +import xdsl_jax.dialects.stablehlo as xstablehlo +from xdsl.dialects.builtin import ( + AnyFloat, + DenseArrayBase, + DenseIntOrFPElementsAttr, + IntegerAttr, + TensorType, + i32, + i64, +) +from xdsl.irdl import ( + IRDLOperation, + attr_def, + irdl_op_definition, + operand_def, + opt_attr_def, + prop_def, + result_def, + traits_def, + var_operand_def, + var_region_def, +) +from xdsl.irdl.attributes import eq +from xdsl.irdl.constraints import AtLeast +from xdsl.traits import NoMemoryEffect, RecursiveMemoryEffect, SingleBlockImplicitTerminator + +from catalyst.python_interface.xdsl_extras import Elementwise, SameOperandsAndResultShape + +from .types import HLO_AnyTensor, HLO_FpOrQuantizedIntTensor, HLO_PredTensor, HLO_Tensor + +# Type aliases +FloatTensorType = TensorType[AnyFloat] + + +@irdl_op_definition +class ClampOp(IRDLOperation): + """Element-wise clamp with min and max bounds. + + See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#clamp + """ + + name = "stablehlo.clamp" + + min = operand_def(HLO_Tensor) + operand = operand_def(HLO_Tensor) + max = operand_def(HLO_Tensor) + result = result_def(HLO_Tensor) + + # TODO: Implement CustomDirective + # assembly_format = """ + # $min `,` $operand `,` $max attr-dict + # `:` custom(type($min), type($operand), type($max), type($result)) + # """ + + traits = traits_def( + NoMemoryEffect(), + # TODO: HLO_SpeculatableIfAllInputsStatic(), + # TODO: HLO_CompatibleOperandsAndResultElementType(), + # TODO: HLO_BroadcastingElementwise(), + # TODO: InferTensorType(), + # TODO: InferShapedTypeOpInterface(), + ) + + +@irdl_op_definition +class CompareOp(IRDLOperation): + """Element-wise compare with direction and type attributes.""" + + name = "stablehlo.compare" + + assembly_format = """ + $comparison_direction `,` $lhs `,` $rhs (`,` $comparison_type^)? attr-dict `:` functional-type(operands, results) + """ + + lhs = operand_def(HLO_Tensor) + rhs = operand_def(HLO_Tensor) + result = result_def(HLO_PredTensor) + comparison_direction = attr_def(xstablehlo.ComparisonDirectionAttr) + comparison_type = opt_attr_def(xstablehlo.ComparisonTypeAttr) + + traits = traits_def( + NoMemoryEffect(), + Elementwise(), + SameOperandsAndResultShape(), + # TODO: HLO_SpeculatableIfAllInputsStatic(), + # TODO: HLO_CompatibleOperandsElementType(), + # TODO: InferTensorTypeWithReify(), + ) + + +@irdl_op_definition +class MapOp(IRDLOperation): + """ + Applies a map function `computation` to `inputs` along the `dimensions` and + produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#map + + Example: + ```mlir + %result = "stablehlo.map"(%input0, %input1) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = stablehlo.multiply %arg0, %arg1 : tensor + stablehlo.return %0 : tensor + }) { + dimensions = array + } : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64> + ``` + """ + + name = "stablehlo.map" + + inputs = var_operand_def(HLO_Tensor) + result = result_def(HLO_Tensor) + dimensions = attr_def(DenseArrayBase.constr(i64)) + computation = var_region_def("single_block") + + traits = traits_def( + RecursiveMemoryEffect(), + SameOperandsAndResultShape(), + SingleBlockImplicitTerminator(xstablehlo.ReturnOp), + # TODO: HLO_RecursivelySpeculatableIfAllInputsStatic(), + # TODO: InferTypeOpInterface + # TODO: InferShapedTypeOpInterface(), + ) + + +@irdl_op_definition +class ReducePrecisionOp(IRDLOperation): + """ + Performs element-wise conversion of `operand` to another floating-point type + that uses `exponent_bits` and `mantissa_bits` and back to the original + floating-point type and produces an `output` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_precision + + Example: + ```mlir + %output = stablehlo.reduce_precision %operand, format = e5m10 : tensor<6xf64> + ``` + """ + + name = "stablehlo.reduce_precision" + + # TODO: Implement CustomDirective + # assembly_format = """ + # $operand `,` `format` `=` custom($exponent_bits, $mantissa_bits) + # attr-dict `:` custom(type($operand), type($output)) + # """ + + operand = operand_def(HLO_FpOrQuantizedIntTensor) + result = result_def(HLO_FpOrQuantizedIntTensor) + + exponent_bits = attr_def(IntegerAttr.constr(type=eq(i32), value=AtLeast(1))) + mantissa_bits = attr_def(IntegerAttr.constr(type=eq(i32), value=AtLeast(0))) + + traits = traits_def( + NoMemoryEffect(), + Elementwise(), + # TODO: HLO_CompatibleOperandsAndResultType(), + # TODO: HLO_SpeculatableIfStaticDimInOutputIsStaticInInput(), + ) + + +@irdl_op_definition +class SelectOp(IRDLOperation): + """ + Produces a `result` tensor where each element is selected from `on_true` or + `on_false` tensor based on the value of the corresponding element of `pred`. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#select + + Example: + ```mlir + %result = stablehlo.select %pred, %on_true, %on_false : tensor<2x2xi1>, tensor<2x2xi32> + ``` + """ + + name = "stablehlo.select" + + # assembly_format = """ + # operands attr-dict `:` + # custom(type($pred), type($on_true), type($on_false), type($result)) + # """ + + pred = operand_def(HLO_PredTensor) + on_true = operand_def(HLO_Tensor) + on_false = operand_def(HLO_Tensor) + result = result_def(HLO_Tensor) + + traits = traits_def( + NoMemoryEffect(), + ) + + +@irdl_op_definition +class ConstantOp(IRDLOperation): + """ + Produces an ``output`` tensor from a constant ``value``. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant + + Example: + ```mlir + %output = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32> + """ + + name = "stablehlo.constant" + + value = prop_def(DenseIntOrFPElementsAttr) + output = result_def(HLO_AnyTensor) + + def __init__(self, value: DenseIntOrFPElementsAttr): + super().__init__(properties={"value": value}, result_types=(value.type,)) diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_unary.py b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_unary.py new file mode 100644 index 0000000000..5e80f6adb7 --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_unary.py @@ -0,0 +1,554 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=too-few-public-methods + +""" +Unary elementwise operations for the StableHLO dialect. +""" + +import abc +from typing import Generic, TypeVar + +from xdsl.dialects.builtin import ( + I1, + AnyFloat, + AnyTensorType, + ComplexType, + TensorType, +) +from xdsl.ir import Attribute, SSAValue +from xdsl.irdl import ( + IRDLOperation, + irdl_op_definition, + operand_def, + opt_attr_def, + result_def, + traits_def, +) +from xdsl.traits import NoMemoryEffect + +from catalyst.python_interface.xdsl_extras import Elementwise, SameOperandsAndResultShape + +from .attributes import ResultAccuracyMode, ResultAccuracyModeAttr +from .types import ( + HLO_FloatTensor, + HLO_FpComplexOrQuantizedIntTensor, + HLO_FpOrComplexTensor, + HLO_FpOrQuantizedIntTensor, + HLO_IntFpOrComplexOrQuantizedIntTensor, + HLO_NonQuantizedTensor, + HLO_PredTensor, + HLO_SIntFpComplexOrQuantizedIntTensor, +) + +# Type aliases +I1TensorType = TensorType[I1] +FloatTensorType = TensorType[AnyFloat] +FloatOrComplexType = AnyFloat | ComplexType +FloatOrComplexTensorType = TensorType[FloatOrComplexType] +ComplexTensorType = TensorType[ComplexType] + +# Generic type variables for templating +T_IN = TypeVar("T_IN", bound=AnyTensorType) +T_OUT = TypeVar("T_OUT", bound=AnyTensorType) + + +class ElementwiseUnaryOperation(IRDLOperation, abc.ABC, Generic[T_IN, T_OUT]): + """ + Templated base class for elementwise unary operations. + + This class provides a flexible template for unary operations that can work + with different tensor types. + + For more informtation about the semantics, see: + https://openxla.org/xla/operation_semantics#element-wise_unary_functions + """ + + operand = operand_def(T_IN) + result = result_def(T_OUT) + + # TODO: Implement CustomDirective + # assembly_format = """ + # $operand attr-dict `:` custom(type($operand), type($result)) + # """ + + traits = traits_def( + NoMemoryEffect(), + SameOperandsAndResultShape(), + Elementwise(), + # TODO: InferShapedTypeOpInterface(), + # TODO: HLO_SpeculatableIfStaticDimInOutputIsStaticInInput(), + ) + + def __init__(self, operand: SSAValue, result_type: Attribute | None = None): + if result_type is None: + result_type = operand.type + super().__init__(operands=(operand,), result_types=(result_type,)) + + +@irdl_op_definition +class ConvertOp(ElementwiseUnaryOperation[HLO_NonQuantizedTensor, HLO_NonQuantizedTensor]): + """ + Performs an element-wise conversion from one element type to another on + `operand` tensor and produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convert + + Example: + ```mlir + %result = stablehlo.convert %operand : (tensor<3xi64>) -> tensor<3xcomplex> + ``` + """ + + name = "stablehlo.convert" + + traits = traits_def(SameOperandsAndResultShape()) + + +@irdl_op_definition +class CosineOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise cosine operation on `operand` tensor and produces a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cosine + + Example: + ```mlir + %result = stablehlo.cosine %operand : tensor<2xf32> + ``` + """ + + name = "stablehlo.cosine" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + # TODO: implement HLO_CompatibleOperandsAndResultType() + # traits = traits_def( + # HLO_CompatibleOperandsAndResultType() + # ) + + +@irdl_op_definition +class ExponentialMinusOneOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise exponential minus one operation on `operand` tensor + and produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential_minus_one + + Example: + ```mlir + %result = stablehlo.exponential_minus_one %operand : tensor<2xf64> + ``` + """ + + name = "stablehlo.exponential_minus_one" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + # TODO: implement HLO_CompatibleOperandsAndResultType() + # traits = traits_def( + # HLO_CompatibleOperandsAndResultType() + # ) + + +@irdl_op_definition +class ExponentialOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise exponential operation on `operand` tensor and produces + a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential + + Example: + ```mlir + %result = stablehlo.exponential %operand : tensor<2x2xf64> + ``` + """ + + name = "stablehlo.exponential" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + # TODO: implement HLO_CompatibleOperandsAndResultType() + # traits = traits_def( + # HLO_CompatibleOperandsAndResultType() + # ) + + +@irdl_op_definition +class FloorOp(ElementwiseUnaryOperation[HLO_FpOrQuantizedIntTensor, HLO_FpOrQuantizedIntTensor]): + """ + Performs element-wise floor of `operand` tensor and produces a `result` + tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#floor + + Example: + ```mlir + %result = stablehlo.floor %operand : tensor<2xf32> + ``` + """ + + name = "stablehlo.floor" + + +@irdl_op_definition +class ImagOp(ElementwiseUnaryOperation[HLO_FpOrComplexTensor, HLO_FloatTensor]): + """ + Extracts the imaginary part, element-wise, from the `operand` and produces a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#imag + + Example: + ```mlir + %result = stablehlo.imag %operand : (tensor<2xcomplex>) -> tensor<2xf32> + ``` + """ + + name = "stablehlo.imag" + + +@irdl_op_definition +class IsFiniteOp(ElementwiseUnaryOperation[HLO_FpOrQuantizedIntTensor, HLO_PredTensor]): + """ + Performs element-wise check whether the value in `x` is finite (i.e. is + neither +Inf, -Inf, nor NaN) and produces a `y` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#is_finite + + Example: + ```mlir + %y = stablehlo.is_finite %x : (tensor<7xf64>) -> tensor<7xi1> + ``` + """ + + name = "stablehlo.is_finite" + + +@irdl_op_definition +class LogOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise logarithm operation on `operand` tensor and produces a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#log + + Example: + ```mlir + %result = stablehlo.log %operand : tensor<2x2xf64> + ``` + """ + + name = "stablehlo.log" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + +@irdl_op_definition +class LogPlusOneOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise logarithm plus one operation on `operand` tensor and + produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#log_plus_one + + Example: + ```mlir + %result = stablehlo.log_plus_one %operand : tensor<5xf64> + ``` + """ + + name = "stablehlo.log_plus_one" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + +@irdl_op_definition +class LogisticOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise logistic operation on `operand` tensor and produces a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#logistic + + Example: + ```mlir + %result = stablehlo.logistic %operand : tensor<2x2xf64> + ``` + """ + + name = "stablehlo.logistic" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + +@irdl_op_definition +class NegateOp( + ElementwiseUnaryOperation[ + HLO_IntFpOrComplexOrQuantizedIntTensor, HLO_IntFpOrComplexOrQuantizedIntTensor + ] +): + """ + Performs element-wise negation of `operand` tensor and produces a `result` + tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#negate + + Example: + ```mlir + %result = stablehlo.negate %operand : tensor<2x3xi32> + ``` + """ + + name = "stablehlo.negate" + + +@irdl_op_definition +class RealOp(ElementwiseUnaryOperation[HLO_FpOrComplexTensor, HLO_FloatTensor]): + """ + Extracts the real part, element-wise, from the `operand` and produces a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#real + + Example: + ```mlir + %result = stablehlo.real %operand : tensor<2xcomplex> : tensor<2xf32> + ``` + """ + + name = "stablehlo.real" + + +@irdl_op_definition +class RoundNearestAfzOp( + ElementwiseUnaryOperation[HLO_FpOrQuantizedIntTensor, HLO_FpOrQuantizedIntTensor] +): + """ + Performs element-wise rounding towards the nearest integer, breaking ties + away from zero, on the `operand` tensor and produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_afz + + Example: + ```mlir + %result = stablehlo.round_nearest_afz %operand : tensor<5xf64> + ``` + """ + + name = "stablehlo.round_nearest_afz" + + +@irdl_op_definition +class RoundNearestEvenOp( + ElementwiseUnaryOperation[HLO_FpOrQuantizedIntTensor, HLO_FpOrQuantizedIntTensor] +): + """ + Performs element-wise rounding towards the nearest integer, breaking ties + towards the even integer, on the `operand` tensor and produces a `result` + tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_even + + Example: + ```mlir + %result = stablehlo.round_nearest_even %operand : tensor<5xf64> + ``` + """ + + name = "stablehlo.round_nearest_even" + + +@irdl_op_definition +class RsqrtOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise reciprocal square root operation on `operand` tensor + and produces a `result` tensor, implementing the `rSqrt` operation from the + IEEE-754 specification. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rsqrt + + Example: + ```mlir + %result = stablehlo.rsqrt %operand : tensor<2x2xf32> + ``` + """ + + name = "stablehlo.rsqrt" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + +@irdl_op_definition +class SignOp( + ElementwiseUnaryOperation[ + HLO_SIntFpComplexOrQuantizedIntTensor, HLO_SIntFpComplexOrQuantizedIntTensor + ] +): + """ + Returns the sign of the `operand` element-wise and produces a `result` + tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sign + + Example: + ```mlir + %result = stablehlo.sign %operand : tensor<5xf64> + ``` + """ + + name = "stablehlo.sign" + + +@irdl_op_definition +class SineOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise sine operation on `operand` tensor and produces a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sine + + Example: + ```mlir + %result = stablehlo.sine %operand : tensor<2xf32> + ``` + """ + + name = "stablehlo.sine" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + +@irdl_op_definition +class SqrtOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise square root operation on `operand` tensor and produces + a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sqrt + + Example: + ```mlir + %result = stablehlo.sqrt %operand : tensor<2x2xf32> + ``` + """ + + name = "stablehlo.sqrt" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + +@irdl_op_definition +class TanOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise tangent operation on `operand` tensor and + produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tan + + Example: + ```mlir + %result = stablehlo.tan %operand : tensor<2x2xf64> + ``` + """ + + name = "stablehlo.tan" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + +@irdl_op_definition +class TanhOp( + ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor] +): + """ + Performs element-wise hyperbolic tangent operation on `operand` tensor and + produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tanh + + Example: + ```mlir + %result = stablehlo.tanh %operand : tensor<2xf32> + ``` + """ + + name = "stablehlo.tanh" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/extensibility.py b/frontend/catalyst/python_interface/dialects/stablehlo/extensibility.py new file mode 100644 index 0000000000..6133b5e6c7 --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/extensibility.py @@ -0,0 +1,167 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=too-few-public-methods + +""" +Dynamism operations for the StableHLO dialect. +""" + + +from xdsl.dialects.builtin import ( + ArrayAttr, + BoolAttr, + DenseIntElementsAttr, + DictionaryAttr, + FlatSymbolRefAttr, + StringAttr, + TensorType, + TupleType, +) +from xdsl.ir import Attribute +from xdsl.irdl import ( + AnyAttr, + IRDLOperation, + irdl_op_definition, + opt_prop_def, + prop_def, + traits_def, + var_operand_def, + var_result_def, +) +from xdsl.traits import ( + MemoryEffect, +) +from xdsl.utils.exceptions import VerifyException + +from .attributes import CustomCallApiVersion, CustomCallApiVersionAttr, OutputOperandAlias + + +@irdl_op_definition +class CustomCallOp(IRDLOperation): + """ + Encapsulates an implementation-defined operation ``call_target_name`` that + takes ``inputs`` and ``called_computations`` and produces ``results``. + + Depending on the API version there are two ways to pass extra bits of static + information to the external function: + 1. Use ``API_VERSION_TYPED_FFI`` which allows passing a dictionary attribute. + 2. Use a previous API version with a ``StringAttr`` to encode backend config. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#custom_call + + Example: + ```mlir + %results = stablehlo.custom_call @foo(%input0) { + backend_config = {bar = 42 : i32}, + api_version = 4 : i32, + called_computations = [@foo] + } : (tensor) -> tensor + ``` + """ + + name = "stablehlo.custom_call" + + inputs = var_operand_def(AnyAttr()) + call_target_name = prop_def(StringAttr) + has_side_effect = prop_def(BoolAttr, default_value=BoolAttr.from_bool(False)) + backend_config = opt_prop_def(DictionaryAttr | StringAttr) + api_version = prop_def( + CustomCallApiVersionAttr, + default_value=CustomCallApiVersionAttr(CustomCallApiVersion.API_VERSION_ORIGINAL), + ) + called_computations = opt_prop_def(ArrayAttr[FlatSymbolRefAttr], default_value=ArrayAttr([])) + operand_layouts = opt_prop_def(ArrayAttr[DenseIntElementsAttr]) + result_layouts = opt_prop_def(ArrayAttr[DenseIntElementsAttr]) + output_operand_aliases = prop_def(ArrayAttr[OutputOperandAlias]) + + result = var_result_def(AnyAttr()) + + traits = traits_def( + MemoryEffect(), + ) + + # TODO: Implement CustomDirective + # assembly_format = """ + # custom($call_target_name) `(` $inputs `)` + # attr-dict `:` functional-type(operands, results) + # """ + + def verify_(self) -> None: + """Verify the CustomCallOp.""" + # If both operand and result layout attributes are not specified then nothing to verify. + if self.operand_layouts is None and self.result_layouts is None: + return + + # Layout constraints for either both operands & results or none should be specified. + if (self.operand_layouts is None) != (self.result_layouts is None): + raise VerifyException( + "Layout attributes should be specified for either both operands and results or none." + ) + + assert self.operand_layouts is not None and self.result_layouts is not None + + def verify_types_and_layouts( + types: tuple[Attribute, ...], layouts: ArrayAttr, value_name: str + ): + if len(types) != len(layouts.data): + raise VerifyException( + "Number of " + f"{value_name}s must match the number of {value_name} layouts, " + f"{len(types)} != {len(layouts.data)}" + ) + + for index, (ty, layout_attr) in enumerate(zip(types, layouts.data)): + # Tuple types are not fully supported with layout constraints yet + if isinstance(ty, TupleType): + raise VerifyException( + "Tuple types are not fully supported with layout constraints yet" + ) + + try: + dims = list(layout_attr.get_values()) + except Exception as exc: + raise VerifyException("invalid layout attribute") from exc + + # For non-tensor types, layout must be empty + if not isinstance(ty, TensorType): + if len(dims) == 0: + continue + raise VerifyException( + "Only tensor types can have non-empty layout: " + f"{value_name} #{index} of type {ty} has layout {dims}" + ) + + # For ranked tensors, require permutation of [0, rank) + rank = ty.get_num_dims() + if rank != len(dims) or sorted(dims) != list(range(rank)): + raise VerifyException( + f"incorrect layout {dims} for type {ty}, layout must be a permutation of [0, {rank})" + ) + + # Operand types + operand_types: tuple[Attribute, ...] = tuple(op.type for op in self.operands) + + # Result types: if single tuple result, use its element types + if len(self.result_types) == 1 and isinstance(self.result_types[0], TupleType): + tuple_ty: TupleType = self.result_types[0] + result_types = tuple(tuple_ty.types.data) + else: + result_types = tuple(self.result_types) + + # Verify that operands and operand layouts match. + verify_types_and_layouts(operand_types, self.operand_layouts, "operand") + # Verify that results and result layouts match. + verify_types_and_layouts(result_types, self.result_layouts, "result") diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py b/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py new file mode 100644 index 0000000000..3597aaaeae --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py @@ -0,0 +1,160 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=too-few-public-methods + +""" +Dynamism operations for the StableHLO dialect. +""" + +from xdsl.dialects.builtin import DenseArrayBase, i64 +from xdsl.irdl import ( + IRDLOperation, + irdl_op_definition, + prop_def, + region_def, + traits_def, + var_operand_def, + var_result_def, +) +from xdsl.irdl.operations import SameVariadicOperandSize +from xdsl.traits import ( + RecursiveMemoryEffect, + SingleBlockImplicitTerminator, +) +from xdsl.utils.exceptions import VerifyException +from xdsl_jax.dialects import stablehlo as xstablehlo + +from .types import HLO_Tensor + + +@irdl_op_definition +class ReduceOp(IRDLOperation): + """ + Applies a reduction function ``body`` to ``inputs`` and ``init_values`` along the + ``dimensions`` and produces a ``result`` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce + + Example: + ```mlir + %result = "stablehlo.reduce"(%input, %init_value) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = stablehlo.add %arg0, %arg1 : tensor + stablehlo.return %0 : tensor + }) { + dimensions = array + } : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + ``` + """ + + name = "stablehlo.reduce" + + inputs = var_operand_def(HLO_Tensor) + init_values = var_operand_def(HLO_Tensor) + dimensions = prop_def(DenseArrayBase.constr(i64)) + result = var_result_def(HLO_Tensor) + body = region_def("single_block") + + irdl_options = [SameVariadicOperandSize()] + + traits = traits_def( + RecursiveMemoryEffect(), + # TODO: InferShapedTypeOpInterface(), + # TODO: HLO_RecursivelySpeculatableIfAllInputsStatic, + # TODO: InferTensorTypeWithReify(), + SingleBlockImplicitTerminator(xstablehlo.ReturnOp), + ) + + # pylint: disable=no-member + # pylint: disable=too-many-branches + def verify_(self): + """Verify the ReduceOp.""" + # Gather shaped operand/result types + input_types = [op.type for op in self.inputs] + init_types = [op.type for op in self.init_values] + + # reduce_c1/c4/c5/i3: verify inputs and infer shape compatibility + dims_attr = self.dimensions + dims = tuple(dims_attr.get_values()) if dims_attr is not None else tuple() + + # Basic structural checks mirroring verifyReduceOpInputsAndInferShape + if len(input_types) == 0: + raise VerifyException("expected at least 1 input for reduce") + if len(input_types) != len(init_types): + raise VerifyException("number of inputs must match number of init_values") + + # All inputs must have equal rank; dimensions must be within rank and unique + # and not empty. + ranks = [] + for t in input_types: + # Tensors by op definition + assert hasattr(t, "get_num_dims") + ranks.append(t.get_num_dims()) + rank0 = ranks[0] + if any(r != rank0 for r in ranks): + raise VerifyException("all inputs must have the same rank") + + if len(dims) == 0: + raise VerifyException("dimensions cannot be empty for reduce") + if len(set(dims)) != len(dims): + raise VerifyException("dimensions should not have duplicates") + if any(d < 0 or d >= rank0 for d in dims): + raise VerifyException("dimensions contains an invalid value") + + # Element type compatibility between each input and its init value + for it, iv in zip(input_types, init_types): + it_elem = it.get_element_type() + iv_elem = iv.get_element_type() + if it_elem != iv_elem: + raise VerifyException("input and init_value must have the same element type") + + # reduce_c2/c6: verify reducer region shape + # Expect block with arity 2 * number of inputs, with matching tensor element types and 0D tensors + if len(self.body.blocks) != 1: + raise VerifyException("reducer must have a single block") + block = self.body.blocks[0] + + expected_args = 2 * len(input_types) + if len(block.args) != expected_args: + raise VerifyException( + f"reducer must take {expected_args} arguments, got {len(block.args)}" + ) + + # Each pair (arg_i, arg_{i+N}) must be 0D tensors of the input element type + for i, it in enumerate(input_types): + it_elem = it.get_element_type() + acc = block.args[i] + val = block.args[i + len(input_types)] + for a in (acc, val): + a_ty = a.type + if not hasattr(a_ty, "get_num_dims") or a_ty.get_num_dims() != 0: + raise VerifyException("reducer arguments must be rank-0 tensors") + if a_ty.get_element_type() != it_elem: + raise VerifyException( + "reducer argument element types must match input element type" + ) + + # Region must terminate with exactly len(inputs) results + ret = block.ops.last + if len(ret.operands) != len(input_types): + raise VerifyException("reducer must return exactly one value per input") + for i, it in enumerate(input_types): + it_elem = it.get_element_type() + rty = ret.operands[i].type + if not hasattr(rty, "get_num_dims") or rty.get_num_dims() != 0: + raise VerifyException("reducer return values must be rank-0 tensors") + if rty.get_element_type() != it_elem: + raise VerifyException("reducer return element types must match input element type") diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/types.py b/frontend/catalyst/python_interface/dialects/stablehlo/types.py new file mode 100644 index 0000000000..7bbbf5ad5f --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/stablehlo/types.py @@ -0,0 +1,249 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +StableHLO type definitions for PennyLane's compiler infrastructure. + +This module provides type definitions based on the StableHLO specification +(https://github.com/openxla/stablehlo/blob/main/docs/spec.md), including +token types and other necessary type definitions for StableHLO operations. +""" + +# pylint: disable=too-few-public-methods + +from typing import TypeAlias + +from xdsl.dialects.builtin import ( + AnyFloatConstr, + ComplexType, + Float32Type, + Float64Type, + IndexType, + IntAttr, + IntAttrConstraint, + IntegerType, + ParametrizedAttribute, + Signedness, + SignednessAttr, + TensorType, + i1, +) +from xdsl.irdl import eq, irdl_attr_definition +from xdsl.irdl.attributes import EqAttrConstraint, ParamAttrConstraint +from xdsl.irdl.constraints import IntSetConstraint +from xdsl_jax.dialects.stablehlo import TokenType + +from catalyst.python_interface.xdsl_extras.constraints import ( + NestedTupleOfConstraint, +) + + +def _create_param_constrained_type( + base_attr: type, widths: list[int], signedness: Signedness | None = None +): + """Create an integer type constrained using ParamAttrConstraint with IntSetConstraint.""" + width_constraint = IntAttrConstraint(IntSetConstraint(frozenset(widths))) + + if signedness is None: + signedness_constraint = None + else: + signedness_constraint = EqAttrConstraint(SignednessAttr(signedness)) + + return ParamAttrConstraint(base_attr, [width_constraint, signedness_constraint]) + + +# ============================================================================= +# Core StableHLO types constraints +# ============================================================================= + +HLO_Pred = eq(i1) +HLO_PredTensor: TypeAlias = TensorType[HLO_Pred] + +# NOTE: IntegerType is defined in the StableHLO spec as: +# IntegerType ::= SignedIntegerType | UnsignedIntegerType, +# but the MLIR implementation is using signless integers instead of signed, +# and there is a TODO to fix it. + +_HLO_INT_WIDTHS = [2, 4, 8, 16, 32, 64] +HLO_SignedInt = _create_param_constrained_type(IntegerType, _HLO_INT_WIDTHS, Signedness.SIGNED) +HLO_UnsignedInt = _create_param_constrained_type(IntegerType, _HLO_INT_WIDTHS, Signedness.UNSIGNED) +HLO_SignlessInt = _create_param_constrained_type(IntegerType, _HLO_INT_WIDTHS, None) + +HLO_Int: TypeAlias = HLO_UnsignedInt | HLO_SignlessInt +HLO_IntTensor: TypeAlias = TensorType[HLO_Int] + +_HLO_INT_OR_PRED_WIDTHS = [1, 2, 4, 8, 16, 32, 64] +HLO_IntOrPred = _create_param_constrained_type(IntegerType, _HLO_INT_OR_PRED_WIDTHS, None) + + +HLO_AnyIntegerOrIndex: TypeAlias = IntegerType | IndexType +HLO_AnyIntegerOrIndexTensor: TypeAlias = TensorType.constr(HLO_AnyIntegerOrIndex) + +HLO_DimensionValue: TypeAlias = HLO_Int | IndexType + +# Constraint variants for use in unions with ParamAttrConstraint +HLO_Float: TypeAlias = AnyFloatConstr +HLO_Float32Or64: TypeAlias = Float32Type | Float64Type +HLO_FloatTensor: TypeAlias = TensorType.constr(HLO_Float) +HLO_Fp32Or64Tensor: TypeAlias = TensorType.constr(HLO_Float32Or64) + +# Complex as a constraint over element types {f32,f64} +HLO_Complex: TypeAlias = ComplexType[HLO_Float32Or64] +HLO_ComplexTensor: TypeAlias = TensorType.constr(HLO_Complex) + +# ============================================================================= +# Quantized element type definitions +# ============================================================================= + + +@irdl_attr_definition +class UniformQuantizedType(ParametrizedAttribute): + """ + Placeholder for StableHLO per-tensor uniform quantized types. + + Parameterized by width to support different quantized integer widths + (e.g., 8-bit, 16-bit quantization). + """ + + name = "stablehlo.uniform_quantized" + width: IntAttr + signedness: SignednessAttr + + +@irdl_attr_definition +class UniformQuantizedPerAxisType(ParametrizedAttribute): + """ + Placeholder for StableHLO per-axis uniform quantized types. + + Parameterized by width to support different quantized integer widths + (e.g., 8-bit, 16-bit quantization). + """ + + name = "stablehlo.uniform_quantized_per_axis" + width: IntAttr + signedness: SignednessAttr + + +# ============================================================================= +# StableHLO quantized type aliases +# ============================================================================= + +_HLO_QUANTIZED_WIDTHS = [2, 4, 8, 16, 32] + +# Constraint-based types for operation definitions +HLO_QuantizedSignedInt = _create_param_constrained_type( + UniformQuantizedType, _HLO_QUANTIZED_WIDTHS, Signedness.SIGNED +) +HLO_QuantizedUnsignedInt = _create_param_constrained_type( + UniformQuantizedType, _HLO_QUANTIZED_WIDTHS, Signedness.UNSIGNED +) +HLO_QuantizedAnySignednessInt = _create_param_constrained_type( + UniformQuantizedType, _HLO_QUANTIZED_WIDTHS, None +) +HLO_QuantizedInt: TypeAlias = HLO_QuantizedSignedInt | HLO_QuantizedUnsignedInt + +HLO_PerAxisQuantizedSignedInt = _create_param_constrained_type( + UniformQuantizedPerAxisType, _HLO_QUANTIZED_WIDTHS, Signedness.SIGNED +) +HLO_PerAxisQuantizedUnsignedInt = _create_param_constrained_type( + UniformQuantizedPerAxisType, _HLO_QUANTIZED_WIDTHS, Signedness.UNSIGNED +) +HLO_PerAxisQuantizedAnySignednessInt = _create_param_constrained_type( + UniformQuantizedPerAxisType, _HLO_QUANTIZED_WIDTHS, None +) +HLO_PerAxisQuantizedInt: TypeAlias = HLO_PerAxisQuantizedSignedInt | HLO_PerAxisQuantizedUnsignedInt + +# ============================================================================= +# Main tensor type definitions +# ============================================================================= + +HLO_Tensor: TypeAlias = TensorType[HLO_Float | HLO_Complex | HLO_IntOrPred | HLO_QuantizedInt] +HLO_NonQuantizedTensor: TypeAlias = TensorType[HLO_Float | HLO_Complex | HLO_IntOrPred] + +# Note: There is a discrepancy between the StableHLO spec and the MLIR implementation. +# The spec does not allow unranked tensors, but the MLIR implementation +# defines it as a tensor of any type and rank. There is a TODO to fix this in MLIR. +# Therefore, we use the correct ranked tensor type. +HLO_AnyTensor: TypeAlias = TensorType[ + HLO_Float | HLO_Complex | HLO_IntOrPred | HLO_QuantizedInt | HLO_PerAxisQuantizedInt +] +HLO_TensorOrToken: TypeAlias = HLO_Tensor | TokenType +HLO_TensorOrPerAxisQuantizedTensorOrToken: TypeAlias = HLO_AnyTensor | TokenType + +# HLO_AnyTuple : NestedTupleOf<[HLO_AnyTensor, HLO_Token]> +HLO_AnyTuple = NestedTupleOfConstraint([HLO_AnyTensor, TokenType]) + +HLO_CustomCallValue: TypeAlias = HLO_Tensor | TokenType | HLO_AnyTuple + +# ============================================================================= +# HLO combined type definitions +# ============================================================================= + +HLO_PredOrIntTensor: TypeAlias = TensorType.constr(HLO_IntOrPred) + +HLO_FpOrComplexTensor: TypeAlias = TensorType.constr(HLO_Float | HLO_Complex) +HLO_FpOrQuantizedIntTensor: TypeAlias = TensorType.constr(HLO_Float | HLO_QuantizedInt) +HLO_FpComplexOrQuantizedIntTensor: TypeAlias = TensorType.constr( + HLO_Float | HLO_Complex | HLO_QuantizedInt +) +HLO_IntFpOrComplexOrQuantizedIntTensor: TypeAlias = TensorType.constr( + HLO_Int | HLO_Float | HLO_Complex | HLO_QuantizedInt +) +HLO_SIntFpComplexOrQuantizedIntTensor: TypeAlias = TensorType.constr( + HLO_SignedInt | HLO_Float | HLO_Complex | HLO_QuantizedInt +) + + +__all__ = [ + # Core types + "HLO_Pred", + "HLO_PredTensor", + "HLO_Int", + "HLO_IntTensor", + "HLO_AnyIntegerOrIndex", + "HLO_AnyIntegerOrIndexTensor", + "HLO_DimensionValue", + "HLO_Float", + "HLO_Float32Or64", + "HLO_FloatTensor", + "HLO_Fp32Or64Tensor", + "HLO_ComplexTensor", + "HLO_SignedInt", + "HLO_UnsignedInt", + "HLO_SignlessInt", + "HLO_QuantizedSignedInt", + "HLO_QuantizedUnsignedInt", + "HLO_QuantizedAnySignednessInt", + "HLO_QuantizedInt", + "HLO_PerAxisQuantizedSignedInt", + "HLO_PerAxisQuantizedUnsignedInt", + "HLO_PerAxisQuantizedAnySignednessInt", + "HLO_PerAxisQuantizedInt", + # Quantized types + "UniformQuantizedType", + "UniformQuantizedPerAxisType", + "HLO_Tensor", + "HLO_NonQuantizedTensor", + "HLO_AnyTensor", + "HLO_TensorOrToken", + "HLO_TensorOrPerAxisQuantizedTensorOrToken", + "HLO_CustomCallValue", + # Combined types + "HLO_PredOrIntTensor", + "HLO_FpOrComplexTensor", + "HLO_FpOrQuantizedIntTensor", + "HLO_FpComplexOrQuantizedIntTensor", + "HLO_IntFpOrComplexOrQuantizedIntTensor", + "HLO_SIntFpComplexOrQuantizedIntTensor", +] diff --git a/frontend/catalyst/python_interface/dialects/transform.py b/frontend/catalyst/python_interface/dialects/transform.py new file mode 100644 index 0000000000..1e0d58963f --- /dev/null +++ b/frontend/catalyst/python_interface/dialects/transform.py @@ -0,0 +1,123 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This file contains an updated version of the transform dialect. +As of the time of writing, xDSL uses the MLIR released with LLVM's +version 20.1.7. However, https://github.com/PennyLaneAI/catalyst/pull/1916 +will be updating MLIR where the transform dialect has the +`apply_registered_pass` operation re-defined. + +See the following changelog on the above PR + + Things related to transform.apply_registered_pass op: + + It now takes in a dynamic_options + + [MLIR][Transform] Allow ApplyRegisteredPassOp to take options as + a param llvm/llvm-project#142683. We don't need to use this as all our pass options are static. + https://github.com/llvm/llvm-project/pull/142683 + + The options it takes in are now dictionaries instead of strings + [MLIR][Transform] apply_registered_pass op's options as a dict llvm/llvm-project#143159 + https://github.com/llvm/llvm-project/pull/143159 + +This file will re-define the apply_registered_pass operation in xDSL +and the transform dialect. + +Once xDSL moves to a newer version of MLIR, these changes should +be contributed upstream. +""" + +from xdsl.dialects.builtin import Dialect + +# pylint: disable=too-few-public-methods +from xdsl.dialects.transform import ApplyRegisteredPassOp as xApplyRegisteredPassOp +from xdsl.dialects.transform import ( + DictionaryAttr, + StringAttr, +) +from xdsl.dialects.transform import Transform as xTransform +from xdsl.dialects.transform import ( + TransformHandleType, + irdl_op_definition, + operand_def, + prop_def, + result_def, +) +from xdsl.ir import Attribute, SSAValue +from xdsl.irdl import IRDLOperation, ParsePropInAttrDict + + +@irdl_op_definition +class ApplyRegisteredPassOp(IRDLOperation): + """ + See external [documentation](https://mlir.llvm.org/docs/Dialects/Transform/#transformapply_registered_pass-transformapplyregisteredpassop). + """ + + name = "transform.apply_registered_pass" + + options = prop_def(DictionaryAttr, default_value=DictionaryAttr({})) + pass_name = prop_def(StringAttr) + target = operand_def(TransformHandleType) + result = result_def(TransformHandleType) + # While this assembly format doesn't match + # the one in upstream MLIR, + # this is because xDSL currently lacks CustomDirectives + # https://mlir.llvm.org/docs/DefiningDialects/Operations/#custom-directives + # https://github.com/xdslproject/xdsl/pull/4829 + # However, storing the property in the attribute should still work + # specially when parsing and printing in generic format. + # Which is how Catalyst and XDSL currently communicate at the moment. + # TODO: Add test. + assembly_format = "$pass_name `to` $target attr-dict `:` functional-type(operands, results)" + irdl_options = [ParsePropInAttrDict()] + + def __init__( + self, + pass_name: str | StringAttr, + target: SSAValue, + options: dict[str | StringAttr, Attribute | str | bool | int] | None = None, + ): + if isinstance(pass_name, str): + pass_name = StringAttr(pass_name) + + if isinstance(options, dict): + options = DictionaryAttr(options) + + super().__init__( + properties={ + "pass_name": pass_name, + "options": options, + }, + operands=[target], + result_types=[target.type], + ) + + +# Copied over from xDSL's sources +# the main difference will be the use +# of a different ApplyRegisteredPassOp +operations = list(xTransform.operations) +del operations[operations.index(xApplyRegisteredPassOp)] +operations.append(ApplyRegisteredPassOp) + +Transform = Dialect( + "transform", + [ + *operations, + ], + [ + *xTransform.attributes, + ], +) diff --git a/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst b/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst new file mode 100644 index 0000000000..b2f88b6060 --- /dev/null +++ b/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst @@ -0,0 +1,1376 @@ +Unified Compiler Cookbook +========================= + +**Note:** The cookbook is developed with the following package versions, +on Python 3.12.11: + +.. code-block:: bash + + jax==0.6.2 + jaxlib==0.6.2 + numpy==2.3.1 + pennylane==0.44.0-dev19 + pennylane-lightning==0.43.0 + pennylane-catalyst==0.14.0-dev15 + xdsl==0.53.0 + xdsl-jax==git+https://github.com/xdslproject/xdsl-jax.git@895f7c13e8d0f02bbe99d7fb9ebcaafea4ea629f#egg=xdsl_jax + +Note that ``xdsl-jax`` does not currently have a release published on +PyPI, so it needs to be installed from GitHub by running the following: + +.. code-block:: bash + + pip install git+https://github.com/xdslproject/xdsl-jax.git + +Motivation +========== + +As we approach FTQC, quantum compilation becomes a more and more +important research area. Catalyst uses MLIR as its intermediate +representation, which is also the layer in which a majority of the +optimizations happen. However, quantum compilation researchers are not +likely to be accustomed with MLIR or C++. + +So, the motivation of the “Unified Compiler” is to provide a Python +layer in which compilation passes can be implemented and applied. +Additionally, we also want to enable researchers to use abstractions +that they are familiar with. We’re aiming to do this using xDSL, which +is a reimplementation of MLIR in Python. + +This document is meant to be a quickstart for users that are interested +in developing compiler passes for Catalyst, but are not familiar with +MLIR or xDSL. + +MLIR basics +=========== + +If readers are already familiar to some degree with MLIR, this section +can be skipped. + +SSA +--- + +“In compiler design, static single assignment form (often abbreviated as +SSA form or simply SSA) is a type of intermediate representation (IR) +where each variable is assigned exactly once” [`1 <#references>`__]. + +SSA is powerful because it allows us to define chains of uses and +definitions, i.e, we can keep track of which operation created a +variable (since each variable only gets created once), and all +operations that use that variable. These chains of uses and definitions, +if used well, can make transformations quite performant, and in some +cases, very simple to implement, but they can also make the IR harder to +parse. + +IR structure +------------ + +MLIR represents programs using a hierarchical, graph-like structure. In +this structure, nodes are called *operations*, and edges are called +*values*. Each value is the result of exactly one operation or argument +for a block (more on that later), and has a type that is defined by the +type system (more on that later also) [`2 <#references>`__]. + +The IR is recursively nested - operations may contain regions, which +contain blocks, which contain operations. More concretely, an operation +may contain zero or more regions, each of which must contain one or more +blocks, each of which holds a list of arguments and an ordered list of +operations that may use those arguments [`3 <#references>`__]. + +Operations +~~~~~~~~~~ + +Operations are basic units of execution. Operations are fully +extensible, i.e., there is no fixed list of operations. Operations can +return zero or more results, take in zero or more operands, declare +properties and attributes, and can have zero or more successors and +regions [`2 <#references>`__]. + +Regions +~~~~~~~ + +A region is an ordered list of blocks. Its semantics are defined by the +operation that contains them [`2 <#references>`__]. For example, an +``scf.IfOp``, which represents a conditional operation, contains two +regions, one for the true branch, and another for the false branch, and +the way we interpret these regions is dependent on the fact that they +belong to the ``scf.IfOp``. + +Blocks +~~~~~~ + +Blocks are lists of operations. The operations inside blocks are +executed in order. Blocks take a list of block arguments, annotated in a +function-like way. The first block in a region is special, and is called +the “entry block”. The block arguments of the entry block are also +arguments to its outer region [`2 <#references>`__]. + +Values +~~~~~~ + +In MLIR, there are computable values with a type, a single defining +operation, and zero or more uses [`4 <#references>`__]. These values are +either the results of operations or block arguments, and adhere to SSA +semantics. + +Dialects +-------- + +MLIR uses dialects to allow developers to define a set of high level +operations, attributes, and types, which can be used in the intermediate +representation to represent custom subroutines, etc. and can be +converted into lower level representations using interpretation rules +that can be defined. For example, the MLIR ``arith`` dialect defines +operations for arithmetic computations, the ``linalg`` dialect defines +linear algebra operations, etc. We use dialects to define quantum +instructions such as gates, measurements, and attributes such as qubits, +quantum registers. + +Def-use and use-def chains +-------------------------- + +Frequently, we might have a value and we want to determine which +instructions use this value. The list of all users of a value is the +*def-use chain*. Conversely, we might have an instruction and we want to +determine which values it uses. The list of values used by an +instruction is the *use-def chain* [`5 <#references>`__]. These chains +allow us to iterate through operations topologically, which can be very +powerful when implementing passes. + +xDSL API +======== + +Now that we’re familiar with the high-level constructs that MLIR uses, +let’s go over what their xDSL actual implementation looks like. + +``SSAValue`` +------------ + +``SSAValue`` is the class used to represent variables. SSA values may +used by operations as operands, and returned as results. The three key +properties that this class provides are listed below: + +- ``type``: The type of the value. Since SSA values are variables, their + value is not known at compile time, but their type is. +- ``uses``: A set of all operations that use a given ``SSAValue`` as an + operand. The operations are wrapped around a convenience class called + ``Use``, and the corresponding operation can be accessed using + ``Use.operation``. +- ``owner``: The operation or block that defined a given ``SSAValue`` + (more on that later) + +``SSAValue`` has two subclasses, which will be seen a lot when +inspecting xDSL modules. These are: + +- ``OpResult``: This subclass represents SSA values that are defined by + an operation + + .. code-block:: + + c = add a b + + In the above pseudocode, ``c`` is defined by the operation ``add``, + and in xDSL, will be represented as an ``OpResult`` instance. + +- ``BlockArgument``: This subclass represents SSA values that are block + arguments + +``Attribute`` +------------- + +An attribute defines a compile-time value. These can be used to define +types, and can also be used to define properties of operations that are +known at compile time. For example: + +- ``CustomOp``, which is an operation from the ``Quantum`` dialect (more + on that later) has a ``gate_name`` that must be a string, represented + by the ``StringAttr`` attribute. ``StringAttr`` has a reference to the + concrete string representing the gate name, and we can access this + concrete value at compile-time using ``CustomOp.gate_name.data``. +- ``CustomOp`` also takes qubits as inputs and outputs, which are of + type ``QubitType``. The ``QubitType`` class inherits ``Attribute``, + but we use it to declare a type that can be used to define SSA values. + +Below is the definition of ``QubitType`` to illustrate: + +.. code-block:: python + + # TypeAttribute just means that this attribute can be used to represent + # the types of the operands and results of an operation, but not its + # properties. + # ParametrizedAttribute means that this attribute can take parameters as + # input (although QubitType doesn't have any parameters). + @irdl_attr_definition + class QubitType(ParametrizedAttribute, TypeAttribute): + """A value-semantic qubit (state).""" + + name = "quantum.bit" + +``Operation`` +------------- + +The ``Operation`` class is used to represent operations, which are basic +units of execution. All instructions in a module are operations. In +fact, the modules themselves are operations. Operations contain several +fields used to define their form and function: + +- Operands: operands are runtime values that the operation consumes. + Note that when defining an operation, we only declare the *types* of + the operands, not the actual operands. Only when we construct an + instance of an operation do we provide actual ``SSAValue``\ s as the + operands, and these must adhere to the type system. +- Properties: properties are compile-time values used to define the + semantics of an operation. For example, the ``CustomOp`` operation + that is used to define quantum gates in Catalyst has a property called + ``gate_name``, which is a string specifier of the gate’s name. This + name directly impacts how the operation should be interpreted, and its + value is known at compile-time. +- Attributes: Operation attributes are stored as a dictionary, + containing more compile-time values. Generally, these don’t get used + much in xDSL, but they serve a purpose very similar to properties. +- Result types: operations may return values, and if so, the types of + the return values must be defined. + +Below is the definition of ``CustomOp`` from the ``Quantum`` dialect, +which represents general quantum gates to illustrate what defining +operations looks like: + +.. code-block:: python + + @irdl_op_definition + class CustomOp(IRDLOperation): + """A generic quantum gate on n qubits with m floating point parameters.""" + + name = "quantum.custom" + + # assembly_format defines what the operation should look like + # when pretty-printed in the IR + assembly_format = """ + $gate_name `(` $params `)` $in_qubits + (`adj` $adjoint^)? + attr-dict + ( `ctrls` `(` $in_ctrl_qubits^ `)` )? + ( `ctrlvals` `(` $in_ctrl_values^ `)` )? + `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )? + """ + + # These options are used because we have operands whose lengths aren't + # known (eg. different instances of CustomOp may have different number + # of qubits depending on the gate they represent (2 for CNOT, 1 for + # RX, etc.). These options basically say, "when the operation instance + # is initialized, create 2 properties that store the length of each of + # the different groups of operands and results. + irdl_options = [ + AttrSizedOperandSegments(as_property=True), + AttrSizedResultSegments(as_property=True), + ] + + # var_operand_def means that the length of this operand + # can vary. + params = var_operand_def(EqAttrConstraint(Float64Type())) + + in_qubits = var_operand_def(BaseAttr(QubitType)) + + # prop_def means that gate_name is a required property + gate_name = prop_def(BaseAttr(StringAttr)) + + # opt_prop_def means adjoint is an optional property. Additionally, + # it's type is a UnitAttr(), which essentially means that a given + # instance of CustomOp is an adjoint gate iff it has an adjoint property. + # The value of the property is irrelevant; it only gets meaning from its + # existance. + adjoint = opt_prop_def(EqAttrConstraint(UnitAttr())) + + in_ctrl_qubits = var_operand_def(BaseAttr(QubitType)) + + in_ctrl_values = var_operand_def(EqAttrConstraint(IntegerType(1))) + + # var_result_def means that the length of out_qubits can vary + out_qubits = var_result_def(BaseAttr(QubitType)) + + out_ctrl_qubits = var_result_def(BaseAttr(QubitType)) + +.. _dialects-1: + +Dialects +-------- + +The ``Dialect`` class in xDSL is used as a container around a list of +operations, types, and attributes, and it also declares a name for the +dialect. At the time of writing this, there are currently 4 custom +dialects available in the xDSL layer of Catalyst: + +- ``Quantum``: this dialect contains operations and attributes necessary + for general qubit-level operations, such as gates, measurements, + qubits, etc. +- ``Catalyst``: this dialect contains operations and attributes for + classical computing features unavailable out of the box with xDSL/MLIR +- ``QEC``: This dialect contains operations and attributes useful for + QEC, such as PPRs/PPMs. +- ``MBQC``: This dialect contains operations and attributes for + representing MBQC formalism. + +Pass API +-------- + +xDSL provides an API for defining and applying transformations on +programs (or modules), which is described below: + +``ModulePass`` +~~~~~~~~~~~~~~ + +``ModulePass`` is used to create rewrite passes over an IR module. It is +the parent class used to define compiler passes (or transforms). +``ModulePass`` has two key fields that must be implemented + +- ``name``: This is the name that is used to reference the pass. +- ``apply``: This method takes a ``ModuleOp`` as input and applies the + rewrite pattern of the pass to the module. Note that this mutates the + input module *in-place* rather than returning a new, transformed + module. + +``RewritePattern`` +~~~~~~~~~~~~~~~~~~ + +``RewritePattern`` is the class that provides the API for pattern +matching. The most important method of this class is +``match_and_rewrite``. The first argument to this method is an +operation, and it must be type-hinted using the specific operation we’re +trying to match. This type hint gets used by xDSL to match the operation +we’re trying to rewrite. The second argument is a ``PatternRewriter`` +instance. I will cover this class in detail below, but it is essentially +the class that provides the API for rewriting the operations that we’re +matching. + +For example, if I wanted to match all Hadamard gates, I would use +``CustomOp`` from the ``Quantum`` dialect in the type hint for the first +argument (since there is no ``Hadamard`` operation in the ``Quantum`` +dialect), and check in the body of the method if the op is a +``Hadamard``: + +.. code-block:: python + + from xdsl import pattern_rewriter + from pennylane.compiler.python_compiler.quantum_dialect import CustomOp + + class MyPattern(pattern_rewriter.RewritePattern): + """Dummy class for example.""" + + # This decorator is what xDSL uses to match operations + # based on the type hint. + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite( + self, op: CustomOp, rewriter: pattern_rewriter.PatternRewriter + ): + if op.gate_name.data != "Hadamard": + # If not Hadamard, we do nothing + return + + # Do whatever we need + +``PatternRewriter`` +~~~~~~~~~~~~~~~~~~~ + +``PatternRewriter`` is the class that provides the API for rewriting the +IR. It includes several methods for replacing/removing/updating +operations, replacing values, replacing uses of a value with another +value, etc. In most cases, any rewriting that users want to do must be +done through this API rather than manually, as it includes state +management for keeping track of whether any changes were made, which is +necessary for the worklist algorithm (more on that later). + +Some key methods are: + +- ``replace_op``: Replaces one operation with another. +- ``replace_all_uses_with``: Replaces all uses of one value with + another. +- ``erase_op``: Erases an operation. If this operation returns any + values, all uses of these values must be updated accordingly before + the erasure. +- ``notify_op_modified``: Method to notify the rewriter that a change + was made to an operation manually. + +The example below shows us implementing a ``RewritePattern`` that +updates all ``Hadamard``\ s with ``PauliX``\ s: + +.. code-block:: python + + from xdsl import pattern_rewriter + from xdsl.dialects import builtin + from pennylane.compiler.python_compiler.dialects.quantum import CustomOp + + class HToXPattern(pattern_rewriter.RewritePattern): + """Dummy class for example.""" + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite( + self, op: CustomOp, rewriter: pattern_rewriter.PatternRewriter + ): + if op.gate_name.data != "Hadamard": + # If not Hadamard, we do nothing + return + + # Update the gate name to PauliX and notify the rewriter that + # the op was manually updated + op.gate_name = builtin.StringAttr("PauliX") + rewriter.notify_op_modified(op) + + # Alternatively, we could also create a new CustomOp for + # the PauliX from scratch, and replace the Hadamard with + # the new op: + # new_op = CustomOp( + # gate_name="PauliX", + # in_qubits=op.in_qubits, + # ) + # rewriter.replace_op(op, new_op) + +``PatternRewriteWalker`` +~~~~~~~~~~~~~~~~~~~~~~~~ + +``PatternRewriteWalker`` walks over the IR in depth-first order, and +applies a provided ``RewritePattern`` to it. By default, it implements a +worklist algorithm that keeps iterating over the operations and matching +and rewriting until a steady state is reached (i.e. no new changes are +detected; this is why ``PatternRewriter`` needs to keep track of whether +any changes were made). + +Putting everything together, we can create a ``ModulePass`` that +replaces all ``Hadamard``\ s with ``PauliX``\ s + +.. code-block:: python + + from xdsl import passes, pattern_rewriter + from xdsl.dialects import builtin + from pennylane.compiler.python_compiler.dialects.quantum import CustomOp + from pennylane.compiler.python_compiler import compiler_transform + + class HToXPattern(pattern_rewriter.RewritePattern): + """Dummy class for example.""" + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite( + self, op: CustomOp, rewriter: pattern_rewriter.PatternRewriter + ): + if op.gate_name.data != "Hadamard": + # If not Hadamard, we do nothing + return + + # Update the gate name to PauliX and notify the rewriter that + # the op was manually updated + op.gate_name = builtin.StringAttr("PauliX") + rewriter.notify_op_modified(op) + + class HToXPass(passes.ModulePass): + """Pass that replaces Hadamards with PauliXs""" + + name = "h-to-x" + + def apply(self, ctx, module): + """Apply the iterative pass.""" + walker = pattern_rewriter.PatternRewriteWalker( + pattern=HToXPattern() + ) + walker.rewrite_module(module) + + # We will cover this later + h_to_x_pass = compiler_transform(HToXPass) + +``PassPipeline`` +~~~~~~~~~~~~~~~~ + +``PassPipeline`` is a meta-pass that takes a sequence of +``ModulePass``\ es as input, and applies them to the input module. The +following example shows how a sequence of ``ModulePass``\ s can be +applied to a module ``mod``: + +.. code-block:: python + + from xdsl import passes + + pipeline = passes.PassPipeline((Pass1(), Pass2(), Pass3())) + pipeline.apply(xdsl.context.Context(), mod) + +To complete the example we’ve been building in this section, let’s put +it all together and implement a ``PassPipeline`` to apply the +``HToXPass`` to an xDSL module. + +Let’s first create the module to which we want to apply the pass. For +this, we will use the ``xdsl_from_qjit`` utility, which is described in +the “PennyLane integration” section below. + +- **Creating the module** + + .. code-block:: python + + import pennylane as qml + from pennylane.compiler.python_compiler.conversion import xdsl_from_qjit + + dev = qml.device("lightning.qubit", wires=3) + + @xdsl_from_qjit + @qml.qjit(target="mlir") + @qml.qnode(dev) + def circuit(): + qml.Hadamard(0) + qml.Hadamard(1) + qml.Hadamard(2) + return qml.state() + + >>> mod = circuit() + >>> print(mod) + builtin.module @circuit { + func.func public @jit_circuit() -> (tensor<8xcomplex>) attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<8xcomplex> + func.return %0 : tensor<8xcomplex> + } + builtin.module @module_circuit { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @circuit() -> (tensor<8xcomplex>) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<3> : tensor}> : () -> tensor + %3 = quantum.alloc(3) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = quantum.custom "Hadamard"() %5 : !quantum.bit + %7 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %8 = tensor.extract %7[] : tensor + %9 = quantum.extract %3[%8] : !quantum.reg -> !quantum.bit + %10 = quantum.custom "Hadamard"() %9 : !quantum.bit + %11 = "stablehlo.constant"() <{value = dense<2> : tensor}> : () -> tensor + %12 = tensor.extract %11[] : tensor + %13 = quantum.extract %3[%12] : !quantum.reg -> !quantum.bit + %14 = quantum.custom "Hadamard"() %13 : !quantum.bit + %15 = tensor.extract %0[] : tensor + %16 = quantum.insert %3[%15], %6 : !quantum.reg, !quantum.bit + %17 = tensor.extract %7[] : tensor + %18 = quantum.insert %16[%17], %10 : !quantum.reg, !quantum.bit + %19 = tensor.extract %11[] : tensor + %20 = quantum.insert %18[%19], %14 : !quantum.reg, !quantum.bit + %21 = quantum.compbasis qreg %20 : !quantum.obs + %22 = quantum.state %21 : tensor<8xcomplex> + quantum.dealloc %20 : !quantum.reg + quantum.device_release + func.return %22 : tensor<8xcomplex> + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } + } + +- **Transforming the module** + + In the above module, there are 3 ``CustomOp``\ s, each with gate name + ``Hadamard``. Let’s try applying our pass to it. Bear in mind that + passes update modules in-place: + + .. code-block:: python + + from xdsl import passes + + pipeline = passes.PassPipeline((HToXPass(),)) + pipeline.apply(xdsl.context.Context(), mod) + + >>> print(mod) + builtin.module @circuit { + func.func public @jit_circuit() -> (tensor<8xcomplex>) attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<8xcomplex> + func.return %0 : tensor<8xcomplex> + } + builtin.module @module_circuit { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @circuit() -> (tensor<8xcomplex>) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<3> : tensor}> : () -> tensor + %3 = quantum.alloc(3) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = quantum.custom "PauliX"() %5 : !quantum.bit + %7 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %8 = tensor.extract %7[] : tensor + %9 = quantum.extract %3[%8] : !quantum.reg -> !quantum.bit + %10 = quantum.custom "PauliX"() %9 : !quantum.bit + %11 = "stablehlo.constant"() <{value = dense<2> : tensor}> : () -> tensor + %12 = tensor.extract %11[] : tensor + %13 = quantum.extract %3[%12] : !quantum.reg -> !quantum.bit + %14 = quantum.custom "PauliX"() %13 : !quantum.bit + %15 = tensor.extract %0[] : tensor + %16 = quantum.insert %3[%15], %6 : !quantum.reg, !quantum.bit + %17 = tensor.extract %7[] : tensor + %18 = quantum.insert %16[%17], %10 : !quantum.reg, !quantum.bit + %19 = tensor.extract %11[] : tensor + %20 = quantum.insert %18[%19], %14 : !quantum.reg, !quantum.bit + %21 = quantum.compbasis qreg %20 : !quantum.obs + %22 = quantum.state %21 : tensor<8xcomplex> + quantum.dealloc %20 : !quantum.reg + quantum.device_release + func.return %22 : tensor<8xcomplex> + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } + } + + Great! We can see that all the ``Hadamard``\ s have been replaced with + ``PauliX``\ s, just how we wanted. + +PennyLane integration +===================== + +This section will cover the API in the ``qml.compiler.python_compiler`` +submodule. + +Lowering to MLIR +---------------- + +Catalyst compiles programs using the following workflow, when program +capture is enabled: + +- Trace user function to create ``ClosedJaxpr`` (plxpr) +- Convert plxpr to value-semantic jaxpr +- Lower value-semantic jaxpr to MLIR +- Apply passes defined in a static pipeline to the MLIR + + - These passes include optimizations, further lowering to more + elementary dialects, and lowering to LLVM + +- Generate machine code + +The integration with the xDSL layer happens after we lower to MLIR. We +currently rely on JAX’s API to lower to MLIR. This has the special +effect of lowering to a specific dialect called StableHLO, which is used +to represent all arithmetic operations present in the program. + +Once lowered to MLIR, if the original ``qjit`` decorator specified the +xDSL pass plugin, we pass control over to the xDSL layer, which applies +all transforms that were requested by the user. We can request the use +of the xDSL plugin like so: + +.. code-block:: python + + from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath + + @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + ... + +.. _ir-structure-1: + +IR structure +~~~~~~~~~~~~ + +The lowered MLIR has a lot of structure to aid developers, which is +described below using the following example: + +.. code-block:: python + + import pennylane as qml + + qml.capture.enable() + dev = qml.device("lightning.qubit", wires=1) + + @qml.qjit + @qml.transforms.cancel_inverses + @qml.transforms.merge_rotations + @qml.qnode(dev) + def circuit(): + qml.X(0) + return qml.state() + +>>> print(circuit.mlir) +module @circuit { + func.func public @jit_circuit() -> tensor<2xcomplex> attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<2xcomplex> + return %0 : tensor<2xcomplex> + } + module @module_circuit { + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) { + %0 = transform.apply_registered_pass "merge-rotations" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + %1 = transform.apply_registered_pass "remove-chained-self-inverse" to %0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + transform.yield + } + } + func.func public @circuit() -> tensor<2xcomplex> attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %c0_i64 = arith.constant 0 : i64 + quantum.device shots(%c0_i64) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "PauliX"() %1 : !quantum.bit + %2 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit + %3 = quantum.compbasis qreg %2 : !quantum.obs + %4 = quantum.state %3 : tensor<2xcomplex> + quantum.dealloc %2 : !quantum.reg + quantum.device_release + return %4 : tensor<2xcomplex> + } + } + func.func @setup() { + quantum.init + return + } + func.func @teardown() { + quantum.finalize + return + } +} + +- The program is represented as a module with a function inside it. This + function contains non-QNode user code (which there is none in our + example for simplicity), and calls to QNodes using the + ``catalyst.launch_kernel`` operation. +- QNodes are represented as modules as well—each QNode has its own + module. These modules have 4 key components: + + - The ``module attributes {transform.with_named_sequence}`` contains + the transforms that the user requested for the QNode. In our + example, those are the ``merge_rotations`` and ``cancel_inverses`` + transforms. + - The ``@circuit`` function represents the body of the QNode. Inside + this body, we initialize a device using the specified shots, + allocate a quantum register, and then apply both quantum and + classical instructions. Note that quantum instructions can *only* be + present inside this function. Note that this function will contain + an attribute called ``qnode``. + +SSA in the ``Quantum`` dialect +------------------------------ + +In the ``Quantum`` dialect, as mentioned earlier, gates are represented +using the ``CustomOp`` operation. This operation accepts ``in_qubits`` +and ``in_ctrl_qubits``, which are variable length sequences of +``QubitType`` attributes, which correspond to wire indices. +``CustomOp``\ s also return ``out_qubits`` and ``out_ctrl_qubits`` of +the same type. + +Before a ``QubitType`` can be used by any gates, it must be extracted +from the quantum register, or a ``QregType``. Quantum registers can be +thought of as a sequence containing all valid wire indices that are +available to be used by gates/measurements. Qubits can be extracted from +and inserted into a quantum register using the ``ExtractOp`` and +``InsertOp`` operations. An ``AllocOp`` is used to allocate a quantum +register with the user-provided number of device wires. + +Let’s take a look at a very simple example. Below, I have a very simple +circuit with two gates applied to the same wire. Let’s take a look at +its MLIR representation: + +- **Example** + + .. code-block:: python + + dev = qml.device("lightning.qubit", wires=3) + + @qml.qjit(target="mlir") + @qml.qnode(dev) + def circuit(): + qml.X(0) + qml.H(0) + return qml.state() + + >>> print(circuit.mlir) + module @circuit { + func.func public @jit_circuit() -> tensor<8xcomplex> attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<8xcomplex> + return %0 : tensor<8xcomplex> + } + module @module_circuit { + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @circuit() -> tensor<8xcomplex> attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %c0_i64 = arith.constant 0 : i64 + quantum.device shots(%c0_i64) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %0 = quantum.alloc( 3) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "PauliX"() %1 : !quantum.bit + %out_qubits_0 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit + %2 = quantum.insert %0[ 0], %out_qubits_0 : !quantum.reg, !quantum.bit + %3 = quantum.compbasis qreg %2 : !quantum.obs + %4 = quantum.state %3 : tensor<8xcomplex> + quantum.dealloc %2 : !quantum.reg + quantum.device_release + return %4 : tensor<8xcomplex> + } + } + func.func @setup() { + quantum.init + return + } + func.func @teardown() { + quantum.finalize + return + } + } + +**Notes** + +- ``quantum.alloc`` initializes a quantum register (``%0``) with 3 + wires, which is the number of wires used to create the PennyLane + device. +- ``quantum.extract`` extracts a qubit corresponding to wire index 0 + (``%1``) + + - ``%1`` is used as input by the ``X`` gate. + +- The ``X`` gate returns a qubit (``%out_qubits``), which is used by the + ``quantum.custom`` corresponding to the ``H`` gate. + + - Note that this is different from how the circuit is defined in + Python. Instead of just using ``%1`` again, the ``H`` gate uses + ``%out_qubits``. + +- The ``H`` gate returns a new qubit (``%out_qubits_0``). This qubit is + consumed by ``quantum.insert``, which inserts updates the quantum + register to essentially say that ``%out_qubits_0`` should be the new + qubit that corresponds to wire index 0. +- In this example, note that ``%1``, ``%out_qubits``, and + ``%out_qubits_0`` all correspond to wire index 0. This has cool + implications, that I will discuss below. + +Dynamic wires +~~~~~~~~~~~~~ + +The lowering rules for Catalyst automatically handle dynamic wires. When +a new dynamic wire, ``w``, is used, all wires used before it are first +inserted into the quantum register using ``InsertOp``. Only then does +the ``QubitType`` corresponding to ``w`` get extracted from the quantum +register using ``ExtractOp``. Consider the following example: + +- **Example** + + .. code-block:: python + + import pennylane as qml + + qml.capture.enable() + + dev = qml.device("lightning.qubit", wires=3) + + @qml.qjit(target="mlir") + @qml.qnode(dev) + def circuit(w1: int, w2: int): + qml.X(0) + qml.Y(w1) + qml.Z(w1) + qml.S(w2) + qml.T(w1) + qml.H(0) + return qml.state() + + >>> print(circuit.mlir) + module @circuit { + func.func public @jit_circuit(%arg0: tensor, %arg1: tensor) -> tensor<8xcomplex> attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_circuit::@circuit(%arg0, %arg1) : (tensor, tensor) -> tensor<8xcomplex> + return %0 : tensor<8xcomplex> + } + module @module_circuit { + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @circuit(%arg0: tensor, %arg1: tensor) -> tensor<8xcomplex> attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %c0_i64 = arith.constant 0 : i64 + quantum.device shots(%c0_i64) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %0 = quantum.alloc( 3) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "PauliX"() %1 : !quantum.bit + %2 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit + %extracted = tensor.extract %arg0[] : tensor + %3 = quantum.extract %2[%extracted] : !quantum.reg -> !quantum.bit + %out_qubits_0 = quantum.custom "PauliY"() %3 : !quantum.bit + %out_qubits_1 = quantum.custom "PauliZ"() %out_qubits_0 : !quantum.bit + %extracted_2 = tensor.extract %arg0[] : tensor + %4 = quantum.insert %2[%extracted_2], %out_qubits_1 : !quantum.reg, !quantum.bit + %extracted_3 = tensor.extract %arg1[] : tensor + %5 = quantum.extract %4[%extracted_3] : !quantum.reg -> !quantum.bit + %out_qubits_4 = quantum.custom "S"() %5 : !quantum.bit + %extracted_5 = tensor.extract %arg1[] : tensor + %6 = quantum.insert %4[%extracted_5], %out_qubits_4 : !quantum.reg, !quantum.bit + %extracted_6 = tensor.extract %arg0[] : tensor + %7 = quantum.extract %6[%extracted_6] : !quantum.reg -> !quantum.bit + %out_qubits_7 = quantum.custom "T"() %7 : !quantum.bit + %extracted_8 = tensor.extract %arg0[] : tensor + %8 = quantum.insert %6[%extracted_8], %out_qubits_7 : !quantum.reg, !quantum.bit + %9 = quantum.extract %8[ 0] : !quantum.reg -> !quantum.bit + %out_qubits_9 = quantum.custom "Hadamard"() %9 : !quantum.bit + %10 = quantum.insert %8[ 0], %out_qubits_9 : !quantum.reg, !quantum.bit + %11 = quantum.compbasis qreg %10 : !quantum.obs + %12 = quantum.state %11 : tensor<8xcomplex> + quantum.dealloc %10 : !quantum.reg + quantum.device_release + return %12 : tensor<8xcomplex> + } + } + func.func @setup() { + quantum.init + return + } + func.func @teardown() { + quantum.finalize + return + } + } + + **Notes** + + - If a qubit is inserted into the quantum register, it is not reused + again, and a new qubit corresponding to the same wire label must be + extracted from the register. + - When a dynamic wire is going to be used for the first time, all + qubits that have been used previously without being re-inserted into + the quantum register must be inserted. + - If a dynamic wire is reused before any other wires, then it does not + need to be inserted to and extracted from the quantum register + again. + - To use static wires after dynamic wires, the dynamic wires are again + re-inserted into the quantum register. + - All of the above make it such that dynamic wires essentially create + barriers that break the qubit def-use chains. This causes some + functionality to be lost, but makes sure that we’re using qubits + safely. + +Implications/notes +~~~~~~~~~~~~~~~~~~ + +- Operations on the same wires can be tracked quickly using the chain of + definitions and uses of the qubits (def-use chains). +- Operations on dynamic wires (i.e wires whose values are only known at + run-time) are handled automatically and we don’t need to worry about + managing how we work around them. For context, this is an issue that + we found in the plxpr variant of ``cancel_inverses``, where two + operations on the same wire that were separated by another operation + on a dynamic wire would get cancelled, which is incorrect + (`source `__). +- One thing to keep in mind is that qubits (``QubitType``) and quantum + registers (``QregType``) are there so that we conform to SSA semantics + in a way that works well for our purposes. They get meaning from how + we choose to interpret them. We could just as easily have defined the + ``Quantum`` dialect and Catalyst’s lowering rules to use wire indices + the same way we do in Python, but we may lose capabilities that MLIR + and xDSL enable through the SSA form. + +``compiler_transform`` +---------------------- + +``compiler_transform`` is the function used to register xDSL +``ModulePass``\ es to be used with ``qjit``-ed workflows. It is +currently accessible as +``qml.compiler.python_compiler.compiler_transform``. + +.. code-block:: python + + from pennylane.compiler.python_compiler import compiler_transform + + class MyPass(xdsl.passes.ModulePass): + """MyPass that does something""" + + name = "my-pass" + + def apply(self, ctx, module): + # Apply the pass to module + return + + my_pass = compiler_transform(MyPass) + + # Program capture must be enabled to use the compiler transform + # as a decorator + qml.capture.enable() + dev = qml.device("lightning.qubit", wires=1) + + @qml.qjit( + pass_plugins=[catalyst.passes.xdsl_plugin.getXDSLPluginAbsolutePath()] + ) + @my_pass + @qml.qnode(dev) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + circuit(1.5) + +The ``compiler_transform`` function returns an object that gives easy +access to the underlying ``ModulePass``, as well as its name as seen by +the compiler. + +>>> my_pass.module_pass +__main__.MyPass +>>> my_pass.name +'my-pass' + +Additionally, we don’t need to manually apply passes using +``PassPipeline`` when decorating QNodes with registered compiler +transforms. Those transforms get applied automatically when the workflow +is compiled! + +Conversion utilities +-------------------- + +The ``python_compiler.conversion`` submodule provides several utilities +for creating xDSL modules from Python functions. There are many more +utilities in the submodule, but I will focus on the most important ones, +and provide examples of how to use them. + +- ``xdsl_module(jitted_fn: Callable) -> Callable[Any, [xdsl.dialects.builtin.ModuleOp]]``: + Create a wrapper around a ``jax.jit``-ed function that returns an xDSL + module +- ``xdsl_from_qjit(qjitted_fn: QJIT) -> Callable[..., xbuiltin.ModuleOp]``: + Create a wrapper around a ``qjit``-ed function that returns an xDSL + module. This is currently not merged to ``master`` +- ``inline_module(from_mod: xbuiltin.ModuleOp, to_mod: xbuiltin.ModuleOp, change_main_to: str = None) -> None``: + This function takes two modules as input, and inlines the whole body + of the first module into the second. Additionally, if + ``change_main_to`` is provided, it looks for a function named + ``main``, and updates its name to ``change_main_to`` +- ``inline_jit_to_module(func: JaxJittedFunction, mod: xbuiltin.ModuleOp, *args, **kwargs) -> None``: + This function takes a ``jax.jit``-ed function, converts it into an + xDSL module, and then inlines the contents of the xDSL module into + ``mod``. Note that this function does not return anything; instead, it + modifies ``mod`` in-place. + +Check out the :doc:`xDSL conversion utilities tutorial <./xdsl_utils_tutorial>` to see examples of how each of +the utilities can be used. + +Useful patterns +=============== + +Now that we have gone over compilers, xDSL, and how it’s being used in +PennyLane, let’s take a look at some common patterns that might be +useful. + +Post-processing functions +------------------------- + +Post-processing functions are purely classical, so we can leverage the +``xdsl_module`` utility function to create xDSL modules from Python +code, and inject it into the modules we are rewriting as needed. The +:doc:`xDSL post-processing tutorial <./xdsl_post_processing>` shows an +example where we perform very simple post-processing on a QNode that +returns an expectation value by squaring the expectation value. + +Splitting tapes +--------------- + +When implementing passes that may transform our module such that +multiple device executions are required (akin to tape transforms that +return multiple tapes), there are various strategies that can be used. +Because there is no guarantee of shared structure between the tapes, +there is no one perfect strategy that can be used for all transforms. +I’ll use tapes to provide details below about some common cases: + +- If all the tapes are identical (eg. ``dynamic_one_shot``), then the + entire execution can be put inside a ``for`` loop, with + post-processing on the execution outputs done how I showed in the + above notebook. +- If all gates are identical, but measurements are different (eg. + ``split_non_commuting``), we can capture all gates in a single + function, and then use a ``for`` loop that iterates over the number of + tapes. Within this ``for`` loop, we would call the aforementioned + function, which would evolve the state, and apply a different + measurement inside each iteration of the loop. Post-processing can be + handled same as above. +- If there is little/no shared structure between the tapes, we would + need separate functions for each of the transformed tapes. We would + need to call each function one by one, and then use the results for + post-processing. + +All of the above are very non-trivial. I will leave out code examples +for now, as that may be unnecessarily time consuming. If we get to the +stage where we need to write a transform that splits into multiple +tapes, we can revisit this section and the Python compiler/compilation +team can assist in developing such transforms. + +Note +~~~~ + +Currently, we don’t have a consistent API for implementing transforms +that split QNodes into multiple device executions (eg. +``split_non_commuting``, etc.). Thus, it is *very* hard to implement +transforms that do that. We cannot guarantee that one transform that +does split QNodes into multiple device executions will work well when +there are other transforms in the pipeline. + +Writing tests +============= + +**Note to readers**: this section is written based on how the testing +infrastructure for the Python compiler exists in PennyLane. However, the +Python compiler may be getting moved to Catalyst, in which case, the +infrastructure would likely change. + +FileCheck +--------- + +`FileCheck `__ is a +pattern matching file verifier developed by the LLVM project and is +widely used to test MLIR. Since xDSL and MLIR are syntactically +identical, we can use the string representation of xDSL modules for +testing with FileCheck. + +Below is an example of how an MLIR/xDSL string may look when populated +with directives that FileCheck can use for testing. In this example, we +have a void function ``test_func`` that takes no arguments, creates two +qubits (corresponding to different wires), and applies a ``PauliX`` to +each of these wires. We can see that there are 4 comments starting with +``CHECK`` - these comments are what FileCheck uses to match the expected +string against the actual string. + +The number of ``CHECK`` comments does not need to be the same as the +number of operations in the program - we can see below that there is no +``CHECK`` for ``func.func`` and ``return``. + +``CHECK`` statements can assign parameters that can be reused by later +``CHECK``\ s. Below, the first two ``CHECK``\ s create ``q0`` and +``q1``, which are matched using the regular expression ``%.+``, which +matches any expressions that start with ``%`` and contain at least one +character after it. The latter 2 ``CHECK``\ s then use ``q0`` and ``q1`` +as the input qubits for the assertion. + +``CHECK`` statements can contain partial statements. Below, the last two +``CHECK``\ s don’t include the outputs of the ``quantum.custom`` +operations. + +.. code-block:: python + + program = """ + func.func @test_func() { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + %1 = "test.op"() : () -> !quantum.bit + // CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit + // CHECK: quantum.custom "PauliX"() [[q1]] : !quantum.bit + %2 = quantum.custom "PauliX"() %0 : !quantum.bit + %3 = quantum.custom "PauliX"() %0 : !quantum.bit + return + } + """ + +FileCheck essentially uses the MLIR generated by a program and asserts +it against directives that can be specified by the user. Commonly used +directives are: + +- ``CHECK``: This is used to check if a specified pattern is found. Its + syntax is very simple: they are fixed strings that must occur in + order, and horizontal whitespace is ignored by default. +- ``CHECK-NOT``: This directive is used to check that the provided + string does not occur between two matches, before the first match, or + after the last match. +- ``CHECK-DAG``: This directive may be used when it’s necessary to match + strings that don’t have to occur in the same order, but it is able to + match valid topological orderings of the program DAG, with edges from + the definition to the uses of a variable. +- ``CHECK-NEXT``: This directive is used to check that the matched line + occurs right after the previously matched line with no other lines in + between them. +- ``CHECK-SAME``: This directive is used when we want to match lines and + would like to verify that matches happen on the same line as the + previous match. + +To find more details about the above directives, or to learn about other +available directives, please refer to the `FileCheck +documentation `__. + +Test dialect +------------ + +xDSL provides a ``Test`` dialect +(`source `__), +which contains many operations that are useful for unit testing. In our +testing, we found the ``TestOp`` operation to be the most useful. This +operation can produce arbitrary results, which we can use to limit +artificial dependencies on other dialects. + +For example, if I just need to assert that a specific gate is present, +without ``TestOp``, I would need my module to contain an ``AllocOp`` +that creates a quantum register, an ``ExtractOp`` that extracts a qubit +from the register, and only then I can use the qubit for the gate I’m +trying to match. Instead, I can just insert a ``TestOp`` that returns a +qubit and use that. + +This is very powerful for unit testing, as it makes writing tests much +simpler, while also limiting the scope of the test as one would expect +for unit tests. + +In the code block in the previous section, ``TestOp`` has been used in +exactly the described way - we use it to create 2 qubits that are then +used as input for the 2 ``PauliX`` gates. + +.. _pennylane-integration-1: + +PennyLane integration +--------------------- + +To use FileCheck with ``pytest``, we use the ```filecheck`` Python +package `__, which allows us to use +assertions for testing in a way that ``pytest`` can understand. All of +the ``filecheck`` API has been captured inside two fixtures available +within the ``tests/python_compiler`` folder: + +- ``run_filecheck``: This fixture is for unit testing. One can specify a + program along with filecheck directives as a multi-line string. +- ``run_filecheck_qjit``: This fixture is for integration testing. One + can create a normal ``qml.qjit``-ed workflow and include filecheck + directives as in-line comments. + +Let’s write tests for the ``HToXPass`` that was implemented in the +`“Pass API” <#pass-api>`__ sub-section to illustrate. The dev comments +will explain what is going on. + +``run_filecheck`` example +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + def test_h_to_x_pass(run_filecheck): + """Test that Hadamard gets converted into PauliX.""" + # The original program creates a qubit, and gives it to a + # CustomOp that is a Hadamard. The CHECK directives check that + # the transformed program has a CustomOp that is a PauliX applied + # to the same qubit, and no CustomOp that is a Hadamard. + + # Below we also see how we can use `test.op`, which we use to + # create a qubit to give to the Hadamard. + program = """ + func.func @test_func() { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit + // CHECK-NOT: quantum.custom "Hadamard" + %1 = quantum.custom "Hadamard"() %0 : !quantum.bit + return + } + """ + + # First, we must create the pass pipeline that we want to apply + # to our program. + pipeline = (HToXPass(),) + + # Next, we use run_filecheck to run filecheck testing on the + # program. The fixture will create an xDSL module from the program + # string, call the filecheck API, and assert correctness. + run_filecheck(program, pipeline) + + # Optionally, there are two keyword arguments that can modify the + # behaviour of run_filecheck. Both the roundtrip and verify + # arguments are false by default. + run_filecheck(program, pipeline, roundtrip=True, verify=True) + + # roundtrip=True makes it so that after parsing the program string + # to an xDSL module, we print it as a string again, and then parse + # it back into an xDSL module. This is useful when writing tests + # for dialects, so that we can check that the dialects can be printed + # parsed correctly. + + # verify=True simply runs `module.verify()`, which iteratively uses + # xDSL's verifiers to verify all operations and attributes in the + # module. + +``run_filecheck_qjit`` example +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath + + def test_h_to_x_pass_integration(run_filecheck_qjit): + """Test that Hadamard gets converted into PauliX.""" + # The original program simply applies a Hadamard to a circuit + # Remember that we need to specify the pass plugin. Additionally, + # with qjit, we can use the decorator created using + # `compiler_transform`. To make sure that the xDSL API works + # correctly, program capture must be enabled. + # qml.capture.enable() + @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath]) + @h_to_x_pass + def circuit(): + # CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + # CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit + # CHECK-NOT: quantum.custom "Hadamard" + qml.Hadamard(0) + return qml.state() + + # Finally, we use the run_filecheck_qjit fixture. We pass it our + # original qjitted workflow. It extracts the filecheck directives + # from the workflow, creates an MLIR program, parses it to xDSL, + # applies the specified transforms, and uses the filecheck API to + # assert correctness. + run_filecheck_qjit(circuit) + + # run_filecheck_qjit also accepts a boolean `verify` argument + # that is false by default, which works exactly the same way as + # the `verify` argument of `run_filecheck`. + run_filecheck_qjit(circuit, verify=True) + +Key blockers +============ + +There are several blockers that are currently disabling developers from +taking full advantage of the ``python_compiler`` submodule. These +include: + +* Lack of support for quantum subroutines. This impacts pattern + matching passes that need to substitute the matched operation(s) with + subroutines containing quantum instructions. + +Strategies to circumvent blockers +--------------------------------- + +* We can use dummy subroutines for now. We know what the inputs and + outputs of these subroutines should be, so we can create our own + ``FuncOp``\ s that adhere to the input/output spec and just have + their body be empty for now. To see an example where we create a + dummy quantum subroutine and use it to develop a pass, check out the + :doc:`xDSL subroutines tutorial <./xdsl_dummy_quantum_subroutines>`. + +Suggested reading +================= + +Useful dialects +--------------- + +- ``scf``: Structured control flow +- ``func``: Functions +- ``builtin``: Core types and attributes +- ``arith``: arithmetic operations +- ``stablehlo``: Advanced match dialect + +References +========== + +#. Wikimedia Foundation. (2025, August 11). *Static single-assignment + form*. Wikipedia. + https://en.wikipedia.org/wiki/Static_single-assignment_form +#. *MLIR Language Reference*. MLIR. (n.d.). + https://mlir.llvm.org/docs/LangRef/ +#. *Understanding the IR Structure*. MLIR. (n.d.-b). + https://mlir.llvm.org/docs/Tutorials/UnderstandingTheIRStructure/ +#. *Mlir::Value class reference*. MLIR. (n.d.-b). + https://mlir.llvm.org/doxygen/classmlir_1_1Value.html +#. *LLVM Programmer’s Manual*. LLVM. (n.d.). + https://llvm.org/docs/ProgrammersManual.html diff --git a/frontend/catalyst/python_interface/doc/xdsl_dummy_quantum_subroutines.rst b/frontend/catalyst/python_interface/doc/xdsl_dummy_quantum_subroutines.rst new file mode 100644 index 0000000000..1a951c5cc7 --- /dev/null +++ b/frontend/catalyst/python_interface/doc/xdsl_dummy_quantum_subroutines.rst @@ -0,0 +1,216 @@ +.. code-block:: python + + from dataclasses import dataclass + + import pennylane as qml + from pennylane.compiler.python_compiler.conversion import xdsl_from_qjit + from pennylane.compiler.python_compiler.dialects.quantum import CustomOp, QubitType + + from xdsl import context, passes, pattern_rewriter + from xdsl.builder import ImplicitBuilder + from xdsl.dialects import builtin, func + from xdsl.ir import Block, Region + +Convert into xDSL module +======================== + +.. code-block:: python + + dev = qml.device("lightning.qubit", wires=5) + + @xdsl_from_qjit + @qml.qjit(target="mlir") + @qml.qnode(dev) + def circuit(x): + qml.H(0) + return qml.expval(qml.Z(0)) + + +>>> qjit_mod = circuit(1.5) +>>> print(qjit_mod) +builtin.module @circuit { + func.func public @jit_circuit(%arg2 : tensor) -> (tensor) attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_circuit::@circuit(%arg2) : (tensor) -> tensor + func.return %0 : tensor + } + builtin.module @module_circuit { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @circuit(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor + %3 = quantum.alloc(5) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = quantum.custom "Hadamard"() %5 : !quantum.bit + %7 = quantum.namedobs %6[PauliZ] : !quantum.obs + %8 = quantum.expval %7 : f64 + %9 = tensor.from_elements %8 : tensor + %10 = tensor.extract %0[] : tensor + %11 = quantum.insert %3[%10], %6 : !quantum.reg, !quantum.bit + quantum.dealloc %11 : !quantum.reg + quantum.device_release + func.return %9 : tensor + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } +} + + +Let’s create a quantum subroutine +================================= + +This subroutine’s purpose is to replace Hadamard gates with a blackbox +(which will be empty for now), so it must take a single qubit as its +argument. Additionally, we are assuming that the qubits on which the +subroutine will act are not just the ones that the subroutine takes as +input, so we must also provide the quantum register as input, and also +give it as the output. + +``FuncOp``\ s have a single region with a single block that contains the +body of the function. + +We need to build the function’s body by populating its inner ``Block`` +with operations. We can do so using the ``xdsl.builder.ImplicitBuilder`` +class. This class can be used as a context manager that takes a +``Block`` as input, and any operations created within the context of the +builder get added to the block. Let’s try it out. + +Here, we create a subroutine that we will use to replace +``Hadamard``\ s. This subroutine applies a gate provided by the user, +and returns the ``out_qubit`` of the gate. + +.. code-block:: python + + def create_hadamard_replacement_subroutine(gate_name): + input_types = (QubitType(),) + output_types = (QubitType(),) + block = Block(arg_types=input_types) + + with ImplicitBuilder(block): + in_qubits = [block.args[0]] + op1 = CustomOp(in_qubits=in_qubits, gate_name=gate_name) + func.ReturnOp(op1.out_qubits[0]) + + region = Region([block]) + funcOp = func.FuncOp("replace_hadamard", (input_types, output_types), region=region) + return funcOp + + +>>> funcOp = create_hadamard_replacement_subroutine("S") +>>> print(funcOp) +func.func @replace_hadamard(%0 : !quantum.bit) -> !quantum.bit { + %1 = quantum.custom "S"() %0 : !quantum.bit + func.return %1 : !quantum.bit +} + + +Now, we write a pass to do the substitution +=========================================== + +.. code-block:: python + + class ReplaceHadamardPattern(pattern_rewriter.RewritePattern): + + def __init__(self, subroutine: func.FuncOp): + self.subroutine = subroutine + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, customOp: CustomOp, rewriter: pattern_rewriter.PatternRewriter): + if customOp.gate_name.data != "Hadamard": + return + + callOp = func.CallOp( + builtin.SymbolRefAttr("replace_hadamard"), + [customOp.in_qubits[0]], + self.subroutine.function_type.outputs.data, + ) + rewriter.insert_op_after_matched_op(callOp) + rewriter.replace_all_uses_with(customOp.out_qubits[0], callOp.results[0]) + rewriter.erase_op(customOp) + + + @dataclass(frozen=True) + class ReplaceHadamardPass(passes.ModulePass): + name = "replace-hadamard" + gate_name: str + + def apply(self, ctx: context.Context, module: builtin.ModuleOp): + funcOp = create_hadamard_replacement_subroutine(self.gate_name) + module.regions[0].blocks.first.add_op(funcOp) + + pattern_rewriter.PatternRewriteWalker( + pattern_rewriter.GreedyRewritePatternApplier([ReplaceHadamardPattern(funcOp)]) + ).rewrite_module(module) + +Let’s see it in action +====================== + +Here, we will replace all ``Hadamard``\ s with ``S``\ s + +.. code-block:: python + + ctx = context.Context() + + pipeline = passes.PassPipeline((ReplaceHadamardPass("S"),)) + pipeline.apply(ctx, qjit_mod) + +Great! We can see below that ``Hadamard`` was replaced by a call to +``replace_hadamard``, which applies a single ``S`` gate. + +>>> print(qjit_mod) +builtin.module @circuit { + func.func public @jit_circuit(%arg2 : tensor) -> (tensor) attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_circuit::@circuit(%arg2) : (tensor) -> tensor + func.return %0 : tensor + } + builtin.module @module_circuit { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @circuit(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor + %3 = quantum.alloc(5) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = func.call @replace_hadamard(%5) : (!quantum.bit) -> !quantum.bit + %7 = quantum.namedobs %6[PauliZ] : !quantum.obs + %8 = quantum.expval %7 : f64 + %9 = tensor.from_elements %8 : tensor + %10 = tensor.extract %0[] : tensor + %11 = quantum.insert %3[%10], %6 : !quantum.reg, !quantum.bit + quantum.dealloc %11 : !quantum.reg + quantum.device_release + func.return %9 : tensor + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } + func.func @replace_hadamard(%0 : !quantum.bit) -> !quantum.bit { + %1 = quantum.custom "S"() %0 : !quantum.bit + func.return %1 : !quantum.bit + } +} diff --git a/frontend/catalyst/python_interface/doc/xdsl_post_processing.rst b/frontend/catalyst/python_interface/doc/xdsl_post_processing.rst new file mode 100644 index 0000000000..3a0924fd23 --- /dev/null +++ b/frontend/catalyst/python_interface/doc/xdsl_post_processing.rst @@ -0,0 +1,225 @@ +Simple tutorial for injecting functions into xDSL modules +========================================================= + +.. code-block:: python + + from dataclasses import dataclass + import jax + + import pennylane as qml + from pennylane.compiler.python_compiler.conversion import inline_module, xdsl_from_qjit, xdsl_module + + from xdsl import context, passes, pattern_rewriter + from xdsl.dialects import builtin, func + from xdsl.traits import SymbolTable + from xdsl.rewriter import InsertPoint + +Create workflow and convert to xDSL module +========================================== + +.. code-block:: python + + @xdsl_from_qjit + @qml.qjit(target="mlir") + def workflow(x, y): + dev = qml.device("lightning.qubit", wires=5) + + @qml.qnode(dev) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + res = circuit(x) + return res - y + + +>>> xmod = workflow(3.5, 4.5) +>>> print(xmod) +builtin.module @workflow { + func.func public @jit_workflow(%arg2 : tensor, %arg3 : tensor) -> (tensor) attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_circuit::@circuit(%arg2) : (tensor) -> tensor + %1 = "stablehlo.convert"(%arg3) : (tensor) -> tensor + %2 = "stablehlo.subtract"(%0, %1) : (tensor, tensor) -> tensor + func.return %2 : tensor + } + builtin.module @module_circuit { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @circuit(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor + %3 = quantum.alloc(5) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = tensor.extract %arg0[] : tensor + %7 = quantum.custom "RX"(%6) %5 : !quantum.bit + %8 = quantum.namedobs %7[PauliZ] : !quantum.obs + %9 = quantum.expval %8 : f64 + %10 = tensor.from_elements %9 : tensor + %11 = tensor.extract %0[] : tensor + %12 = quantum.insert %3[%11], %7 : !quantum.reg, !quantum.bit + quantum.dealloc %12 : !quantum.reg + quantum.device_release + func.return %10 : tensor + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } +} + + +Now, let’s try creating a pass that squares the output of the qnode +=================================================================== + +To do so, we can use the ``inline_module`` utility to easily add our +post-processing function into the module we’re transforming. First, we +create the function that squares the input value and turn it into an +xDSL module. + +.. code-block:: python + + @jax.jit + def square(x): + return x * x + + +>>> square_mod = xdsl_module(square)(1.5) +>>> print(square_mod) +builtin.module @jit_square attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0 : tensor) -> (tensor {jax.result_info = "result"}) { + %0 = "stablehlo.multiply"(%arg0, %arg0) : (tensor, tensor) -> tensor + func.return %0 : tensor + } +} + + +.. code-block:: python + + def is_kernel_launch(op): + return op.name == "catalyst.launch_kernel" + + + class SquarePattern(pattern_rewriter.RewritePattern): + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter): + # We only rewrite the function that calls the qnode, and the caller of the qnode will + # always have catalyst.launch_kernel present. Additionally, we only rewrite the caller + # if it hasn't already been rewritten. We can put a UnitAttr() inside the caller's + # attributes to indicate whether it has been rewritten or not. + if funcOp.attributes.get("transformed") == builtin.UnitAttr() or not any( + is_kernel_launch(op) for op in funcOp.body.ops + ): + return + + # Update funcOp to inidicate that it has been rewritten + funcOp.attributes["transformed"] = builtin.UnitAttr() + + # Insert square into the module + mod = funcOp.parent_op() + inline_module(square_mod, mod, change_main_to="square") + square_fn = SymbolTable.lookup_symbol(mod, "square") + + # Call square_fn and use its results instead of the qnode's results for + # the rest of the function + for op in funcOp.body.walk(): + if is_kernel_launch(op): + callOp = func.CallOp( + builtin.SymbolRefAttr(square_fn.sym_name), + op.results, + square_fn.function_type.outputs.data, + ) + rewriter.insert_op(callOp, InsertPoint.after(op)) + + # We have inserted a CallOp that takes the output of the qnode as input. Let's call + # the qnode output %0, and the CallOp output %1. The following replaces all uses of + # %0 with %1 EXCEPT for the case where %0 is an input to callOp + op.results[0].replace_by_if(callOp.results[0], lambda use: use.operation != callOp) + rewriter.notify_op_modified(funcOp) + + + @dataclass(frozen=True) + class SquarePass(passes.ModulePass): + name = "square" + + def apply(self, ctx: context.Context, module: builtin.ModuleOp): + pattern_rewriter.PatternRewriteWalker( + pattern_rewriter.GreedyRewritePatternApplier([SquarePattern()]) + ).rewrite_module(module) + +Let’s apply the pass to our workflow +==================================== + +.. code-block:: python + + ctx = context.Context() + + pipeline = passes.PassPipeline((SquarePass(),)) + pipeline.apply(ctx, xmod) + +Great! Let’s see what the transformed module looks like +======================================================= + +As you can see below, the ``square_xdsl`` function is the first function +in the module, and it gets called by ``jit_workflow``, and its +inputs/outputs are consistent with the behaviour we wanted. + +>>> print(xmod) +builtin.module @workflow { + func.func public @jit_workflow(%arg2 : tensor, %arg3 : tensor) -> (tensor) attributes {llvm.emit_c_interface, transformed} { + %0 = catalyst.launch_kernel @module_circuit::@circuit(%arg2) : (tensor) -> tensor + %1 = func.call @square(%0) : (tensor) -> tensor + %2 = "stablehlo.convert"(%arg3) : (tensor) -> tensor + %3 = "stablehlo.subtract"(%1, %2) : (tensor, tensor) -> tensor + func.return %3 : tensor + } + builtin.module @module_circuit { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @circuit(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor + %3 = quantum.alloc(5) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = tensor.extract %arg0[] : tensor + %7 = quantum.custom "RX"(%6) %5 : !quantum.bit + %8 = quantum.namedobs %7[PauliZ] : !quantum.obs + %9 = quantum.expval %8 : f64 + %10 = tensor.from_elements %9 : tensor + %11 = tensor.extract %0[] : tensor + %12 = quantum.insert %3[%11], %7 : !quantum.reg, !quantum.bit + quantum.dealloc %12 : !quantum.reg + quantum.device_release + func.return %10 : tensor + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } + func.func public @square(%arg0 : tensor) -> (tensor {jax.result_info = "result"}) { + %0 = "stablehlo.multiply"(%arg0, %arg0) : (tensor, tensor) -> tensor + func.return %0 : tensor + } +} diff --git a/frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst b/frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst new file mode 100644 index 0000000000..532a384e1a --- /dev/null +++ b/frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst @@ -0,0 +1,404 @@ +Python compiler utilities +========================= + +All utilities we care about are in the +``pennylane.compiler.python_compiler.conversion`` submodule. + +.. code-block:: python + + import pennylane as qml + + from pennylane.compiler.python_compiler.conversion import ( + inline_jit_to_module, + inline_module, + xdsl_from_qjit, + xdsl_module, + ) + +``xdsl_module`` +=============== + +This function takes a ``jax.jit``-ed function as input, and returns a +wrapper. This wrapper can be called to return an xDSL module. Note that +this function is intended to be used to covert purely classical +functions into xDSL modules. Let’s take a look at a very simple example: + +.. code-block:: python + + import jax + + + @jax.jit + def inner(x): + return x**2 + + + @jax.jit + def outer(x, y): + return inner(x) - y + + +>>> wrapped_outer = xdsl_module(outer) +>>> jit_mod = wrapped_outer(1.5, 2.5) +>>> print(jit_mod) +builtin.module @jit_outer attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg1 : tensor, %arg2 : tensor) -> (tensor {jax.result_info = "result"}) { + %0 = func.call @inner(%arg1) : (tensor) -> tensor + %1 = "stablehlo.subtract"(%0, %arg2) : (tensor, tensor) -> tensor + func.return %1 : tensor + } + func.func private @inner(%arg0 : tensor) -> (tensor) { + %0 = "stablehlo.multiply"(%arg0, %arg0) : (tensor, tensor) -> tensor + func.return %0 : tensor + } +} + + +Nice! Key points to note: \* The module has the same name as the +decorated function (``outer``) with the ``jit_`` prefix. \* The entry +point function is aptly named ``main``. \* Any jitted functions called +within the entry point have their own function inside the module, and a +corresponding ``func.call`` operation where it gets called. \* If +``inner`` was not decorated with ``jax.jit``, its body would have been +inlined into ``outer``: + +.. code-block:: python + + def inner2(x): + return x**2 + + + @xdsl_module + @jax.jit + def outer2(x, y): + return inner2(x) - y + + +>>> mod2 = outer2(1.5, 2.5) +>>> print(mod2) +builtin.module @jit_outer2 attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0 : tensor, %arg1 : tensor) -> (tensor {jax.result_info = "result"}) { + %0 = "stablehlo.multiply"(%arg0, %arg0) : (tensor, tensor) -> tensor + %1 = "stablehlo.subtract"(%0, %arg1) : (tensor, tensor) -> tensor + func.return %1 : tensor + } +} + + +``xdsl_from_qjit`` +================== + +Since ``xdsl_module`` is for purely classical code, there is another +function that can lower hybrid quantum-classical code written in Python +to an xDSL module. ``xdsl_from_qjit`` takes a ``QJIT``-ed function as +input, and converts it into an xDSL module with the same structure as a +Catalyst program. Let’s check it out: + +.. code-block:: python + + @xdsl_from_qjit + @qml.qjit + def workflow(x, y): + dev = qml.device("lightning.qubit", wires=5) + + @qml.qnode(dev) + def qnode(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + return qnode(x) ** 2 - y + + +>>> qjit_mod = workflow(2.5, 3.5) +>>> print(qjit_mod) +builtin.module @workflow { + func.func public @jit_workflow(%arg2 : tensor, %arg3 : tensor) -> (tensor) attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_qnode::@qnode(%arg2) : (tensor) -> tensor + %1 = "stablehlo.multiply"(%0, %0) : (tensor, tensor) -> tensor + %2 = "stablehlo.convert"(%arg3) : (tensor) -> tensor + %3 = "stablehlo.subtract"(%1, %2) : (tensor, tensor) -> tensor + func.return %3 : tensor + } + builtin.module @module_qnode { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @qnode(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor + %3 = quantum.alloc(5) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = tensor.extract %arg0[] : tensor + %7 = quantum.custom "RX"(%6) %5 : !quantum.bit + %8 = quantum.namedobs %7[PauliZ] : !quantum.obs + %9 = quantum.expval %8 : f64 + %10 = tensor.from_elements %9 : tensor + %11 = tensor.extract %0[] : tensor + %12 = quantum.insert %3[%11], %7 : !quantum.reg, !quantum.bit + quantum.dealloc %12 : !quantum.reg + quantum.device_release + func.return %10 : tensor + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } +} + + +Nice! The usefulness of having this utility is that all dialects that we +would commonly find being used at the MLIR layer are already loaded, so +users don’t need to worry about loading dialects manually. + +``inline_jit_to_module`` +======================== + +This utility takes ``xdsl_module`` a step beyond. It takes a +``jax.jit``-ed function ``func`` and an xDSL module ``mod`` as input. It +lowers ``func`` to an xDSL module, and appends the lowered module’s body +to the end of ``mod``\ ’s body. An additional step that this function +takes is that it renames the ``main`` function in the lowered module to +the same name as ``func`` so that it’s easier to find. + +Let’s try inlining the previous ``outer`` function into ``qjit_mod``. +``inline_jit_to_module`` will modify ``qjit_mod`` in-place: + +>>> inline_jit_to_module(outer, qjit_mod)(1.5, 3.5) +>>> print(qjit_mod) +builtin.module @workflow { + func.func public @jit_workflow(%arg2 : tensor, %arg3 : tensor) -> (tensor) attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_qnode::@qnode(%arg2) : (tensor) -> tensor + %1 = "stablehlo.multiply"(%0, %0) : (tensor, tensor) -> tensor + %2 = "stablehlo.convert"(%arg3) : (tensor) -> tensor + %3 = "stablehlo.subtract"(%1, %2) : (tensor, tensor) -> tensor + func.return %3 : tensor + } + builtin.module @module_qnode { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @qnode(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor + %3 = quantum.alloc(5) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = tensor.extract %arg0[] : tensor + %7 = quantum.custom "RX"(%6) %5 : !quantum.bit + %8 = quantum.namedobs %7[PauliZ] : !quantum.obs + %9 = quantum.expval %8 : f64 + %10 = tensor.from_elements %9 : tensor + %11 = tensor.extract %0[] : tensor + %12 = quantum.insert %3[%11], %7 : !quantum.reg, !quantum.bit + quantum.dealloc %12 : !quantum.reg + quantum.device_release + func.return %10 : tensor + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } + func.func public @outer(%arg1 : tensor, %arg2 : tensor) -> (tensor {jax.result_info = "result"}) { + %0 = func.call @inner(%arg1) : (tensor) -> tensor + %1 = "stablehlo.subtract"(%0, %arg2) : (tensor, tensor) -> tensor + func.return %1 : tensor + } + func.func private @inner(%arg0 : tensor) -> (tensor) { + %0 = "stablehlo.multiply"(%arg0, %arg0) : (tensor, tensor) -> tensor + func.return %0 : tensor + } +} + + +Nice! We can see that two new functions have been added to the bottom of +``qjit_mod``, corresponding to ``outer`` and ``inner``. This allows us +to make function calls to both functions, which can be useful for +embedding post-processing logic into a module, say, when applying a +compiler pass. + +``inline_module`` +================= + +This utility takes two modules as input, and inlines the body of the +first module into the body of the second module. Additionally, +recognizing that modules created from ``jax.jit``-ed functions may +contain a ``FuncOp`` called ``main``, the function also takes a string +as an optional input to rename the ``main`` function to something else. + +Let’s try revisiting the above modules and inlining them using +``inline_module`` instead of ``inline_jit_to_module``: + +.. code-block:: python + + @xdsl_from_qjit + @qml.qjit + def workflow2(x, y): + dev = qml.device("lightning.qubit", wires=5) + + @qml.qnode(dev) + def qnode(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + return qnode(x) ** 2 - y + + +>>> qjit_mod2 = workflow2(2.5, 3.5) +>>> print(qjit_mod2) +builtin.module @workflow2 { + func.func public @jit_workflow2(%arg2 : tensor, %arg3 : tensor) -> (tensor) attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_qnode::@qnode(%arg2) : (tensor) -> tensor + %1 = "stablehlo.multiply"(%0, %0) : (tensor, tensor) -> tensor + %2 = "stablehlo.convert"(%arg3) : (tensor) -> tensor + %3 = "stablehlo.subtract"(%1, %2) : (tensor, tensor) -> tensor + func.return %3 : tensor + } + builtin.module @module_qnode { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @qnode(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor + %3 = quantum.alloc(5) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = tensor.extract %arg0[] : tensor + %7 = quantum.custom "RX"(%6) %5 : !quantum.bit + %8 = quantum.namedobs %7[PauliZ] : !quantum.obs + %9 = quantum.expval %8 : f64 + %10 = tensor.from_elements %9 : tensor + %11 = tensor.extract %0[] : tensor + %12 = quantum.insert %3[%11], %7 : !quantum.reg, !quantum.bit + quantum.dealloc %12 : !quantum.reg + quantum.device_release + func.return %10 : tensor + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } +} + + +.. code-block:: python + + @jax.jit + def inner3(x): + return x**2 + + + @jax.jit + def outer3(x, y): + return inner3(x) - y + + +>>> wrapped_outer3 = xdsl_module(outer3) +>>> jit_mod3 = wrapped_outer3(1.5, 2.5) +>>> print(jit_mod3) +builtin.module @jit_outer3 attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg1 : tensor, %arg2 : tensor) -> (tensor {jax.result_info = "result"}) { + %0 = func.call @inner3(%arg1) : (tensor) -> tensor + %1 = "stablehlo.subtract"(%0, %arg2) : (tensor, tensor) -> tensor + func.return %1 : tensor + } + func.func private @inner3(%arg0 : tensor) -> (tensor) { + %0 = "stablehlo.multiply"(%arg0, %arg0) : (tensor, tensor) -> tensor + func.return %0 : tensor + } +} + + +Now, we will inline the contents of ``jit_mod3`` into ``qjit_mod2``, and +rename ``main`` to ``outer3``. As seen below, ``outer3`` and ``inner3`` +have been inlined into ``qjit_mod2``, just like we wanted. + +One might wonder why we need both ``inline_jit_to_module`` and +``inline_module``. At this stage, the intention is just to enable as +much functionality for users of the Python compiler as possible. There +may be use cases where users may want to create their own module +manually instead of having one be created automatically by JAX or +Catalyst. + +>>> inline_module(jit_mod3, qjit_mod2, change_main_to="outer3") +>>> print(qjit_mod2) +builtin.module @workflow2 { + func.func public @jit_workflow2(%arg2 : tensor, %arg3 : tensor) -> (tensor) attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_qnode::@qnode(%arg2) : (tensor) -> tensor + %1 = "stablehlo.multiply"(%0, %0) : (tensor, tensor) -> tensor + %2 = "stablehlo.convert"(%arg3) : (tensor) -> tensor + %3 = "stablehlo.subtract"(%1, %2) : (tensor, tensor) -> tensor + func.return %3 : tensor + } + builtin.module @module_qnode { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @qnode(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor + %3 = quantum.alloc(5) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = tensor.extract %arg0[] : tensor + %7 = quantum.custom "RX"(%6) %5 : !quantum.bit + %8 = quantum.namedobs %7[PauliZ] : !quantum.obs + %9 = quantum.expval %8 : f64 + %10 = tensor.from_elements %9 : tensor + %11 = tensor.extract %0[] : tensor + %12 = quantum.insert %3[%11], %7 : !quantum.reg, !quantum.bit + quantum.dealloc %12 : !quantum.reg + quantum.device_release + func.return %10 : tensor + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } + func.func public @outer3(%arg1 : tensor, %arg2 : tensor) -> (tensor {jax.result_info = "result"}) { + %0 = func.call @inner3(%arg1) : (tensor) -> tensor + %1 = "stablehlo.subtract"(%0, %arg2) : (tensor, tensor) -> tensor + func.return %1 : tensor + } + func.func private @inner3(%arg0 : tensor) -> (tensor) { + %0 = "stablehlo.multiply"(%arg0, %arg0) : (tensor, tensor) -> tensor + func.return %0 : tensor + } +} diff --git a/frontend/catalyst/python_interface/parser.py b/frontend/catalyst/python_interface/parser.py new file mode 100644 index 0000000000..c6b52f05b2 --- /dev/null +++ b/frontend/catalyst/python_interface/parser.py @@ -0,0 +1,69 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for translating JAX to xDSL""" + +from collections.abc import Sequence + +from xdsl.context import Context as xContext +from xdsl.dialects import arith as xarith +from xdsl.dialects import builtin as xbuiltin +from xdsl.dialects import func as xfunc +from xdsl.dialects import scf as xscf +from xdsl.dialects import tensor as xtensor +from xdsl.ir import Dialect as xDialect +from xdsl.parser import Parser as xParser + +from catalyst.python_interface.dialects import MBQC, QEC, Catalyst, Quantum, StableHLO, Transform + + +class QuantumParser(xParser): # pylint: disable=abstract-method,too-few-public-methods + """A subclass of ``xdsl.parser.Parser`` that automatically loads relevant dialects + into the input context. + + Args: + ctx (xdsl.context.Context): Context to use for parsing. + input (str): Input program string to parse. + name (str): The name for the input. ``""`` by default. + extra_dialects (Sequence[xdsl.ir.Dialect]): Any additional dialects + that should be loaded into the context before parsing. + """ + + default_dialects: tuple[xDialect] = ( + xarith.Arith, + xbuiltin.Builtin, + xfunc.Func, + xscf.Scf, + StableHLO, + xtensor.Tensor, + Transform, + Quantum, + MBQC, + Catalyst, + QEC, + ) + + def __init__( + self, + ctx: xContext, + input: str, + name: str = "", + extra_dialects: Sequence[xDialect] | None = (), + ) -> None: + super().__init__(ctx, input, name) + + extra_dialects = extra_dialects or () + for dialect in self.default_dialects + tuple(extra_dialects): + if self.ctx.get_optional_dialect(dialect.name) is None: + self.ctx.load_dialect(dialect) diff --git a/frontend/catalyst/python_interface/pass_api/__init__.py b/frontend/catalyst/python_interface/pass_api/__init__.py new file mode 100644 index 0000000000..8f2560030f --- /dev/null +++ b/frontend/catalyst/python_interface/pass_api/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""xDSL transforms core API.""" + +from .apply_transform_sequence import ( + ApplyTransformSequence, + available_passes, + is_xdsl_pass, + register_pass, +) +from .transform_interpreter import TransformFunctionsExt, TransformInterpreterPass +from .compiler_transform import PassDispatcher, compiler_transform + +__all__ = [ + "ApplyTransformSequence", + "available_passes", + "is_xdsl_pass", + "PassDispatcher", + "register_pass", + "TransformFunctionsExt", + "TransformInterpreterPass", + "compiler_transform", +] diff --git a/frontend/catalyst/python_interface/pass_api/apply_transform_sequence.py b/frontend/catalyst/python_interface/pass_api/apply_transform_sequence.py new file mode 100644 index 0000000000..a9da6640a0 --- /dev/null +++ b/frontend/catalyst/python_interface/pass_api/apply_transform_sequence.py @@ -0,0 +1,85 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This file contains the pass that applies all passes present in the program representation.""" + + +from dataclasses import dataclass + +from xdsl.context import Context +from xdsl.dialects import builtin +from xdsl.passes import ModulePass, PassPipeline + +from pennylane.typing import Callable + +from .transform_interpreter import TransformInterpreterPass + +available_passes = {} + + +def register_pass(name, _callable): + """Registers the passes available in the dictionary""" + available_passes[name] = _callable # pragma: no cover + + +def is_xdsl_pass(pass_name: str) -> bool: + """Check if a pass name corresponds to an xDSL implemented pass. + + This function checks if the pass is registered in PennyLane's unified compiler + pass registry, which dynamically tracks all available xDSL passes. + + Args: + pass_name (str): Name of the pass to check + + Returns: + bool: True if this is an xDSL compiler pass + """ + return pass_name in available_passes + + +@dataclass(frozen=True) +class ApplyTransformSequence(ModulePass): + """ + Looks for nested modules. Nested modules in this context are guaranteed to correspond + to qnodes. These modules are already annotated with which passes are to be executed. + The pass ApplyTransformSequence will run passes annotated in the qnode modules. + + At the end, we delete the list of passes as they have already been applied. + """ + + name = "apply-transform-sequence" + callback: Callable[[ModulePass, builtin.ModuleOp, ModulePass], None] | None = None + + def apply(self, ctx: Context, op: builtin.ModuleOp) -> None: # pylint: disable=no-self-use + """Applies the transformation""" + nested_modules = [] + for region in op.regions: + for block in region.blocks: + for operation in block.ops: + if isinstance(operation, builtin.ModuleOp): + nested_modules.append(operation) + + pipeline = PassPipeline( + (TransformInterpreterPass(passes=available_passes, callback=self.callback),) + ) + for operation in nested_modules: + pipeline.apply(ctx, operation) + + for mod in nested_modules: + for region in mod.regions: + for block in region.blocks: + for operation in block.ops: + if isinstance(operation, builtin.ModuleOp) and operation.get_attr_or_prop( + "transform.with_named_sequence" + ): + block.erase_op(operation) # pragma: no cover diff --git a/frontend/catalyst/python_interface/pass_api/compiler_transform.py b/frontend/catalyst/python_interface/pass_api/compiler_transform.py new file mode 100644 index 0000000000..2016c16866 --- /dev/null +++ b/frontend/catalyst/python_interface/pass_api/compiler_transform.py @@ -0,0 +1,64 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Core API for registering xDSL transforms for use with PennyLane and Catalyst.""" + +from collections.abc import Callable + +from catalyst.from_plxpr import register_transform +from xdsl.passes import ModulePass + +from pennylane.transforms.core.transform_dispatcher import TransformDispatcher + +from .apply_transform_sequence import register_pass + + +def _create_null_transform(name: str) -> Callable: + """Create a dummy tape transform. The tape transform raises an error if used.""" + + def null_transform(_): + raise RuntimeError( + f"Cannot apply the {name} pass without '@qml.qjit'. Additionally, program capture " + "must be enabled using 'qml.capture.enable()'. Otherwise, the pass must be applied " + f"using the '@catalyst.passes.apply_pass(\"{name}\")' decorator." + ) + + return null_transform + + +class PassDispatcher(TransformDispatcher): + """Wrapper class for applying passes to QJIT-ed workflows.""" + + name: str + module_pass: ModulePass + + def __init__(self, module_pass: ModulePass): + self.module_pass = module_pass + self.name = module_pass.name + tape_transform = _create_null_transform(self.name) + super().__init__(tape_transform) + + +def compiler_transform(module_pass: ModulePass) -> PassDispatcher: + """Wrapper function to register xDSL passes to use with QJIT-ed workflows.""" + dispatcher = PassDispatcher(module_pass) + + # Registration to map from plxpr primitive to pass + register_transform(dispatcher, module_pass.name, False) + + # Registration for apply-transform-sequence interpreter + def get_pass_cls(): + return module_pass + + register_pass(module_pass.name, get_pass_cls) + return dispatcher diff --git a/frontend/catalyst/python_interface/pass_api/transform_interpreter.py b/frontend/catalyst/python_interface/pass_api/transform_interpreter.py new file mode 100644 index 0000000000..87983ec491 --- /dev/null +++ b/frontend/catalyst/python_interface/pass_api/transform_interpreter.py @@ -0,0 +1,146 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom Transform Dialect Interpreter Pass + +Differs from xDSL's upstream implementation by allowing passes +to be passed in as options and for the pipeline to have a callback and apply it +after every pass is run. The callback differs from how xDSL callback mechanism +is integrated into the PassPipeline object since PassPipeline only runs +if there are more than two passes. Here we are running one pass at a time +which will prevent the callback from being called. + + +See here (link valid with xDSL 0.46): https://github.com/xdslproject/xdsl/blob/334492e660b1726bc661efc7afb927e74bac48f4/xdsl/passes.py#L211-L222 +""" + +import io +from collections.abc import Callable + +from xdsl.context import Context +from xdsl.dialects import builtin +from xdsl.dialects.transform import NamedSequenceOp +from xdsl.interpreter import Interpreter, PythonValues, impl, register_impls +from xdsl.interpreters.transform import TransformFunctions +from xdsl.parser import Parser +from xdsl.passes import ModulePass, PassPipeline +from xdsl.printer import Printer +from xdsl.rewriter import Rewriter +from xdsl.utils.exceptions import PassFailedException + +from catalyst.compiler import _quantum_opt +from catalyst.python_interface.dialects.transform import ApplyRegisteredPassOp + + +# pylint: disable=too-few-public-methods +@register_impls +class TransformFunctionsExt(TransformFunctions): + """ + Unlike the implementation available in xDSL, this implementation overrides + the semantics of the `transform.apply_registered_pass` operation by + first always attempting to apply the xDSL pass, but if it isn't found + then it will try to run this pass in Catalyst. + """ + + def __init__(self, ctx, passes, callback=None): + super().__init__(ctx, passes) + # The signature of the callback function is assumed to be + # def callback(previous_pass: ModulePass, module: ModuleOp, next_pass: ModulePass, pass_level=None) -> None + self.callback = callback + self.pass_level = 0 + + def _pre_pass_callback(self, compilation_pass, module): + """Callback wrapper to run the callback function before the pass.""" + if not self.callback: + return + if self.pass_level == 0: + # Since this is the first pass, there is no previous pass + self.callback(None, module, compilation_pass, pass_level=0) + + def _post_pass_callback(self, compilation_pass, module): + """Increment level and run callback if defined.""" + if not self.callback: + return + self.pass_level += 1 + self.callback(compilation_pass, module, None, pass_level=self.pass_level) + + @impl(ApplyRegisteredPassOp) + def run_apply_registered_pass_op( + self, + _interpreter: Interpreter, + op: ApplyRegisteredPassOp, + args: PythonValues, + ) -> PythonValues: + """Try to run the pass in xDSL, if not found then run it in Catalyst.""" + + pass_name = op.pass_name.data + module = args[0] + + # ---- xDSL path ---- + if pass_name in self.passes: + pass_class = self.passes[pass_name]() + pass_instance = pass_class(**op.options.data) + pipeline = PassPipeline((pass_instance,)) + self._pre_pass_callback(pass_instance, module) + pipeline.apply(self.ctx, module) + self._post_pass_callback(pass_instance, module) + return (module,) + + # ---- Catalyst path ---- + buffer = io.StringIO() + Printer(stream=buffer, print_generic_format=True).print_op(module) + + schedule = f"--{pass_name}" + self._pre_pass_callback(pass_name, module) + modified = _quantum_opt(schedule, "-mlir-print-op-generic", stdin=buffer.getvalue()) + + data = Parser(self.ctx, modified).parse_module() + rewriter = Rewriter() + rewriter.replace_op(module, data) + self._post_pass_callback(pass_name, data) + return (data,) + + +class TransformInterpreterPass(ModulePass): + """Transform dialect interpreter""" + + passes: dict[str, Callable[[], type[ModulePass]]] + name = "transform-interpreter" + callback: Callable[[ModulePass, builtin.ModuleOp, ModulePass], None] | None = None + + entry_point: str = "__transform_main" + + def __init__(self, passes, callback): + self.passes = passes + self.callback = callback + + @staticmethod + def find_transform_entry_point(root: builtin.ModuleOp, entry_point: str) -> NamedSequenceOp: + """Find the entry point of the program""" + for op in root.walk(): + if isinstance(op, NamedSequenceOp) and op.sym_name.data == entry_point: + return op + raise PassFailedException( # pragma: no cover + f"{root} could not find a nested named sequence with name: {entry_point}" + ) + + def apply(self, ctx: Context, op: builtin.ModuleOp) -> None: + """Run the interpreter with op.""" + schedule = TransformInterpreterPass.find_transform_entry_point(op, self.entry_point) + interpreter = Interpreter(op) + interpreter.register_implementations(TransformFunctionsExt(ctx, self.passes, self.callback)) + schedule.parent_op().detach() + if self.callback: + self.callback(None, op, None, pass_level=0) + interpreter.call_op(schedule, (op,)) diff --git a/frontend/catalyst/python_interface/transforms/__init__.py b/frontend/catalyst/python_interface/transforms/__init__.py new file mode 100644 index 0000000000..ac661ba0de --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/__init__.py @@ -0,0 +1,66 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PennyLane-xDSL transformations API.""" + +from .mbqc import ( + convert_to_mbqc_formalism_pass, + ConvertToMBQCFormalismPass, + decompose_graph_state_pass, + DecomposeGraphStatePass, + outline_state_evolution_pass, + OutlineStateEvolutionPass, + null_decompose_graph_state_pass, + NullDecomposeGraphStatePass, +) + +from .quantum import ( + combine_global_phases_pass, + CombineGlobalPhasesPass, + diagonalize_final_measurements_pass, + DiagonalizeFinalMeasurementsPass, + iterative_cancel_inverses_pass, + IterativeCancelInversesPass, + measurements_from_samples_pass, + MeasurementsFromSamplesPass, + merge_rotations_pass, + MergeRotationsPass, + split_non_commuting_pass, + SplitNonCommutingPass, +) + + +__all__ = [ + # Quantum + "combine_global_phases_pass", + "CombineGlobalPhasesPass", + "diagonalize_final_measurements_pass", + "DiagonalizeFinalMeasurementsPass", + "iterative_cancel_inverses_pass", + "IterativeCancelInversesPass", + "measurements_from_samples_pass", + "MeasurementsFromSamplesPass", + "merge_rotations_pass", + "MergeRotationsPass", + "split_non_commuting_pass", + "SplitNonCommutingPass", + # MBQC + "convert_to_mbqc_formalism_pass", + "ConvertToMBQCFormalismPass", + "decompose_graph_state_pass", + "DecomposeGraphStatePass", + "OutlineStateEvolutionPass", + "outline_state_evolution_pass", + "null_decompose_graph_state_pass", + "NullDecomposeGraphStatePass", +] diff --git a/frontend/catalyst/python_interface/transforms/mbqc/__init__.py b/frontend/catalyst/python_interface/transforms/mbqc/__init__.py new file mode 100644 index 0000000000..134e9d3e76 --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/mbqc/__init__.py @@ -0,0 +1,48 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""xDSL transformations API specifically for the MBQC transform.""" + +from .convert_to_mbqc_formalism import ConvertToMBQCFormalismPass, convert_to_mbqc_formalism_pass +from .decompose_graph_state import ( + DecomposeGraphStatePass, + NullDecomposeGraphStatePass, + decompose_graph_state_pass, + null_decompose_graph_state_pass, +) +from .graph_state_utils import ( + edge_iter, + generate_adj_matrix, + get_graph_state_edges, + get_num_aux_wires, + n_vertices_from_packed_adj_matrix, +) +from .outline_state_evolution import OutlineStateEvolutionPass, outline_state_evolution_pass + +__all__ = [ + # Passes + "ConvertToMBQCFormalismPass", + "DecomposeGraphStatePass", + "OutlineStateEvolutionPass", + "NullDecomposeGraphStatePass", + "convert_to_mbqc_formalism_pass", + "decompose_graph_state_pass", + "null_decompose_graph_state_pass", + "outline_state_evolution_pass", + # Utils + "get_num_aux_wires", + "get_graph_state_edges", + "n_vertices_from_packed_adj_matrix", + "edge_iter", + "generate_adj_matrix", +] diff --git a/frontend/catalyst/python_interface/transforms/mbqc/convert_to_mbqc_formalism.py b/frontend/catalyst/python_interface/transforms/mbqc/convert_to_mbqc_formalism.py new file mode 100644 index 0000000000..fcd8bd81df --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/mbqc/convert_to_mbqc_formalism.py @@ -0,0 +1,733 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This file contains the implementation of the convert_to_mbqc_formalism transform, +written using xDSL.""" + +import math +from dataclasses import dataclass + +from xdsl import builder, context, passes, pattern_rewriter +from xdsl.dialects import arith, builtin, func, scf +from xdsl.dialects.scf import ForOp, IfOp, IndexSwitchOp, WhileOp +from xdsl.ir import SSAValue +from xdsl.ir.core import Block, OpResult, Region +from xdsl.rewriter import InsertPoint + +from catalyst.python_interface.dialects.mbqc import ( + GraphStatePrepOp, + MeasureInBasisOp, + MeasurementPlaneAttr, + MeasurementPlaneEnum, +) +from catalyst.python_interface.dialects.quantum import ( + CustomOp, + DeallocQubitOp, + ExtractOp, + GlobalPhaseOp, + QubitType, +) +from catalyst.python_interface.pass_api import compiler_transform + +from .graph_state_utils import generate_adj_matrix, get_num_aux_wires + +_PAULIS = { + "PauliX", + "PauliY", + "PauliZ", + "Identity", +} + +_MBQC_ONE_QUBIT_GATES = { + "Hadamard", + "S", + "RZ", + "RotXZX", +} + +_MBQC_TWO_QUBIT_GATES = { + "CNOT", +} + +_MBQC_GATES = _MBQC_ONE_QUBIT_GATES | _MBQC_TWO_QUBIT_GATES + + +@dataclass(frozen=True) +class ConvertToMBQCFormalismPass(passes.ModulePass): + """Pass that converts gates in the MBQC gate set to the MBQC formalism.""" + + name = "convert-to-mbqc-formalism" + + def _prep_graph_state(self, gate_name: str): + """Add a graph state prep operation into the subroutine for each gate and extract and return auxiliary qubits + in the graph state. + + Args: + gate_name[str]: Name of gate operation. + + Return: + graph_qubit_dict : A dictionary of qubits in the graph + """ + num_aux_wres = get_num_aux_wires(gate_name) + + adj_matrix_op = arith.ConstantOp( + builtin.DenseIntOrFPElementsAttr.from_list( + type=builtin.TensorType( + builtin.IntegerType(1), shape=(len(generate_adj_matrix(gate_name)),) + ), + data=generate_adj_matrix(gate_name), + ) + ) + + graph_state_prep_op = GraphStatePrepOp(adj_matrix_op.result, "Hadamard", "CZ") + graph_state_reg = graph_state_prep_op.results[0] + + graph_qubit_dict = {} + # Extract qubit from the graph state reg + for i in range(num_aux_wres): + extract_op = ExtractOp(graph_state_reg, i) + + # Note the following line maps the aux qubit index in the register to the + # standard context book MBQC representation. Note that auxiliary qubits in + # the graph state for one qubit gates only hit the `if` branch as `i` is + # always less than `4`, while the auxiliary qubits in graph state for a `CNOT` + # gate with an index >= 7 would hit the `else` branch. + key = i + 2 if i < 7 else i + 3 + + graph_qubit_dict[key] = extract_op.results[0] + + return graph_qubit_dict + + def _insert_xy_basis_measure_op( + self, + const_angle_op: arith.ConstantOp, + qubit: QubitType, + ): + """Add an arbitrary basis measure related operations to the subroutine. + Args: + const_angle_op (arith.ConstantOp) : The angle of measurement basis. + qubit (QubitType) : The target qubit to be measured. + + Returns: + The results include: 1. a measurement result; 2, a result qubit. + """ + plane_op = MeasurementPlaneAttr(MeasurementPlaneEnum("XY")) + # Create a MeasureInBasisOp op + measure_op = MeasureInBasisOp(in_qubit=qubit, plane=plane_op, angle=const_angle_op) + # Returns the results of the newly created measure_op. + # The results include: 1, a measurement result; 2, a result qubit. + return measure_op.results + + def _insert_cond_arbitrary_basis_measure_op( + self, + meas_parity: builtin.IntegerType, + angle: SSAValue[builtin.Float64Type], + plane: str, + qubit: QubitType, + ): # pylint: disable=too-many-arguments, too-many-positional-arguments + """ + Add a conditional arbitrary basis measurement operation based on a previous measurement result. + Args: + meas_parity (builtin.IntegerType) : A parity of previous measurements. + angle (SSAValue[builtin.Float64Type]) : An angle SSAValue from a parametric gate operation. + plane (str): Plane of the measurement basis. + qubit (QubitType) : The target qubit to be measured. + + Returns: + The results include: 1. a measurement result; 2, a result qubit. + """ + constant_one_op = arith.ConstantOp.from_int_and_width(1, builtin.i1) + cond = arith.CmpiOp(meas_parity, constant_one_op, "eq") + branch = scf.IfOp( + cond, + ( + builtin.IntegerType(1), + QubitType(), + ), + Region(Block()), + Region(Block()), + ) + + plane_op = MeasurementPlaneAttr(MeasurementPlaneEnum(plane)) + + with builder.ImplicitBuilder(branch.true_region): + measure_op = MeasureInBasisOp(in_qubit=qubit, plane=plane_op, angle=angle) + scf.YieldOp(measure_op.results[0], measure_op.results[1]) + with builder.ImplicitBuilder(branch.false_region): + const_neg_angle_op = arith.NegfOp(angle) + measure_neg_op = MeasureInBasisOp( + in_qubit=qubit, plane=plane_op, angle=const_neg_angle_op.result + ) + scf.YieldOp(measure_neg_op.results[0], measure_neg_op.results[1]) + + return branch.results + + def _hadamard_measurements(self, graph_qubit_dict): + """Add measurement ops for a Hadamard gate to the subroutine""" + const_x_angle = arith.ConstantOp( + builtin.FloatAttr(data=0.0, type=builtin.Float64Type()) + ) # measure_x + const_y_angle = arith.ConstantOp( + builtin.FloatAttr(data=math.pi / 2, type=builtin.Float64Type()) + ) # measure_y + m1, graph_qubit_dict[1] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[1] + ) + m2, graph_qubit_dict[2] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[2] + ) + m3, graph_qubit_dict[3] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[3] + ) + m4, graph_qubit_dict[4] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[4] + ) + return [m1, m2, m3, m4], graph_qubit_dict + + def _s_measurements(self, graph_qubit_dict): + """Add measurement ops for a S gate to the subroutine""" + const_x_angle = arith.ConstantOp( + builtin.FloatAttr(data=0.0, type=builtin.Float64Type()) + ) # measure_x + const_y_angle = arith.ConstantOp( + builtin.FloatAttr(data=math.pi / 2, type=builtin.Float64Type()) + ) # measure_y + m1, graph_qubit_dict[1] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[1] + ) + m2, graph_qubit_dict[2] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[2] + ) + m3, graph_qubit_dict[3] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[3] + ) + m4, graph_qubit_dict[4] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[4] + ) + return [m1, m2, m3, m4], graph_qubit_dict + + def _rz_measurements(self, graph_qubit_dict, params): + """Add measurement ops for a RZ gate to the subroutine""" + const_x_angle = arith.ConstantOp( + builtin.FloatAttr(data=0.0, type=builtin.Float64Type()) + ) # measure_x + m1, graph_qubit_dict[1] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[1] + ) + m2, graph_qubit_dict[2] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[2] + ) + m3, graph_qubit_dict[3] = self._insert_cond_arbitrary_basis_measure_op( + m2, params[0], "XY", graph_qubit_dict[3] + ) + m4, graph_qubit_dict[4] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[4] + ) + return [m1, m2, m3, m4], graph_qubit_dict + + def _rotxzx_measurements(self, graph_qubit_dict, params): + """Add measurement ops for a RotXZX gate to the subroutine""" + const_x_angle = arith.ConstantOp( + builtin.FloatAttr(data=0.0, type=builtin.Float64Type()) + ) # measure_x + m1, graph_qubit_dict[1] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[1] + ) + m2, graph_qubit_dict[2] = self._insert_cond_arbitrary_basis_measure_op( + m1, params[0], "XY", graph_qubit_dict[2] + ) + m3, graph_qubit_dict[3] = self._insert_cond_arbitrary_basis_measure_op( + m2, params[1], "XY", graph_qubit_dict[3] + ) + + m1_xor_m3 = arith.XOrIOp(m1, m3) + + m4, graph_qubit_dict[4] = self._insert_cond_arbitrary_basis_measure_op( + m1_xor_m3.result, params[2], "XY", graph_qubit_dict[4] + ) + return [m1, m2, m3, m4], graph_qubit_dict + + def _cnot_measurements(self, graph_qubit_dict): + """Add measurement ops for a CNOT gate to the subroutine""" + const_x_angle = arith.ConstantOp( + builtin.FloatAttr(data=0.0, type=builtin.Float64Type()) + ) # measure_x + const_y_angle = arith.ConstantOp( + builtin.FloatAttr(data=math.pi / 2, type=builtin.Float64Type()) + ) # measure_y + m1, graph_qubit_dict[1] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[1] + ) + m2, graph_qubit_dict[2] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[2] + ) + m3, graph_qubit_dict[3] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[3] + ) + m4, graph_qubit_dict[4] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[4] + ) + m5, graph_qubit_dict[5] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[5] + ) + m6, graph_qubit_dict[6] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[6] + ) + m8, graph_qubit_dict[8] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[8] + ) + m9, graph_qubit_dict[9] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[9] + ) + m10, graph_qubit_dict[10] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[10] + ) + m11, graph_qubit_dict[11] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[11] + ) + m12, graph_qubit_dict[12] = self._insert_xy_basis_measure_op( + const_y_angle, graph_qubit_dict[12] + ) + m13, graph_qubit_dict[13] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[13] + ) + m14, graph_qubit_dict[14] = self._insert_xy_basis_measure_op( + const_x_angle, graph_qubit_dict[14] + ) + + return [m1, m2, m3, m4, m5, m6, m8, m9, m10, m11, m12, m13, m14], graph_qubit_dict + + def _parity_check( + self, + mres: list[builtin.IntegerType], + additional_const_one: bool = False, + ): + """Add parity check related operations to the subroutine. + Args: + mres (list[builtin.IntegerType]): A list of the mid-measurement results. + additional_const_one (bool) : Whether we need to add an additional const one to get the + parity or not. Defaults to False. + Returns: + The result of parity check. + """ + prev_res = mres[0] + xor_op = None + # Create xor ops to iterate all elements in the mres and insert them to the IR + for i in range(1, len(mres)): + xor_op = arith.XOrIOp(prev_res, mres[i]) + prev_res = xor_op.result + + # Create an xor op for an additional const one and insert ops to the IR + if additional_const_one: + constant_one_op = arith.ConstantOp.from_int_and_width(1, builtin.i1) + xor_op = arith.XOrIOp(prev_res, constant_one_op) + prev_res = xor_op.result + + return prev_res + + def _insert_cond_byproduct_op( + self, + parity_res: OpResult, + gate_name: str, + qubit: QubitType, + ): # pylint: disable=too-many-arguments, too-many-positional-arguments + """Add a byproduct op related operations to the subroutine. + Args: + parity_res (OpResult) : Parity check result. + gate_name (str) : The name of the gate to be corrected. + qubit (QubitType) : The result auxiliary qubit to be corrected. + + Return: + The result auxiliary qubit. + """ + constant_one_op = arith.ConstantOp.from_int_and_width(1, builtin.i1) + cond = arith.CmpiOp(parity_res, constant_one_op, "eq") + branch = scf.IfOp(cond, (QubitType(),), Region(Block()), Region(Block())) + + with builder.ImplicitBuilder(branch.true_region): + byproduct_op = CustomOp(in_qubits=qubit, gate_name=gate_name) + scf.YieldOp(byproduct_op.results[0]) + with builder.ImplicitBuilder(branch.false_region): + scf.YieldOp(qubit) + return branch.results[0] + + def _hadamard_corrections( + self, + mres: list[builtin.IntegerType], + qubit: QubitType, + ): + """Add correction ops of a Hadamard gate to the subroutine. + Args: + mres (list[builtin.IntegerType]): A list of the mid-measurement results. + qubit (QubitType) : An auxiliary result qubit. + + Returns: + The result auxiliary qubit. + """ + m1, m2, m3, m4 = mres + + # X correction + x_parity = self._parity_check([m1, m3, m4]) + res_aux_qubit = self._insert_cond_byproduct_op(x_parity, "PauliX", qubit) + + # Z correction + z_parity = self._parity_check([m2, m3]) + res_aux_qubit = self._insert_cond_byproduct_op(z_parity, "PauliZ", res_aux_qubit) + + return res_aux_qubit + + def _s_corrections( + self, + mres: list[builtin.IntegerType], + qubit: QubitType, + ): + """Add correction ops of a S gate to the subroutine. + Args: + mres (list[builtin.IntegerType]): A list of the mid-measurement results. + qubit (QubitType) : An auxiliary result qubit. + + Returns: + The result auxiliary qubit. + """ + m1, m2, m3, m4 = mres + + # X correction + x_parity = self._parity_check([m2, m4]) + res_aux_qubit = self._insert_cond_byproduct_op(x_parity, "PauliX", qubit) + + # Z correction + z_parity = self._parity_check([m1, m2, m3], additional_const_one=True) + res_aux_qubit = self._insert_cond_byproduct_op(z_parity, "PauliZ", res_aux_qubit) + return res_aux_qubit + + def _rot_corrections( + self, + mres: list[builtin.IntegerType], + qubit: QubitType, + ): + """Add correction ops of a RotXZX or RZ gate to the subroutine. + Args: + mres (list[builtin.IntegerType]): A list of the mid-measurement results. + qubit (QubitType) : An auxiliary result qubit. + + Returns: + The result auxiliary qubit. + """ + m1, m2, m3, m4 = mres + # X correction + x_parity = self._parity_check([m2, m4]) + res_aux_qubit = self._insert_cond_byproduct_op(x_parity, "PauliX", qubit) + + # Z correction + z_parity = self._parity_check([m1, m3]) + res_aux_qubit = self._insert_cond_byproduct_op(z_parity, "PauliZ", res_aux_qubit) + return res_aux_qubit + + def _cnot_corrections( + self, + mres: list[builtin.IntegerType], + qubits: list[QubitType], + ): + """Add correction ops of a CNOT gate to the subroutine. + Args: + mres (list[builtin.IntegerType]): A list of the mid-measurement results. + qubits (list[QubitType]) : A list of auxiliary result qubits. + Returns: + The result auxiliary qubits. + """ + m1, m2, m3, m4, m5, m6, m8, m9, m10, m11, m12, m13, m14 = mres + # Corrections for the control qubit + x_parity = self._parity_check([m2, m3, m5, m6]) + ctrl_aux_qubit = self._insert_cond_byproduct_op(x_parity, "PauliX", qubits[0]) + z_parity = self._parity_check([m1, m3, m4, m5, m8, m9, m11], additional_const_one=True) + ctrl_aux_qubit = self._insert_cond_byproduct_op(z_parity, "PauliZ", ctrl_aux_qubit) + + # Corrections for the target qubit + x_parity = self._parity_check([m2, m3, m8, m10, m12, m14]) + tgt_aux_qubit = self._insert_cond_byproduct_op(x_parity, "PauliX", qubits[1]) + z_parity = self._parity_check([m9, m11, m13]) + tgt_aux_qubit = self._insert_cond_byproduct_op(z_parity, "PauliZ", tgt_aux_qubit) + + return ctrl_aux_qubit, tgt_aux_qubit + + def _queue_measurements( + self, gate_name: str, graph_qubit_dict, params: None | list[builtin.Float64Type] = None + ): + """Add measurement ops to the subroutine. + Args: + gate_name (str): Gate name. + graph_qubit_dict (list[builtin.IntegerType]): A list of the mid-measurement results. + params (None | list[builtin.Float64Type]) : Parameters of the gate. + + Returns: + The measurement results and updated graph_qubit_dict. + + """ + match gate_name: + case "Hadamard": + return self._hadamard_measurements(graph_qubit_dict) + case "S": + return self._s_measurements(graph_qubit_dict) + case "RZ": + return self._rz_measurements(graph_qubit_dict, params) + case "RotXZX": + return self._rotxzx_measurements(graph_qubit_dict, params) + case "CNOT": + return self._cnot_measurements(graph_qubit_dict) + case _: + raise ValueError( + f"{gate_name} is not supported in the MBQC formalism. Please decompose it into the MBQC gate set." + ) + + def _insert_byprod_corrections( + self, + gate_name: str, + mres: list[builtin.IntegerType], + qubits: QubitType | list[QubitType], + ): + """Add correction ops for the result auxiliary qubit/s to the subroutine. + Args: + gate_name (str): Gate name. + mres (list[builtin.IntegerType]): A list of the mid-measurement results. + qubits (QubitType | list[QubitType]) : An or a list of auxiliary result qubit. + + Returns: + The result auxiliary qubits. + """ + match gate_name: + case "Hadamard": + return self._hadamard_corrections(mres, qubits) + case "S": + return self._s_corrections(mres, qubits) + case "RotXZX": + return self._rot_corrections(mres, qubits) + case "RZ": + return self._rot_corrections(mres, qubits) + case "CNOT": + return self._cnot_corrections(mres, qubits) + case _: + raise ValueError( + f"{gate_name} is not supported in the MBQC formalism. Please decompose it into the MBQC gate set." + ) + + def _create_single_qubit_gate_subroutine(self, gate_name: str): + """Create a subroutine for a single qubit gate based on the given name. + Args: + gate_name (str): Name of the gate. + + Returns: + The corresponding subroutine (func.FuncOp). + """ + if gate_name not in _MBQC_ONE_QUBIT_GATES: + raise NotImplementedError(f"Subroutine for the {gate_name} gate is not supported.") + # ensure the order of parameters are aligned with customOp + input_types = () + if gate_name == "RZ": + input_types += (builtin.Float64Type(),) + if gate_name == "RotXZX": + input_types += (builtin.Float64Type(),) * 3 + input_types += (QubitType(),) + + output_types = (QubitType(),) + block = Block(arg_types=input_types) + + with builder.ImplicitBuilder(block): + in_qubits = [block.args[-1]] + params = None + if gate_name == "RZ": + params = [block.args[0]] + if gate_name == "RotXZX": + params = [ + block.args[0], + block.args[1], + block.args[2], + ] + + graph_qubit_dict = self._prep_graph_state(gate_name=gate_name) + + cz_op = CustomOp(in_qubits=[in_qubits[0], graph_qubit_dict[2]], gate_name="CZ") + + graph_qubit_dict[1], graph_qubit_dict[2] = cz_op.results + + mres, graph_qubit_dict = self._queue_measurements(gate_name, graph_qubit_dict, params) + + # The following could be removed to support Pauli tracker + by_product_correction = self._insert_byprod_corrections( + gate_name, mres, graph_qubit_dict[5] + ) + + graph_qubit_dict[5] = by_product_correction + + for node in graph_qubit_dict: + if node not in [5]: + _ = DeallocQubitOp(graph_qubit_dict[node]) + + func.ReturnOp(graph_qubit_dict[5]) + + region = Region([block]) + # Note that visibility is set as private to ensure the subroutines that are + # not called (dead code) can be eliminated as the ["symbol-dce"](https://github.com/PennyLaneAI/catalyst/blob/372c376eb821e830da778fdc8af423eeb487eab6/frontend/catalyst/pipelines.py#L248)_ + # pass was added to the pipeline. + funcOp = func.FuncOp( + gate_name.lower() + "_in_mbqc", + (input_types, output_types), + visibility="private", + region=region, + ) + # Add an attribute to the mbqc transform subroutine + funcOp.attributes["mbqc_transform"] = builtin.NoneAttr() + return funcOp + + def _create_cnot_gate_subroutine(self): + """Create a subroutine for a CNOT gate.""" + gate_name = "CNOT" + input_types = ( + QubitType(), + QubitType(), + ) + output_types = ( + QubitType(), + QubitType(), + ) + block = Block(arg_types=input_types) + + with builder.ImplicitBuilder(block): + in_qubits = [block.args[0], block.args[1]] + + graph_qubit_dict = self._prep_graph_state(gate_name=gate_name) + + # Entangle the op.in_qubits[0] with the graph_qubits_dict[2] + cz_op = CustomOp(in_qubits=[in_qubits[0], graph_qubit_dict[2]], gate_name="CZ") + graph_qubit_dict[1], graph_qubit_dict[2] = cz_op.results + + # Entangle op.in_qubits[1] with with the graph_qubits_dict[10] for a CNOT gate + cz_op = CustomOp(in_qubits=[in_qubits[1], graph_qubit_dict[10]], gate_name="CZ") + graph_qubit_dict[9], graph_qubit_dict[10] = cz_op.results + + mres, graph_qubit_dict = self._queue_measurements(gate_name, graph_qubit_dict) + + # The following could be removed to support Pauli tracker + graph_qubit_dict[7], graph_qubit_dict[15] = self._insert_byprod_corrections( + gate_name, mres, [graph_qubit_dict[7], graph_qubit_dict[15]] + ) + + for node in graph_qubit_dict: + if node not in [7, 15]: + _ = DeallocQubitOp(graph_qubit_dict[node]) + + func.ReturnOp( + *( + graph_qubit_dict[7], + graph_qubit_dict[15], + ) + ) + + region = Region([block]) + # Note that visibility is set as private to ensure the subroutines that are + # not called (dead code) can be eliminated as the ["symbol-dce"](https://github.com/PennyLaneAI/catalyst/blob/372c376eb821e830da778fdc8af423eeb487eab6/frontend/catalyst/pipelines.py#L248)_ + # pass was added to the pipeline. + funcOp = func.FuncOp( + gate_name.lower() + "_in_mbqc", + (input_types, output_types), + visibility="private", + region=region, + ) + # Add an attribute to the mbqc transform subroutine + funcOp.attributes["mbqc_transform"] = builtin.NoneAttr() + return funcOp + + # pylint: disable=no-self-use + def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + """Apply the convert-to-mbqc-formalism pass.""" + # Insert subroutines for all gates in the MBQC gate set to the module. + # Note that the visibility of those subroutines are set as private, which ensure + # the ["symbol-dce"](https://github.com/PennyLaneAI/catalyst/blob/372c376eb821e830da778fdc8af423eeb487eab6/frontend/catalyst/pipelines.py#L248)_ + # pass could eliminate the unreferenced subroutines. + subroutine_dict = {} + + for gate_name in _MBQC_ONE_QUBIT_GATES: + funcOp = self._create_single_qubit_gate_subroutine(gate_name) + module.regions[0].blocks.first.add_op(funcOp) + subroutine_dict[gate_name] = funcOp + + cnot_funcOp = self._create_cnot_gate_subroutine() + module.regions[0].blocks.first.add_op(cnot_funcOp) + subroutine_dict["CNOT"] = cnot_funcOp + + pattern_rewriter.PatternRewriteWalker( + pattern_rewriter.GreedyRewritePatternApplier( + [ + ConvertToMBQCFormalismPattern(subroutine_dict), + ] + ), + apply_recursively=False, + ).rewrite_module(module) + + +convert_to_mbqc_formalism_pass = compiler_transform(ConvertToMBQCFormalismPass) + + +class ConvertToMBQCFormalismPattern( + pattern_rewriter.RewritePattern +): # pylint: disable=too-few-public-methods,no-self-use + """RewritePattern for converting to the MBQC formalism.""" + + def __init__(self, subroutines_dict): + self.subroutine_dict = subroutines_dict + + # pylint: disable=no-self-use + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite( + self, + root: func.FuncOp | IfOp | WhileOp | ForOp | IndexSwitchOp, + rewriter: pattern_rewriter.PatternRewriter, + ): + """Match and rewrite for converting to the MBQC formalism.""" + + # Ensure that "Hadamard"/"CZ" gates in mbqc_transform subroutines are not converted. + if isinstance(root, func.FuncOp) and "mbqc_transform" in root.attributes: + return + + for region in root.regions: + # Continue if the region has no block (i.e., function that has no body, and the body is + # defined in runtime.) + if not region.blocks: + continue + + for op in region.ops: + if isinstance(op, CustomOp) and op.gate_name.data in _MBQC_GATES: + callee = builtin.SymbolRefAttr(op.gate_name.data.lower() + "_in_mbqc") + arguments = [] + for param in op.params: + arguments.append(param) + for qubit in op.in_qubits: + arguments.append(qubit) + + return_types = self.subroutine_dict[ + op.gate_name.data + ].function_type.outputs.data + callOp = func.CallOp(callee, arguments, return_types) + rewriter.insert_op(callOp, InsertPoint.before(op)) + for i, out_qubit in enumerate(op.out_qubits): + rewriter.replace_all_uses_with(out_qubit, callOp.results[i]) + rewriter.erase_op(op) + + elif isinstance(op, GlobalPhaseOp) or ( + isinstance(op, CustomOp) and op.gate_name.data in _PAULIS + ): + continue + elif isinstance(op, CustomOp): + raise NotImplementedError( + f"{op.gate_name.data} cannot be converted to the MBQC formalism." + ) diff --git a/frontend/catalyst/python_interface/transforms/mbqc/decompose_graph_state.py b/frontend/catalyst/python_interface/transforms/mbqc/decompose_graph_state.py new file mode 100644 index 0000000000..294c0dd467 --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/mbqc/decompose_graph_state.py @@ -0,0 +1,209 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains the implementations of the decompose_graph_state and +null_decompose_graph_state transforms, written using xDSL. + +.. note:: + + The transforms contained in this module make frequent use of the *densely packed adjacency + matrix* graph representation. For a detailed description of this graph representation, see the + documentation for the GraphStatePrepOp operation in the MBQC dialect in Catalyst. +""" + +from collections.abc import Sequence +from dataclasses import dataclass +from typing import TypeAlias + +from xdsl import context, passes, pattern_rewriter +from xdsl.dialects import builtin +from xdsl.pattern_rewriter import PatternRewriter, RewritePattern + +from catalyst.python_interface.dialects import mbqc, quantum +from catalyst.python_interface.pass_api import compiler_transform + +from .graph_state_utils import edge_iter, n_vertices_from_packed_adj_matrix + +DenselyPackedAdjMatrix: TypeAlias = Sequence[int] | Sequence[bool] + + +@dataclass(frozen=True) +class DecomposeGraphStatePass(passes.ModulePass): + """The decompose-graph-state pass replaces ``graph_state_prep`` operations with their + corresponding sequence of quantum operations for execution on state simulators. + """ + + name = "decompose-graph-state" + + # pylint: disable=no-self-use + def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + """Apply the decompose-graph-state pass.""" + + walker = pattern_rewriter.PatternRewriteWalker(DecomposeGraphStatePattern()) + walker.rewrite_module(module) + + +decompose_graph_state_pass = compiler_transform(DecomposeGraphStatePass) + + +# pylint: disable=too-few-public-methods +class DecomposeGraphStatePattern(RewritePattern): + """Rewrite pattern for the decompose-graph-state transform.""" + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, graph_prep_op: mbqc.GraphStatePrepOp, rewriter: PatternRewriter, /): + """Match and rewrite pattern for graph_state_prep ops.""" + # These are the names of the gates that realize the desired initial individual qubit state + # and entangled state, respectively + init_op_gate_name = graph_prep_op.init_op.data + entangle_op_gate_name = graph_prep_op.entangle_op.data + + adj_matrix = _parse_adj_matrix(graph_prep_op) + n_vertices = n_vertices_from_packed_adj_matrix(adj_matrix) + + # Allocate a register with as many qubits as vertices in the graph + alloc_op = quantum.AllocOp(n_vertices) + rewriter.insert_op(alloc_op) + + # This dictionary maps wires indices in the register to qubit SSA values + graph_qubits_map: dict[int, quantum.QubitSSAValue] = {} + + # In this section, we create the sequences of quantum.extract, quantum.custom (for the init + # and entangle gates) and quantum.insert ops, and gather them up into list for insertion + # later on. + qextract_ops: list[quantum.ExtractOp] = [] + for i in range(n_vertices): + qextract_op = quantum.ExtractOp(alloc_op.qreg, i) + qextract_ops.append(qextract_op) + graph_qubits_map[i] = qextract_op.qubit + + init_ops: list[quantum.CustomOp] = [] + for i in range(n_vertices): + init_op = quantum.CustomOp(in_qubits=graph_qubits_map[i], gate_name=init_op_gate_name) + init_ops.append(init_op) + graph_qubits_map[i] = init_op.out_qubits[0] + + entangle_ops: list[quantum.CustomOp] = [] + for edge in edge_iter(adj_matrix): + q0 = graph_qubits_map[edge[0]] + q1 = graph_qubits_map[edge[1]] + entangle_op = quantum.CustomOp(in_qubits=(q0, q1), gate_name=entangle_op_gate_name) + entangle_ops.append(entangle_op) + graph_qubits_map[edge[0]] = entangle_op.out_qubits[0] + graph_qubits_map[edge[1]] = entangle_op.out_qubits[1] + + qinsert_ops: list[quantum.InsertOp] = [] + qreg = alloc_op.qreg + for i in range(n_vertices): + qinsert_op = quantum.InsertOp(in_qreg=qreg, idx=i, qubit=graph_qubits_map[i]) + qinsert_ops.append(qinsert_op) + qreg = qinsert_op.out_qreg + + # In this section, we iterate over the ops created above and insert them. + # Note that we do not need to specify the insertion point here; all ops are inserted before + # the matched op, automatically putting them in the order we want them in. + for qextract_op in qextract_ops: + rewriter.insert_op(qextract_op) + + for init_op in init_ops: + rewriter.insert_op(init_op) + + for entangle_op in entangle_ops: + rewriter.insert_op(entangle_op) + + for qinsert_op in qinsert_ops: + rewriter.insert_op(qinsert_op) + + # The register that is the result of the last quantum.insert op replaces the register that + # was the result of the graph_state_prep op + rewriter.replace_all_uses_with(graph_prep_op.results[0], qinsert_ops[-1].results[0]) + + # Finally, erase the ops that have now been replaced with quantum ops + rewriter.erase_matched_op() + + # Erase the constant op that returned the adjacency matrix only if it has no other uses + if graph_prep_op.adj_matrix.uses.get_length() == 0: + rewriter.erase_op(graph_prep_op.adj_matrix.owner) + + +@dataclass(frozen=True) +class NullDecomposeGraphStatePass(passes.ModulePass): + """The null-decompose-graph-state pass replaces ``graph_state_prep`` operations with a single + quantum-register allocation operation for execution on null devices. + """ + + name = "null-decompose-graph-state" + + # pylint: disable=no-self-use + def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + """Apply the null-decompose-graph-state pass.""" + + walker = pattern_rewriter.PatternRewriteWalker(NullDecomposeGraphStatePattern()) + walker.rewrite_module(module) + + +null_decompose_graph_state_pass = compiler_transform(NullDecomposeGraphStatePass) + + +# pylint: disable=too-few-public-methods +class NullDecomposeGraphStatePattern(RewritePattern): + """Rewrite pattern for the null-decompose-graph-state transform.""" + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, graph_prep_op: mbqc.GraphStatePrepOp, rewriter: PatternRewriter, /): + """Match and rewrite pattern for graph_state_prep ops.""" + adj_matrix = _parse_adj_matrix(graph_prep_op) + n_vertices = n_vertices_from_packed_adj_matrix(adj_matrix) + + # Allocate a register with as many qubits as vertices in the graph + alloc_op = quantum.AllocOp(n_vertices) + rewriter.insert_op(alloc_op) + + # The newly allocated register replaces the register that was the result of the + # graph_state_prep op + rewriter.replace_all_uses_with(graph_prep_op.results[0], alloc_op.results[0]) + + # Finally, erase the ops that have now been replaced with quantum ops + rewriter.erase_matched_op() + + # Erase the constant op that returned the adjacency matrix only if it has no other uses + if graph_prep_op.adj_matrix.uses.get_length() == 0: + rewriter.erase_op(graph_prep_op.adj_matrix.owner) + + +def _parse_adj_matrix(graph_prep_op: mbqc.GraphStatePrepOp) -> list[int]: + """Parse the adjacency matrix from the result of the ConstantOp given as input to the + graph_state_prep op. + + We assume that the adjacency matrix is stored as a DenseIntOrFPElementsAttr, whose data is + accessible as a Python 'bytes' array. Converting this bytes array to a list results in integer + elements, whose values are typically either 0 for 'false' or 255 for 'true'. + + Returns: + list[int]: The densely packed adjacency matrix as a list of ints. See the note in the module + documentation for a description of this format. + """ + adj_matrix_const_op = graph_prep_op.adj_matrix.owner + adj_matrix_value = adj_matrix_const_op.properties.get("value") + assert adj_matrix_value is not None and hasattr( + adj_matrix_value, "data" + ), f"Unable to read graph adjacency matrix from op `{adj_matrix_const_op}`" + + adj_matrix_bytes = adj_matrix_value.data + assert isinstance(adj_matrix_bytes, builtin.BytesAttr), ( + f"Expected graph adjacency matrix data to be of type 'builtin.BytesAttr', but got " + f"{type(adj_matrix_bytes).__name__}" + ) + + return list(adj_matrix_bytes.data) diff --git a/frontend/catalyst/python_interface/transforms/mbqc/graph_state_utils.py b/frontend/catalyst/python_interface/transforms/mbqc/graph_state_utils.py new file mode 100644 index 0000000000..bf13b09971 --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/mbqc/graph_state_utils.py @@ -0,0 +1,256 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for graph state representations""" + +import math +from collections.abc import Generator, Sequence +from typing import TypeAlias + +from pennylane.exceptions import CompileError + +DenselyPackedAdjMatrix: TypeAlias = Sequence[int] | Sequence[bool] + + +_MBQC_GATE_SET = {"Hadamard", "S", "RZ", "RotXZX", "CNOT"} +_SINGLE_QUBIT_AUX_WIRE_NUM = 4 +_CNOT_AUX_WIRE_NUM = 13 + + +def get_num_aux_wires(gate_name: str) -> int: + """ + Return the number of auxiliary wires required for gates from the MBQC gate set. + The number of auxiliary qubits for a single qubit gate is 4, while it is 13 for a + CNOT gate. + + Args: + gate_name (str): The name of a gate. + + Returns: + The number of auxiliary wires. + """ + if gate_name == "CNOT": + return _CNOT_AUX_WIRE_NUM + if gate_name in _MBQC_GATE_SET: + return _SINGLE_QUBIT_AUX_WIRE_NUM + raise ValueError(f"{gate_name} is not supported in the MBQC formalism.") + + +def get_graph_state_edges(gate_name: str) -> list[tuple[int, int]]: + """ + Return a list of edges information in the graph state of a gate. + + - The connectivity of the target qubits in the register and auxiliary qubits for a single-qubit gate is: + + tgt -- 0 -- 1 -- 2 -- 3 + + Note that the target qubit is not in the adjacency matrix and the connectivity + of the auxiliary qubits is: + edges_in_adj_matrix = [ + (0, 1), + (1, 2), + (2, 3), + ] + + Wire 1 in the above isn't the target wire described in the Fig.2 of [`arXiv:quant-ph/0301052 `_], + 1 in the above maps to 3 in the figure. + + - The connectivity of the ctrl/target qubits in the register and auxiliary qubits for a CNOT gate is: + + ctl -- 0 -- 1 -- 2 -- 3 -- 4 -- 5 + | + 6 + | + tgt -- 7 -- 8 -- 9 -- 10 -- 11 -- 12 + + Note that both ctrl and target qubits are not in the adjacency matrix and the connectivity + of the auxiliary qubits is: + edges_in_adj_matrix = [ + (0, 1), + (1, 2), + (2, 3), + (3, 4), + (4, 5), + (2, 6), + (7, 8), + (6, 9), + (8, 9), + (9, 10), + (10, 11), + (11, 12), + ] + + This graph is labelled based on the rows and columns of the adjacent matrix, but maps on to the graph described in + the Fig.2 of [`arXiv:quant-ph/0301052 `_], where wire 1 is the control and + wire 9 is the target. + + Args: + gate_name (str): The name of a gate. + + Returns: + A list of edges information in the graph state for the given gate. + """ + + if gate_name == "CNOT": + return [ + (0, 1), + (1, 2), + (2, 3), + (3, 4), + (4, 5), + (2, 6), + (7, 8), + (6, 9), + (8, 9), + (9, 10), + (10, 11), + (11, 12), + ] + if gate_name in _MBQC_GATE_SET: + return [ + (0, 1), + (1, 2), + (2, 3), + ] + raise ValueError(f"{gate_name} is not supported in the MBQC formalism.") + + +def n_vertices_from_packed_adj_matrix(adj_matrix: DenselyPackedAdjMatrix) -> int: + """Returns the number of vertices in the graph represented by the given densely packed adjacency + matrix. + + Args: + adj_matrix (DenselyPackedAdjMatrix): The densely packed adjacency matrix, given as a + sequence of bools or ints. See the note in the module documentation for a description of + this format. + + Raises: + CompileError: If the number of elements in `adj_matrix` is not compatible with the number of + elements in the lower-triangular part of a square matrix, excluding the elements along + the diagonal. + + Returns: + int: The number of vertices in the graph. + + Example: + >>> _n_vertices_from_packed_adj_matrix([1, 1, 0, 0, 1, 1]) + 4 + """ + assert isinstance( + adj_matrix, Sequence + ), f"Expected `adj_matrix` to be a sequence, but got {type(adj_matrix).__name__}" + + m = len(adj_matrix) + + # The formula to compute the number of vertices, N, in the graph from the number elements in the + # densely packed adjacency matrix, m, is + # N = (1 + sqrt(1 + 8m)) / 2 + # To avoid floating-point errors in the sqrt function, we break it down into integer-arithmetic + # operations and ensure that the solution is one where N is mathematically a true integer. + + discriminant = 1 + 8 * m + sqrt_discriminant = math.isqrt(discriminant) + + # Check if it's a perfect square + if sqrt_discriminant * sqrt_discriminant != discriminant: + raise CompileError( + f"The number of elements in the densely packed adjacency matrix is {m}, which does not " + f"correspond to an integer number of graph vertices" + ) + + # The numerator, 1 + sqrt(1 + 8m), must be even for the result to be an integer. The quantity + # sqrt(1 + 8m) will always be odd if it's a perfect square, so the quantity (1 + sqrt(1 + 8m)) + # will always be even. We can therefore safely divide (using integer division). + return (1 + sqrt_discriminant) // 2 + + +def edge_iter(adj_matrix: DenselyPackedAdjMatrix) -> Generator[tuple[int, int], None, None]: + """Generate an iterator over the edges in a graph represented by the given densely packed + adjacency matrix. + + Args: + adj_matrix (DenselyPackedAdjMatrix): The densely packed adjacency matrix, given as a + sequence of bools or ints. See the note in the module documentation for a description of + this format. + + Yields: + tuple[int, int]: The next edge in the graph, represented as the pair of vertices labelled + according to their indices in the adjacency matrix. + + Example: + >>> for edge in _edge_iter([1, 1, 0, 0, 1, 1]): + ... print(edge) + (0, 1) + (0, 2) + (1, 3) + (2, 3) + """ + # Calling `_n_vertices_from_packed_adj_matrix()` asserts that the input `adj_matrix` is in the + # correct format and is valid. + n_vertices_from_packed_adj_matrix(adj_matrix) + + j = 1 + k = 0 + for entry in adj_matrix: + if entry: + yield (k, j) + k += 1 + if k == j: + k = 0 + j += 1 + + +def _adj_matrix_generation_helper( + num_vertices: int, edges_in_adj_matrix: list[tuple[int, int]] +) -> list: + """Helper function to generate an adjacency matrix to represent the connectivity of auxiliary qubits in + a graph state for a gate operation with the number of vertices and edges information. + Note that the adjacency matrix here means the lower triangular part of the full adjacency matrix. + It can be represented as below and `x` marks here denotes the matrix diagonal. + x + + x + + + x + . + ........ + . + + + + + + x + + Args: + num_vertices (int) : Number of vertices in the adjacency matrix. + edges_in_adj_matrix (list[tuple[int, int]]): List of edges in the adjacency matrix. + + Return: + An adjacency matrix represents the connectivity of vertices. + """ + adj_matrix_length = num_vertices * (num_vertices - 1) // 2 + adj_matrix = [0] * adj_matrix_length + for edge in edges_in_adj_matrix: + col, row = edge + n = col + (row - 1) * row // 2 + adj_matrix[n] = 1 + + return adj_matrix + + +def generate_adj_matrix(op_name: str) -> list: + """Generate an adjacency matrix represents the connectivity of auxiliary qubits in a + graph state for a gate operation. + Args: + op_name (str): The gate name. Note that only a gate in the MBQC gate set is supported. + Returns: + An adjacent matrix represents the connectivity of auxiliary qubits. + """ + num_aux_wires = get_num_aux_wires(op_name) + edges_list = get_graph_state_edges(op_name) + return _adj_matrix_generation_helper(num_aux_wires, edges_list) diff --git a/frontend/catalyst/python_interface/transforms/mbqc/outline_state_evolution.py b/frontend/catalyst/python_interface/transforms/mbqc/outline_state_evolution.py new file mode 100644 index 0000000000..dc9fbcfc98 --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/mbqc/outline_state_evolution.py @@ -0,0 +1,470 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This file contains the implementation of the outline_state_evolution transform. + +Known limitations +----------------- + + * If the current pass is applied multiple times, the transform will fail as it would redefined the `state_evolution` func. This is + caused by the way we define the terminal_boundary_op. Each time the pass is applied to the IR, it would insert a new + terminal_boundary_op into the IR. TODOs: Instead of inserting a new `terminal_boundary_op` op to the IR when applying the pass, it + would be better to: 1. define a quantum.terminator op before this pass and use it as a delineation of quantum gate operation; + 2. move the `simplify_io` to a separate pass. +""" + +from dataclasses import dataclass +from itertools import chain + +from xdsl import context, passes, pattern_rewriter +from xdsl.dialects import builtin, func +from xdsl.ir import Operation, SSAValue +from xdsl.rewriter import InsertPoint + +from catalyst.python_interface.dialects import quantum +from catalyst.python_interface.pass_api import compiler_transform + + +@dataclass(frozen=True) +class OutlineStateEvolutionPass(passes.ModulePass): + """Pass that puts gate operations into an outline_state_evolution callable.""" + + name = "outline-state-evolution" + + # pylint: disable=no-self-use + def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + """Apply the outline-state-evolution pass.""" + for op in module.ops: + if isinstance(op, func.FuncOp) and "qnode" in op.attributes: + rewriter = pattern_rewriter.PatternRewriter(op) + OutlineStateEvolutionPattern().match_and_rewrite(op, rewriter) + + +outline_state_evolution_pass = compiler_transform(OutlineStateEvolutionPass) + + +class OutlineStateEvolutionPattern(pattern_rewriter.RewritePattern): + """RewritePattern for outlined state evolution regions in a quantum function.""" + + # pylint: disable=too-few-public-methods + def _get_parent_module(self, op: func.FuncOp) -> builtin.ModuleOp: + """Get the first ancestral builtin.ModuleOp op of a given func.func op.""" + while (op := op.parent_op()) and not isinstance(op, builtin.ModuleOp): + pass + if op is None: + raise RuntimeError( + "The given qnode func is not nested within a builtin.module. Please ensure the qnode func is defined in a builtin.module." + ) + return op + + def __init__(self): + self.module: builtin.ModuleOp = None + self.original_func_op: func.FuncOp = None + + # To determine the boundary of quantum gate operations in the IR + self.alloc_op: quantum.AllocOp = None + self.terminal_boundary_op: Operation = None + + # Input and outputs of the state evolution func + self.required_inputs: list[SSAValue] = None + self.required_outputs: list[SSAValue] = None + + # State evolution function region + self.state_evolution_func: func.FuncOp = None + + # pylint: disable=too-many-arguments + def match_and_rewrite(self, func_op: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter): + """Transform a quantum function (qnode) to outline state evolution regions. + This implementation assumes that there is only one `quantum.alloc` operation in + the func operations with a "qnode" attribute and all quantum operations are between + the unique `quantum.alloc` operation and the terminal_boundary_op. All operations in between + are to be moved to the newly created outline-state-evolution function operation.""" + if "qnode" not in func_op.attributes: + return + + self.original_func_op = func_op + + self.module = self._get_parent_module(func_op) + + # Simplify the quantum I/O to use only registers at boundaries + self._simplify_quantum_io(func_op, rewriter) + + # Create a new function op for the state evolution region and insert it + # into the parent scope of the original func with qnode attribute + self._create_state_evolution_function(rewriter) + + # Replace the original region with a call to the state evolution function + # by inserting the corresponding callOp and update the rest of operations + # in the qnode func. + self._finalize_transformation() + + # pylint: disable=no-else-return + def _get_qubit_idx(self, op: Operation) -> int | None: + """Get the index of qubit that an ExtractOp op extracts.""" + if getattr(op, "idx", None): + return op.idx + return getattr(op, "idx_attr", None) + + def _set_up_terminal_boundary_op( + self, + current_reg: quantum.QuregType, + terminal_boundary_op: Operation | None, + qubit_to_reg_idx: dict, + op: Operation, + rewriter: pattern_rewriter.PatternRewriter, + ): + """Set up the terminal boundary operation. This terminal_boundary_op is set as the last + quantum.insert operations added to the IR.""" + insert_ops = set() + + # Insert all qubits recorded in the qubit_to_reg_idx dict before the + # pre-assumed terminal operations. + insertion_point = InsertPoint.before(op) + for qb, idx in qubit_to_reg_idx.items(): + insert_op = quantum.InsertOp(current_reg, idx, qb) + rewriter.insert_op(insert_op, insertion_point=insertion_point) + insert_ops.add(insert_op) + terminal_boundary_op = insert_op + current_reg = insert_op.out_qreg + + # Add the `"terminal_boundary"` attribute to the last newly added + # `quantum.insert` operation. + if terminal_boundary_op is None: + raise RuntimeError("A terminal_boundary_op op is not found in the circuit.") + terminal_boundary_op.attributes["terminal_boundary"] = builtin.UnitAttr() + prev_qreg = terminal_boundary_op.in_qreg + + # extract ops + insertion_point = InsertPoint.before(op) + for qb, idx in list(qubit_to_reg_idx.items()): + extract_op = quantum.ExtractOp(current_reg, idx) + rewriter.insert_op(extract_op, insertion_point=insertion_point) + qb.replace_by_if(extract_op.qubit, lambda use: use.operation not in insert_ops) + for use in qb.uses: + rewriter.notify_op_modified(use.operation) + # update the qubit_to_reg_idx dict + qubit_to_reg_idx[extract_op.qubit] = idx + # pop out qb from the dict + del qubit_to_reg_idx[qb] + return current_reg, prev_qreg, terminal_boundary_op + + # pylint: disable=cell-var-from-loop, too-many-branches + def _simplify_quantum_io( + self, func_op: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter + ) -> func.FuncOp: + """Simplify quantum I/O to use only registers at segment boundaries. + + This ensures that state evolution regions only take registers as input/output, + not individual qubits. + """ + current_reg = None + # Note that all qubits recorded in the `qubit_to_reg_idx` will be inserted into + # the IR and the last insert_op will be set as the `terminal_boundary_op`.` + qubit_to_reg_idx = {} + terminal_boundary_op = None + terminal_op_in_reg = None + + for op in func_op.body.ops: + match op: + case quantum.AllocOp(): + current_reg = op.qreg + case quantum.ExtractOp(): + # Update register mapping + extract_idx = self._get_qubit_idx(op) + qubit_to_reg_idx[op.qubit] = extract_idx + # branch to update extract_op with new qreg + if op.qreg is terminal_op_in_reg: + insertion_point = InsertPoint.before(op) + extract_op = quantum.ExtractOp(current_reg, extract_idx) + rewriter.insert_op(extract_op, insertion_point=insertion_point) + rewriter.replace_all_uses_with(op.results[0], extract_op.results[0]) + rewriter.erase_op(op) + + case quantum.MeasureOp(): + # TODOs: what if the qubit that quantum.measure target at is reset? + # Not a concern by EOY 2025 + qubit_to_reg_idx[op.out_qubit] = qubit_to_reg_idx[op.in_qubit] + del qubit_to_reg_idx[op.in_qubit] + case quantum.CustomOp(): + # To update the qubit_to_reg_idx map for the return type. + for i, qb in enumerate(chain(op.in_qubits, op.in_ctrl_qubits)): + qubit_to_reg_idx[op.results[i]] = qubit_to_reg_idx[qb] + del qubit_to_reg_idx[qb] + case quantum.InsertOp(): + if not terminal_op_in_reg: + if op.idx_attr and qubit_to_reg_idx[op.qubit] is not op.idx_attr: + raise ValueError("op.qubit should be op.idx_attr.") + del qubit_to_reg_idx[op.qubit] + current_reg = op.out_qreg + # branch to update insert_op with new qreg + if op.in_qreg is terminal_op_in_reg: + insertion_point = InsertPoint.before(op) + index = op.idx if op.idx else op.idx_attr + insert_op = quantum.InsertOp(current_reg, index, op.qubit) + rewriter.insert_op(insert_op, insertion_point=insertion_point) + rewriter.replace_all_uses_with(op.out_qreg, insert_op.out_qreg) + rewriter.erase_op(op) + + case _ if ( + isinstance( + op, + ( + quantum.ComputationalBasisOp, + quantum.NamedObsOp, + quantum.HamiltonianOp, + quantum.TensorOp, + ), + ) + and not terminal_boundary_op + ): + current_reg, terminal_op_in_reg, terminal_boundary_op = ( + self._set_up_terminal_boundary_op( + current_reg, terminal_boundary_op, qubit_to_reg_idx, op, rewriter + ) + ) + case _: + # Handle other operations that might has qreg result + # Note that this branch might not be tested so far as adjoint op is not + # tested so far. + if reg := next( + (reg for reg in op.results if isinstance(reg.type, quantum.QuregType)), None + ): + current_reg = reg + + def _create_state_evolution_function(self, rewriter: pattern_rewriter.PatternRewriter): + """Create a new func.func for the state evolution region using clone approach.""" + + alloc_op, terminal_boundary_op = self._find_evolution_range() + + # collect operation from alloc_op to terminal_boundary_op + ops_to_clone = self._collect_operations_in_range(alloc_op, terminal_boundary_op) + + # collect required inputs for the state evolution funcOp + required_inputs = self._collect_required_inputs_for_state_evolution_func(ops_to_clone) + + # collect required outputs for the state evolution funcOp + required_outputs = self._collect_required_outputs(ops_to_clone, terminal_boundary_op) + + register_inputs = [] + other_inputs = [] + for val in required_inputs: + if isinstance(val.type, quantum.QuregType): + register_inputs.append(val) + else: + other_inputs.append(val) + + register_outputs = [] + other_outputs = [] + for val in required_outputs: + if isinstance(val.type, quantum.QuregType): + register_outputs.append(val) + else: + other_outputs.append(val) + + ordered_inputs = register_inputs + other_inputs + ordered_outputs = register_outputs + other_outputs + + input_types = [val.type for val in ordered_inputs] + output_types = [val.type for val in ordered_outputs] + fun_type = builtin.FunctionType.from_lists(input_types, output_types) + + # create a new func.func op and insert it into the IR + state_evolution_func = func.FuncOp( + self.original_func_op.sym_name.data + ".state_evolution", fun_type, visibility="public" + ) + rewriter.insert_op(state_evolution_func, InsertPoint.at_end(self.module.body.block)) + + # TODOs: how to define the `value_mapper` arg is not stated in the xdl.core module [here](https://github.com/xdslproject/xdsl/blob/e1301e0204bcf6ea5ed433e7da00bee57d07e695/xdsl/ir/core.py#L1429)_. + # It looks like storing ssa value to be cloned would maintain the dependency relationship required to build the new DAG for the new ops. + block = state_evolution_func.regions[0].block + value_mapper = {} # only args ssavlue is required + for input, block_arg in zip(ordered_inputs, block.args): + value_mapper[input] = block_arg + + self._clone_operations_to_block(ops_to_clone, block, value_mapper) + self._add_return_statement(block, ordered_outputs, value_mapper) + + self.required_inputs = ordered_inputs + self.required_outputs = ordered_outputs + self.alloc_op = alloc_op + self.terminal_boundary_op = terminal_boundary_op + self.state_evolution_func = state_evolution_func + + def _find_evolution_range(self): + """find alloc_op and terminal_boundary_op""" + alloc_op = None + terminal_boundary_op = None + + for op in self.original_func_op.body.walk(): + if isinstance(op, quantum.AllocOp): + alloc_op = op + elif hasattr(op, "attributes") and "terminal_boundary" in op.attributes: + terminal_boundary_op = op + + if not (alloc_op and terminal_boundary_op): + raise RuntimeError("Could not find both alloc_op and terminal_boundary_op") + + return alloc_op, terminal_boundary_op + + def _collect_operations_in_range(self, begin_op, end_op): + """collect top-level operations in range, let op.clone() handle nesting""" + ops_to_clone = [] + + if begin_op.parent_block() != end_op.parent_block(): + raise ValueError("begin_op and end_op are not in the same block") + + block = begin_op.parent_block() + + # skip until begin_op + op_iter = iter(block.ops) + while (op := next(op_iter, None)) != begin_op: + pass + + # collect top-level operations until end_op + while (op := next(op_iter, None)) != end_op: + ops_to_clone.append(op) + + # collect the terminal_boundary_op + if op is not None: + ops_to_clone.append(op) + + return ops_to_clone + + def _collect_required_inputs_for_state_evolution_func( + self, ops: list[Operation] + ) -> list[SSAValue]: + """Collect required inputs for the state evolution funcOp with a given list of operations. + Note that this method does not intent to keep the order of required input SSAValues. + """ + ops_walk = list(chain(*[op.walk() for op in ops])) + + # a set records the ssa values defined by the ops list + ops_defined_values = set() + # a set records all the ssa values required for all operations in the ops list + all_operands = set() + + for nested_op in ops_walk: + ops_defined_values.update(nested_op.results) + all_operands.update(nested_op.operands) + + if hasattr(nested_op, "regions") and nested_op.regions: + for region in nested_op.regions: + for block in region.blocks: + ops_defined_values.update(block.args) + + # the ssa values not defined by the operations in the ops list + missing_defs = list(all_operands - ops_defined_values) + required_inputs = [v for v in missing_defs if v is not None] + + return required_inputs + + def _collect_required_outputs( + self, ops: list[Operation], terminal_op: Operation + ) -> list[SSAValue]: + """Get required outputs for the state evolution funcOp with a given list of operations. + Note: It only considers the values that are defined in the operations and required by + the operations after the terminal operation. + """ + ops_walk = list(chain(*[op.walk() for op in ops])) + + ops_defined_values = set() + + for op_walk in ops_walk: + ops_defined_values.update(op_walk.results) + + # use list here to maintain the order of required outputs + required_outputs = [] + found_terminal = False + for op in self.original_func_op.body.walk(): + # branch for the operations before the terminal_op + if op == terminal_op: + found_terminal = True + continue + + # branch for the operations after the terminal_op + if found_terminal: + for operand in op.operands: + # a required output is an operand defined by a result of op in the ops + if operand in ops_defined_values and operand not in required_outputs: + required_outputs.append(operand) + + return required_outputs + + def _add_return_statement(self, target_block, required_outputs, value_mapper): + """add a func.return op to function""" + return_values = [] + for output_val in required_outputs: + if output_val not in value_mapper: + raise ValueError(f"output_val {output_val} not in value_mapper") + return_values.append(value_mapper[output_val]) + + return_op = func.ReturnOp(*return_values) + target_block.add_op(return_op) + + def _clone_operations_to_block(self, ops_to_clone, target_block, value_mapper): + """Clone operations to target block, use value_mapper to update references""" + for op in ops_to_clone: + cloned_op = op.clone(value_mapper) + target_block.add_op(cloned_op) + + def _finalize_transformation(self): + """Replace the original function with a call to the state evolution function.""" + original_block = self.original_func_op.body.block + ops_list = list(original_block.ops) + + begin_idx = None + end_idx = None + for i, op in enumerate(ops_list): + if op == self.alloc_op: + begin_idx = i + 1 + elif op == self.terminal_boundary_op: + end_idx = i + 1 + break + + if begin_idx is None: + raise RuntimeError("A quantum.alloc operation is not found in original function.") + if end_idx is None: + raise RuntimeError( + "A terminal_boundary_op operation is not found in original function." + ) + if begin_idx > end_idx: + raise RuntimeError( + "A quantum.alloc operation should come before the terminal_boundary_op." + ) + + post_ops = ops_list[end_idx:] + + call_args = list(self.required_inputs) + result_types = [val.type for val in self.required_outputs] + + call_op = func.CallOp(self.state_evolution_func.sym_name.data, call_args, result_types) + + # TODO: I just removed all ops and add them again to update with value_mapper. + # It's not efficient, just because it's easy to implement. Should using replace use method + # instead. + # De-attach all ops of the original function + call_result_mapper = {} + for i, required_output in enumerate(self.required_outputs): + call_result_mapper[required_output] = call_op.results[i] + + value_mapper = call_result_mapper.copy() + original_block.add_op(call_op) + for op in post_ops: + cloned_op = op.clone(value_mapper) + original_block.add_op(cloned_op) + + # replace ops_list with call_op + for op in chain(reversed(post_ops), reversed(ops_list[begin_idx:end_idx])): + op.detach() + op.erase() diff --git a/frontend/catalyst/python_interface/transforms/quantum/__init__.py b/frontend/catalyst/python_interface/transforms/quantum/__init__.py new file mode 100644 index 0000000000..a3141f8b78 --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/quantum/__init__.py @@ -0,0 +1,41 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""xDSL API for quantum transforms""" +from .cancel_inverses import IterativeCancelInversesPass, iterative_cancel_inverses_pass +from .combine_global_phases import CombineGlobalPhasesPass, combine_global_phases_pass +from .diagonalize_measurements import ( + DiagonalizeFinalMeasurementsPass, + diagonalize_final_measurements_pass, +) +from .measurements_from_samples import ( + MeasurementsFromSamplesPass, + measurements_from_samples_pass, +) +from .merge_rotations import MergeRotationsPass, merge_rotations_pass +from .split_non_commuting import SplitNonCommutingPass, split_non_commuting_pass + +__all__ = [ + "combine_global_phases_pass", + "CombineGlobalPhasesPass", + "diagonalize_final_measurements_pass", + "DiagonalizeFinalMeasurementsPass", + "iterative_cancel_inverses_pass", + "IterativeCancelInversesPass", + "measurements_from_samples_pass", + "MeasurementsFromSamplesPass", + "merge_rotations_pass", + "MergeRotationsPass", + "split_non_commuting_pass", + "SplitNonCommutingPass", +] diff --git a/frontend/catalyst/python_interface/transforms/quantum/cancel_inverses.py b/frontend/catalyst/python_interface/transforms/quantum/cancel_inverses.py new file mode 100644 index 0000000000..9f5053cc67 --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/quantum/cancel_inverses.py @@ -0,0 +1,102 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This file contains the implementation of the cancel_inverses transform, +written using xDSL.""" + +from dataclasses import dataclass + +from xdsl import context, passes, pattern_rewriter +from xdsl.dialects import builtin, func +from xdsl.ir import Operation + +from catalyst.python_interface.dialects.quantum import CustomOp +from catalyst.python_interface.pass_api import compiler_transform + +self_inverses = [ + "Identity", + "Hadamard", + "PauliX", + "PauliY", + "PauliZ", + "CNOT", + "CZ", + "CY", + "CH", + "SWAP", + "Toffoli", + "CCZ", +] + + +def _can_cancel(op: CustomOp, next_op: Operation) -> bool: + if isinstance(next_op, CustomOp): + if op.gate_name.data == next_op.gate_name.data: + if ( + op.out_qubits == next_op.in_qubits + and op.out_ctrl_qubits == next_op.in_ctrl_qubits + and op.in_ctrl_values == next_op.in_ctrl_values + ): + return True + + return False + + +class IterativeCancelInversesPattern( + pattern_rewriter.RewritePattern +): # pylint: disable=too-few-public-methods + """RewritePattern for iteratively cancelling consecutive self-inverse gates.""" + + # pylint: disable=no-self-use + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter): + """Implementation of rewriting FuncOps that may contain operations corresponding to + self-inverse gates.""" + for op in funcOp.body.walk(): + + while isinstance(op, CustomOp) and op.gate_name.data in self_inverses: + + next_user = None + for use in op.results[0].uses: + user = use.operation + if _can_cancel(op, user): + next_user: CustomOp = user + break + + if next_user is None: + break + + for q1, q2 in zip(op.in_qubits, next_user.out_qubits, strict=True): + rewriter.replace_all_uses_with(q2, q1) + for cq1, cq2 in zip(op.in_ctrl_qubits, next_user.out_ctrl_qubits, strict=True): + rewriter.replace_all_uses_with(cq2, cq1) + rewriter.erase_op(next_user) + rewriter.erase_op(op) + op = op.in_qubits[0].owner + + +@dataclass(frozen=True) +class IterativeCancelInversesPass(passes.ModulePass): + """Pass for iteratively cancelling consecutive self-inverse gates.""" + + name = "xdsl-cancel-inverses" + + # pylint: disable=no-self-use + def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + """Apply the iterative cancel inverses pass.""" + pattern_rewriter.PatternRewriteWalker( + pattern_rewriter.GreedyRewritePatternApplier([IterativeCancelInversesPattern()]) + ).rewrite_module(module) + + +iterative_cancel_inverses_pass = compiler_transform(IterativeCancelInversesPass) diff --git a/frontend/catalyst/python_interface/transforms/quantum/combine_global_phases.py b/frontend/catalyst/python_interface/transforms/quantum/combine_global_phases.py new file mode 100644 index 0000000000..4b9d027f84 --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/quantum/combine_global_phases.py @@ -0,0 +1,86 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This file contains the implementation of the combine_global_phases transform, +written using xDSL.""" + +from dataclasses import dataclass + +from xdsl import context, passes, pattern_rewriter +from xdsl.dialects import arith, builtin, func +from xdsl.dialects.scf import ForOp, IfOp, WhileOp +from xdsl.rewriter import InsertPoint + +from catalyst.python_interface.dialects.quantum import GlobalPhaseOp +from catalyst.python_interface.pass_api import compiler_transform + + +class CombineGlobalPhasesPattern( + pattern_rewriter.RewritePattern +): # pylint: disable=too-few-public-methods + """RewritePattern for combining all :class:`~pennylane.GlobalPhase` gates within the same region + at the last global phase gate.""" + + # pylint: disable=no-self-use + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite( + self, root: func.FuncOp | IfOp | ForOp | WhileOp, rewriter: pattern_rewriter.PatternRewriter + ): # pylint: disable=cell-var-from-loop + """Match and rewrite for the combine-global-phases pattern acting on functions or + control-flow blocks containing GlobalPhase operations. + """ + + for region in root.regions: + phi = None + global_phases = [] + for op in region.ops: + if isinstance(op, GlobalPhaseOp): + global_phases.append(op) + + if len(global_phases) < 2: + continue + + prev = global_phases[0] + phi_sum = prev.operands[0] + for current in global_phases[1:]: + phi = current.operands[0] + addOp = arith.AddfOp(phi, phi_sum) + rewriter.insert_op(addOp, InsertPoint.before(current)) + phi_sum = addOp.result + + rewriter.erase_op(prev) + prev = current + + prev.operands[0].replace_by_if(phi_sum, lambda use: use.operation == prev) + rewriter.notify_op_modified(prev) + + +@dataclass(frozen=True) +class CombineGlobalPhasesPass(passes.ModulePass): + """Pass that combines all global phases within a region into the last global phase operation + within the region. + """ + + name = "combine-global-phases" + + # pylint: disable=no-self-use + def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + """Apply the combine-global-phases pass.""" + pattern_rewriter.PatternRewriteWalker( + CombineGlobalPhasesPattern(), + apply_recursively=False, + ).rewrite_module(module) + + +combine_global_phases_pass = compiler_transform(CombineGlobalPhasesPass) diff --git a/frontend/catalyst/python_interface/transforms/quantum/diagonalize_measurements.py b/frontend/catalyst/python_interface/transforms/quantum/diagonalize_measurements.py new file mode 100644 index 0000000000..763bff79e9 --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/quantum/diagonalize_measurements.py @@ -0,0 +1,158 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This file contains the implementation of the diagonalize_final_measurements transform, +written using xDSL. + + +Known Limitations +----------------- + * Only observables PauliX, PauliY, PauliZ, Hadamard and Identity are currently supported when using this transform (but + these are also the only observables currently supported in the Quantum dialect as NamedObservable). + * Unlike the current tape-based implementation of the transform, it doesn't allow for diagonalization of a subset + of observables. + * Unlike the current tape-based implementation of the transform, conversion to measurements based on eigvals + and wires (rather than the PauliZ observable) is not currently supported. + * Unlike the tape-based implementation, this pass will NOT raise an error if given a circuit that is invalid because + it contains non-commuting measurements. It should be assumed that this transform results in incorrect outputs unless + split_non_commuting is applied to break non-commuting measurements into separate tapes. +""" + +from dataclasses import dataclass + +from pennylane.ops import Hadamard, PauliX, PauliY +from xdsl import context, passes, pattern_rewriter +from xdsl.dialects import arith, builtin +from xdsl.rewriter import InsertPoint + +from catalyst.python_interface.dialects.quantum import ( + CustomOp, + GlobalPhaseOp, + MultiRZOp, + NamedObservable, + NamedObservableAttr, + NamedObsOp, + QubitUnitaryOp, +) + +from ...pass_api import compiler_transform + + +def _generate_mapping(): + _gate_map = {} + _params_map = {} + + for op in PauliX(0), PauliY(0), Hadamard(0): + diagonalizing_gates = op.diagonalizing_gates() + + _gate_map[op.name] = [gate.name for gate in diagonalizing_gates] + _params_map[op.name] = [gate.data for gate in diagonalizing_gates] + + return _gate_map, _params_map + + +_gate_map, _params_map = _generate_mapping() + + +def _diagonalize(obs: NamedObsOp) -> bool: + """Whether to diagonalize a given observable.""" + if obs.type.data in {"PauliZ", "Identity"}: + return False + if obs.type.data in _gate_map: + return True + raise NotImplementedError(f"Observable {obs.type.data} is not supported for diagonalization") + + +class DiagonalizeFinalMeasurementsPattern( + pattern_rewriter.RewritePattern +): # pylint: disable=too-few-public-methods + """RewritePattern for diagonalizing final measurements.""" + + # pylint: disable=no-self-use + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, observable: NamedObsOp, rewriter: pattern_rewriter.PatternRewriter): + """Replace non-diagonalized observables with their diagonalizing gates and PauliZ.""" + + if _diagonalize(observable): + + diagonalizing_gates = _gate_map[observable.type.data] + params = _params_map[observable.type.data] + + qubit = observable.qubit + + insert_point = InsertPoint.before(observable) + + for name, op_data in zip(diagonalizing_gates, params): + if op_data: + param_ssa_values = [] + for param in op_data: + paramOp = arith.ConstantOp( + builtin.FloatAttr(data=param, type=builtin.Float64Type()) + ) + rewriter.insert_op(paramOp, insert_point) + param_ssa_values.append(paramOp.results[0]) + + gate = CustomOp(in_qubits=qubit, gate_name=name, params=param_ssa_values) + else: + gate = CustomOp(in_qubits=qubit, gate_name=name) + + rewriter.insert_op(gate, insert_point) + + qubit = gate.out_qubits[0] + + # we need to replace the initial qubit use everwhere EXCEPT the use that is now the + # input to the first diagonalizing gate. Its not enough to only change the NamedObsOp, + # because the qubit might be inserted/deallocated later + uses_to_change = [ + use + for use in observable.qubit.uses + if not isinstance( + use.operation, (CustomOp, GlobalPhaseOp, MultiRZOp, QubitUnitaryOp) + ) + ] + num_observables = len( + [use for use in uses_to_change if isinstance(use.operation, NamedObsOp)] + ) + + if num_observables > 1: + raise RuntimeError( + "Each wire can only have one set of diagonalizing gates applied, but the circuit contains multiple observables with the same wire." + ) + + observable.qubit.replace_by_if(qubit, lambda use: use in uses_to_change) + for use in uses_to_change: + rewriter.notify_op_modified(use.operation) + + # then we also update the observable to be in the Z basis. Since this is done with the + # rewriter, we don't need to call `rewriter.notify_modified(observable)` regarding this + diag_obs = NamedObsOp( + qubit=qubit, obs_type=NamedObservableAttr(NamedObservable("PauliZ")) + ) + rewriter.replace_op(observable, diag_obs) + + +@dataclass(frozen=True) +class DiagonalizeFinalMeasurementsPass(passes.ModulePass): + """Pass for diagonalizing final measurements.""" + + name = "diagonalize-final-measurements" + + # pylint: disable= no-self-use + def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + """Apply the diagonalize final measurements pass.""" + pattern_rewriter.PatternRewriteWalker(DiagonalizeFinalMeasurementsPattern()).rewrite_module( + module + ) + + +diagonalize_final_measurements_pass = compiler_transform(DiagonalizeFinalMeasurementsPass) diff --git a/frontend/catalyst/python_interface/transforms/quantum/measurements_from_samples.py b/frontend/catalyst/python_interface/transforms/quantum/measurements_from_samples.py new file mode 100644 index 0000000000..527c433ff2 --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/quantum/measurements_from_samples.py @@ -0,0 +1,659 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains the implementation of the measurements_from_samples transform, +written using xDSL. + +Known Limitations +----------------- + + * Only measurements in the computational basis (or where the observable is a Pauli Z op) are + currently supported; for arbitrary observables we require an equivalent compilation pass of the + diagonalize_measurements transform. + * The compilation pass assumes a static number of shots. + * Usage patterns that are not yet supported with program capture are also not supported in the + compilation pass. For example, operator arithmetic is not currently supported, such as + qml.expval(qml.Y(0) @ qml.X(1)). + * qml.counts() is not supported since the return type/shape is different in PennyLane and Catalyst. + See https://docs.pennylane.ai/projects/catalyst/en/stable/dev/quick_start.html#measurements + for more information. +""" + +from abc import abstractmethod +from dataclasses import dataclass +from itertools import islice + +import jax +import jax.numpy as jnp +from pennylane.exceptions import CompileError +from xdsl import context, ir, passes, pattern_rewriter +from xdsl.dialects import arith, builtin, func, tensor +from xdsl.pattern_rewriter import PatternRewriter, RewritePattern +from xdsl.rewriter import InsertPoint + +from catalyst.python_interface.conversion import xdsl_module +from catalyst.python_interface.dialects import quantum +from catalyst.python_interface.pass_api import compiler_transform + + +@dataclass(frozen=True) +class MeasurementsFromSamplesPass(passes.ModulePass): + """Pass that replaces all terminal measurements in a program with a single + :func:`pennylane.sample` measurement, and adds postprocessing instructions to recover the + original measurement. + """ + + name = "measurements-from-samples" + + # pylint: disable=no-self-use + def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + """Apply the measurements-from-samples pass.""" + shots = _get_static_shots_value_from_first_device_op(module) + + greedy_applier = pattern_rewriter.GreedyRewritePatternApplier( + [ + ExpvalAndVarPattern(shots), + ProbsPattern(shots), + CountsPattern(shots), + StatePattern(shots), + ] + ) + walker = pattern_rewriter.PatternRewriteWalker(greedy_applier, apply_recursively=False) + walker.rewrite_module(module) + + +measurements_from_samples_pass = compiler_transform(MeasurementsFromSamplesPass) + + +class MeasurementsFromSamplesPattern(RewritePattern): + """Rewrite pattern base class for the ``measurements_from_samples`` transform, which replaces + all terminal measurements in a program with a single :func:`pennylane.sample` measurement, and + adds postprocessing instructions to recover the original measurement. + + Args: + shots (int): The number of shots (e.g. as retrieved from the DeviceInitOp). + """ + + def __init__(self, shots: int): + super().__init__() + assert isinstance( + shots, int + ), f"Expected `shots` to be an integer value but got {type(shots).__name__}" + self._shots = shots + + @abstractmethod + def match_and_rewrite(self, op: ir.Operation, rewriter: PatternRewriter, /): + """Abstract method for measurements-from-samples match-and-rewrite patterns.""" + + @classmethod + def get_observable_op(cls, op: quantum.ExpvalOp | quantum.VarianceOp) -> quantum.NamedObsOp: + """Return the observable op (quantum.NamedObsOp) given as an input operand to `op`. + + We assume that `op` is either a quantum.ExpvalOp or quantum.VarianceOp, but this is not + strictly enforced. + + Args: + op (quantum.ExpvalOp | quantum.VarianceOp): The op that uses the observable op. + + Returns: + quantum.NamedObsOp: The observable op. + """ + observable_op = op.operands[0].owner + cls._validate_observable_op(observable_op) + + return observable_op + + @staticmethod + def _validate_observable_op(op: quantum.NamedObsOp): + """Validate the observable op. + + Assert that the op is a quantum.NamedObsOp and check if it is supported in the current + implementation of the measurements-from-samples transform. + + Raises: + NotImplementedError: If the observable is anything but a PauliZ quantum.NamedObsOp. + """ + assert isinstance( + op, quantum.NamedObsOp + ), f"Expected `op` to be a quantum.NamedObsOp, but got {type(op).__name__}" + + if op.type.data != "PauliZ": + raise NotImplementedError( + f"Observable '{op.type.data}' used as input to measurement operation is not " + f"supported for the measurements_from_samples transform; currently only the " + f"PauliZ observable is permitted" + ) + + @staticmethod + def insert_compbasis_op( + in_qubit: ir.SSAValue, ref_op: ir.Operation, rewriter: PatternRewriter + ) -> quantum.ComputationalBasisOp: + """Create and insert a computational-basis op (quantum.ComputationalBasisOp). + + The computation-basis op uses `in_qubit` as its input operand. It is inserted *before* the + given reference operation, `ref_op`, using the supplied `rewriter`. + + Args: + in_qubit (SSAValue): The SSA value used as input to the computational-basis op. + ref_op (Operation): The reference op before which the quantum.ComputationalBasisOp is + inserted. + rewriter (PatternRewriter): The xDSL pattern rewriter. + + Returns: + quantum.ComputationalBasisOp: The inserted computation-basis op. + """ + assert isinstance(in_qubit, ir.SSAValue) and isinstance(in_qubit.type, quantum.QubitType), ( + f"Expected `in_qubit` to be an SSAValue with type quantum.QubitType, but got " + f"{type(in_qubit).__name__}" + ) + + # The input operands are [[qubit, ...], qreg] + compbasis_op = quantum.ComputationalBasisOp( + operands=[in_qubit, None], result_types=[quantum.ObservableType()] + ) + rewriter.insert_op(compbasis_op, insertion_point=InsertPoint.before(ref_op)) + + return compbasis_op + + @staticmethod + def insert_sample_op( + compbasis_op: quantum.ComputationalBasisOp, + shots: int, + n_qubits: int, + rewriter: PatternRewriter, + ) -> quantum.SampleOp: + """Create and insert a sample op (quantum.SampleOp). + + The type of the returned samples array is currently restricted to be static, with shape + (shots, n_qubits). + + The sample op is inserted after the supplied `compbasis_op`. + + Args: + compbasis_op (quantum.ComputationalBasisOp): The computational-basis op used as the + input operand to the sample op. + shots (int): Number of shots (to set the shape of the sample op returned array). + n_qubits (int): Number of qubits (to set the shape of the sample op returned array). + rewriter (PatternRewriter): The xDSL pattern rewriter. + + Returns: + quantum.SampleOp: The inserted sample op. + """ + assert isinstance(compbasis_op, quantum.ComputationalBasisOp), ( + f"Expected `compbasis_op` to be a quantum.ComputationalBasisOp, but got " + f"{type(compbasis_op).__name__}" + ) + + sample_op = quantum.SampleOp( + operands=[compbasis_op.results[0], None, None], + result_types=[builtin.TensorType(builtin.Float64Type(), [shots, n_qubits])], + ) + rewriter.insert_op(sample_op, insertion_point=InsertPoint.after(compbasis_op)) + + return sample_op + + @staticmethod + def get_postprocessing_func_op_from_block_by_name( + block: ir.Block, name: str + ) -> func.FuncOp | None: + """Return the post-processing FuncOp from the given `block` with the given `name`. + + If the block does not contain a FuncOp with the matching name, returns None. + + Args: + block (Block): The xDSL block to search. + name (str): The name of the post-processing FuncOp. + + Returns: + func.FuncOp: The FuncOp with matching name. + None: If no match was found. + """ + for op in block.ops: + if isinstance(op, func.FuncOp) and op.sym_name.data == name: + return op + + return None + + @classmethod + def get_postprocessing_funcs_from_module_and_insert( + cls, + postprocessing_module: builtin.ModuleOp, + matched_op: ir.Operation, + name: str | None = None, + ) -> func.FuncOp: + """Get the post-processing FuncOp from `postprocessing_module` (and any helper functions + also contained in `postprocessing_module`) and insert it (them) immediately after the FuncOp + (in the same block) that contains `matched_op`. + + The post-processing function recovers the original measurement process result from the + samples array. This post-postprocessing function is optionally renamed to `name`, if given. + + Args: + postprocessing_module (builtin.ModuleOp): The MLIR module containing the post-processing + FuncOp. + matched_op (Operation): The reference op, the parent of which is used as the + reference point when inserting the post-processing FuncOp. This is usually the op + matched in the call to match_and_rewrite(). + name (str, optional): The name to assign to the post-processing FuncOp, if given. + + Returns: + func.FuncOp: The inserted post-processing FuncOp. + """ + parent_func_op = matched_op.parent_op() + + assert isinstance(parent_func_op, func.FuncOp), ( + f"Expected parent of matched op '{matched_op}' to be a func.FuncOp, but got " + f"{type(parent_func_op).__name__}" + ) + + # This first op in `postprocessing_module` is the "main" post-processing function + postprocessing_func_op = postprocessing_module.body.ops.first.clone() + assert isinstance(postprocessing_func_op, func.FuncOp), ( + f"Expected the first operator of `postprocessing_module` to be a func.FuncOp but " + f"got {type(postprocessing_func_op).__name__}" + ) + + # The function name from jax.jit is 'main'; rename it here + if name is not None: + postprocessing_func_op.sym_name = builtin.StringAttr(data=name) + + parent_block = parent_func_op.parent + parent_block.insert_op_after(postprocessing_func_op, parent_func_op) + + # Get and insert helper functions, if any + if len(postprocessing_module.body.ops) > 1: + prev_op = postprocessing_func_op + for _op in islice(postprocessing_module.body.ops, 1, None): + helper_op = _op.clone() + parent_block.insert_op_after(helper_op, prev_op) + prev_op = helper_op + + return postprocessing_func_op + + @staticmethod + def insert_constant_int_op( + value: int, + insert_point: InsertPoint, + rewriter: PatternRewriter, + value_type: int = 64, + ) -> arith.ConstantOp: + """Create and insert a constant op with the given integer value. + + The integer value is contained within a rankless, dense tensor. + + Args: + value (int): The integer value. + insert_point (InsertPoint): The insertion point for the constant op. + rewriter (PatternRewriter): The xDSL pattern rewriter. + value_type (int, optional): The integer value type (i.e. number of bits). Defaults to 64. + + Returns: + arith.ConstantOp: The created constant op. + """ + constant_int_op = arith.ConstantOp( + builtin.DenseIntOrFPElementsAttr.from_list( + type=builtin.TensorType(builtin.IntegerType(value_type), shape=()), data=(value,) + ) + ) + rewriter.insert_op(constant_int_op, insertion_point=insert_point) + + return constant_int_op + + @staticmethod + def get_n_qubits_from_qreg(qreg: ir.SSAValue): + """Get the number of qubits from a qreg SSA value. + + This method walks back through the SSA graph from the qreg until it reaches its root + quantum.alloc op, or alloc-like op (with possibly zero or more quantum.insert ops + in-between), from which the number of qubits is extracted. + + An op is "alloc-like" if it has an 'nqubits_attr' attribute. + + Args: + qreg (SSAValue): The qreg SSA value. + """ + assert isinstance(qreg, ir.SSAValue) and isinstance(qreg.type, quantum.QuregType), ( + f"Expected `qreg` to be an SSAValue with type quantum.QuregType, but got " + f"{type(qreg).__name__}" + ) + + def _walk_back_to_alloc_op( + insert_or_alloc_op: quantum.AllocOp | quantum.InsertOp, + ) -> quantum.AllocOp | None: + """Recursively walk back from a quantum.insert op to its root quantum.alloc op or + alloc-like op. + + Once found, return the quantum.alloc op. + """ + if ( + isinstance(insert_or_alloc_op, quantum.AllocOp) + or insert_or_alloc_op.properties.get("nqubits_attr") is not None + ): + return insert_or_alloc_op + + if isinstance(insert_or_alloc_op, quantum.InsertOp): + return _walk_back_to_alloc_op(insert_or_alloc_op.operands[0].owner) + + return None + + alloc_op = _walk_back_to_alloc_op(qreg.owner) + assert alloc_op is not None, "Unable to walk back from qreg to alloc op" + + nqubits_attr = alloc_op.properties.get("nqubits_attr") + assert ( + nqubits_attr is not None + ), "Unable to determine number of qubits from alloc op; missing property 'nqubits_attr'" + + n_qubits = nqubits_attr.value.data + assert n_qubits is not None, "Unable to determine number of qubits from qreg SSA value" + + return n_qubits + + +class ExpvalAndVarPattern(MeasurementsFromSamplesPattern): + """A rewrite pattern for the ``measurements_from_samples`` transform that matches and rewrites + ``qml.expval()`` and ``qml.var()`` operations. + + Args: + shots (int): The number of shots (e.g. as retrieved from the DeviceInitOp). + """ + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite( + self, matched_op: quantum.ExpvalOp | quantum.VarianceOp, rewriter: PatternRewriter, / + ): + """Match and rewrite for quantum.ExpvalOp.""" + observable_op = self.get_observable_op(matched_op) + in_qubit = observable_op.operands[0] + compbasis_op = self.insert_compbasis_op(in_qubit, observable_op, rewriter) + sample_op = self.insert_sample_op(compbasis_op, self._shots, 1, rewriter) + + # Insert the post-processing function into current module or get handle to it if already + # inserted + match matched_op: + case quantum.ExpvalOp(): + postprocessing_func_name = f"expval_from_samples.tensor.{self._shots}x1xf64" + postprocessing_jit_func = _postprocessing_expval + case quantum.VarianceOp(): + postprocessing_func_name = f"var_from_samples.tensor.{self._shots}x1xf64" + postprocessing_jit_func = _postprocessing_var + case _: + assert False, ( + f"Expected a quantum.ExpvalOp or quantum.VarianceOp, but got " + f"{type(matched_op).__name__}" + ) + + postprocessing_func_op = self.get_postprocessing_func_op_from_block_by_name( + matched_op.parent_op().parent, postprocessing_func_name + ) + + if postprocessing_func_op is None: + # TODO: Do we have to set the shape of the samples array statically here? Or can the + # shape (shots, wire) be dynamic and given as SSA values? + # Same goes for the column/wire indices (the second argument). + postprocessing_module = postprocessing_jit_func( + jax.core.ShapedArray([self._shots, 1], float), 0 + ) + + postprocessing_func_op = self.get_postprocessing_funcs_from_module_and_insert( + postprocessing_module, matched_op, postprocessing_func_name + ) + + # Insert the op that specifies which column in the samples array to access. + # TODO: This also assumes MP acts on a single wire (hence we always use column 0 of the + # samples here); what if MP acts on multiple wires? + column_index_op = self.insert_constant_int_op( + 0, insert_point=InsertPoint.after(sample_op), rewriter=rewriter + ) + + # Insert the call to the post-processing function + postprocessing_func_call_op = func.CallOp( + callee=builtin.FlatSymbolRefAttr(postprocessing_func_op.sym_name), + arguments=[sample_op.results[0], column_index_op], + return_types=[builtin.TensorType(builtin.Float64Type(), shape=())], + ) + + # The op to replace is not the expval/var op itself, but the tensor.from_elements op that follows + op_to_replace = list(matched_op.results[0].uses)[0].operation + assert isinstance( + op_to_replace, tensor.FromElementsOp + ), f"Expected to replace a tensor.from_elements op, but got {type(op_to_replace).__name__}" + rewriter.replace_op(op_to_replace, postprocessing_func_call_op) + + # Finally, erase the expval/var op and its associated observable op + rewriter.erase_matched_op() + rewriter.erase_op(observable_op) + + +class ProbsPattern(MeasurementsFromSamplesPattern): + """A rewrite pattern for the ``measurements_from_samples`` transform that matches and rewrites + ``qml.probs()`` operations. + + Args: + shots (int): The number of shots (e.g. as retrieved from the DeviceInitOp). + """ + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, probs_op: quantum.ProbsOp, rewriter: PatternRewriter, /): + """Match and rewrite for quantum.ProbsOp.""" + compbasis_op = probs_op.operands[0].owner + + n_qubits = None + if compbasis_op.qreg is not None: + n_qubits = self.get_n_qubits_from_qreg(compbasis_op.qreg) + + elif not compbasis_op.qubits == (): + n_qubits = len(compbasis_op.qubits) + + assert ( + n_qubits is not None + ), "Unable to determine number of qubits from quantum.compbasis op" + + sample_op = self.insert_sample_op(compbasis_op, self._shots, n_qubits, rewriter) + + # Insert the post-processing function into current module or + # get handle to it if already inserted + postprocessing_func_name = f"probs_from_samples.tensor.{self._shots}x{n_qubits}xf64" + + postprocessing_func_op = self.get_postprocessing_func_op_from_block_by_name( + probs_op.parent_op().parent, postprocessing_func_name + ) + + if postprocessing_func_op is None: + # TODO: Do we have to set the shape of the samples array statically here? Or can the + # shape (shots, wire) be dynamic and given as SSA values? + # Same goes for the column/wire indices (the second argument). + postprocessing_module = _postprocessing_probs( + jax.core.ShapedArray([self._shots, n_qubits], float) + ) + + postprocessing_func_op = self.get_postprocessing_funcs_from_module_and_insert( + postprocessing_module, probs_op, postprocessing_func_name + ) + + # Insert the call to the post-processing function + postprocessing_func_call_op = func.CallOp( + callee=builtin.FlatSymbolRefAttr(postprocessing_func_op.sym_name), + arguments=[sample_op.results[0]], + return_types=[builtin.TensorType(builtin.Float64Type(), shape=(2**n_qubits,))], + ) + + rewriter.replace_matched_op(postprocessing_func_call_op) + + +class CountsPattern(MeasurementsFromSamplesPattern): + """A rewrite pattern for the ``measurements_from_samples`` transform that matches and rewrites + ``qml.counts()`` operations. + + Currently there is no plan to support ``qml.counts()`` for this transform. It is included for + completeness and to notify users that workloads containing ``counts`` measurement processes are + not supported with the measurements-from-samples transform. + + Args: + shots (int): The number of shots (e.g. as retrieved from the DeviceInitOp). + """ + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, counts_op: quantum.CountsOp, rewriter: PatternRewriter, /): + """Match and rewrite for quantum.CountsOp.""" + raise NotImplementedError("qml.counts() operations are not supported.") + + +class StatePattern(MeasurementsFromSamplesPattern): + """A rewrite pattern for the ``measurements_from_samples`` transform that matches and rewrites + ``qml.state()`` operations. + + It is not possible to recover a quantum state from samples; this pattern is included for + completeness and to notify users that workloads containing ``state`` measurement processes are + not supported with the measurements-from-samples transform. + + Args: + shots (int): The number of shots (e.g. as retrieved from the DeviceInitOp). + """ + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, state_op: quantum.StateOp, rewriter: PatternRewriter, /): + """Match and rewrite for quantum.StateOp.""" + raise NotImplementedError("qml.state() operations are not supported.") + + +def _get_static_shots_value_from_first_device_op(module: builtin.ModuleOp) -> int: + """Returns the number of shots as a static (i.e. known at compile time) integer value from the + first instance of a device-initialization op (quantum.DeviceInitOp) found in `module`. + + If `module` contains multiple quantum.DeviceInitOp ops, only the number of shots from the + *first* instance is used, and the others are ignored. + + This function expects the number of shots to be an SSA value given as an operand to the + quantum.DeviceInitOp op. It also assumes that the number of shots is static, retrieving it from + the 'value' attribute of its corresponding constant op. + + Args: + module (builtin.ModuleOp): The MLIR module containing the quantum.DeviceInitOp. + + Returns: + int: The number of shots. + + Raises: + CompileError: If `module` does not contain a quantum.DeviceInitOp. + """ + device_op = None + + for op in module.body.walk(): + if isinstance(op, quantum.DeviceInitOp): + device_op = op + break + + if device_op is None: + raise CompileError( + "Cannot get number of shots; the module does not contain a quantum.DeviceInitOp" + ) + + # The number of shots is passed as an SSA value operand to the DeviceInitOp + shots_operand = device_op.shots + shots_extract_op = shots_operand.owner + + if isinstance(shots_extract_op, tensor.ExtractOp): + shots_constant_op = shots_extract_op.operands[0].owner + shots_value_attribute: builtin.DenseIntOrFPElementsAttr = shots_constant_op.properties.get( + "value" + ) + if shots_value_attribute is None: + raise ValueError("Cannot get number of shots; the constant op has no 'value' attribute") + + shots_int_values = shots_value_attribute.get_values() + if len(shots_int_values) != 1: + raise ValueError(f"Expected a single shots value, got {len(shots_int_values)}") + + return shots_int_values[0] + + if isinstance(shots_extract_op, arith.ConstantOp): + shots_value_attribute: builtin.IntAttr = shots_extract_op.properties.get("value") + return shots_value_attribute.value.data + + raise ValueError( + f"Expected owner of shots operand to be a tensor.ExtractOp or arith.ConstantOp but got " + f"{type(shots_extract_op).__name__}" + ) + + +@xdsl_module +@jax.jit +def _postprocessing_expval(samples, column): + """Post-processing to recover the expectation value from the given `samples` array for each + requested `column` in the array. + + This function assumes that the samples are in the computational basis (0s and 1s) and that the + observable operand of the expectation value has eigenvalues +1 and -1. + + Args: + samples (jax.core.ShapedArray): Array of samples, with shape (shots, wires). + column (int, jax.core.ShapedArray): Column index (or indices) of the `samples` array over + which the expectation value is computed. + + Returns: + jax.core.ShapedArray: The expectation value for each requested column. + """ + return jnp.mean(1.0 - 2.0 * samples[:, column], axis=0) + + +@xdsl_module +@jax.jit +def _postprocessing_var(samples, column): + """Post-processing to recover the variance from the given `samples` array for each requested + `column` in the array. + + This function assumes that the samples are in the computational basis (0s and 1s) and that the + observable operand of the variance has eigenvalues +1 and -1. + + Args: + samples (jax.core.ShapedArray): Array of samples, with shape (shots, wires). + column (int, jax.core.ShapedArray): Column index (or indices) of the `samples` array over + which the variance is computed. + + Returns: + jax.core.ShapedArray: The variance for each requested column. + """ + return jnp.var(1.0 - 2.0 * samples[:, column], axis=0) + + +@xdsl_module +@jax.jit +def _postprocessing_probs(samples): + """Post-processing to recover the probability values from the given `samples` array. + + This function assumes that the samples are in the computational basis (0s and 1s). + + Args: + samples (jax.core.ShapedArray): Array of samples, with shape (shots, wires). + """ + n_samples = samples.shape[0] + n_wires = samples.shape[1] + + # Convert samples from a list of 0, 1 integers to base 10 representation + powers_of_two = 2 ** jnp.arange(n_wires)[::-1] + indices = samples @ powers_of_two + dim = 2**n_wires + + # This block is effectively equivalent to `jnp.bincount(indices.astype(int), length=dim)`. + # However, we are currently not able to use jnp.bincount with Catalyst because after lowering, + # it contains a stablehlo.scatter op with , + # which we currently do not support. + # If Catalyst PR https://github.com/PennyLaneAI/catalyst/pull/1849 is merged, then we should be + # able to use bincount. + counts = jnp.zeros(dim, dtype=int) + for i in indices.astype(int): + counts = counts.at[i].add(1) + + return counts / n_samples diff --git a/frontend/catalyst/python_interface/transforms/quantum/merge_rotations.py b/frontend/catalyst/python_interface/transforms/quantum/merge_rotations.py new file mode 100644 index 0000000000..f1477f7c7b --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/quantum/merge_rotations.py @@ -0,0 +1,141 @@ +# Copyright 2018-2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This file contains the implementation of the merge_rotations transform, +written using xDSL.""" + +from dataclasses import dataclass + +from xdsl import context, passes, pattern_rewriter +from xdsl.dialects import arith, builtin, func +from xdsl.ir import Operation +from xdsl.rewriter import InsertPoint + +from catalyst.python_interface.dialects.quantum import CustomOp +from catalyst.python_interface.pass_api import compiler_transform + +# Can handle all composible rotations except Rot... for now +composable_rotations = [ + "RX", + "RY", + "RZ", + "PhaseShift", + "CRX", + "CRY", + "CRZ", + "ControlledPhaseShift", + "IsingXX", + "IsingYY", + "IsingXY", + "IsingZZ", + "MultiRZ", + "SingleExcitation", + "DoubleExcitation", + "SingleExcitationMinus", + "SingleExcitationPlus", + "DoubleExcitationMinus", + "DoubleExcitationPlus", + "OrbitalRotation", +] + + +def _can_merge(op: CustomOp, next_op: Operation) -> bool: + if isinstance(next_op, CustomOp): + if op.gate_name.data == next_op.gate_name.data: + if ( + op.out_qubits == next_op.in_qubits + and op.out_ctrl_qubits == next_op.in_ctrl_qubits + and op.in_ctrl_values == next_op.in_ctrl_values + ): + return True + + return False + + +class MergeRotationsPattern( + pattern_rewriter.RewritePattern +): # pylint: disable=too-few-public-methods + """RewritePattern for merging consecutive composable rotations.""" + + # pylint: disable=no-self-use + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter): + """Implementation of rewriting FuncOps that may contain operations corresponding to + consecutive composable rotations.""" + for op in funcOp.body.walk(): + if not isinstance(op, CustomOp): + continue + + gate_name = op.gate_name.data + if gate_name not in composable_rotations: + continue + + param = op.operands[0] + while True: + for use in op.results[0].uses: + user = use.operation + if _can_merge(op, user): + next_user = user + break + else: + # No next_user was set because no user could be merged. Go to next op + break + + for q1, q2 in zip(op.in_qubits, op.out_qubits, strict=True): + rewriter.replace_all_uses_with(q2, q1) + for cq1, cq2 in zip(op.in_ctrl_qubits, op.out_ctrl_qubits, strict=True): + rewriter.replace_all_uses_with(cq2, cq1) + + # Whether op is adjoint will determine adjoint of new_op, and how to combine angles + is_adjoint = getattr(op, "adjoint", False) + rewriter.erase_op(op) + + next_param = next_user.operands[0] + next_is_adjoint = getattr(next_user, "adjoint", False) + if is_adjoint == next_is_adjoint: + # If both op and next_user are adjoint, or neither is, we add angles + combOp = arith.AddfOp(param, next_param) + else: + # If one of op and next_user is adjoint, we subtract angles + combOp = arith.SubfOp(param, next_param) + rewriter.insert_op(combOp, InsertPoint.before(next_user)) + param = combOp.result + + new_op = CustomOp( + in_qubits=next_user.in_qubits, + gate_name=next_user.gate_name, + params=(param,), + in_ctrl_qubits=next_user.in_ctrl_qubits, + in_ctrl_values=next_user.in_ctrl_values, + adjoint=is_adjoint, + ) + rewriter.replace_op(next_user, new_op) + op = new_op + + +@dataclass(frozen=True) +class MergeRotationsPass(passes.ModulePass): + """Pass for merging consecutive composable rotation gates.""" + + name = "xdsl-merge-rotations" + + # pylint: disable=no-self-use + def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + """Apply the merge rotations pass.""" + pattern_rewriter.PatternRewriteWalker( + pattern_rewriter.GreedyRewritePatternApplier([MergeRotationsPattern()]) + ).rewrite_module(module) + + +merge_rotations_pass = compiler_transform(MergeRotationsPass) diff --git a/frontend/catalyst/python_interface/transforms/quantum/split_non_commuting.py b/frontend/catalyst/python_interface/transforms/quantum/split_non_commuting.py new file mode 100644 index 0000000000..8dd0378ed9 --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/quantum/split_non_commuting.py @@ -0,0 +1,498 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This file contains a limited prototype of the split_non_commuting pass. + +Known Limitations +----------------- + + * Only single-term observables with no coefficients are supported - there is no support for CompositeOp or SymbolicOp observables + * Only the Expval measurement process is supported + * There is no option to specify a grouping strategy (this will be more relevant once CompositeOp support is added) + * Hence, only the "wires" grouping strategy is implemented, not taking into account observable-commutation logic yet. + * There is no efficient handling of duplicate observables - a circuit that returns multiple measurements on the same observable will split into multiple executions (this will be more relevant once CompositeOp support is added) + +Example: +------------------ +For the following IR: +``` +func.func public circ(...) { + ... + %reg0 = quantum.insert ... + %reg1 = quantum.insert ... + %0 = func.call @circ.state_evolution(%reg0, ...) + %q0 = quantum.extract ... + %q1 = quantum.extract ... + ... + %1 = quantum.expval %q0 + %2 = quantum.expval %q1 + ... + func.return %1, %2 +} +``` + +We want to split the function into two functions based on the wire-based grouping strategy. +``` +func.func public circ(...) { + %0 = func.call @circ.dup0(...) + %1 = func.call @circ.dup1(...) + func.return %0, %1 +} +func.func circ.dup0(...) { + ... + %reg0 = quantum.insert ... + %reg1 = quantum.insert ... + %0 = func.call @circ.state_evolution(%reg1, ...) + %q0 = quantum.extract ... + ... + %1 = quantum.expval %q0 + return %1 +} +func.func circ.dup1(...) { + ... + %reg1 = quantum.insert ... + %0 = func.call @circ.state_evolution(%reg1, ...) + %q1 = quantum.extract ... + ... + %1 = quantum.expval %q1 + return %1 +} +func.func circ.state_evolution(%reg, ...) { + %q0 = quantum.extract ... + %q1 = quantum.extract ... + ... + %new_reg0 = quantum.insert ... + %new_reg1 = quantum.insert ... + return %new_reg1 +} +``` + +reference: https://docs.pennylane.ai/en/stable/code/api/pennylane.transforms.split_non_commuting.html +""" + +from dataclasses import dataclass +from typing import Type, TypeVar + +from xdsl import context, passes, pattern_rewriter +from xdsl.dialects import builtin, func +from xdsl.ir import Operation, SSAValue +from xdsl.rewriter import InsertPoint + +from catalyst.python_interface.dialects import quantum +from catalyst.python_interface.pass_api import compiler_transform + + +@dataclass(frozen=True) +class SplitNonCommutingPass(passes.ModulePass): + """Pass that splits quantum functions measuring non-commuting observables. + + This pass groups measurements using the "wires" grouping strategy and splits + the function into multiple executions, one per group of measurements. + """ + + name = "split-non-commuting" + + def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + """Apply the split non-commuting pass to all QNode functions in the module.""" + for op in module.ops: + if isinstance(op, func.FuncOp) and "qnode" in op.attributes: + rewriter = pattern_rewriter.PatternRewriter(op) + SplitNonCommutingPattern().match_and_rewrite(op, rewriter) + + +split_non_commuting_pass = compiler_transform(SplitNonCommutingPass) + + +class SplitNonCommutingPattern(pattern_rewriter.RewritePattern): + """Pattern that splits a quantum function into multiple functions based on wire-based grouping. + + Measurements acting on different wires are grouped together, while measurements + acting on the same wire are split into separate groups. + """ + + def __init__(self): + self.module: builtin.ModuleOp = None + + T = TypeVar("T") + + def get_parent_of_type(self, op: Operation, kind: Type[T]) -> T | None: + """Walk up the parent tree until an op of the specified type is found.""" + while (op := op.parent_op()) and not isinstance(op, kind): + pass + if not isinstance(op, kind): + raise ValueError(f"get_parent_of_type: expected {kind} but got {type(op)}, op: {op}") + return op + + def clone_operations_to_block(self, ops_to_clone, target_block, value_mapper): + """Clone operations to target block, use value_mapper to update references""" + for op in ops_to_clone: + cloned_op = op.clone(value_mapper) + target_block.add_op(cloned_op) + + def create_dup_function( + self, func_op: func.FuncOp, i: int, rewriter: pattern_rewriter.PatternRewriter + ): + """Create a new function for the dup region by fully cloning the original function.""" + # Use the same signature as the original function + original_func_type = func_op.function_type + input_types = list(original_func_type.inputs.data) + output_types = list(original_func_type.outputs.data) + fun_type = builtin.FunctionType.from_lists(input_types, output_types) + + dup_func = func.FuncOp(func_op.sym_name.data + ".dup." + str(i), fun_type) + rewriter.insert_op(dup_func, InsertPoint.at_end(self.module.body.block)) + + # Map original function arguments to dup function arguments + dup_block = dup_func.regions[0].block + orig_block = func_op.body.block + value_mapper = {} + for orig_arg, dup_arg in zip(orig_block.args, dup_block.args): + value_mapper[orig_arg] = dup_arg + + # Clone all operations except the return statement + ops_to_clone = [] + return_op = None + for op in orig_block.ops: + if isinstance(op, func.ReturnOp): + return_op = op + else: + ops_to_clone.append(op) + + # Clone operations + self.clone_operations_to_block(ops_to_clone, dup_block, value_mapper) + + # Clone the return statement + if return_op: + return_values = [value_mapper.get(val, val) for val in return_op.operands] + new_return_op = func.ReturnOp(*return_values) + dup_block.add_op(new_return_op) + + # Remove expvals from other groups and update return statement + self.remove_group(dup_func, i) + + return dup_func + + def remove_group(self, dup_func: func.FuncOp, target_group: int): + """Remove measurement operations from other groups and update return statement.""" + # Find the return operation in the dup function + return_op = list(dup_func.body.ops)[-1] + + return_values_to_remove = set[SSAValue]() + for operand in return_op.operands: + group_id = self.find_group_for_return_value(operand) + if group_id != target_group: + return_values_to_remove.add(operand) + + # collect all operations to remove + remove_ops = list[Operation]([value.owner for value in return_values_to_remove]) + + # update return statement + self.update_return_statement(dup_func, return_values_to_remove) + + # remove operations + while remove_ops: + op = remove_ops.pop(0) + users = [use.operation for result in list(op.results) for use in list(result.uses)] + + # if the operation has users, skip it + if len(users) > 0: + continue + + if not self.is_observable_op(op): + # keep walking up the chain + for operand in op.operands: + if operand not in remove_ops: + remove_ops.append(operand.owner) + + op.detach() + op.erase() + + def update_return_statement(self, func_op: func.FuncOp, values_to_remove: set[SSAValue]): + """Update the return statement to remove specified values.""" + # Find the return operation + return_op = None + for op in func_op.body.ops: + if isinstance(op, func.ReturnOp): + return_op = op + break + + if not return_op: + return + + # Filter out values to remove + new_return_values = [val for val in return_op.operands if val not in values_to_remove] + + # Create new return operation + new_return_op = func.ReturnOp(*new_return_values) + + # Replace the old return operation + return_op.detach() + return_op.erase() # Important: erase to remove operand uses + func_op.body.block.add_op(new_return_op) + + # Update function signature + new_output_types = [val.type for val in new_return_values] + input_types = [arg.type for arg in func_op.body.block.args] + new_fun_type = builtin.FunctionType.from_lists(input_types, new_output_types) + func_op.function_type = new_fun_type + + def is_measurement_op(self, op: Operation) -> bool: + """Check if an operation is a measurement operation.""" + # TODO: support more measurement operations + if isinstance(op, quantum.ExpvalOp): + return True + if isinstance(op, (quantum.VarianceOp, quantum.ProbsOp, quantum.SampleOp)): + raise NotImplementedError( + f"measurement operations other than expval are not supported: {op}" + ) + return False + + def is_observable_op(self, op: Operation) -> bool: + """Check if an operation is an observable operation.""" + if isinstance( + op, + ( + quantum.NamedObsOp, + quantum.ComputationalBasisOp, + quantum.HamiltonianOp, + quantum.TensorOp, + ), + ): + return True + return False + + def calculate_num_groups(self, func_op: func.FuncOp) -> int: + """Calculate the number of groups using the "wires" grouping strategy. + + This function groups measurements based on wire overlaps only, disregarding + the actual commutation relations between observables. Measurements acting on + different wires are grouped together, while measurements acting on the same + wire are split into different groups. + + The function also stores the group ID in the "group" attribute of each + measurement operation, which is later used to handle the splitting mechanics. + + Args: + func_op: The function operation containing measurements to group. + + Returns: + The number of groups created. + """ + # Find all measurement operations in the current function + measurement_ops = [op for op in func_op.body.ops if self.is_measurement_op(op)] + + # For each measurement operation, find the qubits the operation acts on + op_to_acted_qubits: dict[Operation, set[SSAValue]] = { + measurement_op: set() for measurement_op in measurement_ops + } + + for measurement_op in measurement_ops: + observable = measurement_op.operands[0] + op_to_acted_qubits[measurement_op].update(self.get_qubits_from_observable(observable)) + + # Group measurements: operations on different qubits can be in the same group + # Operations on the same qubit must be in different groups + groups: list[dict[Operation, set[SSAValue]]] = [] # Each group stores op -> qubits mapping + + for measurement_op, qubits in op_to_acted_qubits.items(): + if len(qubits) > 1: + raise NotImplementedError("operations acting on multiple qubits are not supported") + + # Find a group where no operation acts on any of the same qubits + assigned_group_id = None + + for group_id, group in enumerate(groups): + # Get all qubits already used in this group + used_qubits = set() + for group_qubits in group.values(): + used_qubits.update(group_qubits) + + # Check if this measurement's qubits conflict with the group + if not qubits.intersection(used_qubits): + # No conflict - can add to this group + group[measurement_op] = qubits + assigned_group_id = group_id + break + + # If no suitable group found, create a new one + if assigned_group_id is None: + assigned_group_id = len(groups) + groups.append({measurement_op: qubits}) + + # Tag the measurement operation with the group attribute + measurement_op.attributes["group"] = builtin.IntegerAttr( + assigned_group_id, builtin.IntegerType(64) + ) + + return len(groups) + + def get_qubits_from_observable(self, observable: SSAValue) -> set[SSAValue] | None: + """Get the qubit used by an observable operation. + + Traces back from an observable to find the qubit it operates on. + Handles NamedObsOp, ComputationalBasisOp, HamiltonianOp, and TensorOp. + """ + assert observable.owner is not None, "observable should have an owner" + + acted_qubits = set[SSAValue]() + + obs_op = observable.owner + + # For NamedObsOp, the first operand is the qubit + if isinstance(obs_op, (quantum.NamedObsOp)): + acted_qubits.add(obs_op.operands[0]) + + # For other observable operations, we need to handle multiple qubits + elif isinstance( + obs_op, (quantum.HamiltonianOp, quantum.TensorOp, quantum.ComputationalBasisOp) + ): + raise NotImplementedError(f"unsupported observable operation: {obs_op}") + + return acted_qubits + + def analyze_group_return_positions( + self, func_op: func.FuncOp, num_groups: int + ) -> dict[int, list[int]]: + """Analyze which return value positions belong to each group. + + Returns a dict mapping group_id -> list of final return value positions + Example: {0: [0, 2], 1: [1]} for + return qml.expval(qml.X(0)), qml.expval(qml.X(1)), qml.expval(qml.Y(0)) + """ + # Find the return operation + return_op = list(func_op.body.ops)[-1] + + # For each return value, trace back to find its group + group_positions = {i: [] for i in range(num_groups)} + + for position, return_value in enumerate(return_op.operands): + # Trace back to find the expval operation + group_id = self.find_group_for_return_value(return_value) + if group_id is not None: + group_positions[group_id].append(position) + + return group_positions + + def find_group_for_return_value(self, return_value: SSAValue) -> int | None: + """Trace back from a return value to find which group's expval produced it.""" + # BFS backward to find expval + to_check = [return_value] + checked = set() + + while to_check: + val = to_check.pop(0) + if val in checked: + continue + checked.add(val) + + op = val.owner + + # If we found a measurement operation, check its group + if self.is_measurement_op(op) and "group" in op.attributes: + group_attr = op.attributes["group"] + return group_attr.value.data + + # Otherwise, check operands + to_check.extend([operand for operand in op.operands if operand not in checked]) + + return None + + def replace_original_with_calls( + self, + func_op: func.FuncOp, + dup_functions: list[func.FuncOp], + group_return_positions: dict[int, list[int]], + ): + """Replace original function body with calls to dup functions. + + Args: + dup_functions: List of duplicate functions (one per group) + group_return_positions: Dict mapping group_id -> list of return positions + """ + original_block = func_op.body.block + + for op in reversed(func_op.body.ops): + op.detach() + op.erase() + + # Collect parameters needed for dup function calls + # Dup functions take the same parameters as the original begin/end region + # Look at what original function was using and find corresponding values + call_args = list(original_block.args) # Use function arguments as base + + group_results = dict[int, list[SSAValue]]() # group_id -> list of result values + + for group_id, dup_func in enumerate(dup_functions): + # Get the function signature to determine result types + func_type = dup_func.function_type + result_types = list(func_type.outputs.data) + + # Create the call operation + call_op = func.CallOp(dup_func.sym_name.data, call_args, result_types) + original_block.add_op(call_op) + + # Store results for this group + group_results[group_id] = list(call_op.results) + + # Reconstruct the return statement in the original order + # Calculate total number of return values + total_returns = sum(len(positions) for positions in group_return_positions.values()) + final_return_values = [None] * total_returns + + for group_id, positions in group_return_positions.items(): + group_vals = group_results[group_id] + assert len(group_vals) == len( + positions + ), "number of group values and positions must match" + + for i, position in enumerate(positions): + final_return_values[position] = group_vals[i] + + assert all( + v is not None for v in final_return_values + ), "final return values should not be None" + + # Create new return operation + return_op = func.ReturnOp(*final_return_values) + original_block.add_op(return_op) + + def match_and_rewrite(self, func_op: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter): + """Split a quantum function into multiple functions using wire-based grouping. + + Creates one duplicate function per group, where each duplicate function contains + only the measurements from that group. The original function is replaced with + calls to these duplicate functions, and the results are combined in the original + return order. + + Args: + func_op: The function operation to split. + rewriter: The pattern rewriter for creating new operations. + """ + self.module = self.get_parent_of_type(func_op, builtin.ModuleOp) + assert self.module is not None, "got orphaned qnode function" + + # Calculate the number of groups using wires-based grouping strategy + num_groups = self.calculate_num_groups(func_op) + + # Analyze return value positions for each group + group_return_positions = self.analyze_group_return_positions(func_op, num_groups) + + # Create dup function for each group + dup_functions = [] + for i in range(num_groups): + dup_func = self.create_dup_function(func_op, i, rewriter) + dup_functions.append(dup_func) + + # Replace original function body with calls to dup functions + self.replace_original_with_calls(func_op, dup_functions, group_return_positions) diff --git a/frontend/catalyst/python_interface/utils.py b/frontend/catalyst/python_interface/utils.py new file mode 100644 index 0000000000..9db8e39008 --- /dev/null +++ b/frontend/catalyst/python_interface/utils.py @@ -0,0 +1,77 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""General purpose utilities to use with xDSL.""" + +from numbers import Number + +from xdsl.dialects.arith import ConstantOp as arithConstantOp +from xdsl.dialects.builtin import ComplexType, ShapedType +from xdsl.dialects.tensor import ExtractOp as tensorExtractOp +from xdsl.ir import SSAValue + +from catalyst.python_interface.dialects.stablehlo import ConstantOp as hloConstantOp + + +def get_constant_from_ssa(value: SSAValue) -> Number | None: + """Return the concrete value corresponding to an SSA value if it is a numerical constant. + + .. note:: + + This function currently only returns constants if they are scalar. For non-scalar + constants, ``None`` will be returned. + + Args: + value (xdsl.ir.SSAValue): the SSA value to check + + Returns: + Number or None: If the value corresponds to a constant, its concrete value will + be returned, else ``None``. + """ + + # If the value has a shape, we can assume that it is not scalar. We check + # this because constant-like operations can return container types. This includes + # arith.constant, which may return containers, and stablehlo.constant, which + # always returns a container. + if not isinstance(value.type, ShapedType): + owner = value.owner + + if isinstance(owner, arithConstantOp): + const_attr = owner.value + return const_attr.value.data + + # Constant-like operations can also create scalars by returning rank 0 tensors. + # In this case, the owner of a scalar value should be a tensor.extract, which + # uses the aforementioned rank 0 constant tensor as input. + if isinstance(owner, tensorExtractOp): + tensor_ = owner.tensor + if ( + len(owner.indices) == 0 + and len(tensor_.type.shape) == 0 + and isinstance(tensor_.owner, (arithConstantOp, hloConstantOp)) + ): + dense_attr = tensor_.owner.value + # We know that the tensor has shape (). Dense element attributes store + # their data as a sequence. For a scalar, this will be a sequence with + # a single element. + val = dense_attr.get_values()[0] + if isinstance(tensor_.type.element_type, ComplexType): + # If the dtype is complex, the value will be a 2-tuple containing + # the real and imaginary components of the number rather than a + # Python complex number + val = val[0] + 1j * val[1] + + return val + + return None diff --git a/frontend/catalyst/python_interface/visualization/__init__.py b/frontend/catalyst/python_interface/visualization/__init__.py new file mode 100644 index 0000000000..0a550976e5 --- /dev/null +++ b/frontend/catalyst/python_interface/visualization/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Visualization functions for Catalyst and xDSL transformations. +""" + + +from .collector import QMLCollector +from .draw import draw +from .mlir_graph import generate_mlir_graph + +__all__ = ["QMLCollector", "draw", "generate_mlir_graph"] diff --git a/frontend/catalyst/python_interface/visualization/collector.py b/frontend/catalyst/python_interface/visualization/collector.py new file mode 100644 index 0000000000..d2c355e0bf --- /dev/null +++ b/frontend/catalyst/python_interface/visualization/collector.py @@ -0,0 +1,146 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This file contains the implementation of the QMLCollector class, +which collects and maps PennyLane operations and measurements from xDSL.""" + +from functools import singledispatchmethod + +import xdsl +from pennylane.measurements import MeasurementProcess +from pennylane.operation import Operator +from xdsl.dialects import builtin, func +from xdsl.ir import SSAValue + +from catalyst.python_interface.dialects.quantum import ( + AllocOp, + CustomOp, + ExpvalOp, + ExtractOp, + GlobalPhaseOp, + MeasureOp, + MultiRZOp, + ProbsOp, + QubitUnitaryOp, + SampleOp, + SetBasisStateOp, + SetStateOp, + StateOp, + VarianceOp, +) + +from .xdsl_conversion import dispatch_wires_extract, xdsl_to_qml_measurement, xdsl_to_qml_op + + +class QMLCollector: + """Collects PennyLane ops and measurements from an xDSL module. + + Walks all `FuncOp`s in the given module, building a mapping of SSA qubits to wire indices, + and converting supported xDSL operations and measurements to PennyLane objects. + """ + + def __init__(self, module: builtin.ModuleOp): + self.module = module + self.wire_to_ssa_qubits: dict[int, SSAValue] = {} + self.quantum_register: SSAValue | None = None + + @singledispatchmethod + def handle(self, xdsl_op: xdsl.ir.Operation) -> None: + """Default handler for unsupported operations.""" + if len(xdsl_op.regions) > 0: + raise NotImplementedError("xDSL operations with regions are not yet supported.") + + ############################################################ + ### Measurements + ############################################################ + + @handle.register + def _(self, xdsl_meas: StateOp) -> MeasurementProcess: + return xdsl_to_qml_measurement(xdsl_meas) + + @handle.register + def _(self, xdsl_meas_op: ExpvalOp | VarianceOp | ProbsOp | SampleOp) -> MeasurementProcess: + obs_op = xdsl_meas_op.obs.owner + return xdsl_to_qml_measurement(xdsl_meas_op, xdsl_to_qml_measurement(obs_op)) + + @handle.register + def _(self, xdsl_measure: MeasureOp) -> MeasurementProcess: + return xdsl_to_qml_measurement(xdsl_measure) + + ############################################################ + ### Operators + ############################################################ + + @handle.register + def _( + self, + xdsl_op: ( + CustomOp | GlobalPhaseOp | QubitUnitaryOp | SetStateOp | MultiRZOp | SetBasisStateOp + ), + ) -> Operator: + if self.quantum_register is None: + raise ValueError("Quantum register (AllocOp) not found.") + if not self.wire_to_ssa_qubits: + raise NotImplementedError("No wires extracted from the register found.") + return xdsl_to_qml_op(xdsl_op) + + ############################################################ + ### Internal Methods + ############################################################ + + # TODO: this will probably no longer be needed once PR #7937 is merged + def _process_qubit_mapping(self, op): + """Populate wire mappings from AllocOp and ExtractOp.""" + + if isinstance(op, AllocOp): + if self.quantum_register is not None: + raise ValueError("Found more than one AllocOp for this FuncOp.") + self.quantum_register = op.qreg + + elif isinstance(op, ExtractOp): + wire = dispatch_wires_extract(op) + if wire not in self.wire_to_ssa_qubits: + self.wire_to_ssa_qubits[wire] = op.qubit + + def clear_mappings(self): + """Clear all wire and parameter mappings.""" + + self.wire_to_ssa_qubits.clear() + self.quantum_register = None + + def collect(self, reset: bool = True) -> tuple[list[Operator], list[MeasurementProcess]]: + """Collect PennyLane ops and measurements from the module.""" + + if reset: + self.clear_mappings() + + collected_ops: list[Operator] = [] + collected_meas: list[MeasurementProcess] = [] + + for func_op in self.module.body.ops: + + if not isinstance(func_op, func.FuncOp): + continue + + for op in func_op.body.ops: + + self._process_qubit_mapping(op) + result = self.handle(op) + + if isinstance(result, Operator): + collected_ops.append(result) + + elif isinstance(result, MeasurementProcess): + collected_meas.append(result) + + return collected_ops, collected_meas diff --git a/frontend/catalyst/python_interface/visualization/draw.py b/frontend/catalyst/python_interface/visualization/draw.py new file mode 100644 index 0000000000..87088b73a1 --- /dev/null +++ b/frontend/catalyst/python_interface/visualization/draw.py @@ -0,0 +1,103 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This file contains the implementation of the `draw` function for the Unified Compiler.""" + +from __future__ import annotations + +import warnings +from functools import wraps +from typing import TYPE_CHECKING + +from pennylane.tape import QuantumScript + +from catalyst import qjit +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from catalyst.python_interface.compiler import Compiler + +from .collector import QMLCollector + +if TYPE_CHECKING: + from pennylane.typing import Callable + from pennylane.workflow.qnode import QNode + from xdsl.dialects.builtin import ModuleOp + +# TODO: This caching mechanism should be improved, +# because now it relies on a mutable global state +_cache_store: dict[Callable, dict[int, tuple[str, str]]] = {} + + +def _get_mlir_module(qnode: QNode, args, kwargs) -> ModuleOp: + """Ensure the QNode is compiled and return its MLIR module.""" + if hasattr(qnode, "mlir_module") and qnode.mlir_module is not None: + return qnode.mlir_module + + func = getattr(qnode, "user_function", qnode) + jitted_qnode = qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(func) + jitted_qnode.jit_compile(args, **kwargs) + return jitted_qnode.mlir_module + + +def draw(qnode: QNode, *, level: None | int = None) -> Callable: + """ + Draw the QNode at the specified level. + + This function can be used to visualize the QNode at different stages of the transformation pipeline + when xDSL or Catalyst compilation passes are applied. + If the specified level is not available, the highest level will be used as a fallback. + + The provided QNode is assumed to be decorated with compilation passes. + If no passes are applied, the original QNode is visualized. + + Args: + qnode (.QNode): the input QNode that is to be visualized. The QNode is assumed to be compiled with ``qjit``. + level (None | int): the level of transformation to visualize. If `None`, the final level is visualized. + + + Returns: + Callable: A wrapper function that visualizes the QNode at the specified level. + + """ + cache: dict[int, tuple[str, str]] = _cache_store.setdefault(qnode, {}) + + def _draw_callback(previous_pass, module, next_pass, pass_level=0): + """Callback function for circuit drawing.""" + + pass_instance = previous_pass if previous_pass else next_pass + collector = QMLCollector(module) + ops, meas = collector.collect() + tape = QuantumScript(ops, meas) + pass_name = pass_instance.name if hasattr(pass_instance, "name") else pass_instance + cache[pass_level] = ( + tape.draw(show_matrices=False), + pass_name if pass_level else "No transforms", + ) + + @wraps(qnode) + def wrapper(*args, **kwargs): + if args or kwargs: + warnings.warn( + "The `draw` function does not yet support dynamic arguments.\n" + "To visualize the circuit with dynamic parameters or wires, please use the\n" + "`compiler.python_compiler.visualization.generate_mlir_graph` function instead.", + UserWarning, + ) + mlir_module = _get_mlir_module(qnode, args, kwargs) + Compiler.run(mlir_module, callback=_draw_callback) + + if not cache: + return None + + return cache.get(level, cache[max(cache.keys())])[0] + + return wrapper diff --git a/frontend/catalyst/python_interface/visualization/mlir_graph.py b/frontend/catalyst/python_interface/visualization/mlir_graph.py new file mode 100644 index 0000000000..7683feaecf --- /dev/null +++ b/frontend/catalyst/python_interface/visualization/mlir_graph.py @@ -0,0 +1,119 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This file contains the implementation of the MLIR graph generation for the Unified Compiler framework. +""" +from __future__ import annotations + +import io +import subprocess +from functools import wraps +from pathlib import Path +from typing import TYPE_CHECKING + +from xdsl.printer import Printer + +from catalyst.compiler import CompileError, _get_catalyst_cli_cmd +from catalyst.python_interface.compiler import Compiler + +from .draw import _get_mlir_module + +if TYPE_CHECKING: + from pennylane import QNode + from pennylane.typing import Callable + +try: + from graphviz import Source as GraphSource + + has_graphviz = True +except (ModuleNotFoundError, ImportError) as import_error: # pragma: no cover + has_graphviz = False + + +# TODO: This interface can be removed once the _quantum_opt interface +# implemented in catalyst.compiler returns the `stderr` output +def _quantum_opt_stderr(*args, stdin=None, stderr_return=False): + """Raw interface to quantum-opt""" + return _catalyst(("--tool", "opt"), *args, stdin=stdin, stderr_return=stderr_return) + + +def _catalyst(*args, stdin=None, stderr_return=False): + """Raw interface to catalyst""" + cmd = _get_catalyst_cli_cmd(*args, stdin=stdin) + try: + result = subprocess.run(cmd, input=stdin, check=True, capture_output=True, text=True) + if stderr_return: + return result.stdout, result.stderr + return result.stdout + except subprocess.CalledProcessError as e: + raise CompileError(f"catalyst failed with error code {e.returncode}: {e.stderr}") from e + + +def _mlir_graph_callback(previous_pass, module, next_pass, pass_level=0): + """Callback function for MLIR graph generation.""" + + pass_instance = previous_pass if previous_pass else next_pass + buffer = io.StringIO() + Printer(stream=buffer, print_generic_format=True).print_op(module) + _, dot_graph = _quantum_opt_stderr( + "--view-op-graph", stdin=buffer.getvalue(), stderr_return=True + ) + graph = GraphSource(dot_graph) + + out_dir = Path("mlir_generated_graphs") + out_dir.mkdir(exist_ok=True) + + pass_name = pass_instance.name if hasattr(pass_instance, "name") else pass_instance + pass_name = f"after_{pass_name}" if pass_level else "no_transforms" + out_file = out_dir / f"QNode_level_{pass_level}_{pass_name}.svg" + + with open(out_file, "wb") as f: + f.write(graph.pipe(format="svg")) + + +def generate_mlir_graph(qnode: QNode) -> Callable: + """ + Generate an MLIR graph for the given QNode and saves it to a file. + + This function uses the callback mechanism of the unified compiler framework to generate + the MLIR graph in between compilation passes. The provided QNode is assumed to be decorated with xDSL compilation passes. + The ``qjit`` decorator is used to recompile the QNode with the passes and the provided arguments. + + If no passes are applied, the original QNode is visualized. + + Args: + qnode (.QNode): the input QNode that is to be visualized. + + + Returns: + Callable: A wrapper function that generates the MLIR graph. + + """ + + if not has_graphviz: + raise ImportError( + "This feature requires graphviz, a library for graph visualization. " + "It can be installed with:\n\npip install graphviz" + ) # pragma: no cover + + @wraps(qnode) + def wrapper(*args, **kwargs): + # We re-compile the qnode to ensure the passes are applied + # with the args and kwargs provided by the user. + # TODO: we could integrate the callback mechanism within `qjit`, + # so that we wouldn't need to recompile the qnode twice. + mlir_module = _get_mlir_module(qnode, args, kwargs) + Compiler.run(mlir_module, callback=_mlir_graph_callback) + + return wrapper diff --git a/frontend/catalyst/python_interface/visualization/xdsl_conversion.py b/frontend/catalyst/python_interface/visualization/xdsl_conversion.py new file mode 100644 index 0000000000..1fa496892c --- /dev/null +++ b/frontend/catalyst/python_interface/visualization/xdsl_conversion.py @@ -0,0 +1,330 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This file contains the implementation of the QMLCollector class, +which collects and maps PennyLane operations and measurements from xDSL.""" + +from __future__ import annotations + +import inspect +from collections.abc import Callable +from typing import TYPE_CHECKING + +from pennylane import ops +from pennylane.measurements import expval, probs, sample, state, var +from pennylane.operation import Operator +from pennylane.ops import MidMeasure +from pennylane.ops import __all__ as ops_all +from pennylane.ops import measure +from xdsl.dialects.builtin import DenseIntOrFPElementsAttr, IntegerAttr, IntegerType +from xdsl.dialects.tensor import ExtractOp as TensorExtractOp +from xdsl.ir import SSAValue + +from catalyst.python_interface.dialects.quantum import ( + CustomOp, + ExtractOp, + GlobalPhaseOp, + MeasureOp, + MultiRZOp, + NamedObsOp, + QubitUnitaryOp, + SetBasisStateOp, + SetStateOp, +) + +if TYPE_CHECKING: + from pennylane.measurements import MeasurementProcess + +has_jax = True +try: + import jax +except ImportError: + has_jax = False + + +from_str_to_PL_gate = { + name: getattr(ops, name) + for name in ops_all + if inspect.isclass(getattr(ops, name, None)) and issubclass(getattr(ops, name), Operator) +} + +from_str_to_PL_measurement = { + "quantum.state": state, + "quantum.probs": probs, + "quantum.sample": sample, + "quantum.expval": expval, + "quantum.var": var, + "quantum.measure": measure, +} + + +###################################################### +### Gate/Measurement resolution +###################################################### + + +def _resolve(name: str, mapping: dict, kind: str): + try: + return mapping[name] + except KeyError as exc: + raise NotImplementedError(f"Unsupported {kind}: {name}") from exc + + +def resolve_gate(name: str) -> Operator: + """Resolve the gate from the name.""" + return _resolve(name, from_str_to_PL_gate, "gate") + + +def resolve_measurement(name: str) -> MeasurementProcess: + """Resolve the measurement from the name.""" + return _resolve(name, from_str_to_PL_measurement, "measurement") + + +###################################################### +### Helpers +###################################################### + + +def _tensor_shape_from_ssa(ssa: SSAValue) -> list[int]: + """Extract the concrete shape from an SSA tensor value.""" + # pylint: disable= protected-access + tensor_abstr_shape = ssa.owner.operand._type.shape.data + return [dim.data for dim in tensor_abstr_shape] + + +def _extract(op, attr: str, resolver: Callable, single: bool = False): + """Helper to extract and resolve attributes.""" + values = getattr(op, attr, None) + if not values: + return [] if not single else None + return resolver(values) if single else [resolver(v) for v in values if v is not None] + + +def _extract_dense_constant_value(op) -> float | int: + """Extract the first value from a stablehlo.constant op.""" + attr = op.properties.get("value") + if isinstance(attr, DenseIntOrFPElementsAttr): + # TODO: handle multi-value cases if needed + return attr.get_values()[0] + raise NotImplementedError(f"Unexpected attr type in constant: {type(attr)}") + + +def _apply_adjoint_and_ctrls(qml_op: Operator, xdsl_op) -> Operator: + """Apply adjoint and control modifiers to a gate if needed.""" + if xdsl_op.properties.get("adjoint"): + qml_op = ops.op_math.adjoint(qml_op) + ctrls = ssa_to_qml_wires(xdsl_op, control=True) + if ctrls: + cvals = ssa_to_qml_params(xdsl_op, control=True) + qml_op = ops.op_math.ctrl(qml_op, control=ctrls, control_values=cvals) + return qml_op + + +# pylint: disable=too-many-return-statements +def resolve_constant_params(ssa: SSAValue) -> float | int: + """Resolve a constant parameter SSA value to a Python float or int.""" + op = ssa.owner + + if isinstance(op, TensorExtractOp): + return resolve_constant_params(op.tensor) + + match op.name: + + case "arith.addf": + return sum(resolve_constant_params(o) for o in op.operands) + + case "arith.constant": + return op.value.value.data # Catalyst + + case "stablehlo.constant": + return _extract_dense_constant_value(op) + + case "stablehlo.convert" | "stablehlo.broadcast_in_dim": + return resolve_constant_params(op.operands[0]) + + case "stablehlo.concatenate": + return [resolve_constant_params(operand) for operand in op.operands] + + case "stablehlo.reshape": + res_type = op.result_types[0] + shape = res_type.get_shape() + type = res_type.get_element_type() + return jax.numpy.array(shape, dtype=int if isinstance(type, IntegerType) else float) + + case _: + raise NotImplementedError(f"Cannot resolve parameters for operation: {op}") + + +def dispatch_wires_extract(op: ExtractOp): + """Dispatch the wire resolution for the given extract operation.""" + if op.idx_attr is not None: # used by Catalyst + return resolve_constant_wire(op.idx_attr) + return resolve_constant_wire(op.idx) # used by xDSL + + +def resolve_constant_wire(ssa: SSAValue) -> float | int: + """Resolve the wire for the given SSA qubit.""" + if isinstance(ssa, IntegerAttr): # Catalyst + return ssa.value.data + + op = ssa.owner + + match op: + + case TensorExtractOp(tensor=tensor): + return resolve_constant_wire(tensor) + + case _ if op.name == "stablehlo.convert": + return resolve_constant_wire(op.operands[0]) + + case _ if op.name == "stablehlo.constant": + return _extract_dense_constant_value(op) + + case ( + CustomOp() + | GlobalPhaseOp() + | QubitUnitaryOp() + | SetStateOp() + | MultiRZOp() + | SetBasisStateOp() + ): + all_qubits = list(getattr(op, "in_qubits", [])) + list( + getattr(op, "in_ctrl_qubits", []) + ) + return resolve_constant_wire(all_qubits[ssa.index]) + + case ExtractOp(): + return dispatch_wires_extract(op) + + case MeasureOp(in_qubit=in_qubit): + return resolve_constant_wire(in_qubit) + + case _: + raise NotImplementedError(f"Cannot resolve wire for op: {op}") + + +###################################################### +### Parameters/Wires Conversion +###################################################### + + +def ssa_to_qml_params( + op, control: bool = False, single: bool = False +) -> list[float | int] | float | int | None: + """Get the parameters from the operation.""" + return _extract(op, "in_ctrl_values" if control else "params", resolve_constant_params, single) + + +def ssa_to_qml_wires(op: CustomOp, control: bool = False) -> list[int]: + """Get the wires from the operation.""" + return _extract(op, "in_ctrl_qubits" if control else "in_qubits", resolve_constant_wire) + + +def ssa_to_qml_wires_named(op: NamedObsOp) -> int: + """Get the wire from the named observable operation.""" + if not op.qubit: + raise ValueError("No qubit found for named observable operation.") + return resolve_constant_wire(op.qubit) + + +############################################################ +### xDSL ---> PennyLane Operators/Measurements conversion +############################################################ + + +def xdsl_to_qml_op(op) -> Operator: + """Convert an xDSL operation into a PennyLane Operator. + + Args: + op: The xDSL operation to convert. + + Returns: + A PennyLane Operator. + """ + + match op.name: + + case "quantum.gphase": + gate = ops.GlobalPhase(ssa_to_qml_params(op, single=True), wires=ssa_to_qml_wires(op)) + + case "quantum.unitary": + gate = ops.qubit.matrix_ops.QubitUnitary( + U=jax.numpy.zeros(_tensor_shape_from_ssa(op.matrix)), wires=ssa_to_qml_wires(op) + ) + + case "quantum.set_state": + gate = ops.qubit.state_preparation.StatePrep( + state=jax.numpy.zeros(_tensor_shape_from_ssa(op.in_state)), + wires=ssa_to_qml_wires(op), + ) + + case "quantum.multirz": + gate = ops.qubit.parametric_ops_multi_qubit.MultiRZ( + theta=_extract(op, "theta", resolve_constant_params, single=True), + wires=ssa_to_qml_wires(op), + ) + + case "quantum.set_basis_state": + gate = ops.qubit.state_preparation.BasisState( + state=jax.numpy.zeros(_tensor_shape_from_ssa(op.basis_state)), + wires=ssa_to_qml_wires(op), + ) + + case "quantum.custom": + gate_cls = resolve_gate(op.properties.get("gate_name").data) + gate = gate_cls(*ssa_to_qml_params(op), wires=ssa_to_qml_wires(op)) + + case _: + raise NotImplementedError(f"Unsupported gate: {op.name}") + + return _apply_adjoint_and_ctrls(gate, op) + + +def xdsl_to_qml_measurement(op, *args, **kwargs) -> MeasurementProcess | Operator: + """Convert any xDSL measurement/observable operation to a PennyLane object. + + Args: + op: The xDSL measurement/observable operation to convert. + + Returns: + A PennyLane MeasurementProcess or Operator. + """ + + match op.name: + + case "quantum.measure": + postselect = op.postselect.value.data if op.postselect is not None else None + return MidMeasure([resolve_constant_wire(op.in_qubit)], postselect=postselect) + + case "quantum.namedobs": + return resolve_gate(op.type.data.value)(wires=ssa_to_qml_wires_named(op)) + + case "quantum.tensor": + return ops.op_math.prod( + *(xdsl_to_qml_measurement(operand.owner) for operand in op.operands) + ) + + case "quantum.hamiltonian": + coeffs = _extract(op, "coeffs", resolve_constant_params, single=True) + ops_list = [xdsl_to_qml_measurement(term.owner) for term in op.terms] + return ops.LinearCombination(coeffs, ops_list) + case "quantum.compbasis": + return _extract(op, "qubits", resolve_constant_wire) + + case ( + "quantum.state" | "quantum.probs" | "quantum.sample" | "quantum.expval" | "quantum.var" + ): + return resolve_measurement(op.name)(*args, **kwargs) + + case _: + raise NotImplementedError(f"Unsupported measurement/observable: {op.name}") diff --git a/frontend/catalyst/python_interface/xdsl_extras/__init__.py b/frontend/catalyst/python_interface/xdsl_extras/__init__.py new file mode 100644 index 0000000000..094b56f4fc --- /dev/null +++ b/frontend/catalyst/python_interface/xdsl_extras/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains additional utilities and functionality not available upstream in xDSL.""" + +from .constraints import MemRefConstraint, TensorConstraint, NestedTupleOfConstraint +from .traits import ( + AllMatchSameOperatorTrait, + Elementwise, + SameOperandsAndResultElementType, + SameOperandsAndResultShape, + SameOperandsElementType, +) + +__all__ = [ + # Constraints + "NestedTupleOfConstraint", + "MemRefConstraint", + "TensorConstraint", + # Traits + "AllMatchSameOperatorTrait", + "Elementwise", + "SameOperandsAndResultElementType", + "SameOperandsAndResultShape", + "SameOperandsElementType", +] diff --git a/frontend/catalyst/python_interface/xdsl_extras/constraints.py b/frontend/catalyst/python_interface/xdsl_extras/constraints.py new file mode 100644 index 0000000000..8c812e35f2 --- /dev/null +++ b/frontend/catalyst/python_interface/xdsl_extras/constraints.py @@ -0,0 +1,234 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains additional type and attribute constraints that are currently not available +upstream in xDSL.""" + +from abc import ABC, abstractmethod +from collections.abc import Collection, Mapping, Sequence +from dataclasses import dataclass +from typing import TypeVar + +from xdsl.dialects.builtin import ( + ArrayAttr, + IntAttr, + IntAttrConstraint, + MemRefType, + TensorType, + TupleType, +) +from xdsl.ir import Attribute, AttributeInvT +from xdsl.irdl import ( + AnyAttr, + AnyInt, + AttrConstraint, + ConstraintContext, + EqIntConstraint, + IntConstraint, + IntSetConstraint, + IRDLAttrConstraint, + RangeLengthConstraint, + RangeOf, + irdl_to_attr_constraint, +) +from xdsl.utils.exceptions import VerifyException + + +@dataclass(frozen=True, init=False) +class ContainerConstraint(AttrConstraint, ABC): + r"""Internal base class for constraining the element type and shape of container types. + + The shape of the container can be constrained by providing a constraint either directly + for the shape, or by providing a constraint for the rank. There are two ways to provide + an explicit shape constraint: + + * By providing an ``IRDLAttrConstraint`` for the ``shape`` argument. + * By providing a sequence of ``int``\ s specifying the concrete expected shape for + the ``shape`` argument. + + There are three ways to provide the rank constraint: + + * By providing an ``IRDLAttrConstraint`` for the ``rank`` argument. + * By providing an ``int`` representing the concrete expected rank for the ``rank`` + argument. + * By providing a collection of ``int``\ s specifying the various allowed ranks for the + ``rank`` argument. + + .. note:: + + Only one of the ``shape`` or ``rank`` constraint may be provided, not both. + + Args: + element_type (IRDLAttrConstraint | None): The constraint for the element type. + Default is ``None``, which indicates that any element type is allowed. + shape (IRDLAttrConstraint | Sequence[int] | None): The constraint for the shape. + Default is ``None``, which indicates that any shape is allowed. + rank (IRDLAttrConstraint | Collection[int] | int | None): The constraint for the + rank. Default is ``None``, which indicates that any rank is allowed. + """ + + element_type: IRDLAttrConstraint[AttributeInvT] + + shape: IRDLAttrConstraint[AttributeInvT] + + def __init__( + self, + *, + element_type: IRDLAttrConstraint[AttributeInvT] | None = None, + shape: IRDLAttrConstraint[AttributeInvT] | Sequence[int] | None = None, + rank: IRDLAttrConstraint[AttributeInvT] | Collection[int] | int | None = None, + ): + element_type = element_type or AnyAttr() + element_type_constr = element_type + shape_constr = None + + if shape is not None and rank is not None: + raise ValueError("Only one of 'shape' or 'rank' may be provided.") + + if shape is None and rank is None: + shape_constr = AnyAttr() + + elif shape is not None: + shape_constr = shape + if isinstance(shape_constr, Sequence): + shape_constr = ArrayAttr([IntAttr(s) for s in shape]) + + # rank is not None + else: + shape_constr = rank + if isinstance(shape_constr, (int, Collection)): + # Constrain shape to have length `rank` if `rank` is an int + if isinstance(shape_constr, int): + length_constr = EqIntConstraint(shape_constr) + else: + length_constr = IntSetConstraint(frozenset(shape_constr)) + shape_constr = ArrayAttr.constr( + RangeLengthConstraint( + constraint=RangeOf(IntAttrConstraint(AnyInt())), length=length_constr + ) + ) + + if not isinstance(element_type_constr, (Attribute, AttrConstraint)): + raise TypeError( + f"{element_type} is not a valid constraint for the 'element_type' argument. " + "'element_type' must be an AttrConstraint or Attribute." + ) + if not isinstance(shape_constr, (Attribute, AttrConstraint)): + if shape is not None: + raise TypeError( + f"{shape} is not a valid constraint for the 'shape' argument. 'shape' " + "must be an AttrConstraint, Attribute, or sequence of integers." + ) + raise TypeError( + f"{rank} is not a valid constraint for the 'rank' argument. 'rank' must be " + "an AttrConstraint, Attribute, integer, or collection of integers." + ) + + object.__setattr__(self, "element_type", element_type_constr) + object.__setattr__(self, "shape", shape_constr) + + @property + @abstractmethod + def expected_type(self) -> type[Attribute]: + """The expected IR type class (e.g., builtin.TensorType).""" + + @property + def type_name(self) -> str: + """The name of the type for use in error messages (e.g., 'tensor').""" + return self.expected_type.name + + def get_bases(self) -> set[type[Attribute]]: + """Get a set of base types that can satisfy this constraint (e.g., {builtin.TensorType}).""" + return {self.expected_type} + + def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None: + # pylint: disable=missing-function-docstring + if not isinstance(attr, self.expected_type): + raise VerifyException(f"{attr} should be of type {self.expected_type.__name__}.") + constr = self.expected_type.constr(element_type=self.element_type, shape=self.shape) + constr.verify(attr, constraint_context) + + def mapping_type_vars( + self, type_var_mapping: dict[TypeVar, AttrConstraint] + ) -> "ContainerConstraint": + # pylint: disable=unused-argument,missing-function-docstring + return self + + +@dataclass(frozen=True, init=False) +class TensorConstraint(ContainerConstraint): + """TensorType constraint for element type and shape.""" + + @property + def expected_type(self): + return TensorType + + +@dataclass(frozen=True, init=False) +class MemRefConstraint(ContainerConstraint): + """MemRefType constraint for element type and shape.""" + + @property + def expected_type(self): + return MemRefType + + +@dataclass(frozen=True, init=False) +class NestedTupleOfConstraint(AttrConstraint[TupleType]): + """Constrain a nested tuple whose flattened leaves all match any allowed constraints.""" + + elem_constraints: tuple[AttrConstraint, ...] + + def __init__(self, elem_constraints: Sequence[object]): + object.__setattr__( + self, + "elem_constraints", + tuple(irdl_to_attr_constraint(c) for c in elem_constraints), + ) + + def get_flattened(self, a: Attribute): + """Get the flattened leaves of a tuple.""" + if isinstance(a, TupleType): + for t in a.types.data: + yield from self.get_flattened(t) + else: + yield a + + def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None: + """Verify that the attribute is a tuple of allowed types.""" + if not isinstance(attr, TupleType): + raise VerifyException(f"expected TupleType, got {type(attr)}") + + leaves = list(self.get_flattened(attr)) + + for i, leaf in enumerate(leaves): + matched = False + for constr in self.elem_constraints: + try: + constr.verify(leaf, constraint_context) + matched = True + break + except VerifyException: + # Try next allowed constraint + pass + if not matched: + raise VerifyException(f"tuple leaf {i} failed all allowed constraints: {leaf}") + + def mapping_type_vars( + self, + type_var_mapping: Mapping[TypeVar, AttrConstraint | IntConstraint], + ) -> AttrConstraint: + """Map type variables to constraints.""" + # pylint: disable=unused-argument + return self diff --git a/frontend/catalyst/python_interface/xdsl_extras/traits.py b/frontend/catalyst/python_interface/xdsl_extras/traits.py new file mode 100644 index 0000000000..092241dd82 --- /dev/null +++ b/frontend/catalyst/python_interface/xdsl_extras/traits.py @@ -0,0 +1,240 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Traits for xDSL operations. + +This module provides operation traits that can be used to define operation invariants, +additional semantic information, or to group operations that have similar properties. +""" + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +from xdsl.dialects.builtin import TensorType, VectorType +from xdsl.ir import Attribute, Operation +from xdsl.traits import OpTrait +from xdsl.utils.exceptions import VerifyException +from xdsl.utils.type import get_element_type_or_self, have_compatible_shape + + +@dataclass(frozen=True) +class SameOperandsAndResultShape(OpTrait): + """Constrain the operation to have the same shape for all operands and results.""" + + # TODO: This trait should be added to ElementwiseBinaryOperation and + # ElementwiseUnaryOperation operations when upstreaming to xdsl. + + def verify(self, op: Operation) -> None: + """Verify that the operation has the same shape for all operands and results.""" + + if len(op.results) < 1 or len(op.operands) < 1: + raise VerifyException(f"'{op.name}' requires at least one result or operand") + + # Get all types (operands and results) to check for compatible shapes + all_types = list(op.operand_types) + list(op.result_types) + + # Check that all types have compatible shapes + for type_to_check in all_types[1:]: + if not have_compatible_shape(all_types[0], type_to_check): + + raise VerifyException( + f"'{op.name}' requires the same shape for all operands and results" + ) + + +@dataclass(frozen=True) +class SameOperandsElementType(OpTrait): + """Constrain the operation to have the same element type for all operands.""" + + # TODO: This trait should be added to ElementwiseBinaryOperation and + # ElementwiseUnaryOperation operations when upstreaming to xdsl. + + def verify(self, op: Operation) -> None: + """Verify that the operation has the same element type for all operands.""" + + if len(op.operands) <= 1: + return + + # Get the element type of the first operand + first_elem_type = get_element_type_or_self(op.operand_types[0]) + + # Check that all other operands have the same element type + for operand_type in op.operand_types[1:]: + elem_type = get_element_type_or_self(operand_type) + if elem_type != first_elem_type: + raise VerifyException( + f"'{op.name}' requires the same element type for all operands" + ) + + +@dataclass(frozen=True) +class SameOperandsAndResultElementType(OpTrait): + """Constrain the operation to have the same element type for all operands and results.""" + + def verify(self, op: Operation) -> None: + """Verify that the operation has the same element type for all operands and results.""" + + if len(op.results) < 1 or len(op.operands) < 1: + raise VerifyException(f"'{op.name}' requires at least one result or operand") + + # Get the element type of the first operand + first_elem_type = get_element_type_or_self(op.operand_types[0]) + + all_types = list(op.operand_types) + list(op.result_types) + + # Check that all other operands have the same element type + for type_to_check in all_types[1:]: + elem_type = get_element_type_or_self(type_to_check) + if elem_type != first_elem_type: + raise VerifyException( + f"'{op.name}' requires the same element type for all operands and results" + ) + + +@dataclass(frozen=True) +class Elementwise(OpTrait): + """ + The following is the definition of the `Elementwise` trait from MLIR: + + https://github.com/llvm/llvm-project/blob/f8cb7987c64dcffb72414a40560055cb717dbf74/mlir/include/mlir/IR/OpDefinition.h#L1378-L1409 + + TODO: Add this trait to all the elementwise operations in xdsl when upstreaming. + + Tags elementwise operations on vectors or tensors. + + NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this trait. + In particular, broadcasting behavior is not allowed. + + An `Elementwise` op must satisfy the following properties: + + 1. If any result is a vector/tensor then at least one operand must also be a + vector/tensor. + 2. If any operand is a vector/tensor then there must be at least one result + and all results must be vectors/tensors. + 3. All operand and result vector/tensor types must be of the same shape. The + shape may be dynamic in which case the op's behaviour is undefined for + non-matching shapes. + 4. The operation must be elementwise on its vector/tensor operands and + results. When applied to single-element vectors/tensors, the result must + be the same per element. + + Rationale: + - 1. and 2. guarantee a well-defined iteration space and exclude the cases + of 0 non-scalar operands or 0 non-scalar results, which complicate a + generic definition of the iteration space. + - 3. guarantees that folding can be done across scalars/vectors/tensors with + the same pattern, as otherwise lots of special handling for type + mismatches would be needed. + - 4. guarantees that no error handling is needed. Higher-level dialects + should reify any needed guards or error handling code before lowering to + an Elementwise op. + """ + + def verify(self, op: Operation) -> None: + """Verify that the operation is elementwise.""" + + # Filter mappable types from results and operands (vectors/tensors only) + result_mappable_types = [t for t in op.result_types if Elementwise.is_mappable_type(t)] + operand_mappable_types = [t for t in op.operand_types if Elementwise.is_mappable_type(t)] + + # If the op only has scalar operand/result types, then we have nothing to check + if not result_mappable_types and not operand_mappable_types: + return + + # If a result is non-scalar, then at least one operand must be non-scalar + if result_mappable_types and not operand_mappable_types: + raise VerifyException( + f"'{op.name}': if a result is non-scalar, then at least one " + "operand must be non-scalar" + ) + + # At this point, operand_mappable_types should not be empty + assert operand_mappable_types, "At least one operand must be a vector or tensor" + + # If an operand is non-scalar, then there must be at least one non-scalar result + if not result_mappable_types: + raise VerifyException( + f"'{op.name}': if an operand is non-scalar, then there must be at " + "least one non-scalar result" + ) + + # If an operand is non-scalar, then all results must be non-scalar + if len(result_mappable_types) != len(op.results): + raise VerifyException( + f"'{op.name}': if an operand is non-scalar, then all results must be non-scalar" + ) + + # All non-scalar operands/results must have the same shape and base type + all_types = operand_mappable_types + result_mappable_types + + # Check that all types have compatible shapes + for type_to_check in all_types[1:]: + if not have_compatible_shape(all_types[0], type_to_check): + raise VerifyException( + f"'{op.name}': all non-scalar operands/results must have the " + "same shape and base type" + ) + + @staticmethod + def is_mappable_type(attr_type: Attribute) -> bool: + """Return True if the type is elementwise-mappable (vector or tensor). + + There is a TODO in MLIR to generalize this trait to avoid hardcoding vector/tensor. + We should update this when the TODO is resolved. + """ + return isinstance(attr_type, (VectorType, TensorType)) + + +@dataclass(frozen=True) +class AllMatchSameOperatorTrait(OpTrait): + """ + Verify that a list of operation attributes all match under the same operator + (e.g., size, rank, type, shape, element type). + + Parameters: + - attr_names: attribute names on the op to compare + - operator: callable taking the attribute value and returning a comparable value + - summary: human-readable name of the property used in error messages + """ + + attr_names: tuple[str, ...] + operator: Callable[[Any], Any] + summary: str + + def verify(self, op: Operation) -> None: + """Verify that the operation attributes all match under the same operator.""" + attributes = [] + for name in self.attr_names: + value = getattr(op, name, None) + if value is None: + return + attributes.append(value) + + if len(attributes) <= 1: + return + + names_str = ", ".join(self.attr_names) + try: + results = [self.operator(attr) for attr in attributes] + except (TypeError, ValueError, AttributeError) as e: + raise VerifyException(f"cannot compute {self.summary} for {{{names_str}}}: {e}") from e + + first = results[0] + if any(res != first for res in results[1:]): + results_str = ", ".join(str(r) for r in results) + raise VerifyException( + f"all of {{{names_str}}} must have the same {self.summary}: got {self.summary}s {results_str}" + ) diff --git a/frontend/test/pytest/python_interface/conftest.py b/frontend/test/pytest/python_interface/conftest.py new file mode 100644 index 0000000000..d0d7cf38ce --- /dev/null +++ b/frontend/test/pytest/python_interface/conftest.py @@ -0,0 +1,203 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pytest configuration for tests for the pennylane.compiler.python_compiler submodule.""" + +from inspect import getsource +from io import StringIO + +import pytest + +deps_available = True + +try: + from filecheck.finput import FInput + from filecheck.matcher import Matcher + from filecheck.options import parse_argv_options + from filecheck.parser import Parser, pattern_for_opts + from xdsl.context import Context + from xdsl.dialects import test + from xdsl.passes import PassPipeline + from xdsl.printer import Printer + + from catalyst.python_interface import Compiler, QuantumParser + from catalyst.python_interface.conversion import parse_generic_to_xdsl_module +except (ImportError, ModuleNotFoundError): + deps_available = False + + +def _run_filecheck_impl(program_str, pipeline=(), verify=False, roundtrip=False): + """Run filecheck on an xDSL module, comparing it to a program string containing + filecheck directives.""" + if not deps_available: + return + + ctx = Context() + xdsl_module = QuantumParser(ctx, program_str, extra_dialects=(test.Test,)).parse_module() + + if roundtrip: + # Print generic format + stream = StringIO() + Printer(stream=stream, print_generic_format=True).print_op(xdsl_module) + xdsl_module = QuantumParser(ctx, stream.getvalue()).parse_module() + + if verify: + xdsl_module.verify() + + pipeline = PassPipeline(pipeline) + pipeline.apply(ctx, xdsl_module) + + if verify: + xdsl_module.verify() + + stream = StringIO() + Printer(stream).print_op(xdsl_module) + opts = parse_argv_options(["filecheck", __file__]) + matcher = Matcher( + opts, + FInput("no-name", stream.getvalue()), + Parser(opts, StringIO(program_str), *pattern_for_opts(opts)), + ) + + exit_code = matcher.run() + assert ( + exit_code == 0 + ), f""" + filecheck failed with exit code {exit_code}. + + Original program string: + {program_str} + + Parsed module: + {stream.getvalue()} + """ + + +@pytest.fixture(scope="function") +def run_filecheck(): + """Fixture to run filecheck on an xDSL module. + + This fixture uses FileCheck to verify the correctness of a parsed MLIR string. Testers + can provide a pass pipeline to transform the IR, and verify correctness by including + FileCheck directives as comments in the input program string. + + Args: + program_str (str): The MLIR string containing the input program and FileCheck directives + pipeline (tuple[ModulePass]): A sequence containing all passes that should be applied + before running FileCheck + verify (bool): Whether or not to verify the IR after parsing and transforming. + ``False`` by default. + roundtrip (bool): Whether or not to use round-trip testing. This is useful for dialect + tests to verify that xDSL both parses and prints the IR correctly. If ``True``, we parse + the program string into an xDSL module, print it in generic format, and then parse the + generic program string back to an xDSL module. ``False`` by default. + """ + if not deps_available: + pytest.skip("Cannot run lit tests without xDSL and filecheck.") + + yield _run_filecheck_impl + + +def _get_filecheck_directives(qjit_fn): + """Return a string containing all FileCheck directives in the source function.""" + try: + src = getsource(qjit_fn) + except Exception as e: + raise RuntimeError(f"Could not get source for {qjit_fn}") from e + + filecheck_directives = [] + for line in src.splitlines(): + line = line.strip() + if line[0] != "#": + continue + + line = line[1:].strip() + if line.startswith("CHECK"): + filecheck_directives.append("// " + line) + + return "\n".join(filecheck_directives) + + +def _run_filecheck_qjit_impl(qjit_fn, verify=False): + """Run filecheck on a qjit-ed function, using FileCheck directives in its inline + comments to assert correctness.""" + if not deps_available: + return + + checks = _get_filecheck_directives(qjit_fn) + compiler = Compiler() + mlir_module = compiler.run(qjit_fn.mlir_module) + + # The following is done because ``mlir_module`` will be in the generic syntax, and + # we want as many ops to be pretty printed as possible. + mod_str = mlir_module.operation.get_asm( + binary=False, print_generic_op_form=True, assume_verified=True + ) + xdsl_module = parse_generic_to_xdsl_module(mod_str) + + if verify: + xdsl_module.verify() + + opts = parse_argv_options(["filecheck", __file__]) + matcher = Matcher( + opts, + FInput("no-name", str(xdsl_module)), + Parser(opts, StringIO(checks), *pattern_for_opts(opts)), + ) + + exit_code = matcher.run() + assert exit_code == 0, f"filecheck failed with exit code {exit_code}" + + +@pytest.fixture(scope="function") +def run_filecheck_qjit(): + """Fixture to run filecheck on a qjit-ed function. + + This fixture yields a function that takes a QJIT object as input, parses its + MLIR, applies any passes that are present, and uses FileCheck to check the + output IR against FileCheck directives that may be present in the source + function as inline comments. + + Args: + qjit_fn (Callable): The QJIT object on which we want to run lit tests + verify (bool): Whether or not to verify the IR after parsing and transforming. + ``False`` by default. + + An example showing how to use the fixture is shown below. We apply the + ``merge_rotations_pass`` and check that there is only one rotation in + the final IR: + + .. code-block:: python + + def test_qjit(self, run_filecheck_qjit): + # Test that the merge_rotations_pass works as expected when used with `qjit` + dev = qml.device("lightning.qubit", wires=2) + + @qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()]) + @merge_rotations_pass + @qml.qnode(dev) + def circuit(x: float, y: float): + # CHECK: [[phi:%.*]] = arith.addf + # CHECK: quantum.custom "RX"([[phi]]) + # CHECK-NOT: quantum.custom + qml.RX(x, 0) + qml.RX(y, 0) + return qml.state() + + run_filecheck_qjit(circuit) + + """ + if not deps_available: + pytest.skip("Cannot run lit tests without xDSL and filecheck.") + + yield _run_filecheck_qjit_impl diff --git a/frontend/test/pytest/python_interface/dialects/test_catalyst_dialect.py b/frontend/test/pytest/python_interface/dialects/test_catalyst_dialect.py new file mode 100644 index 0000000000..e5af80a305 --- /dev/null +++ b/frontend/test/pytest/python_interface/dialects/test_catalyst_dialect.py @@ -0,0 +1,108 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test module for pennylane/compiler/python_compiler/dialects/catalyst.py.""" + +import pytest + +# pylint: disable=wrong-import-position + +xdsl = pytest.importorskip("xdsl") +filecheck = pytest.importorskip("filecheck") + +pytestmark = pytest.mark.external + +from catalyst.python_interface.dialects import Catalyst + +all_ops = list(Catalyst.operations) +all_attrs = list(Catalyst.attributes) + +expected_ops_names = { + "AssertionOp": "catalyst.assert", + "CallbackCallOp": "catalyst.callback_call", + "CallbackOp": "catalyst.callback", + "CustomCallOp": "catalyst.custom_call", + "LaunchKernelOp": "catalyst.launch_kernel", + "ListDeallocOp": "catalyst.list_dealloc", + "ListInitOp": "catalyst.list_init", + "ListLoadDataOp": "catalyst.list_load_data", + "ListPopOp": "catalyst.list_pop", + "ListPushOp": "catalyst.list_push", + "PrintOp": "catalyst.print", +} + +expected_attrs_names = { + "ArrayListType": "catalyst.arraylist", +} + + +def test_catalyst_dialect_name(): + """Test that the Catalyst dialect name is correct.""" + assert Catalyst.name == "catalyst" + + +@pytest.mark.parametrize("op", all_ops) +def test_all_operations_names(op): + """Test that all operations have the expected name.""" + op_class_name = op.__name__ + expected_name = expected_ops_names.get(op_class_name) + assert ( + expected_name is not None + ), f"Unexpected operation {op_class_name} found in Catalyst dialect" + assert op.name == expected_name + + +@pytest.mark.parametrize("attr", all_attrs) +def test_all_attributes_names(attr): + """Test that all attributes have the expected name.""" + attr_class_name = attr.__name__ + expected_name = expected_attrs_names.get(attr_class_name) + assert ( + expected_name is not None + ), f"Unexpected attribute {attr_class_name} found in Catalyst dialect" + assert attr.name == expected_name + + +def test_assembly_format(run_filecheck): + """Test the assembly format of the catalyst ops.""" + program = """ + // CHECK: [[LIST:%.+]] = catalyst.list_init : !catalyst.arraylist + %list = catalyst.list_init : !catalyst.arraylist + + // CHECK: [[DATA:%.+]] = catalyst.list_load_data [[LIST]] : !catalyst.arraylist -> memref + %data = catalyst.list_load_data %list : !catalyst.arraylist -> memref + + // CHECK: [[VAL:%.+]] = "test.op"() : () -> f64 + %val = "test.op"() : () -> f64 + + // CHECK: [[POP_RESULT:%.+]] = catalyst.list_pop [[LIST]] : !catalyst.arraylist + %pop_result = catalyst.list_pop %list : !catalyst.arraylist + + // CHECK: catalyst.list_push [[VAL]], [[LIST]] : !catalyst.arraylist + catalyst.list_push %val, %list : !catalyst.arraylist + + // CHECK: catalyst.list_dealloc [[LIST]] : !catalyst.arraylist + catalyst.list_dealloc %list : !catalyst.arraylist + + // CHECK: [[CUSTOM_RESULT:%.+]] = catalyst.custom_call fn("custom_function") ([[VAL]]) : (f64) -> f64 + %custom_result = catalyst.custom_call fn("custom_function")(%val) : (f64) -> f64 + + // CHECK: [[KERNEL_RESULT:%.+]] = catalyst.launch_kernel @kernel_name([[VAL]]) : (f64) -> f64 + %kernel_result = catalyst.launch_kernel @kernel_name(%val) : (f64) -> f64 + + // CHECK: [[CALLBACK_RESULT:%.+]] = catalyst.callback_call @callback_func([[VAL]]) : (f64) -> f64 + %callback_result = catalyst.callback_call @callback_func(%val) : (f64) -> f64 + """ + + run_filecheck(program, roundtrip=True) diff --git a/frontend/test/pytest/python_interface/dialects/test_mbqc_dialect.py b/frontend/test/pytest/python_interface/dialects/test_mbqc_dialect.py new file mode 100644 index 0000000000..3ca61a8ba5 --- /dev/null +++ b/frontend/test/pytest/python_interface/dialects/test_mbqc_dialect.py @@ -0,0 +1,178 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test module for pennylane/compiler/python_compiler/dialects/mbqc.py.""" + +import pytest + +# pylint: disable=wrong-import-position + +xdsl = pytest.importorskip("xdsl") +filecheck = pytest.importorskip("filecheck") + +pytestmark = pytest.mark.external + +from xdsl.dialects import arith, builtin, test +from xdsl.utils.exceptions import VerifyException + +from catalyst.python_interface.dialects import Quantum, mbqc + +all_ops = list(mbqc.MBQC.operations) +all_attrs = list(mbqc.MBQC.attributes) + +expected_ops_names = { + "MeasureInBasisOp": "mbqc.measure_in_basis", + "GraphStatePrepOp": "mbqc.graph_state_prep", +} + +expected_attrs_names = { + "MeasurementPlaneAttr": "mbqc.measurement_plane", +} + + +def test_mbqc_dialect_name(): + """Test that the MBQCDialect name is correct.""" + assert mbqc.MBQC.name == "mbqc" + + +@pytest.mark.parametrize("op", all_ops) +def test_all_operations_names(op): + """Test that all operations have the expected name.""" + op_class_name = op.__name__ + expected_name = expected_ops_names.get(op_class_name) + assert expected_name is not None, f"Unexpected operation {op_class_name} found in MBQCDialect" + assert op.name == expected_name + + +@pytest.mark.parametrize("attr", all_attrs) +def test_all_attributes_names(attr): + """Test that all attributes have the expected name.""" + attr_class_name = attr.__name__ + expected_name = expected_attrs_names.get(attr_class_name) + assert expected_name is not None, f"Unexpected attribute {attr_class_name} found in MBQCDialect" + assert attr.name == expected_name + + +def test_assembly_format(run_filecheck): + """Test the assembly format of the mbqc ops.""" + program = r""" + // CHECK: [[angle:%.+]] = "test.op"() : () -> f64 + %angle = "test.op"() : () -> f64 + + // CHECK: [[qubit:%.+]] = "test.op"() : () -> !quantum.bit + %qubit = "test.op"() : () -> !quantum.bit + + // CHECK: [[mres0:%.+]], [[out_qubit0:%.+]] = mbqc.measure_in_basis{{\s*}}[XY, [[angle]]] [[qubit]] : i1, !quantum.bit + %mres0, %out_qubit0 = mbqc.measure_in_basis [XY, %angle] %qubit : i1, !quantum.bit + + // CHECK: [[mres1:%.+]], [[out_qubit1:%.+]] = mbqc.measure_in_basis{{\s*}}[YZ, [[angle]]] [[qubit]] : i1, !quantum.bit + %mres1, %out_qubit1 = mbqc.measure_in_basis [YZ, %angle] %qubit : i1, !quantum.bit + + // CHECK: [[mres2:%.+]], [[out_qubit2:%.+]] = mbqc.measure_in_basis{{\s*}}[ZX, [[angle]]] [[qubit]] : i1, !quantum.bit + %mres2, %out_qubit2 = mbqc.measure_in_basis [ZX, %angle] %qubit : i1, !quantum.bit + + // CHECK: [[mres3:%.+]], [[out_qubit3:%.+]] = mbqc.measure_in_basis{{\s*}}[XY, [[angle]]] [[qubit]] postselect 0 : i1, !quantum.bit + %mres3, %out_qubit3 = mbqc.measure_in_basis [XY, %angle] %qubit postselect 0 : i1, !quantum.bit + + // CHECK: [[mres4:%.+]], [[out_qubit4:%.+]] = mbqc.measure_in_basis{{\s*}}[XY, [[angle]]] [[qubit]] postselect 1 : i1, !quantum.bit + %mres4, %out_qubit4 = mbqc.measure_in_basis [XY, %angle] %qubit postselect 1 : i1, !quantum.bit + + // COM: Check generic format + // CHECK: {{%.+}}, {{%.+}} = mbqc.measure_in_basis[XY, [[angle]]] [[qubit]] postselect 0 : i1, !quantum.bit + %res:2 = "mbqc.measure_in_basis"(%qubit, %angle) <{plane = #mbqc, postselect = 0 : i32}> : (!quantum.bit, f64) -> (i1, !quantum.bit) + + // CHECK: [[adj_matrix:%.+]] = arith.constant {{.*}} : tensor<6xi1> + // CHECK: [[graph_reg:%.+]] = mbqc.graph_state_prep{{\s*}}([[adj_matrix]] : tensor<6xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + %adj_matrix = arith.constant dense<[1, 0, 1, 0, 0, 1]> : tensor<6xi1> + %graph_reg = mbqc.graph_state_prep (%adj_matrix : tensor<6xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + """ + + run_filecheck(program, roundtrip=True) + + +class TestMeasureInBasisOp: + """Unit tests for the mbqc.measure_in_basis op.""" + + @pytest.mark.parametrize("plane,", ["XY", "YZ", "ZX"]) + @pytest.mark.parametrize("postselect", ["", "postselect 0", "postselect 1"]) + def test_measure_in_basis_properties(self, plane, postselect): + """Test the parsing of the mbqc.measure_in_basis op's properties.""" + program = rf""" + %angle = "test.op"() : () -> f64 + %qubit = "test.op"() : () -> !quantum.bit + + %mres, %out_qubit = mbqc.measure_in_basis [{plane}, %angle] %qubit {postselect} : i1, !quantum.bit + """ + + ctx = xdsl.context.Context() + + ctx.load_dialect(builtin.Builtin) + ctx.load_dialect(test.Test) + ctx.load_dialect(Quantum) + ctx.load_dialect(mbqc.MBQC) + + module = xdsl.parser.Parser(ctx, program).parse_module() + + measure_in_basis_op: mbqc.MeasureInBasisOp = module.ops.last + assert isinstance(measure_in_basis_op, mbqc.MeasureInBasisOp) + + assert measure_in_basis_op.properties["plane"].data == plane + + if postselect: + assert measure_in_basis_op.properties["postselect"].value.data == int(postselect[-1]) + else: + assert measure_in_basis_op.properties.get("postselect") is None + + @pytest.mark.parametrize("postselect", [-1, 2]) + def test_invalid_postselect_raises_on_verify(self, postselect): + """Test that using an invalid postselect value (a value other than 0 or 1) raises a + VerifyException during verification.""" + + program = rf""" + %angle = "test.op"() : () -> f64 + %qubit = "test.op"() : () -> !quantum.bit + + %mres, %out_qubit = mbqc.measure_in_basis [XY, %angle] %qubit postselect {postselect} : i1, !quantum.bit + """ + + ctx = xdsl.context.Context() + + ctx.load_dialect(builtin.Builtin) + ctx.load_dialect(test.Test) + ctx.load_dialect(Quantum) + ctx.load_dialect(mbqc.MBQC) + + module = xdsl.parser.Parser(ctx, program).parse_module() + + measure_in_basis_op: mbqc.MeasureInBasisOp = module.ops.last + assert isinstance(measure_in_basis_op, mbqc.MeasureInBasisOp) + + with pytest.raises(VerifyException, match="'postselect' must be 0 or 1"): + measure_in_basis_op.verify_() + + @pytest.mark.parametrize("init_op", ["Hadamard", builtin.StringAttr(data="Hadamard")]) + @pytest.mark.parametrize("entangle_op", ["CZ", builtin.StringAttr(data="CZ")]) + def test_graph_state_prep_instantiation(self, init_op, entangle_op): + """Test the instantiation of a mbqc.graph_state_prep op.""" + adj_matrix = [1, 0, 1, 0, 0, 1] + adj_matrix_op = arith.ConstantOp( + builtin.DenseIntOrFPElementsAttr.from_list( + type=builtin.TensorType(builtin.IntegerType(1), shape=(6,)), data=adj_matrix + ) + ) + graph_state_prep_op = mbqc.GraphStatePrepOp(adj_matrix_op.result, init_op, entangle_op) + + assert graph_state_prep_op.adj_matrix == adj_matrix_op.result + assert graph_state_prep_op.init_op == builtin.StringAttr(data="Hadamard") + assert graph_state_prep_op.entangle_op == builtin.StringAttr(data="CZ") diff --git a/frontend/test/pytest/python_interface/dialects/test_qec_dialect.py b/frontend/test/pytest/python_interface/dialects/test_qec_dialect.py new file mode 100644 index 0000000000..274e8f4465 --- /dev/null +++ b/frontend/test/pytest/python_interface/dialects/test_qec_dialect.py @@ -0,0 +1,98 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test module for pennylane/compiler/python_compiler/dialects/qec.py.""" + +import pytest + +# pylint: disable=wrong-import-position + +xdsl = pytest.importorskip("xdsl") +filecheck = pytest.importorskip("filecheck") + +pytestmark = pytest.mark.external + +from catalyst.python_interface.dialects import QEC + +all_ops = list(QEC.operations) +all_attrs = list(QEC.attributes) + +expected_ops_names = { + "FabricateOp": "qec.fabricate", + "LayerOp": "qec.layer", + "PPMeasurementOp": "qec.ppm", + "PPRotationOp": "qec.ppr", + "PrepareStateOp": "qec.prepare", + "SelectPPMeasurementOp": "qec.select.ppm", + "YieldOp": "qec.yield", +} + +expected_attrs_names = { + "LogicalInit": "qec.enum", +} + + +def test_qec_dialect_name(): + """Test that the QEC dialect name is correct.""" + assert QEC.name == "qec" + + +@pytest.mark.parametrize("op", all_ops) +def test_all_operations_names(op): + """Test that all operations have the expected name.""" + op_class_name = op.__name__ + expected_name = expected_ops_names.get(op_class_name) + assert expected_name is not None, f"Unexpected operation {op_class_name} found in QEC dialect" + assert op.name == expected_name + + +@pytest.mark.parametrize("attr", all_attrs) +def test_all_attributes_names(attr): + """Test that all attributes have the expected name.""" + attr_class_name = attr.__name__ + expected_name = expected_attrs_names.get(attr_class_name) + assert expected_name is not None, f"Unexpected attribute {attr_class_name} found in QEC dialect" + assert attr.name == expected_name + + +def test_assembly_format(run_filecheck): + """Test the assembly format of the qec ops.""" + program = """ + // CHECK: [[QUBIT:%.+]] = "test.op"() : () -> !quantum.bit + %qubit = "test.op"() : () -> !quantum.bit + + // CHECK: [[COND:%.+]] = "test.op"() : () -> i1 + %cond = "test.op"() : () -> i1 + + // CHECK: [[FABRICATED:%.+]] = qec.fabricate magic : !quantum.bit + %fabricated = qec.fabricate magic : !quantum.bit + + // CHECK: [[PREPARED:%.+]] = qec.prepare zero [[QUBIT]] : !quantum.bit + %prepared = qec.prepare zero %qubit : !quantum.bit + + // CHECK: [[ROTATED:%.+]] = qec.ppr ["X", "I", "Z"](4) [[QUBIT]] : !quantum.bit + %rotated = qec.ppr ["X", "I", "Z"](4) %qubit : !quantum.bit + + // CHECK: [[MEASURED:%.+]], [[OUT_QUBITS:%.+]] = qec.ppm ["X", "I", "Z"] [[QUBIT]] : i1, !quantum.bit + %measured, %out_qubits = qec.ppm ["X", "I", "Z"] %qubit : i1, !quantum.bit + + // CHECK: [[MEASURED_COND:%.+]], [[OUT_QUBITS_COND:%.+]] = qec.ppm ["X", "I", "Z"] [[QUBIT]] cond([[COND]]) : i1, !quantum.bit + %measured_cond, %out_qubits_cond = qec.ppm ["X", "I", "Z"] %qubit cond(%cond) : i1, !quantum.bit + + // CHECK: [[SELECT_MEASURED:%.+]], [[SELECT_OUT:%.+]] = qec.select.ppm([[COND]], ["X"], ["Z"]) [[QUBIT]] : i1, !quantum.bit + %select_measured, %select_out = qec.select.ppm (%cond, ["X"], ["Z"]) %qubit : i1, !quantum.bit + + """ + + run_filecheck(program) diff --git a/frontend/test/pytest/python_interface/dialects/test_quantum_dialect.py b/frontend/test/pytest/python_interface/dialects/test_quantum_dialect.py new file mode 100644 index 0000000000..843412be09 --- /dev/null +++ b/frontend/test/pytest/python_interface/dialects/test_quantum_dialect.py @@ -0,0 +1,697 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test module for pennylane/compiler/python_compiler/dialects/quantum.py.""" + +import pytest + +# pylint: disable=wrong-import-position + +xdsl = pytest.importorskip("xdsl") +filecheck = pytest.importorskip("filecheck") + +pytestmark = pytest.mark.external + +from xdsl.dialects.builtin import ( + I32, + ComplexType, + Float64Type, + StringAttr, + TensorType, + UnitAttr, + i1, +) +from xdsl.dialects.test import TestOp +from xdsl.ir import AttributeCovT, OpResult + +from catalyst.python_interface.dialects import Quantum +from catalyst.python_interface.dialects.quantum import ( + CustomOp, + NamedObservableAttr, + ObservableType, + QubitType, + QuregType, +) + +all_ops = list(Quantum.operations) +all_attrs = list(Quantum.attributes) + +expected_ops_names = { + "AdjointOp": "quantum.adjoint", + "AllocOp": "quantum.alloc", + "AllocQubitOp": "quantum.alloc_qb", + "ComputationalBasisOp": "quantum.compbasis", + "CountsOp": "quantum.counts", + "CustomOp": "quantum.custom", + "DeallocOp": "quantum.dealloc", + "DeallocQubitOp": "quantum.dealloc_qb", + "DeviceInitOp": "quantum.device", + "DeviceReleaseOp": "quantum.device_release", + "ExpvalOp": "quantum.expval", + "ExtractOp": "quantum.extract", + "FinalizeOp": "quantum.finalize", + "GlobalPhaseOp": "quantum.gphase", + "HamiltonianOp": "quantum.hamiltonian", + "HermitianOp": "quantum.hermitian", + "InitializeOp": "quantum.init", + "InsertOp": "quantum.insert", + "MeasureOp": "quantum.measure", + "MultiRZOp": "quantum.multirz", + "NamedObsOp": "quantum.namedobs", + "NumQubitsOp": "quantum.num_qubits", + "PCPhaseOp": "quantum.pcphase", + "ProbsOp": "quantum.probs", + "QubitUnitaryOp": "quantum.unitary", + "SampleOp": "quantum.sample", + "SetBasisStateOp": "quantum.set_basis_state", + "SetStateOp": "quantum.set_state", + "StateOp": "quantum.state", + "TensorOp": "quantum.tensor", + "VarianceOp": "quantum.var", + "YieldOp": "quantum.yield", +} + +expected_attrs_names = { + "ObservableType": "quantum.obs", + "QubitType": "quantum.bit", + "QuregType": "quantum.reg", + "ResultType": "quantum.res", + "NamedObservableAttr": "quantum.named_observable", +} + +TestOp.__test__ = False +"""Setting this attribute silences the PytestCollectionWarning that TestOp can not be collected +for testing, because it is a class with __init__ method.""" + + +# Test function taken from xdsl/utils/test_value.py +def create_ssa_value(t: AttributeCovT) -> OpResult[AttributeCovT]: + """Create a single SSA value with the given type for testing purposes.""" + op = TestOp(result_types=(t,)) + return op.results[0] + + +q0 = create_ssa_value(QubitType()) +q1 = create_ssa_value(QubitType()) +q2 = create_ssa_value(QubitType()) +qreg = create_ssa_value(QuregType()) +theta = create_ssa_value(Float64Type()) +dim = create_ssa_value(Float64Type()) +pauli_x = NamedObservableAttr("PauliX") +obs = create_ssa_value(ObservableType()) +i = create_ssa_value(I32) +matrix = create_ssa_value(TensorType(element_type=Float64Type, shape=(2, 2))) +coeffs = create_ssa_value(TensorType(Float64Type(), shape=(10,))) +samples = create_ssa_value(TensorType(Float64Type(), shape=(8, 7))) +basis_state = create_ssa_value(TensorType(i1, shape=(8,))) +state = create_ssa_value(TensorType(ComplexType(Float64Type()), shape=(16,))) + +expected_ops_init_kwargs = { + "AdjointOp": {"qreg": qreg, "region": (CustomOp(gate_name="CNOT", in_qubits=(q0, q1)),)}, + "AllocOp": {"nqubits": 3}, + "AllocQubitOp": {}, + "ComputationalBasisOp": {"operands": (q0, None), "result_types": (obs,)}, + "CountsOp": { + "operands": (obs, i, None, None), + "result_types": (TensorType(Float64Type(), shape=(1,)), TensorType(I32, shape=(1,))), + }, + "CustomOp": { + "gate_name": "RX", + "in_qubits": (q0, q1), + "in_ctrl_qubits": (q2,), + "params": (theta,), + "adjoint": True, + }, + "DeallocOp": {"qreg": qreg}, + "DeallocQubitOp": {"qubit": q0}, + "DeviceInitOp": { + "operands": (i,), + "properties": {"lib": StringAttr("lib"), "device_name": StringAttr("my_device")}, + }, + "DeviceReleaseOp": {}, + "ExpvalOp": {"obs": obs}, + "ExtractOp": {"qreg": qreg, "idx": i}, + "FinalizeOp": {}, + "GlobalPhaseOp": {"params": theta, "in_ctrl_qubits": q0}, + "HamiltonianOp": {"operands": (coeffs, (obs,)), "result_types": (obs,)}, + "HermitianOp": {"operands": (matrix, (q0, q1)), "result_types": (obs,)}, + "InitializeOp": {}, + "InsertOp": {"in_qreg": qreg, "idx": i, "qubit": q1}, + "MeasureOp": {"in_qubit": q0, "postselect": i}, + "MultiRZOp": { + "theta": theta, + "in_qubits": (q1, q0), + "in_ctrl_qubits": (q2,), + "in_ctrl_values": (i,), + "adjoint": UnitAttr(), + }, + "NamedObsOp": {"qubit": q0, "obs_type": pauli_x}, + "NumQubitsOp": {"result_types": (i,)}, + "PCPhaseOp": { + "theta": theta, + "dim": dim, + "in_qubits": (q1, q0), + "in_ctrl_qubits": (q2,), + "in_ctrl_values": (i,), + "adjoint": False, + }, + "ProbsOp": { + "operands": (obs, i, None), + "result_types": (TensorType(Float64Type(), shape=(8,)),), + }, + "QubitUnitaryOp": {"matrix": matrix, "in_qubits": (q2,), "adjoint": True}, + "SampleOp": {"operands": (obs, i, samples), "result_types": (samples,)}, + "SetBasisStateOp": {"operands": (basis_state, (q0, q2)), "result_types": ((q1, q2),)}, + "SetStateOp": {"operands": (state, (q0, q1)), "result_types": ((q0, q1),)}, + "StateOp": {"operands": (obs, i, state), "result_types": (state,)}, + "TensorOp": {"operands": ((obs, obs),), "result_types": (obs,)}, + "VarianceOp": {"obs": (obs,)}, + "YieldOp": {"operands": (qreg,)}, +} + + +def test_quantum_dialect_name(): + """Test that the QuantumDialect name is correct.""" + assert Quantum.name == "quantum" + + +@pytest.mark.parametrize("op", all_ops) +def test_all_operations_names(op): + """Test that all operations have the expected name.""" + op_class_name = op.__name__ + expected_name = expected_ops_names.get(op_class_name) + assert ( + expected_name is not None + ), f"Unexpected operation {op_class_name} found in QuantumDialect" + assert op.name == expected_name + + +def test_only_existing_operations_are_expected(): + """Test that the expected operations above only contain existing operations.""" + existing_ops_names = {op.__name__ for op in all_ops} + assert existing_ops_names == set(expected_ops_names) + + +@pytest.mark.parametrize("op", all_ops) +def test_operation_construction(op): + kwargs = expected_ops_init_kwargs[op.__name__] + _ = op(**kwargs) + + +@pytest.mark.parametrize("attr", all_attrs) +def test_all_attributes_names(attr): + """Test that all attributes have the expected name.""" + attr_class_name = attr.__name__ + expected_name = expected_attrs_names.get(attr_class_name) + assert ( + expected_name is not None + ), f"Unexpected attribute {attr_class_name} found in QuantumDialect" + assert attr.name == expected_name + + +def test_only_existing_attributes_are_expected(): + """Test that the expected attributes above only contain existing attributes.""" + existing_attrs_names = {attr.__name__ for attr in all_attrs} + assert existing_attrs_names == set(expected_attrs_names) + + +class TestAssemblyFormat: + """Lit tests for assembly format of operations/attributes in the Quantum + dialect.""" + + def test_qubit_qreg_operations(self, run_filecheck): + """Test that the assembly format for operations for allocation/deallocation of + qubits/quantum registers works correctly.""" + + # Tests for allocation/deallocation ops: AllocOp, DeallocOp, AllocQubitOp, DeallocQubitOp + # Tests for extraction/insertion ops: ExtractOp, InsertOp + program = """ + ////////////////// **Allocation of register with dynamic number of wires** ////////////////// + // CHECK: [[NQUBITS:%.+]] = "test.op"() : () -> i64 + // CHECK: [[QREG_DYN:%.+]] = quantum.alloc([[NQUBITS]]) : !quantum.reg + %nqubits = "test.op"() : () -> i64 + %qreg_dynamic = quantum.alloc(%nqubits) : !quantum.reg + + ////////////////// **Deallocation of dynamic register** ////////////////// + // CHECK: quantum.dealloc [[QREG_DYN]] : !quantum.reg + quantum.dealloc %qreg_dynamic : !quantum.reg + + ////////////////// **Allocation of register with static number of wires** ////////////////// + // CHECK: [[QREG_STATIC:%.+]] = quantum.alloc(10) : !quantum.reg + %qreg_static = quantum.alloc(10) : !quantum.reg + + ////////////////// **Deallocation of static register** ////////////////// + // CHECK: quantum.dealloc [[QREG_STATIC]] : !quantum.reg + quantum.dealloc %qreg_static : !quantum.reg + + ////////////////// **Dynamic qubit allocation** ////////////////// + // CHECK: [[DYN_QUBIT:%.+]] = quantum.alloc_qb : !quantum.bit + %dyn_qubit = quantum.alloc_qb : !quantum.bit + + ////////////////// **Dynamic qubit deallocation** ////////////////// + // CHECK: quantum.dealloc_qb [[DYN_QUBIT]] : !quantum.bit + quantum.dealloc_qb %dyn_qubit : !quantum.bit + + ////////////////////////////////////////////////////// + ////////////////// Quantum register ////////////////// + ////////////////////////////////////////////////////// + // CHECK: [[QREG:%.+]] = "test.op"() : () -> !quantum.reg + %qreg = "test.op"() : () -> !quantum.reg + + ////////////////// **Static qubit extraction** ////////////////// + // CHECK: [[STATIC_QUBIT:%.+]] = quantum.extract [[QREG]][[[STATIC_INDEX:0]]] : !quantum.reg -> !quantum.bit + %static_qubit = quantum.extract %qreg[0] : !quantum.reg -> !quantum.bit + + ////////////////// **Dynamic qubit extraction** ////////////////// + // CHECK: [[DYN_INDEX:%.+]] = "test.op"() : () -> i64 + // CHECK: [[DYN_QUBIT1:%.+]] = quantum.extract [[QREG]][[[DYN_INDEX]]] : !quantum.reg -> !quantum.bit + %dyn_index = "test.op"() : () -> i64 + %dyn_qubit1 = quantum.extract %qreg[%dyn_index] : !quantum.reg -> !quantum.bit + + ////////////////// **Static qubit insertion** ////////////////// + // CHECK: [[QREG1:%.+]] = quantum.insert [[QREG]][[[STATIC_INDEX]]], [[STATIC_QUBIT]] : !quantum.reg, !quantum.bit + %qreg1 = quantum.insert %qreg[0], %static_qubit : !quantum.reg, !quantum.bit + + ////////////////// **Dynamic qubit insertion** ////////////////// + // CHECK: quantum.insert [[QREG1]][[[DYN_INDEX]]], [[DYN_QUBIT1]] : !quantum.reg, !quantum.bit + %qreg2 = quantum.insert %qreg1[%dyn_index], %dyn_qubit1 : !quantum.reg, !quantum.bit + """ + + run_filecheck(program, roundtrip=True, verify=True) + + def test_quantum_ops(self, run_filecheck): + """Test that the assembly format for quantum non-terminal operations works correctly.""" + + # Tests for CustomOp, GlobalPhaseOp, MeasureOp, MultiRZOp, QubitUnitaryOp + program = """ + //////////////////////////////////////////////////////////////////////// + ////////////////// Qubits, params, and control values ////////////////// + //////////////////////////////////////////////////////////////////////// + ////////////////// **Qubits** ////////////////// + // CHECK: [[Q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[Q1:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[Q2:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[Q3:%.+]] = "test.op"() : () -> !quantum.bit + %q0 = "test.op"() : () -> !quantum.bit + %q1 = "test.op"() : () -> !quantum.bit + %q2 = "test.op"() : () -> !quantum.bit + %q3 = "test.op"() : () -> !quantum.bit + + ////////////////// **Params** ////////////////// + // CHECK: [[PARAM1:%.+]] = "test.op"() : () -> f64 + // CHECK: [[PARAM2:%.+]] = "test.op"() : () -> f64 + // CHECK: [[MAT_TENSOR:%.+]] = "test.op"() : () -> tensor<4x4xcomplex> + // CHECK: [[MAT_MEMREF:%.+]] = "test.op"() : () -> memref<4x4xcomplex> + %param1 = "test.op"() : () -> f64 + %param2 = "test.op"() : () -> f64 + %mat_tensor = "test.op"() : () -> tensor<4x4xcomplex> + %mat_memref = "test.op"() : () -> memref<4x4xcomplex> + + ////////////////// **Control values** ////////////////// + // CHECK: [[TRUE_CST:%.+]] = "test.op"() : () -> i1 + // CHECK: [[FALSE_CST:%.+]] = "test.op"() : () -> i1 + %true_cst = "test.op"() : () -> i1 + %false_cst = "test.op"() : () -> i1 + + /////////////////////////////////////////////////////////////////////// + ///////////////////////// **Operation tests** ///////////////////////// + /////////////////////////////////////////////////////////////////////// + + ////////////////// **CustomOp tests** ////////////////// + // No params, no control wires + // CHECK: {{%.+}}, {{%.+}} = quantum.custom "Gate"() [[Q0]], [[Q1]] : !quantum.bit, !quantum.bit + %qc1, %qc2 = quantum.custom "Gate"() %q0, %q1 : !quantum.bit, !quantum.bit + + // Params, no control wires + // CHECK: {{%.+}}, {{%.+}} = quantum.custom "ParamGate"([[PARAM1]], [[PARAM2]]) [[Q0]], [[Q1]] : !quantum.bit, !quantum.bit + %qc3, %qc4 = quantum.custom "ParamGate"(%param1, %param2) %q0, %q1 : !quantum.bit, !quantum.bit + + // Control wires and values + // CHECK: {{%.+}}, {{%.+}} = quantum.custom "ControlledGate"() [[Q0]] ctrls([[Q1]]) ctrlvals([[TRUE_CST]]) : !quantum.bit ctrls !quantum.bit + %qc5, %qc6 = quantum.custom "ControlledGate"() %q0 ctrls(%q1) ctrlvals(%true_cst) : !quantum.bit ctrls !quantum.bit + + // Adjoint + // CHECK: {{%.+}} = quantum.custom "AdjGate"() [[Q0]] adj : !quantum.bit + %qc8 = quantum.custom "AdjGate"() %q0 adj : !quantum.bit + + ////////////////// **GlobalPhaseOp tests** ////////////////// + // No control wires + // CHECK: quantum.gphase([[PARAM1]]) : + quantum.gphase(%param1) : + + // Control wires and values + // CHECK: {{%.+}}, {{%.+}} = quantum.gphase([[PARAM1]]) ctrls([[Q0]], [[Q1]]) ctrlvals([[FALSE_CST]], [[TRUE_CST]]) : !quantum.bit, !quantum.bit + %qg1, %qg2 = quantum.gphase(%param1) ctrls(%q0, %q1) ctrlvals(%false_cst, %true_cst) : !quantum.bit, !quantum.bit + + // Adjoint + // CHECK: {{%.+}} = quantum.gphase([[PARAM1]]) {adjoint} ctrls([[Q0]]) ctrlvals([[TRUE_CST]]) : !quantum.bit + %qg3 = quantum.gphase(%param1) {adjoint} ctrls(%q0) ctrlvals(%true_cst) : !quantum.bit + + ////////////////// **MultiRZOp tests** ////////////////// + // No control wires + // CHECK: {{%.+}}, {{%.+}} = quantum.multirz([[PARAM1]]) [[Q0]], [[Q1]] : !quantum.bit, !quantum.bit + %qm1, %qm2 = quantum.multirz(%param1) %q0, %q1 : !quantum.bit, !quantum.bit + + // Control wires and values + // CHECK: {{%.+}}, {{%.+}}, {{%.+}} = quantum.multirz([[PARAM1]]) [[Q0]], [[Q1]] ctrls([[Q2]]) ctrlvals([[TRUE_CST]]) : !quantum.bit, !quantum.bit + %qm3, %qm4, %qm5 = quantum.multirz(%param1) %q0, %q1 ctrls(%q2) ctrlvals(%true_cst) : !quantum.bit, !quantum.bit ctrls !quantum.bit + + // Adjoint + // CHECK: {{%.+}}, {{%.+}} = quantum.multirz([[PARAM1]]) [[Q0]], [[Q1]] adj : !quantum.bit, !quantum.bit + %qm6, %qm7 = quantum.multirz(%param1) %q0, %q1 adj : !quantum.bit, !quantum.bit + + ////////////////// **PCPhaseOp tests** ////////////////// + // No control wires + // CHECK: {{%.+}}, {{%.+}}, {{%.+}} = quantum.pcphase([[PARAM1]], [[PARAM2]]) [[Q0]], [[Q1]], [[Q2]] : !quantum.bit, !quantum.bit, !quantum.bit + %qp1, %qp2, %qp3 = quantum.pcphase(%param1, %param2) %q0, %q1, %q2 : !quantum.bit, !quantum.bit, !quantum.bit + + // Control wires and values + // CHECK: {{%.+}}, {{%.+}}, {{%.+}} = quantum.pcphase([[PARAM1]], [[PARAM2]]) [[Q0]], [[Q1]] ctrls([[Q2]]) ctrlvals([[TRUE_CST]]) : !quantum.bit, !quantum.bit ctrls !quantum.bit + %qp4, %qp5, %qp6 = quantum.pcphase(%param1, %param2) %q0, %q1 ctrls(%q2) ctrlvals(%true_cst) : !quantum.bit, !quantum.bit ctrls !quantum.bit + + // Adjoint + // CHECK: {{%.+}}, {{%.+}} = quantum.pcphase([[PARAM1]], [[PARAM2]]) [[Q0]], [[Q1]] adj : !quantum.bit, !quantum.bit + %qp7, %qp8 = quantum.pcphase(%param1, %param2) %q0, %q1 adj : !quantum.bit, !quantum.bit + + ////////////////// **QubitUnitaryOp tests** ////////////////// + // No control wires + // CHECK: {{%.+}}, {{%.+}} = quantum.unitary([[MAT_TENSOR]] : tensor<4x4xcomplex>) [[Q0]], [[Q1]] : !quantum.bit, !quantum.bit + %qb1, %qb2 = quantum.unitary(%mat_tensor : tensor<4x4xcomplex>) %q0, %q1 : !quantum.bit, !quantum.bit + + // Control wires and values + // CHECK: {{%.+}}, {{%.+}} {{%.+}} = quantum.unitary([[MAT_TENSOR]] : tensor<4x4xcomplex>) [[Q0]], [[Q1]] ctrls([[Q2]]) ctrlvals([[FALSE_CST]]) : !quantum.bit, !quantum.bit ctrls !quantum.bit + %qb3, %qb4, %qb5 = quantum.unitary(%mat_tensor : tensor<4x4xcomplex>) %q0, %q1 ctrls(%q2) ctrlvals(%false_cst) : !quantum.bit, !quantum.bit ctrls !quantum.bit + + // Adjoint + // CHECK: {{%.+}}, {{%.+}} = quantum.unitary([[MAT_TENSOR]] : tensor<4x4xcomplex>) [[Q0]], [[Q1]] adj : !quantum.bit, !quantum.bit + %qb6, %qb7 = quantum.unitary(%mat_tensor : tensor<4x4xcomplex>) %q0, %q1 adj : !quantum.bit, !quantum.bit + + // MemRef + // CHECK: {{%.+}}, {{%.+}} = quantum.unitary([[MAT_MEMREF]] : memref<4x4xcomplex>) [[Q0]], [[Q1]] : !quantum.bit, !quantum.bit + %qb8, %qb9 = quantum.unitary(%mat_memref : memref<4x4xcomplex>) %q0, %q1 : !quantum.bit, !quantum.bit + + ////////////////// **MeasureOp tests** ////////////////// + // No postselection + // CHECK: {{%.+}}, {{%.+}} = quantum.measure [[Q0]] : i1, !quantum.bit + %mres1, %mqubit1 = quantum.measure %q0 : i1, !quantum.bit + + // Postselection + // CHECK: {{%.+}}, {{%.+}} = quantum.measure [[Q1]] postselect 0 : i1, !quantum.bit + // CHECK: {{%.+}}, {{%.+}} = quantum.measure [[Q2]] postselect 1 : i1, !quantum.bit + %mres2, %mqubit2 = quantum.measure %q1 postselect 0 : i1, !quantum.bit + %mres3, %mqubit3 = quantum.measure %q2 postselect 1 : i1, !quantum.bit + """ + + run_filecheck(program, roundtrip=True, verify=True) + + def test_state_prep(self, run_filecheck): + """Test that the assembly format for state prep operations works correctly.""" + + # Tests for SetBasisStateOp, SetStateOp + program = """ + //////////////////////////////////////////// + ////////////////// Qubits ////////////////// + //////////////////////////////////////////// + // CHECK: [[Q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[Q1:%.+]] = "test.op"() : () -> !quantum.bit + %q0 = "test.op"() : () -> !quantum.bit + %q1 = "test.op"() : () -> !quantum.bit + + ////////////////// **SetBasisStateOp tests** ////////////////// + // Basis state containers + // CHECK: [[BASIS_TENSOR:%.+]] = "test.op"() : () -> tensor<2xi1> + // CHECK: [[BASIS_MEMREF:%.+]] = "test.op"() : () -> memref<2xi1> + %basis_tensor = "test.op"() : () -> tensor<2xi1> + %basis_memref = "test.op"() : () -> memref<2xi1> + + // Basis state operations + // CHECK: [[Q2:%.+]], [[Q3:%.+]] = quantum.set_basis_state([[BASIS_TENSOR]]) [[Q0]], [[Q1]] : (tensor<2xi1>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit) + // CHECK: [[Q4:%.+]], [[Q5:%.+]] = quantum.set_basis_state([[BASIS_MEMREF]]) [[Q2]], [[Q3]] : (memref<2xi1>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit) + %q2, %q3 = quantum.set_basis_state(%basis_tensor) %q0, %q1 : (tensor<2xi1>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit) + %q4, %q5 = quantum.set_basis_state(%basis_memref) %q2, %q3 : (memref<2xi1>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit) + + ////////////////// **SetStateOp tests** ////////////////// + // State vector containers + // CHECK: [[STATE_TENSOR:%.+]] = "test.op"() : () -> tensor<4xcomplex> + // CHECK: [[STATE_MEMREF:%.+]] = "test.op"() : () -> memref<4xcomplex> + %state_tensor = "test.op"() : () -> tensor<4xcomplex> + %state_memref = "test.op"() : () -> memref<4xcomplex> + + // State prep operations + // CHECK: [[Q6:%.+]], [[Q7:%.+]] = quantum.set_state([[STATE_TENSOR]]) [[Q4]], [[Q5]] : (tensor<4xcomplex>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit) + // CHECK: quantum.set_state([[STATE_MEMREF]]) [[Q6]], [[Q7]] : (memref<4xcomplex>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit) + %q6, %q7 = quantum.set_state(%state_tensor) %q4, %q5 : (tensor<4xcomplex>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit) + %q8, %q9 = quantum.set_state(%state_memref) %q6, %q7 : (memref<4xcomplex>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit) + """ + + run_filecheck(program, roundtrip=True, verify=True) + + def test_observables(self, run_filecheck): + """Test that the assembly format for observable operations works correctly.""" + + # Tests for observables: ComputationalBasisOp, HamiltonianOp, HermitianOp, + # NamedObsOp, TensorOp + program = """ + ////////////////////////////////////////////////////// + //////////// Quantum register and qubits //////////// + ////////////////////////////////////////////////////// + // CHECK: [[QREG:%.+]] = "test.op"() : () -> !quantum.reg + %qreg = "test.op"() : () -> !quantum.reg + + // CHECK: [[Q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[Q1:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[Q2:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[Q3:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[Q4:%.+]] = "test.op"() : () -> !quantum.bit + %q0 = "test.op"() : () -> !quantum.bit + %q1 = "test.op"() : () -> !quantum.bit + %q2 = "test.op"() : () -> !quantum.bit + %q3 = "test.op"() : () -> !quantum.bit + %q4 = "test.op"() : () -> !quantum.bit + + ////////////////////////////////////////////// + //////////// **Observable tests** //////////// + ////////////////////////////////////////////// + + //////////// **NamedObsOp** //////////// + // CHECK: [[X_OBS:%.+]] = quantum.namedobs [[Q0]][PauliX] : !quantum.obs + // CHECK: [[Y_OBS:%.+]] = quantum.namedobs [[Q1]][PauliY] : !quantum.obs + // CHECK: [[Z_OBS:%.+]] = quantum.namedobs [[Q2]][PauliZ] : !quantum.obs + // CHECK: [[H_OBS:%.+]] = quantum.namedobs [[Q3]][Hadamard] : !quantum.obs + // CHECK: [[I_OBS:%.+]] = quantum.namedobs [[Q4]][Identity] : !quantum.obs + %x_obs = quantum.namedobs %q0[PauliX] : !quantum.obs + %y_obs = quantum.namedobs %q1[PauliY] : !quantum.obs + %z_obs = quantum.namedobs %q2[PauliZ] : !quantum.obs + %h_obs = quantum.namedobs %q3[Hadamard] : !quantum.obs + %i_obs = quantum.namedobs %q4[Identity] : !quantum.obs + + //////////// **HermitianOp** //////////// + // Create tensor/memref + // CHECK: [[HERM_TENSOR:%.+]] = "test.op"() : () -> tensor<2x2xcomplex> + // CHECK: [[HERM_MEMREF:%.+]] = "test.op"() : () -> memref<2x2xcomplex> + %herm_tensor = "test.op"() : () -> tensor<2x2xcomplex> + %herm_memref = "test.op"() : () -> memref<2x2xcomplex> + + // Create Hermitians + // CHECK: [[HERM1:%.+]] = quantum.hermitian([[HERM_TENSOR]] : tensor<2x2xcomplex>) [[Q0]] : !quantum.obs + // CHECK: [[HERM2:%.+]] = quantum.hermitian([[HERM_MEMREF]] : memref<2x2xcomplex>) [[Q1]] : !quantum.obs + %herm1 = quantum.hermitian(%herm_tensor : tensor<2x2xcomplex>) %q0 : !quantum.obs + %herm2 = quantum.hermitian(%herm_memref : memref<2x2xcomplex>) %q1 : !quantum.obs + + //////////// **TensorOp** //////////// + // CHECK: [[TENSOR_OBS:%.+]] = quantum.tensor [[X_OBS]], [[HERM2]], [[I_OBS]] : !quantum.obs + %tensor_obs = quantum.tensor %x_obs, %herm2, %i_obs : !quantum.obs + + //////////// **HamiltonianOp** //////////// + // Create tensor/memref + // CHECK: [[HAM_TENSOR:%.+]] = "test.op"() : () -> tensor<3xf64> + // CHECK: [[HAM_MEMREF:%.+]] = "test.op"() : () -> memref<3xf64> + %ham_tensor = "test.op"() : () -> tensor<3xf64> + %ham_memref = "test.op"() : () -> memref<3xf64> + + // Create Hamiltonians + // CHECK: {{%.+}} = quantum.hamiltonian([[HAM_TENSOR]] : tensor<3xf64>) [[TENSOR_OBS]], [[X_OBS]], [[HERM1]] : !quantum.obs + // CHECK: {{%.+}} = quantum.hamiltonian([[HAM_MEMREF]] : memref<3xf64>) [[TENSOR_OBS]], [[X_OBS]], [[HERM1]] : !quantum.obs + %ham1 = quantum.hamiltonian(%ham_tensor : tensor<3xf64>) %tensor_obs, %x_obs, %herm1 : !quantum.obs + %ham2 = quantum.hamiltonian(%ham_memref : memref<3xf64>) %tensor_obs, %x_obs, %herm1 : !quantum.obs + + //////////// **ComputationalBasisOp** //////////// + // CHECK: {{%.+}} = quantum.compbasis qubits [[Q0]], [[Q1]] : !quantum.obs + // CHECK: {{%.+}} = quantum.compbasis qreg [[QREG]] : !quantum.obs + %cb_01 = quantum.compbasis qubits %q0, %q1 : !quantum.obs + %cb_all = quantum.compbasis qreg %qreg : !quantum.obs + """ + + run_filecheck(program, roundtrip=True, verify=True) + + def test_measurements(self, run_filecheck): + """Test that the assembly format for measurement operations works correctly.""" + + # Tests for measurements: CountsOp, ExpvalOp, MeasureOp, ProbsOp, SampleOp, + # StateOp, VarianceOp + program = """ + /////////////////////////////////////////////////// + //////////// Observables and constants //////////// + /////////////////////////////////////////////////// + // CHECK: [[OBS:%.+]] = "test.op"() : () -> !quantum.obs + %obs = "test.op"() : () -> !quantum.obs + + // CHECK: [[DYN_WIRES:%.+]] = "test.op"() : () -> i64 + %dyn_wires = "test.op"() : () -> i64 + // CHECK: [[DYN_SHOTS:%.+]] = "test.op"() : () -> i64 + %dyn_shots = "test.op"() : () -> i64 + + /////////////////////////////////////////////// + //////////// **Measurement tests** //////////// + /////////////////////////////////////////////// + + ///////////////////// **ExpvalOp** ///////////////////// + // CHECK: {{%.+}} = quantum.expval [[OBS]] : f64 + %expval = quantum.expval %obs : f64 + + ///////////////////// **VarianceOp** ///////////////////// + // CHECK: {{%.+}} = quantum.var [[OBS]] : f64 + %var = quantum.var %obs : f64 + + ///////////////////// **CountsOp** ///////////////////// + // Counts with static shape + // CHECK: {{%.+}}, {{%.+}} = quantum.counts [[OBS]] : tensor<6xf64>, tensor<6xi64> + %eigvals1, %counts1 = quantum.counts %obs : tensor<6xf64>, tensor<6xi64> + + // Counts with dynamic shape + // CHECK: {{%.+}}, {{%.+}} = quantum.counts [[OBS]] shape [[DYN_WIRES]] : tensor, tensor + %eigvals2, %counts2 = quantum.counts %obs shape %dyn_wires : tensor, tensor + + // Counts with no results (mutate memref in-place) + // CHECK: [[EIGVALS_IN:%.+]] = "test.op"() : () -> memref<16xf64> + // CHECK: [[COUNTS_IN:%.+]] = "test.op"() : () -> memref<16xi64> + // CHECK: quantum.counts [[OBS]] in([[EIGVALS_IN]] : memref<16xf64>, [[COUNTS_IN]] : memref<16xi64>) + %eigvals_in = "test.op"() : () -> memref<16xf64> + %counts_in = "test.op"() : () -> memref<16xi64> + quantum.counts %obs in(%eigvals_in : memref<16xf64>, %counts_in : memref<16xi64>) + + ///////////////////// **ProbsOp** ///////////////////// + // Probs with static shape + // CHECK: {{%.+}} = quantum.probs [[OBS]] : tensor<8xf64> + %probs1 = quantum.probs %obs : tensor<8xf64> + + // Probs with dynamic shape + // CHECK: {{%.+}} = quantum.probs [[OBS]] shape [[DYN_WIRES]] : tensor + %probs2 = quantum.probs %obs shape %dyn_wires : tensor + + // Probs with no results (mutate memref in-place) + // CHECK: [[PROBS_IN:%.+]] = "test.op"() : () -> memref<16xf64> + // CHECK: quantum.probs [[OBS]] in([[PROBS_IN]] : memref<16xf64>) + %probs_in = "test.op"() : () -> memref<16xf64> + quantum.probs %obs in(%probs_in : memref<16xf64>) + + ///////////////////// **StateOp** ///////////////////// + // State with static shape + // CHECK: {{%.+}} = quantum.state [[OBS]] : tensor<8xcomplex> + %state1 = quantum.state %obs : tensor<8xcomplex> + + // State with dynamic shape + // CHECK: {{%.+}} = quantum.state [[OBS]] shape [[DYN_WIRES]] : tensor> + %state2 = quantum.state %obs shape %dyn_wires : tensor> + + // State with no results (mutate memref in-place) + // CHECK: [[STATE_IN:%.+]] = "test.op"() : () -> memref<16xcomplex> + // CHECK: quantum.state [[OBS]] in([[STATE_IN]] : memref<16xcomplex>) + %state_in = "test.op"() : () -> memref<16xcomplex> + quantum.state %obs in(%state_in : memref<16xcomplex>) + + ///////////////////// **SampleOp** ///////////////////// + // Samples with static shape + // CHECK: {{%.+}} = quantum.sample [[OBS]] : tensor<10x3xf64> + %samples1 = quantum.sample %obs : tensor<10x3xf64> + + // Samples with dynamic wires + // CHECK: {{%.+}} = quantum.sample [[OBS]] shape [[DYN_WIRES]] : tensor<10x?xf64> + %samples2 = quantum.sample %obs shape %dyn_wires : tensor<10x?xf64> + + // Samples with dynamic shots + // CHECK: {{%.+}} = quantum.sample [[OBS]] shape [[DYN_SHOTS]] : tensor + %samples3 = quantum.sample %obs shape %dyn_shots : tensor + + // Samples with dynamic wires and shots + // CHECK: {{%.+}} = quantum.sample [[OBS]] shape [[DYN_SHOTS]], [[DYN_WIRES]] : tensor + %samples4 = quantum.sample %obs shape %dyn_shots, %dyn_wires : tensor + + // Samples with no results (mutate memref in-place) + // CHECK: [[SAMPLES_IN:%.+]] = "test.op"() : () -> memref<7x4xf64> + // CHECK: quantum.sample [[OBS]] in([[SAMPLES_IN]] : memref<7x4xf64>) + %samples_in = "test.op"() : () -> memref<7x4xf64> + quantum.sample %obs in(%samples_in : memref<7x4xf64>) + """ + + run_filecheck(program, roundtrip=True, verify=True) + + def test_miscellaneous_operations(self, run_filecheck): + """Test that the assembly format for miscelleneous operations + works correctly.""" + + # Tests for AdjointOp, DeviceInitOp, DeviceReleaseOp, FinalizeOp, InitializeOp, + # NumQubitsOp, YieldOp + program = """ + ////////////////////////////////////////// + //////////// Quantum register //////////// + ////////////////////////////////////////// + // CHECK: [[QREG:%.+]] = "test.op"() : () -> !quantum.reg + %qreg = "test.op"() : () -> !quantum.reg + + //////////// **AdjointOp and YieldOp tests** //////////// + // CHECK: quantum.adjoint([[QREG]]) : !quantum.reg { + // CHECK-NEXT: ^bb0([[ARG_QREG:%.+]] : !quantum.reg): + // CHECK-NEXT: quantum.yield [[ARG_QREG]] : !quantum.reg + // CHECK-NEXT: } + %qreg1 = quantum.adjoint(%qreg) : !quantum.reg { + ^bb0(%arg_qreg: !quantum.reg): + quantum.yield %arg_qreg : !quantum.reg + } + + //////////// **DeviceInitOp tests** //////////// + // Integer SSA value for shots + // CHECK: [[SHOTS:%.+]] = "test.op"() : () -> i64 + %shots = "test.op"() : () -> i64 + + // No auto qubit management + // CHECK: quantum.device shots([[SHOTS]]) ["foo", "bar", "baz"] + quantum.device shots(%shots) ["foo", "bar", "baz"] + + // Auto qubit management + // CHECK: quantum.device shots([[SHOTS]]) ["foo", "bar", "baz"] {auto_qubit_management} + quantum.device shots(%shots) ["foo", "bar", "baz"] {auto_qubit_management} + + //////////// **DeviceReleaseOp tests** //////////// + // CHECK: quantum.device_release + quantum.device_release + + //////////// **FinalizeOp tests** //////////// + // CHECK: quantum.finalize + quantum.finalize + + //////////// **InitializeOp tests** //////////// + // CHECK: quantum.init + quantum.init + + //////////// **NumQubitsOp tests** //////////// + // CHECK: quantum.num_qubits : i64 + %nqubits = quantum.num_qubits : i64 + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/python_interface/dialects/test_stablehlo_dialect.py b/frontend/test/pytest/python_interface/dialects/test_stablehlo_dialect.py new file mode 100644 index 0000000000..ea4492d690 --- /dev/null +++ b/frontend/test/pytest/python_interface/dialects/test_stablehlo_dialect.py @@ -0,0 +1,964 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test module for pennylane/compiler/python_compiler/dialects/stablehlo.py.""" + +import pytest + +# pylint: disable=wrong-import-position + +xdsl = pytest.importorskip("xdsl") +filecheck = pytest.importorskip("filecheck") + +pytestmark = pytest.mark.external + + +def test_all_unary_operations(run_filecheck): + """Test all unary elementwise operations.""" + program = r""" + // CHECK: %[[tf32:.*]] = "test.op"() : () -> tensor + %tf32 = "test.op"() : () -> tensor + + // CHECK: %[[tf64:.*]] = "test.op"() : () -> tensor + %tf64 = "test.op"() : () -> tensor + + // CHECK: %[[tcomplex:.*]] = "test.op"() : () -> tensor> + %tcomplex = "test.op"() : () -> tensor> + + // CHECK: %convert = "stablehlo.convert"(%[[tf32]]) : (tensor) -> tensor + %convert = "stablehlo.convert"(%tf32) : (tensor) -> tensor + + // CHECK: %cos = "stablehlo.cosine"(%[[tf32]]) : (tensor) -> tensor + %cos = "stablehlo.cosine"(%tf32) : (tensor) -> tensor + + // CHECK: %exp = "stablehlo.exponential"(%[[tf32]]) : (tensor) -> tensor + %exp = "stablehlo.exponential"(%tf32) : (tensor) -> tensor + + // CHECK: %exponential_minus_one = "stablehlo.exponential_minus_one"(%[[tf32]]) : (tensor) -> tensor + %exponential_minus_one = "stablehlo.exponential_minus_one"(%tf32) : (tensor) -> tensor + + // CHECK: %floor = "stablehlo.floor"(%[[tf64]]) : (tensor) -> tensor + %floor = "stablehlo.floor"(%tf64) : (tensor) -> tensor + + // CHECK: %imag = "stablehlo.imag"(%[[tcomplex]]) : (tensor>) -> tensor + %imag = "stablehlo.imag"(%tcomplex) : (tensor>) -> tensor + + // CHECK: %is_finite = "stablehlo.is_finite"(%[[tf32]]) : (tensor) -> tensor + %is_finite = "stablehlo.is_finite"(%tf32) : (tensor) -> tensor + + // CHECK: %log = "stablehlo.log"(%[[tf32]]) : (tensor) -> tensor + %log = "stablehlo.log"(%tf32) : (tensor) -> tensor + + // CHECK: %log_plus_one = "stablehlo.log_plus_one"(%[[tf64]]) : (tensor) -> tensor + %log_plus_one = "stablehlo.log_plus_one"(%tf64) : (tensor) -> tensor + + // CHECK: %logistic = "stablehlo.logistic"(%[[tf32]]) : (tensor) -> tensor + %logistic = "stablehlo.logistic"(%tf32) : (tensor) -> tensor + + // CHECK: %negate = "stablehlo.negate"(%[[tf32]]) : (tensor) -> tensor + %negate = "stablehlo.negate"(%tf32) : (tensor) -> tensor + + // CHECK: %real = "stablehlo.real"(%[[tcomplex]]) : (tensor>) -> tensor + %real = "stablehlo.real"(%tcomplex) : (tensor>) -> tensor + + // CHECK: %round_afz = "stablehlo.round_nearest_afz"(%[[tf64]]) : (tensor) -> tensor + %round_afz = "stablehlo.round_nearest_afz"(%tf64) : (tensor) -> tensor + + // CHECK: %round_even = "stablehlo.round_nearest_even"(%[[tf64]]) : (tensor) -> tensor + %round_even = "stablehlo.round_nearest_even"(%tf64) : (tensor) -> tensor + + // CHECK: %rsqrt = "stablehlo.rsqrt"(%[[tf32]]) : (tensor) -> tensor + %rsqrt = "stablehlo.rsqrt"(%tf32) : (tensor) -> tensor + + // CHECK: %sign = "stablehlo.sign"(%[[tf32]]) : (tensor) -> tensor + %sign = "stablehlo.sign"(%tf32) : (tensor) -> tensor + + // CHECK: %sin = "stablehlo.sine"(%[[tf32]]) : (tensor) -> tensor + %sin = "stablehlo.sine"(%tf32) : (tensor) -> tensor + + // CHECK: %sqrt = "stablehlo.sqrt"(%[[tf32]]) : (tensor) -> tensor + %sqrt = "stablehlo.sqrt"(%tf32) : (tensor) -> tensor + + // CHECK: %tan = "stablehlo.tan"(%[[tf64]]) : (tensor) -> tensor + %tan = "stablehlo.tan"(%tf64) : (tensor) -> tensor + + // CHECK: %tanh = "stablehlo.tanh"(%[[tf32]]) : (tensor) -> tensor + %tanh = "stablehlo.tanh"(%tf32) : (tensor) -> tensor + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def test_all_binary_operations(run_filecheck): + """Test all binary elementwise operations.""" + program = r""" + // CHECK: %[[tf32_1:.*]] = "test.op"() : () -> tensor + %tf32_1 = "test.op"() : () -> tensor + + // CHECK: %[[tf32_2:.*]] = "test.op"() : () -> tensor + %tf32_2 = "test.op"() : () -> tensor + + // CHECK: %[[tf64_1:.*]] = "test.op"() : () -> tensor + %tf64_1 = "test.op"() : () -> tensor + + // CHECK: %[[tf64_2:.*]] = "test.op"() : () -> tensor + %tf64_2 = "test.op"() : () -> tensor + + // CHECK: %complex = "stablehlo.complex"(%[[tf32_1]], %[[tf32_2]]) : (tensor, tensor) -> tensor> + %complex = "stablehlo.complex"(%tf32_1, %tf32_2) : (tensor, tensor) -> tensor> + + // CHECK: %divide = "stablehlo.divide"(%[[tf32_1]], %[[tf32_2]]) : (tensor, tensor) -> tensor + %divide = "stablehlo.divide"(%tf32_1, %tf32_2) : (tensor, tensor) -> tensor + + // CHECK: %maximum = "stablehlo.maximum"(%[[tf32_1]], %[[tf32_2]]) : (tensor, tensor) -> tensor + %maximum = "stablehlo.maximum"(%tf32_1, %tf32_2) : (tensor, tensor) -> tensor + + // CHECK: %minimum = "stablehlo.minimum"(%[[tf32_1]], %[[tf32_2]]) : (tensor, tensor) -> tensor + %minimum = "stablehlo.minimum"(%tf32_1, %tf32_2) : (tensor, tensor) -> tensor + + // CHECK: %power = "stablehlo.power"(%[[tf64_1]], %[[tf64_2]]) : (tensor, tensor) -> tensor + %power = "stablehlo.power"(%tf64_1, %tf64_2) : (tensor, tensor) -> tensor + + // CHECK: %remainder = "stablehlo.remainder"(%[[tf32_1]], %[[tf32_2]]) : (tensor, tensor) -> tensor + %remainder = "stablehlo.remainder"(%tf32_1, %tf32_2) : (tensor, tensor) -> tensor + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def test_all_other_operations(run_filecheck): + """Test all other elementwise operations.""" + program = r""" + // CHECK: %[[tf32:.*]] = "test.op"() : () -> tensor + %tf32 = "test.op"() : () -> tensor + + // CHECK: %[[tf64:.*]] = "test.op"() : () -> tensor + %tf64 = "test.op"() : () -> tensor + + // CHECK: %[[ti1:.*]] = "test.op"() : () -> tensor + %ti1 = "test.op"() : () -> tensor + + // CHECK: %clamp = "stablehlo.clamp"(%[[tf32]], %[[tf32]], %[[tf32]]) : (tensor, tensor, tensor) -> tensor + %clamp = "stablehlo.clamp"(%tf32, %tf32, %tf32) : (tensor, tensor, tensor) -> tensor + + // CHECK: %compare = stablehlo.compare EQ, %[[tf32]], %[[tf32]] : (tensor, tensor) -> tensor + %compare = "stablehlo.compare"(%tf32, %tf32) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + + // CHECK: %map = "stablehlo.map"(%[[tf32]], %[[tf32]]) ({ + // CHECK: ^[[bb0:.*]](%arg0 : tensor, %arg1 : tensor): + // CHECK: %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor, tensor) -> tensor + // CHECK: "stablehlo.return"(%0) : (tensor) -> () + // CHECK: }) {dimensions = array} : (tensor, tensor) -> tensor + %map = "stablehlo.map"(%tf32, %tf32) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) { + dimensions = array + } : (tensor, tensor) -> tensor + + // CHECK: %reduce_precision = "stablehlo.reduce_precision"(%[[tf64]]) {exponent_bits = 5 : i32, mantissa_bits = 10 : i32} : (tensor) -> tensor + %reduce_precision = "stablehlo.reduce_precision"(%tf64) {exponent_bits = 5 : i32, mantissa_bits = 10 : i32} : (tensor) -> tensor + + // CHECK: %select = "stablehlo.select"(%[[ti1]], %[[tf32]], %[[tf32]]) : (tensor, tensor, tensor) -> tensor + %select = "stablehlo.select"(%ti1, %tf32, %tf32) : (tensor, tensor, tensor) -> tensor + + // CHECK: %constant1 = "stablehlo.constant"() <{value = dense<[[0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00]]> : tensor<2x2xf32>}> : () -> tensor<2x2xf32> + %constant1 = "stablehlo.constant"() <{value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>}> : () -> tensor<2x2xf32> + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def test_invalid_ir_shape_mismatch(run_filecheck): + """Test that operations with shape mismatches are properly rejected.""" + program = r""" + %tf32_2x3 = "test.op"() : () -> tensor<2x3xf32> + %tf64_3x2 = "test.op"() : () -> tensor<3x2xf64> + + // This should fail verification due to shape mismatch + %convert = "stablehlo.convert"(%tf32_2x3) : (tensor<2x3xf32>) -> tensor<3x2xf64> + """ + + with pytest.raises( + Exception, match="all non-scalar operands/results must have the same shape and base type" + ): + run_filecheck(program, roundtrip=True, verify=True) + + +def test_invalid_ir_type_mismatch(run_filecheck): + """Test that operations with type mismatches are properly rejected.""" + program = r""" + %ti32 = "test.op"() : () -> tensor<2x3xi32> + + // This should fail verification due to type mismatch (cosine expects float/complex) + %cos = "stablehlo.cosine"(%ti32) : (tensor<2x3xi32>) -> tensor<2x3xi32> + """ + + with pytest.raises(Exception, match="'operand' at position 0 does not verify"): + run_filecheck(program, roundtrip=True, verify=True) + + +def test_invalid_ir_missing_operands(run_filecheck): + """Test that operations with missing operands are properly rejected.""" + program = r""" + %result = "stablehlo.convert"() : () -> tensor<2x3xf64> + """ + + with pytest.raises(Exception, match="Expected 1 operand"): + run_filecheck(program, roundtrip=True, verify=True) + + +def test_invalid_ir_trait_verification_failure(run_filecheck): + """Test that operations that violate trait constraints are properly rejected.""" + program = r""" + %tf32_2x3 = "test.op"() : () -> tensor<2x3xf32> + %tf64_3x2 = "test.op"() : () -> tensor<3x2xf64> + + // This should fail verification due to shape mismatch between operands + %complex = "stablehlo.complex"(%tf32_2x3, %tf64_3x2) : (tensor<2x3xf32>, tensor<3x2xf64>) -> tensor<2x3xcomplex> + """ + + with pytest.raises(Exception, match="requires the same shape"): + run_filecheck(program, roundtrip=True, verify=True) + + +def test_invalid_ir_operand_result_shape_mismatch(run_filecheck): + """Test that operations with operand vs result shape mismatches are properly rejected.""" + program = r""" + %tf32_2x3 = "test.op"() : () -> tensor<2x3xf32> + + // This should fail verification due to shape mismatch between operand and result + %convert = "stablehlo.convert"(%tf32_2x3) : (tensor<2x3xf32>) -> tensor<3x2xf64> + """ + + with pytest.raises( + Exception, match="all non-scalar operands/results must have the same shape and base type" + ): + run_filecheck(program, roundtrip=True, verify=True) + + +def test_control_flow_operations(run_filecheck): + """Test the IfOp operation.""" + program = r""" + // Test IfOp: + + // CHECK: %[[pred:.*]] = "test.op"() : () -> tensor + %pred = "test.op"() : () -> tensor + + // CHECK: %[[result:.*]] = "stablehlo.if"(%[[pred]]) ({ + // CHECK: "stablehlo.return"(%[[pred]]) : (tensor) -> () + // CHECK: }, { + // CHECK: "stablehlo.return"(%[[pred]]) : (tensor) -> () + // CHECK: }) : (tensor) -> tensor + %result = "stablehlo.if"(%pred) ({ + "stablehlo.return"(%pred) : (tensor) -> () + }, { + "stablehlo.return"(%pred) : (tensor) -> () + }) : (tensor) -> tensor + + // Test WhileOp: + + // CHECK: %[[init_i:.*]] = "test.op"() : () -> tensor + %init_i = "test.op"() : () -> tensor + + // CHECK: %[[init_sum:.*]] = "test.op"() : () -> tensor + %init_sum = "test.op"() : () -> tensor + + // CHECK: %[[ten:.*]] = "test.op"() : () -> tensor + %ten = "test.op"() : () -> tensor + + // CHECK: %[[one:.*]] = "test.op"() : () -> tensor + %one = "test.op"() : () -> tensor + + // CHECK: %[[results:.*]], %[[results_1:.*]] = "stablehlo.while"(%[[init_i]], %[[init_sum]]) ({ + // CHECK: ^{{.*}}(%[[arg0:.*]] : tensor, %[[arg1:.*]] : tensor): + // CHECK: %[[cond:.*]] = stablehlo.compare LT, %[[arg0]], %[[ten]] : (tensor, tensor) -> tensor + // CHECK: "stablehlo.return"(%[[cond]]) : (tensor) -> () + // CHECK: }, { + // CHECK: ^{{.*}}(%[[arg0_1:.*]] : tensor, %[[arg1_1:.*]] : tensor): + // CHECK: %[[new_sum:.*]] = "stablehlo.add"(%[[arg1_1]], %[[one]]) : (tensor, tensor) -> tensor + // CHECK: %[[new_i:.*]] = "stablehlo.add"(%[[arg0_1]], %[[one]]) : (tensor, tensor) -> tensor + // CHECK: "stablehlo.return"(%[[new_i]], %[[new_sum]]) : (tensor, tensor) -> () + // CHECK: }) : (tensor, tensor) -> (tensor, tensor) + %results:2 = "stablehlo.while"(%init_i, %init_sum) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %cond = "stablehlo.compare"(%arg0, %ten) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%cond) : (tensor) -> () + }, { + ^bb0(%arg0: tensor, %arg1: tensor): + %new_sum = "stablehlo.add"(%arg1, %one) : (tensor, tensor) -> tensor + %new_i = "stablehlo.add"(%arg0, %one) : (tensor, tensor) -> tensor + "stablehlo.return"(%new_i, %new_sum) : (tensor, tensor) -> () + }) : (tensor, tensor) -> (tensor, tensor) + + // Test OptimizationBarrierOp: + + // CHECK: %[[operand:.*]] = "test.op"() : () -> tensor + %operand = "test.op"() : () -> tensor + + // CHECK: %[[result2:.*]] = "stablehlo.optimization_barrier"(%[[operand]]) : (tensor) -> tensor + %result2 = "stablehlo.optimization_barrier"(%operand) : (tensor) -> tensor + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def test_data_movement_operations(run_filecheck): + """Test all data movement operations.""" + program = r""" + ////////////////// Setup test operations ////////////////// + // CHECK: %[[input1:.*]] = "test.op"() : () -> tensor<3x2xi64> + %input1 = "test.op"() : () -> tensor<3x2xi64> + + // CHECK: %[[input2:.*]] = "test.op"() : () -> tensor<1x2xi64> + %input2 = "test.op"() : () -> tensor<1x2xi64> + + // CHECK: %[[operand:.*]] = "test.op"() : () -> tensor<2x3x4x2xi32> + %operand = "test.op"() : () -> tensor<2x3x4x2xi32> + + // CHECK: %[[start_indices:.*]] = "test.op"() : () -> tensor<2x2x3x2xi64> + %start_indices = "test.op"() : () -> tensor<2x2x3x2xi64> + + // CHECK: %[[reshape_input:.*]] = "test.op"() : () -> tensor<2xf32> + %reshape_input = "test.op"() : () -> tensor<2xf32> + + // CHECK: %[[scatter_input:.*]] = "test.op"() : () -> tensor<2x3x4x2xi64> + %scatter_input = "test.op"() : () -> tensor<2x3x4x2xi64> + + // CHECK: %[[scatter_indices:.*]] = "test.op"() : () -> tensor<2x2x3x2xi64> + %scatter_indices = "test.op"() : () -> tensor<2x2x3x2xi64> + + // CHECK: %[[scatter_updates:.*]] = "test.op"() : () -> tensor<2x2x3x2x2xi64> + %scatter_updates = "test.op"() : () -> tensor<2x2x3x2x2xi64> + + // CHECK: %[[slice_input:.*]] = "test.op"() : () -> tensor<3x4xi64> + %slice_input = "test.op"() : () -> tensor<3x4xi64> + + // CHECK: %[[broadcast_input:.*]] = "test.op"() : () -> tensor<1x3xi32> + %broadcast_input = "test.op"() : () -> tensor<1x3xi32> + + ////////////////// Test ConcatenateOp ////////////////// + // CHECK: %concatenate = "stablehlo.concatenate"(%[[input1]], %[[input2]]) <{dimension = 0 : i64}> : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64> + %concatenate = "stablehlo.concatenate"(%input1, %input2) {dimension = 0 : i64} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64> + + ////////////////// Test GatherOp ////////////////// + // CHECK: %gather = "stablehlo.gather"(%[[operand]], %[[start_indices]]) + // CHECK-SAME: dimension_numbers = #stablehlo.gather< + // CHECK-NEXT: offset_dims = [3, 4], + // CHECK-NEXT: collapsed_slice_dims = [1], + // CHECK-NEXT: operand_batching_dims = [0], + // CHECK-NEXT: start_indices_batching_dims = [1], + // CHECK-NEXT: start_index_map = [2, 1], + // CHECK-NEXT: index_vector_dim = 3 + // CHECK-NEXT: slice_sizes = array, indices_are_sorted = false + %gather = "stablehlo.gather"(%operand, %start_indices) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3, 4], + collapsed_slice_dims = [1], + operand_batching_dims = [0], + start_indices_batching_dims = [1], + start_index_map = [2, 1], + index_vector_dim = 3>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32> + + ////////////////// Test ReshapeOp ////////////////// + // CHECK: %reshape = stablehlo.reshape %[[reshape_input]] : (tensor<2xf32>) -> tensor<1x2xf32> + %reshape = "stablehlo.reshape"(%reshape_input) : (tensor<2xf32>) -> tensor<1x2xf32> + + ////////////////// Test ScatterOp ////////////////// + // CHECK: %scatter = "stablehlo.scatter"(%[[scatter_input]], %[[scatter_indices]], %[[scatter_updates]]) + // CHECK-SAME: scatter_dimension_numbers = #stablehlo.scatter< + // CHECK-NEXT: update_window_dims = [3, 4], + // CHECK-NEXT: inserted_window_dims = [1], + // CHECK-NEXT: input_batching_dims = [0], + // CHECK-NEXT: scatter_indices_batching_dims = [1], + // CHECK-NEXT: scatter_dims_to_operand_dims = [2, 1], + // CHECK-NEXT: index_vector_dim = 3 + // CHECK-NEXT: indices_are_sorted = false, unique_indices = false + // CHECK-NEXT: ^[[bb0:.*]](%arg0 : tensor, %arg1 : tensor): + // CHECK-NEXT: %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + // CHECK-NEXT: "stablehlo.return"(%0) : (tensor) -> () + %scatter = "stablehlo.scatter"(%scatter_input, %scatter_indices, %scatter_updates) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [3, 4], + inserted_window_dims = [1], + input_batching_dims = [0], + scatter_indices_batching_dims = [1], + scatter_dims_to_operand_dims = [2, 1], + index_vector_dim = 3>, + indices_are_sorted = false, + unique_indices = false + } : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64> + + ////////////////// Test SliceOp ////////////////// + // CHECK: %slice = "stablehlo.slice"(%[[slice_input]]) + // CHECK-SAME: start_indices = array, + // CHECK-SAME: limit_indices = array, + // CHECK-SAME: strides = array + // CHECK-SAME: : (tensor<3x4xi64>) -> tensor<2x2xi64> + %slice = "stablehlo.slice"(%slice_input) { + start_indices = array, + limit_indices = array, + strides = array + } : (tensor<3x4xi64>) -> tensor<2x2xi64> + + ////////////////// Test BroadcastInDimOp ////////////////// + // CHECK: %broadcast = stablehlo.broadcast_in_dim %[[broadcast_input]], dims = [2, 1] : (tensor<1x3xi32>) -> tensor<2x3x2xi32> + %broadcast = "stablehlo.broadcast_in_dim"(%broadcast_input) {broadcast_dimensions = array} : (tensor<1x3xi32>) -> tensor<2x3x2xi32> + + ////////////////// Test DynamicSliceOp ////////////////// + // CHECK: %[[dyn_operand:.*]] = "test.op"() : () -> tensor<4x4xi32> + %dyn_operand = "test.op"() : () -> tensor<4x4xi32> + + // CHECK: %[[start0:.*]] = "test.op"() : () -> tensor + %start0 = "test.op"() : () -> tensor + + // CHECK: %[[start1:.*]] = "test.op"() : () -> tensor + %start1 = "test.op"() : () -> tensor + + // CHECK: %dynamic_slice = "stablehlo.dynamic_slice"(%[[dyn_operand]], %[[start0]], %[[start1]]) + // CHECK-SAME: slice_sizes = array + // CHECK-SAME: : (tensor<4x4xi32>, tensor, tensor) -> tensor<2x3xi32> + %dynamic_slice = "stablehlo.dynamic_slice"(%dyn_operand, %start0, %start1) { + slice_sizes = array + } : (tensor<4x4xi32>, tensor, tensor) -> tensor<2x3xi32> + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def test_invalid_slice_operations(run_filecheck): + """Test invalid slice operations that should fail verification.""" + program_slice_mismatch = r""" + // CHECK: %input = "test.op"() : () -> tensor<3x8xi64> + %input = "test.op"() : () -> tensor<3x8xi64> + + // This should fail verification due to mismatched array sizes + // CHECK: %slice = "stablehlo.slice"(%input) {start_indices = array, limit_indices = array, strides = array} : (tensor<3x8xi64>) -> tensor<2x2xi64> + %slice = "stablehlo.slice"(%input) { + start_indices = array, + limit_indices = array, + strides = array + } : (tensor<3x8xi64>) -> tensor<2x2xi64> + """ + + with pytest.raises( + Exception, + match="all of \\{start_indices, limit_indices, strides\\} must have the same size: got sizes 2, 3, 2", + ): + run_filecheck(program_slice_mismatch, roundtrip=True, verify=True) + + +def test_invalid_slice_element_type_mismatch(run_filecheck): + """Test that SliceOp rejects mismatched operand/result element types.""" + program = r""" + %slice_input = "test.op"() : () -> tensor<3x4xi64> + // CHECK: %slice_input = "test.op"() : () -> tensor<3x4xi64> + // Mismatched element type: operand is i64, result is f32 + %slice = "stablehlo.slice"(%slice_input) { + start_indices = array, + limit_indices = array, + strides = array + } : (tensor<3x4xi64>) -> tensor<2x2xf32> + """ + + # Expect verification failure due to element type mismatch + with pytest.raises( + Exception, match="requires the same element type for all operands and results" + ): + run_filecheck(program, roundtrip=True, verify=True) + + +def test_invalid_gather_element_type_mismatch(run_filecheck): + """Test that GatherOp rejects mismatched operand/result element types.""" + program = r""" + %operand = "test.op"() : () -> tensor<2x3x4x2xi32> + %start_indices = "test.op"() : () -> tensor<2x2x3x2xi64> + + // Mismatched element type: operand is i32, result is f32 + %gather_bad = "stablehlo.gather"(%operand, %start_indices) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3, 4], + collapsed_slice_dims = [1], + operand_batching_dims = [0], + start_indices_batching_dims = [1], + start_index_map = [2, 1], + index_vector_dim = 3>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xf32> + """ + + # Expect verification failure due to element type mismatch between operand and result + with pytest.raises( + Exception, match=r"all of \{operand, result\} must have the same element type" + ): + run_filecheck(program, roundtrip=True, verify=True) + + +def test_invalid_reshape_operations(run_filecheck): + """Test invalid reshape operations that should fail verification.""" + program_reshape_mismatch = r""" + %reshape_input = "test.op"() : () -> tensor<2xf32> + + // This should fail verification due to element count mismatch (2 != 4) + %reshape_bad = "stablehlo.reshape"(%reshape_input) : (tensor<2xf32>) -> tensor<2x2xf32> + """ + + with pytest.raises(Exception, match="number of output elements"): + run_filecheck(program_reshape_mismatch, roundtrip=True, verify=True) + + +def test_invalid_broadcast_in_dim_operations(run_filecheck): + """Test invalid broadcast_in_dim operations that should fail verification.""" + # Test dims size mismatch. + program_broadcast_dims_size_mismatch = r""" + %broadcast_input = "test.op"() : () -> tensor<1x3xi32> + + // dims has size 1, but operand rank is 2 + %broadcast_bad = "stablehlo.broadcast_in_dim"(%broadcast_input) {broadcast_dimensions = array} : (tensor<1x3xi32>) -> tensor<2x3x2xi32> + """ + + with pytest.raises(Exception, match="broadcast_dimensions size .* does not match operand rank"): + run_filecheck(program_broadcast_dims_size_mismatch, roundtrip=True, verify=True) + + # Test duplicate dims. + program_broadcast_duplicate_dims = r""" + %broadcast_input = "test.op"() : () -> tensor<1x3xi32> + + // duplicate entries in broadcast_dimensions are not allowed + %broadcast_bad = "stablehlo.broadcast_in_dim"(%broadcast_input) {broadcast_dimensions = array} : (tensor<1x3xi32>) -> tensor<2x3x2xi32> + """ + + with pytest.raises(Exception, match="broadcast_dimensions should not have duplicates"): + run_filecheck(program_broadcast_duplicate_dims, roundtrip=True, verify=True) + + # Test dim index out of bounds. + program_broadcast_dim_oob = r""" + %broadcast_input = "test.op"() : () -> tensor<1x3xi32> + + // result rank is 2, but dims contains 2 (out of bounds) + %broadcast_bad = "stablehlo.broadcast_in_dim"(%broadcast_input) {broadcast_dimensions = array} : (tensor<1x3xi32>) -> tensor<2x3xi32> + """ + + with pytest.raises(Exception, match="broadcast_dimensions contains invalid value"): + run_filecheck(program_broadcast_dim_oob, roundtrip=True, verify=True) + + # Test operand dim not 1 and not equal to result dim. + program_broadcast_dim_mismatch = r""" + %broadcast_input = "test.op"() : () -> tensor<2x3xi32> + + // operand[0] = 2, result[0] = 4; dims = [0, 2] -> mismatch on dim 0 + %broadcast_bad = "stablehlo.broadcast_in_dim"(%broadcast_input) {broadcast_dimensions = array} : (tensor<2x3xi32>) -> tensor<4x3x2xi32> + """ + + with pytest.raises( + Exception, + match="size of operand dimension .* is not equal to 1 or size of result dimension", + ): + run_filecheck(program_broadcast_dim_mismatch, roundtrip=True, verify=True) + + +def test_dynamism_operations(run_filecheck): + """Test all dynamism operations.""" + program = r""" + ////////////////// Setup ////////////////// + // CHECK: %[[operand:.*]] = "test.op"() : () -> tensor<1x3xi64> + %operand = "test.op"() : () -> tensor<1x3xi64> + + // CHECK: %[[out_dims:.*]] = "test.op"() : () -> tensor<3xi64> + %out_dims = "test.op"() : () -> tensor<3xi64> + + ////////////////// Test DynamicBroadcastInDimOp ////////////////// + // CHECK: %dynamic_bcast = stablehlo.dynamic_broadcast_in_dim %[[operand]], %[[out_dims]], dims = [2, 1] : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64> + %dynamic_bcast = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out_dims) { + broadcast_dimensions = array + } : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64> + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def test_reduction_operations(run_filecheck): + """Test all reduction operations.""" + program = r""" + ////////////////// Setup ////////////////// + // CHECK: %[[input:.*]] = "test.op"() : () -> tensor<1x6xi64> + %input = "test.op"() : () -> tensor<1x6xi64> + + // CHECK: %[[init:.*]] = "test.op"() : () -> tensor + %init = "test.op"() : () -> tensor + + ////////////////// Test ReduceOp ////////////////// + // CHECK: %reduce = "stablehlo.reduce"(%[[input]], %[[init]]) <{dimensions = array}> ({ + // CHECK: ^[[bb0:.*]](%arg0 : tensor, %arg1 : tensor): + // CHECK: %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + // CHECK: "stablehlo.return"(%0) : (tensor) -> () + // CHECK: }) : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + %reduce = "stablehlo.reduce"(%input, %init) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) {dimensions = array} : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def test_invalid_reduction_operations(run_filecheck): + """Test invalid cases for ReduceOp verifier.""" + + # Duplicate dimensions + program_dup_dims = r""" + %input = "test.op"() : () -> tensor<1x6xi64> + %init = "test.op"() : () -> tensor + + %reduce = "stablehlo.reduce"(%input, %init) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) {dimensions = array} : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + """ + + with pytest.raises(Exception, match=r"dimensions should not have duplicates"): + run_filecheck(program_dup_dims, roundtrip=True, verify=True) + + # Dimension out of range + program_dim_oob = r""" + %input = "test.op"() : () -> tensor<1x6xi64> + %init = "test.op"() : () -> tensor + + %reduce = "stablehlo.reduce"(%input, %init) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) {dimensions = array} : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + """ + + with pytest.raises(Exception, match=r"dimensions contains an invalid value"): + run_filecheck(program_dim_oob, roundtrip=True, verify=True) + + # Input/init element type mismatch + program_elem_mismatch = r""" + %input = "test.op"() : () -> tensor<1x6xi64> + %init = "test.op"() : () -> tensor + + %reduce = "stablehlo.reduce"(%input, %init) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) {dimensions = array} : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + """ + + with pytest.raises(Exception, match=r"input and init_value must have the same element type"): + run_filecheck(program_elem_mismatch, roundtrip=True, verify=True) + + # Reducer wrong arity (expects 2 args per input; give 1) + program_wrong_arity = r""" + %input = "test.op"() : () -> tensor<1x6xi64> + %init = "test.op"() : () -> tensor + + %reduce = "stablehlo.reduce"(%input, %init) ({ + ^bb0(%acc: tensor): + "stablehlo.return"(%acc) : (tensor) -> () + }) {dimensions = array} : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + """ + + with pytest.raises(Exception, match=r"reducer must take 2 arguments, got 1"): + run_filecheck(program_wrong_arity, roundtrip=True, verify=True) + + # Reducer arg wrong rank (should be 0D) + program_arg_rank = r""" + %input = "test.op"() : () -> tensor<1x6xi64> + %init = "test.op"() : () -> tensor + + %reduce = "stablehlo.reduce"(%input, %init) ({ + ^bb0(%arg0: tensor<2xi64>, %arg1: tensor<2xi64>): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi64> + "stablehlo.return"(%0) : (tensor<2xi64>) -> () + }) {dimensions = array} : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + """ + + with pytest.raises(Exception, match=r"reducer arguments must be rank-0 tensors"): + run_filecheck(program_arg_rank, roundtrip=True, verify=True) + + # Reducer return wrong count + program_return_count = r""" + %input = "test.op"() : () -> tensor<1x6xi64> + %init = "test.op"() : () -> tensor + + %reduce = "stablehlo.reduce"(%input, %init) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + "stablehlo.return"() : () -> () + }) {dimensions = array} : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + """ + + with pytest.raises(Exception, match=r"reducer must return exactly one value per input"): + run_filecheck(program_return_count, roundtrip=True, verify=True) + + +def test_custom_call_basic(run_filecheck): + """CustomCallOp minimal form without layouts should verify.""" + program = r""" + // CHECK: %[[ARG:.*]] = "test.op"() : () -> tensor<2x3xi32> + %arg = "test.op"() : () -> tensor<2x3xi32> + + // CHECK: %[[RES:.*]] = "stablehlo.custom_call"(%[[ARG]]) + // CHECK-SAME: call_target_name = "foo" + // CHECK-SAME: api_version = #stablehlo + %res = "stablehlo.custom_call"(%arg) { + call_target_name = "foo", + api_version = #stablehlo, + output_operand_aliases = [] + } : (tensor<2x3xi32>) -> tensor<2x3xi32> + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def test_custom_call_with_layouts(run_filecheck): + """CustomCallOp with matching operand/result layouts should verify.""" + program = r""" + // CHECK: %[[ARG:.*]] = "test.op"() : () -> tensor<2x3xi32> + %arg = "test.op"() : () -> tensor<2x3xi32> + + // CHECK: %[[RES:.*]] = "stablehlo.custom_call"(%[[ARG]]) + // CHECK-SAME: operand_layouts = [dense<[1, 0]> : tensor<2xindex>] + // CHECK-SAME: result_layouts = [dense<[1, 0]> : tensor<2xindex>] + %res = "stablehlo.custom_call"(%arg) { + call_target_name = "foo", + api_version = #stablehlo, + operand_layouts = [dense<[1, 0]> : tensor<2xindex>], + result_layouts = [dense<[1, 0]> : tensor<2xindex>], + output_operand_aliases = [] + } : (tensor<2x3xi32>) -> tensor<2x3xi32> + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def test_custom_call_missing_result_layouts(run_filecheck): + """Providing only operand_layouts should fail (must provide both or none).""" + program = r""" + %arg = "test.op"() : () -> tensor<2x3xi32> + + %res = "stablehlo.custom_call"(%arg) { + call_target_name = "foo", + api_version = #stablehlo, + operand_layouts = [dense<[1, 0]> : tensor<2xindex>], + output_operand_aliases = [] + } : (tensor<2x3xi32>) -> tensor<2x3xi32> + """ + + with pytest.raises( + Exception, + match=r"either both operands and results or none", + ): + run_filecheck(program, roundtrip=True, verify=True) + + +def test_custom_call_layouts_mismatch(run_filecheck): + """Number of layouts must match number of operands/results.""" + program = r""" + %arg0 = "test.op"() : () -> tensor<2x3xi32> + %arg1 = "test.op"() : () -> tensor<2x3xi32> + + %res = "stablehlo.custom_call"(%arg0, %arg1) { + call_target_name = "foo", + api_version = #stablehlo, + operand_layouts = [dense<[1, 0]> : tensor<2xindex>], + result_layouts = [dense<[1, 0]> : tensor<2xindex>], + output_operand_aliases = [] + } : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + """ + + with pytest.raises( + Exception, match=r"Number of operands must match the number of operand layouts" + ): + run_filecheck(program, roundtrip=True, verify=True) + + +def test_custom_call_incorrect_layout_perm(run_filecheck): + """Layout must be a permutation of [0, rank).""" + program = r""" + %arg = "test.op"() : () -> tensor<2x3xi32> + + %res = "stablehlo.custom_call"(%arg) { + call_target_name = "foo", + api_version = #stablehlo, + operand_layouts = [dense<[0]> : tensor<1xindex>], + result_layouts = [dense<[0]> : tensor<1xindex>], + output_operand_aliases = [] + } : (tensor<2x3xi32>) -> tensor<2x3xi32> + """ + + with pytest.raises(Exception, match=r"layout must be a permutation of \[0, 2\)"): + run_filecheck(program, roundtrip=True, verify=True) + + +def test_custom_call_single_tuple_result_with_element_layouts(run_filecheck): + """Single tuple result with element-wise layouts should verify (common case).""" + program = r""" + // CHECK: %[[ARG0:.*]] = "test.op"() : () -> tensor<2x3xi32> + // CHECK: %[[ARG1:.*]] = "test.op"() : () -> tensor<1xi32> + %arg0 = "test.op"() : () -> tensor<2x3xi32> + %arg1 = "test.op"() : () -> tensor<1xi32> + + // CHECK: %[[RES:.*]] = "stablehlo.custom_call"(%[[ARG0]]) + // CHECK-SAME: call_target_name = "foo" + // CHECK-SAME: api_version = #stablehlo + // CHECK-SAME: operand_layouts = [dense<[1, 0]> : tensor<2xindex>] + // CHECK-SAME: result_layouts = [dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>] + %res = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + api_version = #stablehlo, + operand_layouts = [dense<[1, 0]> : tensor<2xindex>], + result_layouts = [dense<[1, 0]> : tensor<2xindex>, dense<[0]> : tensor<1xindex>], + output_operand_aliases = [] + } : (tensor<2x3xi32>) -> tuple, tensor<1xi32>> + """ + run_filecheck(program, roundtrip=True, verify=True) + + +def test_invalid_dynamic_broadcast_in_dim_operations(run_filecheck): + """Test invalid dynamic_broadcast_in_dim cases that should fail verification.""" + + # dims size mismatch (c2) + program_dims_size_mismatch = r""" + %operand = "test.op"() : () -> tensor<1x3xi64> + %out = "test.op"() : () -> tensor<3xi64> + + %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) { + broadcast_dimensions = array + } : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64> + """ + + with pytest.raises( + Exception, match=r"broadcast_dimensions size \(1\) does not match operand rank \(2\)" + ): + run_filecheck(program_dims_size_mismatch, roundtrip=True, verify=True) + + # result rank < operand rank (c3) + program_result_rank_too_small = r""" + %operand = "test.op"() : () -> tensor<1x3xi64> + %out = "test.op"() : () -> tensor<1xi64> + + %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) { + broadcast_dimensions = array + } : (tensor<1x3xi64>, tensor<1xi64>) -> tensor<3xi64> + """ + + with pytest.raises(Exception, match=r"result rank \(1\) is less than operand rank \(2\)"): + run_filecheck(program_result_rank_too_small, roundtrip=True, verify=True) + + # duplicate dims (c4) + program_duplicate_dims = r""" + %operand = "test.op"() : () -> tensor<1x3xi64> + %out = "test.op"() : () -> tensor<2xi64> + + %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) { + broadcast_dimensions = array + } : (tensor<1x3xi64>, tensor<2xi64>) -> tensor<2x3xi64> + """ + + with pytest.raises(Exception, match=r"broadcast_dimensions should not have duplicates"): + run_filecheck(program_duplicate_dims, roundtrip=True, verify=True) + + # dim index out of bounds (c5 bounds) + program_dim_oob = r""" + %operand = "test.op"() : () -> tensor<1x3xi64> + %out = "test.op"() : () -> tensor<2xi64> + + %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) { + broadcast_dimensions = array + } : (tensor<1x3xi64>, tensor<2xi64>) -> tensor<2x3xi64> + """ + + with pytest.raises( + Exception, match=r"broadcast_dimensions contains invalid value 2 for result with rank 2" + ): + run_filecheck(program_dim_oob, roundtrip=True, verify=True) + + # per-dimension size compatibility (c5 compatibility) + program_dim_incompatible = r""" + %operand = "test.op"() : () -> tensor<2x3xi32> + %out = "test.op"() : () -> tensor<3xi64> + + %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) { + broadcast_dimensions = array + } : (tensor<2x3xi32>, tensor<3xi64>) -> tensor<4x3x2xi32> + """ + + with pytest.raises( + Exception, + match=r"size of operand dimension 0 \(2\) is not compatible with size of result dimension 0 \(4\)", + ): + run_filecheck(program_dim_incompatible, roundtrip=True, verify=True) + + # output_dimensions length incompatible with result rank when static (c7) + program_outlen_mismatch = r""" + %operand = "test.op"() : () -> tensor<1x3xi64> + %out = "test.op"() : () -> tensor<2xi64> + + %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) { + broadcast_dimensions = array + } : (tensor<1x3xi64>, tensor<2xi64>) -> tensor<2x3x2xi64> + """ + + with pytest.raises( + Exception, + match=r"length of output_dimensions \(2\) is not compatible with result rank \(3\)", + ): + run_filecheck(program_outlen_mismatch, roundtrip=True, verify=True) + + # duplicate expansion hints across both lists (c8) + program_dup_hints = r""" + %operand = "test.op"() : () -> tensor<1x1xi64> + %out = "test.op"() : () -> tensor<2xi64> + + %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) { + broadcast_dimensions = array, + known_expanding_dimensions = array, + known_nonexpanding_dimensions = array + } : (tensor<1x1xi64>, tensor<2xi64>) -> tensor<2x1xi64> + """ + + with pytest.raises( + Exception, match=r"duplicate expansion hint for at least one operand dimension" + ): + run_filecheck(program_dup_hints, roundtrip=True, verify=True) + + # hint refers to invalid operand dimension (c9/c10) + program_hint_oob = r""" + %operand = "test.op"() : () -> tensor<1x3xi64> + %out = "test.op"() : () -> tensor<2xi64> + + %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) { + broadcast_dimensions = array, + known_expanding_dimensions = array + } : (tensor<1x3xi64>, tensor<2xi64>) -> tensor<2x3xi64> + """ + + with pytest.raises( + Exception, + match=r"hint for expanding dimension 5 does not refer to a valid operand dimension", + ): + run_filecheck(program_hint_oob, roundtrip=True, verify=True) diff --git a/frontend/test/pytest/python_interface/dialects/test_transform_dialect.py b/frontend/test/pytest/python_interface/dialects/test_transform_dialect.py new file mode 100644 index 0000000000..04ac810c22 --- /dev/null +++ b/frontend/test/pytest/python_interface/dialects/test_transform_dialect.py @@ -0,0 +1,149 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test module for pennylane/compiler/python_compiler/transform.py.""" + +from dataclasses import dataclass + +import pytest + +# pylint: disable=wrong-import-position + +xdsl = pytest.importorskip("xdsl") +filecheck = pytest.importorskip("filecheck") + +pytestmark = pytest.mark.external + +from xdsl import passes +from xdsl.context import Context +from xdsl.dialects import builtin +from xdsl.dialects.builtin import DictionaryAttr, IntegerAttr, i64 +from xdsl.dialects.transform import AnyOpType +from xdsl.utils.exceptions import VerifyException +from xdsl.utils.test_value import create_ssa_value + +from catalyst.python_interface.conversion import xdsl_from_docstring +from catalyst.python_interface.dialects import transform +from catalyst.python_interface.dialects.transform import ApplyRegisteredPassOp +from catalyst.python_interface.pass_api import ( + ApplyTransformSequence, + compiler_transform, +) + + +def test_dict_options(): + """Test ApplyRegisteredPassOp constructor with dict options.""" + target = create_ssa_value(AnyOpType()) + options = {"option1": 1, "option2": True} + + op = ApplyRegisteredPassOp("canonicalize", target, options) + + assert op.pass_name.data == "canonicalize" + assert isinstance(op.options, DictionaryAttr) + assert op.options == DictionaryAttr({"option1": 1, "option2": True}) + assert op.verify_() is None + + +def test_attr_options(): + """Test ApplyRegisteredPassOp constructor with DictionaryAttr options.""" + target = create_ssa_value(AnyOpType()) + options = DictionaryAttr({"test-option": IntegerAttr(42, i64)}) + + # This should trigger the __init__ method + op = ApplyRegisteredPassOp("canonicalize", target, options) + + assert op.pass_name.data == "canonicalize" + assert isinstance(op.options, DictionaryAttr) + assert op.options == DictionaryAttr({"test-option": IntegerAttr(42, i64)}) + assert op.verify_() is None + + +def test_none_options(): + """Test ApplyRegisteredPassOp constructor with None options.""" + target = create_ssa_value(AnyOpType()) + + # This should trigger the __init__ method + op = ApplyRegisteredPassOp("canonicalize", target, None) + + assert op.pass_name.data == "canonicalize" + assert isinstance(op.options, DictionaryAttr) + assert op.options == DictionaryAttr({}) + assert op.verify_() is None + + +def test_invalid_options(): + """Test ApplyRegisteredPassOp constructor with invalid options type.""" + target = create_ssa_value(AnyOpType()) + + with pytest.raises( + VerifyException, match="invalid_options should be of base attribute dictionary" + ): + ApplyRegisteredPassOp("canonicalize", target, "invalid_options").verify_() + + +def test_transform_dialect_filecheck(run_filecheck): + """Test that the transform dialect operations are parsed correctly.""" + program = """ + "builtin.module"() ({ + "transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({ + ^bb0(%arg0: !transform.any_op): + %0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op + // CHECK: options = {"invalid-option" = 1 : i64} + %1 = "transform.apply_registered_pass"(%0) <{options = {"invalid-option" = 1 : i64}, pass_name = "canonicalize"}> : (!transform.any_op) -> !transform.any_op + "transform.yield"() : () -> () + }) : () -> () + }) {transform.with_named_sequence} : () -> () + """ + + run_filecheck(program) + + +def test_integration_for_transform_interpreter(capsys): + """Test that a pass with options is run via the transform interpreter""" + + @compiler_transform + @dataclass(frozen=True) + class _HelloWorld(passes.ModulePass): + name = "test-hello-world" + + custom_print: str | None = None + + def apply(self, _ctx: Context, _module: builtin.ModuleOp) -> None: + if self.custom_print: + print(self.custom_print) + else: + print("hello world") + + @xdsl_from_docstring + def program(): + """ + builtin.module { + builtin.module { + transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) { + %0 = "transform.apply_registered_pass"(%arg0) <{options = {"custom_print" = "Hello from custom option!"}, pass_name = "test-hello-world"}> : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + transform.yield + } + } + } + """ + + ctx = xdsl.context.Context() + ctx.load_dialect(builtin.Builtin) + ctx.load_dialect(transform.Transform) + + mod = program() + pipeline = xdsl.passes.PassPipeline((ApplyTransformSequence(),)) + pipeline.apply(ctx, mod) + + assert "Hello from custom option!" in capsys.readouterr().out diff --git a/frontend/test/pytest/python_interface/test_python_compiler.py b/frontend/test/pytest/python_interface/test_python_compiler.py new file mode 100644 index 0000000000..2110477fde --- /dev/null +++ b/frontend/test/pytest/python_interface/test_python_compiler.py @@ -0,0 +1,531 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test module for pennylane/compiler/python_compiler/impl.py""" + +from dataclasses import dataclass + +# pylint: disable=wrong-import-position +import pytest + +pytestmark = pytest.mark.catalyst + +catalyst = pytest.importorskip("catalyst") +jax = pytest.importorskip("jax") +jaxlib = pytest.importorskip("jaxlib") +xdsl = pytest.importorskip("xdsl") + +import pennylane as qml +from pennylane.capture import enabled as capture_enabled +from xdsl import passes +from xdsl.context import Context +from xdsl.dialects import builtin +from xdsl.interpreters import Interpreter + +from catalyst import CompileError +from catalyst.passes import apply_pass +from catalyst.passes import cancel_inverses as catalyst_cancel_inverses +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from catalyst.python_interface import Compiler +from catalyst.python_interface.conversion import ( + mlir_from_docstring, + mlir_module, + xdsl_from_docstring, +) +from catalyst.python_interface.dialects import transform +from catalyst.python_interface.pass_api import ( + ApplyTransformSequence, + TransformFunctionsExt, + TransformInterpreterPass, + available_passes, + compiler_transform, +) +from catalyst.python_interface.transforms import ( + iterative_cancel_inverses_pass, + merge_rotations_pass, +) + + +@dataclass(frozen=True) +class HelloWorldPass(passes.ModulePass): + """A simple pass that prints 'hello world' when run.""" + + name = "hello-world" + + def apply(self, _ctx: Context, _module: builtin.ModuleOp) -> None: + print("hello world") + + +hello_world_pass = compiler_transform(HelloWorldPass) + + +def test_compiler(): + """Test that we can pass a jax module into the compiler. + + In this particular case, the compiler is not doing anything + because this module does not contain nested modules which is what + is expected of Catalyst. + + So, it just tests that Compiler.run does not trigger an assertion + and returns a valid + """ + + @mlir_module + @jax.jit + def identity(x): + return x + + input_module = identity(1) + retval = Compiler.run(input_module) + assert isinstance(retval, jaxlib.mlir.ir.Module) + assert str(retval) == str(input_module) + + +def test_generic_catalyst_program(): + """ + test that actually will trigger the transform interpreter + """ + + @mlir_from_docstring + def program(): + """ + "builtin.module"() <{sym_name = "circuit"}> ({ + "func.func"() <{function_type = () -> tensor<2xcomplex>, sym_name = "jit_circuit", sym_visibility = "public"}> ({ + %8 = "catalyst.launch_kernel"() <{callee = @module_circuit::@circuit}> : () -> tensor<2xcomplex> + "func.return"(%8) : (tensor<2xcomplex>) -> () + }) {llvm.emit_c_interface} : () -> () + "builtin.module"() <{sym_name = "module_circuit"}> ({ + "builtin.module"() ({ + "transform.named_sequence"() <{function_type = (!transform.op<"builtin.module">) -> (), sym_name = "__transform_main"}> ({ + ^bb0(%arg0: !transform.op<"builtin.module">): + "transform.yield"() : () -> () + }) : () -> () + }) {transform.with_named_sequence} : () -> () + "func.func"() <{function_type = () -> tensor<2xcomplex>, sym_name = "circuit", sym_visibility = "public"}> ({ + %0 = "arith.constant"() <{value = 0 : i64}> : () -> i64 + "quantum.device"(%0) <{kwargs = "{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}", lib = "/usr/local/lib/python3.11/dist-packages/pennylane_lightning/liblightning_qubit_catalyst.so", name = "LightningSimulator"}> : (i64) -> () + %1 = "quantum.alloc"() <{nqubits_attr = 1 : i64}> : () -> !quantum.reg + %2 = "quantum.extract"(%1) <{idx_attr = 0 : i64}> : (!quantum.reg) -> !quantum.bit + %3 = "quantum.custom"(%2) <{gate_name = "Hadamard", operandSegmentSizes = array, resultSegmentSizes = array}> : (!quantum.bit) -> !quantum.bit + %4 = "quantum.custom"(%3) <{gate_name = "Hadamard", operandSegmentSizes = array, resultSegmentSizes = array}> : (!quantum.bit) -> !quantum.bit + %5 = "quantum.insert"(%1, %4) <{idx_attr = 0 : i64}> : (!quantum.reg, !quantum.bit) -> !quantum.reg + %6 = "quantum.compbasis"(%5) <{operandSegmentSizes = array}> : (!quantum.reg) -> !quantum.obs + %7 = "quantum.state"(%6) <{operandSegmentSizes = array}> : (!quantum.obs) -> tensor<2xcomplex> + "quantum.dealloc"(%5) : (!quantum.reg) -> () + "quantum.device_release"() : () -> () + "func.return"(%7) : (tensor<2xcomplex>) -> () + }) {diff_method = "parameter-shift", llvm.linkage = #llvm.linkage, qnode} : () -> () + }) : () -> () + "func.func"() <{function_type = () -> (), sym_name = "setup"}> ({ + "quantum.init"() : () -> () + "func.return"() : () -> () + }) : () -> () + "func.func"() <{function_type = () -> (), sym_name = "teardown"}> ({ + "quantum.finalize"() : () -> () + "func.return"() : () -> () + }) : () -> () + }) : () -> () + """ + + retval = Compiler.run(program()) + assert isinstance(retval, jaxlib.mlir.ir.Module) + + +def test_generic_catalyst_program_as_string(): + """ + test that actually will trigger the transform interpreter + """ + + program: str = """ + "builtin.module"() <{sym_name = "circuit"}> ({ + "func.func"() <{function_type = () -> tensor<2xcomplex>, sym_name = "jit_circuit", sym_visibility = "public"}> ({ + %8 = "catalyst.launch_kernel"() <{callee = @module_circuit::@circuit}> : () -> tensor<2xcomplex> + "func.return"(%8) : (tensor<2xcomplex>) -> () + }) {llvm.emit_c_interface} : () -> () + "builtin.module"() <{sym_name = "module_circuit"}> ({ + "builtin.module"() ({ + "transform.named_sequence"() <{function_type = (!transform.op<"builtin.module">) -> (), sym_name = "__transform_main"}> ({ + ^bb0(%arg0: !transform.op<"builtin.module">): + "transform.yield"() : () -> () + }) : () -> () + }) {transform.with_named_sequence} : () -> () + "func.func"() <{function_type = () -> tensor<2xcomplex>, sym_name = "circuit", sym_visibility = "public"}> ({ + %0 = "arith.constant"() <{value = 0 : i64}> : () -> i64 + "quantum.device"(%0) <{kwargs = "{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}", lib = "/usr/local/lib/python3.11/dist-packages/pennylane_lightning/liblightning_qubit_catalyst.so", name = "LightningSimulator"}> : (i64) -> () + %1 = "quantum.alloc"() <{nqubits_attr = 1 : i64}> : () -> !quantum.reg + %2 = "quantum.extract"(%1) <{idx_attr = 0 : i64}> : (!quantum.reg) -> !quantum.bit + %3 = "quantum.custom"(%2) <{gate_name = "Hadamard", operandSegmentSizes = array, resultSegmentSizes = array}> : (!quantum.bit) -> !quantum.bit + %4 = "quantum.custom"(%3) <{gate_name = "Hadamard", operandSegmentSizes = array, resultSegmentSizes = array}> : (!quantum.bit) -> !quantum.bit + %5 = "quantum.insert"(%1, %4) <{idx_attr = 0 : i64}> : (!quantum.reg, !quantum.bit) -> !quantum.reg + %6 = "quantum.compbasis"(%5) <{operandSegmentSizes = array}> : (!quantum.reg) -> !quantum.obs + %7 = "quantum.state"(%6) <{operandSegmentSizes = array}> : (!quantum.obs) -> tensor<2xcomplex> + "quantum.dealloc"(%5) : (!quantum.reg) -> () + "quantum.device_release"() : () -> () + "func.return"(%7) : (tensor<2xcomplex>) -> () + }) {diff_method = "parameter-shift", llvm.linkage = #llvm.linkage, qnode} : () -> () + }) : () -> () + "func.func"() <{function_type = () -> (), sym_name = "setup"}> ({ + "quantum.init"() : () -> () + "func.return"() : () -> () + }) : () -> () + "func.func"() <{function_type = () -> (), sym_name = "teardown"}> ({ + "quantum.finalize"() : () -> () + "func.return"() : () -> () + }) : () -> () + }) : () -> () + """ + + retval = Compiler.run(program) + assert isinstance(retval, str) + + +def test_raises_error_when_pass_does_not_exists(): + """Attempts to run pass "this-pass-does-not-exists" on an empty module. + + This should raise an error + """ + + @xdsl_from_docstring + def empty_module(): + """ + builtin.module {} + """ + + @xdsl_from_docstring + def schedule_module(): + """ + builtin.module { + builtin.module { + transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) { + %0 = transform.apply_registered_pass "this-pass-does-not-exists" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + transform.yield + } + } + } + """ + + ctx = Context() + ctx.load_dialect(builtin.Builtin) + ctx.load_dialect(transform.Transform) + schedule = TransformInterpreterPass.find_transform_entry_point( + schedule_module(), "__transform_main" + ) + interpreter = Interpreter(empty_module()) + interpreter.register_implementations(TransformFunctionsExt(ctx, {})) + with pytest.raises(CompileError): + interpreter.call_op(schedule, (empty_module(),)) + + +def test_decorator(): + """Test that the decorator has modified the available_passes dictionary""" + assert "hello-world" in available_passes + assert available_passes["hello-world"]() == HelloWorldPass + + +def test_integration_for_transform_interpreter(capsys): + """Test that a pass is run via the transform interpreter""" + + # The hello-world pass is in the IR + @xdsl_from_docstring + def program(): + """ + builtin.module { + builtin.module { + transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) { + %0 = transform.apply_registered_pass "hello-world" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + transform.yield + } + } + } + """ + + ctx = xdsl.context.Context() + ctx.load_dialect(builtin.Builtin) + ctx.load_dialect(transform.Transform) + + pipeline = xdsl.passes.PassPipeline((ApplyTransformSequence(),)) + pipeline.apply(ctx, program()) + captured = capsys.readouterr() + assert captured.out.strip() == "hello world" + + +class TestCatalystIntegration: + """Tests for integration of the Python compiler with Catalyst""" + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_integration_catalyst_no_passes_with_capture(self): + """Test that the xDSL plugin can be used even when no passes are applied + when capture is enabled.""" + + assert capture_enabled() + + @catalyst.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def f(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + out = f(1.5) + assert jax.numpy.allclose(out, jax.numpy.cos(1.5)) + + def test_integration_catalyst_no_passes_no_capture(self): + """Test that the xDSL plugin can be used even when no passes are applied + when capture is disabled.""" + + assert not capture_enabled() + + @catalyst.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def f(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + out = f(1.5) + assert jax.numpy.allclose(out, jax.numpy.cos(1.5)) + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_integration_catalyst_xdsl_pass_with_capture(self, capsys): + """Test that a pass is run via the transform interpreter when using with a + qjit workflow and capture is enabled.""" + + assert capture_enabled() + + @catalyst.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @hello_world_pass + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def f(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + out = f(1.5) + assert jax.numpy.allclose(out, jax.numpy.cos(1.5)) + captured = capsys.readouterr() + assert captured.out.strip() == "hello world" + + def test_integration_catalyst_xdsl_pass_no_capture(self, capsys): + """Test that a pass is run via the transform interpreter when using with a + qjit workflow and capture is disabled.""" + + assert not capture_enabled() + + @catalyst.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @apply_pass("hello-world") + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def f(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + out = f(1.5) + assert jax.numpy.allclose(out, jax.numpy.cos(1.5)) + captured = capsys.readouterr() + assert captured.out.strip() == "hello world" + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_integration_catalyst_mixed_passes_with_capture(self, capsys): + """Test that both Catalyst and Python compiler passes can be used with qjit + when capture is enabled.""" + + assert capture_enabled() + + @catalyst.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @hello_world_pass + @qml.transforms.cancel_inverses + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def f(x): + qml.RX(x, 0) + qml.X(0) + qml.X(0) + return qml.expval(qml.Z(0)) + + out = f(1.5) + assert jax.numpy.allclose(out, jax.numpy.cos(1.5)) + captured = capsys.readouterr() + assert captured.out.strip() == "hello world" + + def test_integration_catalyst_mixed_passes_no_capture(self, capsys): + """Test that both Catalyst and Python compiler passes can be used with qjit + when capture is disabled.""" + + assert not capture_enabled() + + @catalyst.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @apply_pass("hello-world") + @catalyst_cancel_inverses + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def f(x): + qml.RX(x, 0) + qml.X(0) + qml.X(0) + return qml.expval(qml.Z(0)) + + out = f(1.5) + assert jax.numpy.allclose(out, jax.numpy.cos(1.5)) + captured = capsys.readouterr() + assert captured.out.strip() == "hello world" + + +class TestCallbackIntegration: + """Test the integration of the callback functionality""" + + def test_callback_integration(self, capsys): + """Test that the callback mechanism works with the transform interpreter""" + + @compiler_transform + @dataclass(frozen=True) + class _(passes.ModulePass): + name = "none-pass" + + def apply(self, _ctx: Context, _module: builtin.ModuleOp) -> None: ... + + def print_between_passes(*_, pass_level=0): + if pass_level == 0: + return + print("hello world") + + @xdsl_from_docstring + def program(): + """ + builtin.module { + builtin.module { + transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) { + %0 = transform.apply_registered_pass "none-pass" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + transform.yield + } + } + } + """ + + ctx = Context() + ctx.load_dialect(builtin.Builtin) + pipeline = xdsl.passes.PassPipeline( + (ApplyTransformSequence(callback=print_between_passes),) + ) + pipeline.apply(ctx, program()) + captured = capsys.readouterr() + assert captured.out.strip() == "hello world" + + def test_callback_prints_module_after_each_pass(self, capsys): + """Test that the callback prints the module after each pass""" + + # pylint: disable=redefined-outer-name + def print_between_passes(_, module, __, pass_level=0): + if pass_level == 0: + return + print("=== Between Pass ===") + print(module) + + @xdsl_from_docstring + def program_2_passes(): + """ + builtin.module { + builtin.module @module_foo { + transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) { + %0 = transform.apply_registered_pass "xdsl-cancel-inverses" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + %1 = transform.apply_registered_pass "xdsl-merge-rotations" to %0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + transform.yield %1 : !transform.op<"builtin.module"> + func.func public @foo() { + %2 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %3 = tensor.extract %2[] : tensor + %4 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %5 = quantum.alloc(1) : !quantum.reg + %6 = tensor.extract %2[] : tensor + %7 = quantum.extract %5[%6] : !quantum.reg -> !quantum.bit + %8 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor + %9 = tensor.extract %8[] : tensor + %10 = quantum.custom "RX"(%9) %7 : !quantum.bit + %11 = tensor.extract %8[] : tensor + %12 = quantum.custom "RX"(%11) %10 : !quantum.bit + %13 = quantum.custom "Hadamard"() %12 : !quantum.bit + %14 = quantum.custom "Hadamard"() %13 : !quantum.bit + %15 = tensor.extract %1[] : tensor + %16 = quantum.insert %5[%15], %14 : !quantum.reg, !quantum.bit + %17 = quantum.compbasis qreg %16 : !quantum.obs + %18 = quantum.probs %17 : tensor<2xf64> + } + } + } + } + """ + + ctx = Context() + ctx.load_dialect(builtin.Builtin) + pipeline = xdsl.passes.PassPipeline( + (ApplyTransformSequence(callback=print_between_passes),) + ) + pipeline.apply(ctx, program_2_passes()) + + out = capsys.readouterr().out + printed_modules = out.split("=== Between Pass ===")[1:] + + assert ( + len(printed_modules) == 2 + ), "Callback should have been called twice (after each pass)." + + # callback after cancel-inverses + assert 'quantum.custom "RX"' in printed_modules[0] + assert 'quantum.custom "Hadamard"' not in printed_modules[0] + + # callback after merge-rotations + # We expect an `arith.addf` if rotations were merged + assert "arith.addf" in printed_modules[1], "Expected merged RX gates into a single rotation" + assert 'quantum.custom "RX"' in printed_modules[1] + + assert printed_modules[0] != printed_modules[1], "IR should differ between passes" + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_callback_run_integration(self, capsys): + """Test that the callback is integrated into the pass pipeline with the Compiler.run() method""" + + # pylint: disable=redefined-outer-name + def print_between_passes(_, module, __, pass_level=0): + if pass_level == 0: + return + print("=== Between Pass ===") + print(module) + + @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @iterative_cancel_inverses_pass + @merge_rotations_pass + @qml.qnode(qml.device("null.qubit", wires=2)) + def circuit(): + qml.RX(0.1, 0) + qml.RX(2.0, 0) + qml.Hadamard(1) + qml.Hadamard(1) + return qml.state() + + Compiler.run(circuit.mlir_module, callback=print_between_passes) + out = capsys.readouterr().out + printed_modules = out.split("=== Between Pass ===")[1:] + + assert ( + len(printed_modules) == 2 + ), "Callback should have been called twice (after each pass)." + + # callback after merge-rotations + # We expect an `arith.addf` if rotations were merged + assert "arith.addf" in printed_modules[0], "Expected merged RX gates into a single rotation" + assert 'quantum.custom "RX"' in printed_modules[0] + assert 'quantum.custom "Hadamard"' in printed_modules[0] + + # callback after cancel-inverses + assert "arith.addf" in printed_modules[1], "Expected merged RX gates into a single rotation" + assert 'quantum.custom "RX"' in printed_modules[1] + assert 'quantum.custom "Hadamard"' not in printed_modules[1] + + assert printed_modules[0] != printed_modules[1], "IR should differ between passes" + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/python_interface/test_xdsl_utils.py b/frontend/test/pytest/python_interface/test_xdsl_utils.py new file mode 100644 index 0000000000..23254dd7ec --- /dev/null +++ b/frontend/test/pytest/python_interface/test_xdsl_utils.py @@ -0,0 +1,120 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for xDSL utilities.""" + +import pytest + +pytestmark = pytest.mark.external +xdsl = pytest.importorskip("xdsl") + +# pylint: disable=wrong-import-position +from xdsl.dialects import arith, builtin, tensor, test + +from catalyst.python_interface.dialects.stablehlo import ConstantOp as hloConstantOp +from catalyst.python_interface.utils import get_constant_from_ssa + + +class TestGetConstantFromSSA: + """Unit tests for ``get_constant_from_ssa``.""" + + def test_non_constant(self): + """Test that ``None`` is returned if the input is not a constant.""" + val = test.TestOp(result_types=(builtin.Float64Type(),)).results[0] + assert get_constant_from_ssa(val) is None + + @pytest.mark.parametrize( + "const, attr_type, dtype", + [ + (11, builtin.IntegerAttr, builtin.IntegerType(64)), + (5, builtin.IntegerAttr, builtin.IndexType()), + (2.5, builtin.FloatAttr, builtin.Float64Type()), + ], + ) + def test_scalar_constant_arith(self, const, attr_type, dtype): + """Test that constants created by ``arith.constant`` are returned correctly.""" + const_attr = attr_type(const, dtype) + val = arith.ConstantOp(value=const_attr).results[0] + + assert get_constant_from_ssa(val) == const + + @pytest.mark.parametrize( + "const, elt_type", + [ + (11, builtin.IntegerType(64)), + (9, builtin.IndexType()), + (2.5, builtin.Float64Type()), + (-1.1 + 2.3j, builtin.ComplexType(builtin.Float64Type())), + ], + ) + @pytest.mark.parametrize("constant_op", [arith.ConstantOp, hloConstantOp]) + def test_scalar_constant_extracted_from_rank0_tensor(self, const, elt_type, constant_op): + """Test that constants created by ``stablehlo.constant`` are returned correctly.""" + data = const + if isinstance(const, complex): + # For complex numbers, the number must be split into a 2-tuple containing + # the real and imaginary part when initializing a dense elements attr. + data = (const.real, const.imag) + + dense_attr = builtin.DenseIntOrFPElementsAttr.from_list( + type=builtin.TensorType(element_type=elt_type, shape=()), + data=(data,), + ) + tensor_ = constant_op(value=dense_attr).results[0] + val = tensor.ExtractOp(tensor=tensor_, indices=[], result_type=elt_type).results[0] + + assert get_constant_from_ssa(val) == const + + def test_tensor_constant_arith(self): + """Test that ``None`` is returned if the input is a tensor created by ``arith.constant``.""" + dense_attr = builtin.DenseIntOrFPElementsAttr.from_list( + type=builtin.TensorType(element_type=builtin.Float64Type(), shape=(3,)), + data=(1, 2, 3), + ) + val = arith.ConstantOp(value=dense_attr).results[0] + + assert get_constant_from_ssa(val) is None + + def test_tensor_constant_stablehlo(self): + """Test that ``None`` is returned if the input is a tensor created by ``stablehlo.constant``.""" + dense_attr = builtin.DenseIntOrFPElementsAttr.from_list( + type=builtin.TensorType(element_type=builtin.Float64Type(), shape=(3,)), + data=(1.0, 2.0, 3.0), + ) + val = hloConstantOp(value=dense_attr).results[0] + + assert get_constant_from_ssa(val) is None + + def test_extract_scalar_from_constant_tensor_stablehlo(self): + """Test that ``None`` is returned if the input is a scalar constant, but it was extracted + from a non-scalar constant.""" + # Index SSA value to be used for extracting a value from a tensor + dummy_index = test.TestOp(result_types=(builtin.IndexType(),)).results[0] + + dense_attr = builtin.DenseIntOrFPElementsAttr.from_list( + type=builtin.TensorType(element_type=builtin.Float64Type(), shape=(3,)), + data=(1.0, 2.0, 3.0), + ) + tensor_ = hloConstantOp(value=dense_attr).results[0] + val = tensor.ExtractOp( + tensor=tensor_, indices=[dummy_index], result_type=builtin.Float64Type() + ).results[0] + # val is a value that we got by indexing into a constant tensor with rank >= 1 + assert isinstance(val.type, builtin.Float64Type) + + assert get_constant_from_ssa(val) is None + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/python_interface/transforms/mbqc/test_graph_state_utils.py b/frontend/test/pytest/python_interface/transforms/mbqc/test_graph_state_utils.py new file mode 100644 index 0000000000..37fbde22ee --- /dev/null +++ b/frontend/test/pytest/python_interface/transforms/mbqc/test_graph_state_utils.py @@ -0,0 +1,201 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the graph state utils""" + +# pylint: disable=wrong-import-position + +import pytest + +pytestmark = pytest.mark.external + +pytest.importorskip("xdsl") +pytest.importorskip("catalyst") + +from pennylane.exceptions import CompileError + +from catalyst.python_interface.transforms.mbqc.graph_state_utils import ( + _adj_matrix_generation_helper, + edge_iter, + get_graph_state_edges, + get_num_aux_wires, + n_vertices_from_packed_adj_matrix, +) + + +@pytest.fixture(scope="module", name="mbqc_single_qubit_graph") +def fixture_mbqc_single_qubit_graph(): + """Fixture that returns the densely packed adjacency matrix for the graph state used for + representing single-qubit gates in the MBQC formalism. + + The graph state is as follows: + + 0 -- 1 -- 2 -- 3 + """ + # fmt: off + packed_adj_matrix = [ + # 0 1 2 + 1, # 1 + 0, 1, # 2 + 0, 0, 1 # 3 + ] + return packed_adj_matrix + + +@pytest.fixture(scope="module", name="mbqc_cnot_graph") +def fixture_mbqc_cnot_graph(): + """Fixture that returns the densely packed adjacency matrix for the graph state used for + representing a CNOT gate in the MBQC formalism. + + The graph state is as follows: + + 0 -- 1 -- 2 -- 3 -- 4 -- 5 + | + 6 + | + 7 -- 8 -- 9 -- 10 - 11 - 12 + """ + # fmt: off + packed_adj_matrix = [ + # 0 1 2 3 4 5 6 7 8 9 10 11 + 1, + 0, 1, + 0, 0, 1, + 0, 0, 0, 1, + 0, 0, 0, 0, 1, + 0, 0, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, + 0, 0, 0, 0, 0, 0, 1, 0, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 + ] + return packed_adj_matrix + + +class TestGraphStateUtils: + """Unit tests for graph state utils.""" + + def test_unsupported_gate(self): + """Test error raised for unsupported gates""" + with pytest.raises(ValueError, match="rotxzx is not supported in the MBQC formalism."): + get_graph_state_edges("rotxzx") + + with pytest.raises(ValueError, match="rotXzx is not supported in the MBQC formalism."): + get_num_aux_wires("rotXzx") + + def test_adj_matrix_generation_helper(self): + """Test that error raised for unsupported gates.""" + num_vertices = 4 + edges = [(0, 1), (1, 2), (2, 3)] + adj_matrix = _adj_matrix_generation_helper(num_vertices, edges) + assert adj_matrix == [1, 0, 1, 0, 0, 1] + + @pytest.mark.parametrize("n_vertices", range(1, 16)) + def test_n_vertices(self, n_vertices: int): + """Test that the ``_n_vertices_from_packed_adj_matrix`` function returns correct results + when given a valid densely packed adjacency matrix as input. + + This test performs the inverse operation of _n_vertices_from_packed_adj_matrix by computing + the number of elements in the densely packed adjacency matrix from the given number of + vertices, `n_vertices`, generates a null list of this length, and checks the function output + given this list is the same as `n_vertices`. + """ + n_elements = int(n_vertices * (n_vertices - 1) / 2) + adj_matrix = [0] * n_elements + n_observed = n_vertices_from_packed_adj_matrix(adj_matrix) + + assert n_observed == n_vertices + + @pytest.mark.parametrize("n", [2, 4, 5, 7, 8, 9]) + def test_n_vertices_raises_on_invalid(self, n): + """Test that the ``_n_vertices_from_packed_adj_matrix`` function raises a CompileError when + given an invalid densely packed adjacency matrix as input. + """ + with pytest.raises(CompileError, match="densely packed adjacency matrix"): + adj_matrix = [0] * n + _ = n_vertices_from_packed_adj_matrix(adj_matrix) + + @pytest.mark.parametrize( + "adj_matrix, expected_edges", + [ + ([], []), + ([0], []), + ([1], [(0, 1)]), + ([1, 0, 0], [(0, 1)]), + ([1, 1, 0], [(0, 1), (0, 2)]), + ([1, 1, 1], [(0, 1), (0, 2), (1, 2)]), + ([False], []), + ([True], [(0, 1)]), + ], + ) + def test_edge_iter(self, adj_matrix, expected_edges): + """Test that the ``_edge_iter`` generator function yields correct results when given a valid + densely packed adjacency matrix as input.""" + edges = list(edge_iter(adj_matrix)) + assert edges == expected_edges + + @pytest.mark.parametrize("n", [2, 4, 5, 7, 8, 9]) + def test_edge_iter_raises_on_invalid(self, n): + """Test that the ``_edge_iter`` generator function raises a CompileError when given an + invalid densely packed adjacency matrix as input. + """ + with pytest.raises(CompileError, match="densely packed adjacency matrix"): + adj_matrix = [0] * n + _ = list(edge_iter(adj_matrix)) + + def test_n_vertices_mbqc_single_qubit(self, mbqc_single_qubit_graph): + """Test that the ``_n_vertices_from_packed_adj_matrix`` function correctly determines that + the number of vertices in the densely packed adjacency matrix for the graph state used for + representing single-qubit gates in the MBQC formalism is equal to 4. + """ + n_observed = n_vertices_from_packed_adj_matrix(mbqc_single_qubit_graph) + assert n_observed == 4 + + def test_n_vertices_mbqc_cnot(self, mbqc_cnot_graph): + """Test that the ``_n_vertices_from_packed_adj_matrix`` function correctly determines that + the number of vertices in the densely packed adjacency matrix for the graph state used for + representing a CNOT gate in the MBQC formalism is equal to 13. + """ + n_observed = n_vertices_from_packed_adj_matrix(mbqc_cnot_graph) + assert n_observed == 13 + + def test_edge_iter_mbqc_single_qubit(self, mbqc_single_qubit_graph): + """Test that the ``_edge_iter`` generator function applied to the densely packed adjacency + matrix for the graph state used for representing single-qubit gates in the MBQC formalism + yields the correct edges. + + For reference, the graph is: + + 0 -- 1 -- 2 -- 3 + """ + edges_observed = list(edge_iter(mbqc_single_qubit_graph)) + + assert edges_observed == get_graph_state_edges("RZ") + + def test_edge_iter_mbqc_cnot(self, mbqc_cnot_graph): + """Test that the ``_edge_iter`` generator function applied to the densely packed adjacency + matrix for the graph state used for representing a CNOT gate in the MBQC formalism yields + the correct edges. + + For reference, the graph is: + + 0 -- 1 -- 2 -- 3 -- 4 -- 5 + | + 6 + | + 7 -- 8 -- 9 -- 10 - 11 - 12 + """ + edges_observed = list(edge_iter(mbqc_cnot_graph)) + assert edges_observed == get_graph_state_edges("CNOT") diff --git a/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_convert_to_mbqc_formalism.py b/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_convert_to_mbqc_formalism.py new file mode 100644 index 0000000000..e6fc9a3229 --- /dev/null +++ b/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_convert_to_mbqc_formalism.py @@ -0,0 +1,576 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the convert to MBQC formalism transform""" +import pytest + +pytestmark = pytest.mark.external + +xdsl = pytest.importorskip("xdsl") +catalyst = pytest.importorskip("catalyst") + +import pennylane as qml +from pennylane.ftqc import RotXZX + +# pylint: disable=wrong-import-position +from catalyst.ftqc import mbqc_pipeline +from catalyst.python_interface.transforms import ( + ConvertToMBQCFormalismPass, + convert_to_mbqc_formalism_pass, + decompose_graph_state_pass, + measurements_from_samples_pass, +) + + +class TestConvertToMBQCFormalismPass: + """Unit tests for ConvertToMBQCFormalismPass.""" + + def test_unsupported_gate(self, run_filecheck): + """Test for error threw for unsupported gate""" + program = """ + func.func @test_func() { + %0 = "test.op"() : () -> !quantum.bit + %1 = quantum.custom "IsingXX"() %0 : !quantum.bit + return + } + """ + with pytest.raises(NotImplementedError): + pipeline = (ConvertToMBQCFormalismPass(),) + run_filecheck(program, pipeline) + + def test_unconverted_gate_set(self, run_filecheck): + """Test for supported gates that are not converted in the pass""" + program = """ + func.func @test_func(%arg0 :f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK-NEXT: [[q0:%.+]] = quantum.custom "PauliX"() [[q0:%.+]] : !quantum.bit + %1 = quantum.custom "PauliX"() %0 : !quantum.bit + // CHECK-NEXT: [[q0:%.+]] = quantum.custom "PauliY"() [[q0:%.+]] : !quantum.bit + %2 = quantum.custom "PauliY"() %1 : !quantum.bit + // CHECK-NEXT: [[q0:%.+]] = quantum.custom "PauliZ"() [[q0:%.+]] : !quantum.bit + %3 = quantum.custom "PauliZ"() %2 : !quantum.bit + // CHECK-NEXT: [[q0:%.+]] = quantum.custom "Identity"() [[q0:%.+]] : !quantum.bit + %4 = quantum.custom "Identity"() %3 : !quantum.bit + // CHECK-NEXT: quantum.gphase + quantum.gphase %arg0 + return + } + """ + pipeline = (ConvertToMBQCFormalismPass(),) + run_filecheck(program, pipeline) + + def test_hadamard_gate(self, run_filecheck): + """Test for lowering a Hadamard gate to a MBQC formalism.""" + program = """ + func.func @test_func() { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK-NEXT: [[q0:%.+]] = func.call @hadamard_in_mbqc([[q0:%.+]]) : (!quantum.bit) -> !quantum.bit + %1 = quantum.custom "Hadamard"() %0 : !quantum.bit + return + } + + // CHECK: func.func private @hadamard_in_mbqc(%0 : !quantum.bit) -> !quantum.bit attributes {mbqc_transform = none} { + // CHECK-NEXT: %1 = arith.constant dense<[true, false, true, false, false, true]> : tensor<6xi1> + // CHECK-NEXT: %2 = mbqc.graph_state_prep(%1 : tensor<6xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + // CHECK-NEXT: %3 = quantum.extract %2[0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %4 = quantum.extract %2[1] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %5 = quantum.extract %2[2] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %6 = quantum.extract %2[3] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %7, %8 = quantum.custom "CZ"() %0, %3 : !quantum.bit, !quantum.bit + // CHECK-NEXT: %9 = arith.constant 0.000000e+00 : f64 + // CHECK-NEXT: %10 = arith.constant 1.5707963267948966 : f64 + // CHECK-NEXT: %11, %12 = mbqc.measure_in_basis[XY, %9] %7 : i1, !quantum.bit + // CHECK-NEXT: %13, %14 = mbqc.measure_in_basis[XY, %10] %8 : i1, !quantum.bit + // CHECK-NEXT: %15, %16 = mbqc.measure_in_basis[XY, %10] %4 : i1, !quantum.bit + // CHECK-NEXT: %17, %18 = mbqc.measure_in_basis[XY, %10] %5 : i1, !quantum.bit + // CHECK-NEXT: %19 = arith.xori %11, %15 : i1 + // CHECK-NEXT: %20 = arith.xori %19, %17 : i1 + // CHECK-NEXT: %21 = arith.constant true + // CHECK-NEXT: %22 = arith.cmpi eq, %20, %21 : i1 + // CHECK-NEXT: %23 = scf.if %22 -> (!quantum.bit) { + // CHECK-NEXT: %24 = quantum.custom "PauliX"() %6 : !quantum.bit + // CHECK-NEXT: scf.yield %24 : !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %6 : !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: %25 = arith.xori %13, %15 : i1 + // CHECK-NEXT: %26 = arith.constant true + // CHECK-NEXT: %27 = arith.cmpi eq, %25, %26 : i1 + // CHECK-NEXT: %28 = scf.if %27 -> (!quantum.bit) { + // CHECK-NEXT: %29 = quantum.custom "PauliZ"() %23 : !quantum.bit + // CHECK-NEXT: scf.yield %29 : !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %23 : !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: quantum.dealloc_qb %14 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %16 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %18 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %12 : !quantum.bit + // CHECK-NEXT: func.return %28 : !quantum.bit + // CHECK-NEXT: } + """ + + pipeline = (ConvertToMBQCFormalismPass(),) + run_filecheck(program, pipeline) + + def test_s_gate(self, run_filecheck): + """Test for lowering a S gate to a MBQC formalism.""" + program = """ + func.func @test_func() { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK-NEXT: [[q0:%.+]] = func.call @s_in_mbqc([[q0:%.+]]) : (!quantum.bit) -> !quantum.bit + %1 = quantum.custom "S"() %0 : !quantum.bit + return + } + + // CHECK: func.func private @s_in_mbqc(%0 : !quantum.bit) -> !quantum.bit attributes {mbqc_transform = none} { + // CHECK-NEXT: %1 = arith.constant dense<[true, false, true, false, false, true]> : tensor<6xi1> + // CHECK-NEXT: %2 = mbqc.graph_state_prep(%1 : tensor<6xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + // CHECK-NEXT: %3 = quantum.extract %2[0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %4 = quantum.extract %2[1] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %5 = quantum.extract %2[2] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %6 = quantum.extract %2[3] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %7, %8 = quantum.custom "CZ"() %0, %3 : !quantum.bit, !quantum.bit + // CHECK-NEXT: %9 = arith.constant 0.000000e+00 : f64 + // CHECK-NEXT: %10 = arith.constant 1.5707963267948966 : f64 + // CHECK-NEXT: %11, %12 = mbqc.measure_in_basis[XY, %9] %7 : i1, !quantum.bit + // CHECK-NEXT: %13, %14 = mbqc.measure_in_basis[XY, %9] %8 : i1, !quantum.bit + // CHECK-NEXT: %15, %16 = mbqc.measure_in_basis[XY, %10] %4 : i1, !quantum.bit + // CHECK-NEXT: %17, %18 = mbqc.measure_in_basis[XY, %9] %5 : i1, !quantum.bit + // CHECK-NEXT: %19 = arith.xori %13, %17 : i1 + // CHECK-NEXT: %20 = arith.constant true + // CHECK-NEXT: %21 = arith.cmpi eq, %19, %20 : i1 + // CHECK-NEXT: %22 = scf.if %21 -> (!quantum.bit) { + // CHECK-NEXT: %23 = quantum.custom "PauliX"() %6 : !quantum.bit + // CHECK-NEXT: scf.yield %23 : !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %6 : !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: %24 = arith.xori %11, %13 : i1 + // CHECK-NEXT: %25 = arith.xori %24, %15 : i1 + // CHECK-NEXT: %26 = arith.constant true + // CHECK-NEXT: %27 = arith.xori %25, %26 : i1 + // CHECK-NEXT: %28 = arith.constant true + // CHECK-NEXT: %29 = arith.cmpi eq, %27, %28 : i1 + // CHECK-NEXT: %30 = scf.if %29 -> (!quantum.bit) { + // CHECK-NEXT: %31 = quantum.custom "PauliZ"() %22 : !quantum.bit + // CHECK-NEXT: scf.yield %31 : !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %22 : !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: quantum.dealloc_qb %14 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %16 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %18 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %12 : !quantum.bit + // CHECK-NEXT: func.return %30 : !quantum.bit + // CHECK-NEXT: } + """ + + pipeline = (ConvertToMBQCFormalismPass(),) + run_filecheck(program, pipeline) + + def test_rz_gate(self, run_filecheck): + """Test for lowering a RZ gate to a MBQC formalism.""" + program = """ + func.func @test_func(%param0: f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK-NEXT: [[q0:%.+]] = func.call @rz_in_mbqc(%param0, %0) : (f64, !quantum.bit) -> !quantum.bit + %1 = quantum.custom "RZ"(%param0) %0 : !quantum.bit + return + } + // CHECK: func.func private @rz_in_mbqc(%0 : f64, %1 : !quantum.bit) -> !quantum.bit attributes {mbqc_transform = none} { + // CHECK-NEXT: %2 = arith.constant dense<[true, false, true, false, false, true]> : tensor<6xi1> + // CHECK-NEXT: %3 = mbqc.graph_state_prep(%2 : tensor<6xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + // CHECK-NEXT: %4 = quantum.extract %3[0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %5 = quantum.extract %3[1] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %6 = quantum.extract %3[2] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %7 = quantum.extract %3[3] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %8, %9 = quantum.custom "CZ"() %1, %4 : !quantum.bit, !quantum.bit + // CHECK-NEXT: %10 = arith.constant 0.000000e+00 : f64 + // CHECK-NEXT: %11, %12 = mbqc.measure_in_basis[XY, %10] %8 : i1, !quantum.bit + // CHECK-NEXT: %13, %14 = mbqc.measure_in_basis[XY, %10] %9 : i1, !quantum.bit + // CHECK-NEXT: %15 = arith.constant true + // CHECK-NEXT: %16 = arith.cmpi eq, %13, %15 : i1 + // CHECK-NEXT: %17, %18 = scf.if %16 -> (i1, !quantum.bit) { + // CHECK-NEXT: %19, %20 = mbqc.measure_in_basis[XY, %0] %5 : i1, !quantum.bit + // CHECK-NEXT: scf.yield %19, %20 : i1, !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: %21 = arith.negf %0 : f64 + // CHECK-NEXT: %22, %23 = mbqc.measure_in_basis[XY, %21] %5 : i1, !quantum.bit + // CHECK-NEXT: scf.yield %22, %23 : i1, !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: %24, %25 = mbqc.measure_in_basis[XY, %10] %6 : i1, !quantum.bit + // CHECK-NEXT: %26 = arith.xori %13, %24 : i1 + // CHECK-NEXT: %27 = arith.constant true + // CHECK-NEXT: %28 = arith.cmpi eq, %26, %27 : i1 + // CHECK-NEXT: %29 = scf.if %28 -> (!quantum.bit) { + // CHECK-NEXT: %30 = quantum.custom "PauliX"() %7 : !quantum.bit + // CHECK-NEXT: scf.yield %30 : !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %7 : !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: %31 = arith.xori %11, %17 : i1 + // CHECK-NEXT: %32 = arith.constant true + // CHECK-NEXT: %33 = arith.cmpi eq, %31, %32 : i1 + // CHECK-NEXT: %34 = scf.if %33 -> (!quantum.bit) { + // CHECK-NEXT: %35 = quantum.custom "PauliZ"() %29 : !quantum.bit + // CHECK-NEXT: scf.yield %35 : !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %29 : !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: quantum.dealloc_qb %14 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %18 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %25 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %12 : !quantum.bit + // CHECK-NEXT: func.return %34 : !quantum.bit + // CHECK-NEXT: } + """ + + pipeline = (ConvertToMBQCFormalismPass(),) + run_filecheck(program, pipeline) + + def test_rotxzx_gate(self, run_filecheck): + """Test for lowering a RotXZX gate to a MBQC formalism.""" + program = """ + func.func @test_func(%param0: f64, %param1: f64, %param2: f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK-NEXT: [[q0:%.+]] = func.call @rotxzx_in_mbqc(%param0, %param1, %param2, %0) : (f64, f64, f64, !quantum.bit) -> !quantum.bit + %1 = quantum.custom "RotXZX"(%param0, %param1, %param2) %0 : !quantum.bit + return + } + // CHECK: func.func private @rotxzx_in_mbqc(%0 : f64, %1 : f64, %2 : f64, %3 : !quantum.bit) -> !quantum.bit attributes {mbqc_transform = none} { + // CHECK-NEXT: %4 = arith.constant dense<[true, false, true, false, false, true]> : tensor<6xi1> + // CHECK-NEXT: %5 = mbqc.graph_state_prep(%4 : tensor<6xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + // CHECK-NEXT: %6 = quantum.extract %5[0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %7 = quantum.extract %5[1] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %8 = quantum.extract %5[2] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %9 = quantum.extract %5[3] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: %10, %11 = quantum.custom "CZ"() %3, %6 : !quantum.bit, !quantum.bit + // CHECK-NEXT: %12 = arith.constant 0.000000e+00 : f64 + // CHECK-NEXT: %13, %14 = mbqc.measure_in_basis[XY, %12] %10 : i1, !quantum.bit + // CHECK-NEXT: %15 = arith.constant true + // CHECK-NEXT: %16 = arith.cmpi eq, %13, %15 : i1 + // CHECK-NEXT: %17, %18 = scf.if %16 -> (i1, !quantum.bit) { + // CHECK-NEXT: %19, %20 = mbqc.measure_in_basis[XY, %0] %11 : i1, !quantum.bit + // CHECK-NEXT: scf.yield %19, %20 : i1, !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: %21 = arith.negf %0 : f64 + // CHECK-NEXT: %22, %23 = mbqc.measure_in_basis[XY, %21] %11 : i1, !quantum.bit + // CHECK-NEXT: scf.yield %22, %23 : i1, !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: %24 = arith.constant true + // CHECK-NEXT: %25 = arith.cmpi eq, %17, %24 : i1 + // CHECK-NEXT: %26, %27 = scf.if %25 -> (i1, !quantum.bit) { + // CHECK-NEXT: %28, %29 = mbqc.measure_in_basis[XY, %1] %7 : i1, !quantum.bit + // CHECK-NEXT: scf.yield %28, %29 : i1, !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: %30 = arith.negf %1 : f64 + // CHECK-NEXT: %31, %32 = mbqc.measure_in_basis[XY, %30] %7 : i1, !quantum.bit + // CHECK-NEXT: scf.yield %31, %32 : i1, !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: %33 = arith.xori %13, %26 : i1 + // CHECK-NEXT: %34 = arith.constant true + // CHECK-NEXT: %35 = arith.cmpi eq, %33, %34 : i1 + // CHECK-NEXT: %36, %37 = scf.if %35 -> (i1, !quantum.bit) { + // CHECK-NEXT: %38, %39 = mbqc.measure_in_basis[XY, %2] %8 : i1, !quantum.bit + // CHECK-NEXT: scf.yield %38, %39 : i1, !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: %40 = arith.negf %2 : f64 + // CHECK-NEXT: %41, %42 = mbqc.measure_in_basis[XY, %40] %8 : i1, !quantum.bit + // CHECK-NEXT: scf.yield %41, %42 : i1, !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: %43 = arith.xori %17, %36 : i1 + // CHECK-NEXT: %44 = arith.constant true + // CHECK-NEXT: %45 = arith.cmpi eq, %43, %44 : i1 + // CHECK-NEXT: %46 = scf.if %45 -> (!quantum.bit) { + // CHECK-NEXT: %47 = quantum.custom "PauliX"() %9 : !quantum.bit + // CHECK-NEXT: scf.yield %47 : !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %9 : !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: %48 = arith.xori %13, %26 : i1 + // CHECK-NEXT: %49 = arith.constant true + // CHECK-NEXT: %50 = arith.cmpi eq, %48, %49 : i1 + // CHECK-NEXT: %51 = scf.if %50 -> (!quantum.bit) { + // CHECK-NEXT: %52 = quantum.custom "PauliZ"() %46 : !quantum.bit + // CHECK-NEXT: scf.yield %52 : !quantum.bit + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %46 : !quantum.bit + // CHECK-NEXT: } + // CHECK-NEXT: quantum.dealloc_qb %18 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %27 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %37 : !quantum.bit + // CHECK-NEXT: quantum.dealloc_qb %14 : !quantum.bit + // CHECK-NEXT: func.return %51 : !quantum.bit + // CHECK-NEXT:} + """ + + pipeline = (ConvertToMBQCFormalismPass(),) + run_filecheck(program, pipeline) + + def test_cnot_gate(self, run_filecheck): + """Test for lowering a CNOT gate to a MBQC formalism.""" + program = """ + func.func @test_func() { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK-NEXT: [[q1:%.+]] = "test.op"() : () -> !quantum.bit + %1 = "test.op"() : () -> !quantum.bit + // CHECK-NEXT: [[q0:%.+]], [[q1:%.+]]= func.call @cnot_in_mbqc([[q0:%.+]], [[q1:%.+]]) : (!quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit) + %2, %3 = quantum.custom "CNOT"() %0, %1 : !quantum.bit, !quantum.bit + return + } + // CHECK: func.func private @cnot_in_mbqc(%0 : !quantum.bit, %1 : !quantum.bit) -> (!quantum.bit, !quantum.bit) attributes {mbqc_transform = none} { + """ + + pipeline = (ConvertToMBQCFormalismPass(),) + run_filecheck(program, pipeline) + + def test_switch_statement(self, run_filecheck): + """Test that the convert_to_mbqc_formalism_pass works correctly with a switch statement.""" + program = """ + func.func @test_func(%qubits: !quantum.bit, %l : index) { + %0 = scf.index_switch %l -> !quantum.bit + case 0 { + // CHECK-NOT: quantum.custom "Hadamard"() + %q1 = quantum.custom "Hadamard"() %qubits : !quantum.bit + scf.yield %q1 : !quantum.bit + } + default { + // CHECK-NOT: quantum.custom "S"() + %q2 = quantum.custom "S"() %qubits : !quantum.bit + scf.yield %q2 : !quantum.bit + } + + return + } + """ + pipeline = (ConvertToMBQCFormalismPass(),) + run_filecheck(program, pipeline) + + def test_function_no_body(self, run_filecheck): + """Test that the convert_to_mbqc_formalism_pass works correctly with a function that has no body.""" + program = """ + func.func @test_func() { + // CHECK: func.func private @func_1(f64, f64, i1) -> f64 + func.func private @func_1(f64, f64, i1) -> f64 + // CHECK: func.func private @func_2(memref, f64, f64, i1) + func.func private @func_2(memref, f64, f64, i1) + return + } + """ + pipeline = (ConvertToMBQCFormalismPass(),) + run_filecheck(program, pipeline) + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_gates_in_mbqc_gate_set_lowering(self, run_filecheck_qjit): + """Test that the convert_to_mbqc_formalism_pass works correctly with qjit and unrolled loops.""" + dev = qml.device("null.qubit", wires=1000) + + @qml.qjit( + target="mlir", + pipelines=mbqc_pipeline(), + autograph=True, + ) + @convert_to_mbqc_formalism_pass + @qml.set_shots(1000) + @qml.qnode(dev) + def circuit(): + # CHECK-LABEL: circuit + # CHECK-NOT: quantum.custom "CNOT"() + # CHECK-NOT: quantum.custom "S"() + # CHECK-NOT: quantum.custom "RZ"() + # CHECK-NOT: quantum.custom "RotXZX"() + # CHECK-NOT: quantum.custom "Hadamard"() + # CHECK-NOT: scf.if + # CHECK: scf.for + # CHECK: func.call @hadamard_in_mbqc + # CHECK: func.call @s_in_mbqc + # CHECK: func.call @rotxzx_in_mbqc + # CHECK: func.call @rz_in_mbqc + # CHECK: func.call @cnot_in_mbqc + # CHECK: mbqc.graph_state_prep + # CHECK: quantum.custom "CZ"() + # CHECK: mbqc.measure_in_basis + # CHECK: scf.if + # CHECK: quantum.custom "PauliX"() + # CHECK: quantum.custom "PauliZ"() + # CHECK: quantum.dealloc_qb + for i in range(50): + qml.H(i) + qml.S(i) + RotXZX(0.1, 0.2, 0.3, wires=[i]) + qml.RZ(phi=0.1, wires=[i]) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.Z(wires=0)) + + run_filecheck_qjit(circuit) + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_gates_in_mbqc_gate_set_lowering_for(self, run_filecheck_qjit): + """Test that the convert_to_mbqc_formalism_pass works correctly with qjit and for-loop structure.""" + dev = qml.device("null.qubit", wires=1000) + + @qml.for_loop(1, 1000, 1) + def loop_for(i): + qml.H(i) + qml.S(i) + RotXZX(0.1, 0.2, 0.3, wires=[i]) + qml.RZ(phi=0.1, wires=[i]) + + @qml.qjit( + target="mlir", + pipelines=mbqc_pipeline(), + autograph=True, + ) + @convert_to_mbqc_formalism_pass + @qml.set_shots(1000) + @qml.qnode(dev) + def circuit(): + # CHECK-LABEL: circuit + # CHECK-NOT: quantum.custom "CNOT"() + # CHECK-NOT: quantum.custom "S"() + # CHECK-NOT: quantum.custom "RZ"() + # CHECK-NOT: quantum.custom "RotXZX"() + # CHECK-NOT: quantum.custom "Hadamard"() + # CHECK: scf.for + # CHECK: func.call @hadamard_in_mbqc + # CHECK: func.call @s_in_mbqc + # CHECK: func.call @rotxzx_in_mbqc + # CHECK: func.call @rz_in_mbqc + # CHECK: func.call @cnot_in_mbqc + # CHECK: mbqc.graph_state_prep + # CHECK: quantum.custom "CZ"() + # CHECK: mbqc.measure_in_basis + # CHECK: scf.if + # CHECK: quantum.custom "PauliX"() + # CHECK: quantum.custom "PauliZ"() + # CHECK: quantum.dealloc_qb + loop_for() + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.Z(wires=0)) + + run_filecheck_qjit(circuit) + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_gates_in_mbqc_gate_set_lowering_graph_state_decomp(self, run_filecheck_qjit): + """Test that the convert_to_mbqc_formalism_pass works correctly with qjit and for-loop structure.""" + dev = qml.device("null.qubit", wires=1000) + + def loop_for(i): + qml.H(i) + qml.S(i) + RotXZX(0.1, 0.2, 0.3, wires=[i]) + qml.RZ(phi=0.1, wires=[i]) + + @qml.qjit( + target="mlir", + pipelines=mbqc_pipeline(), + autograph=True, + ) + @decompose_graph_state_pass + @convert_to_mbqc_formalism_pass + @qml.set_shots(1000) + @qml.qnode(dev) + def circuit(): + # CHECK-NOT: quantum.custom "CNOT"() + # CHECK-NOT: quantum.custom "S"() + # CHECK-NOT: quantum.custom "RZ"() + # CHECK-NOT: quantum.custom "RotXZX"() + # CHECK-NOT: mbqc.graph_state_prep + # CHECK: scf.for + # CHECK: quantum.custom "Hadamard"() + # CHECK: quantum.custom "CZ"() + # CHECK: mbqc.measure_in_basis + # CHECK: scf.if + # CHECK: quantum.custom "PauliX"() + # CHECK: quantum.custom "PauliZ"() + # CHECK: quantum.dealloc_qb + for i in range(1000): + loop_for(i) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.Z(wires=0)) + + run_filecheck_qjit(circuit) + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_gates_in_mbqc_gate_set_lowering_while(self, run_filecheck_qjit): + """Test that the convert_to_mbqc_formalism_pass works correctly with qjit and while-loop structure.""" + dev = qml.device("null.qubit", wires=1000) + + @qml.while_loop(lambda i: i > 1000) + def while_for(i): + qml.H(i) + qml.S(i) + RotXZX(0.1, 0.2, 0.3, wires=[i]) + qml.RZ(phi=0.1, wires=[i]) + i = i + 1 + return i + + @qml.qjit( + target="mlir", + autograph=True, + ) + @convert_to_mbqc_formalism_pass + @qml.set_shots(1000) + @qml.qnode(dev) + def circuit(): + # CHECK-NOT: quantum.custom "CNOT"() + # CHECK-NOT: quantum.custom "S"() + # CHECK-NOT: quantum.custom "RZ"() + # CHECK-NOT: quantum.custom "RotXZX"() + # CHECK-NOT: quantum.custom "Hadamard"() + # CHECK: scf.while + # CHECK: quantum.custom "CZ"() + # CHECK: mbqc.measure_in_basis + # CHECK: scf.if + # CHECK: quantum.custom "PauliX"() + # CHECK: quantum.custom "PauliZ"() + # CHECK: quantum.dealloc_qb + while_for(0) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.Z(wires=0)) + + run_filecheck_qjit(circuit) + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_gates_in_mbqc_gate_set_e2e(self): + """Test that the convert_to_mbqc_formalism_pass end to end on null.qubit.""" + dev = qml.device("null.qubit", wires=1000) + + @qml.qjit( + target="mlir", + pipelines=mbqc_pipeline(), + autograph=True, + ) + @decompose_graph_state_pass + @convert_to_mbqc_formalism_pass + @measurements_from_samples_pass + @qml.set_shots(1000) + @qml.qnode(dev) + def circuit(): + for i in range(1000): + qml.H(i) + qml.S(i) + RotXZX(0.1, 0.2, 0.3, wires=[i]) + qml.RZ(phi=0.1, wires=[i]) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.Z(wires=0)) + + res = circuit() + assert res == 1.0 diff --git a/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_decompose_graph_state.py b/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_decompose_graph_state.py new file mode 100644 index 0000000000..6c74d782e0 --- /dev/null +++ b/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_decompose_graph_state.py @@ -0,0 +1,507 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit and integration tests for the Python compiler `decompose-graph-state` transform. + +FileCheck notation hint: + + Qubit variable names are written as q0a, q0b, q1a, etc. in the FileCheck program. The leading + number indicates the wire index of that qubit, and the second letter increments by one after + each use. +""" + +# pylint: disable=wrong-import-position + +import pytest + +pytestmark = pytest.mark.external + +xdsl = pytest.importorskip("xdsl") + +from catalyst.python_interface.transforms import ( + DecomposeGraphStatePass, + NullDecomposeGraphStatePass, +) + + +class TestDecomposeGraphStatePass: + """Unit tests for the decompose-graph-state pass.""" + + def test_1_qubit(self, run_filecheck): + """Test the decompose-graph-state pass for a 1-qubit graph state.""" + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[]> : tensor<0xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: [[graph_reg:%.+]] = quantum.alloc(1) : !quantum.reg + // CHECK: [[q0a:%.+]] = quantum.extract [[graph_reg]][0] : !quantum.reg -> !quantum.bit + // CHECK: [[q0b:%.+]] = quantum.custom "Hadamard"() [[q0a]] : !quantum.bit + // CHECK: [[out_reg0:%.+]] = quantum.insert [[graph_reg]][0], [[q0b]] : !quantum.reg, !quantum.bit + + %adj_matrix = arith.constant dense<[]> : tensor<0xi1> + %qreg = mbqc.graph_state_prep (%adj_matrix : tensor<0xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (DecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_2_qubit_chain(self, run_filecheck): + """Test the decompose-graph-state pass for a 2-qubit graph state. The qubit connectivity is: + + 0 -- 1 + + which has the adjacency matrix representation + + 0 1 + 1 0 + + and densely packed adjacency matrix representation + + [1] + """ + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[1]> : tensor<1xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: [[graph_reg:%.+]] = quantum.alloc(2) : !quantum.reg + + // CHECK: [[q0a:%.+]] = quantum.extract [[graph_reg]][0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: [[q1a:%.+]] = quantum.extract [[graph_reg]][1] : !quantum.reg -> !quantum.bit + + // CHECK: [[q0b:%.+]] = quantum.custom "Hadamard"() [[q0a]] : !quantum.bit + // CHECK-NEXT: [[q1b:%.+]] = quantum.custom "Hadamard"() [[q1a]] : !quantum.bit + + // CHECK: [[q0c:%.+]], [[q1c:%.+]] = quantum.custom "CZ"() [[q0b]], [[q1b]] : !quantum.bit, !quantum.bit + + // CHECK: [[out_reg00:%.+]] = quantum.insert [[graph_reg]][0], [[q0c]] : !quantum.reg, !quantum.bit + // CHECK-NEXT: [[out_reg01:%.+]] = quantum.insert [[out_reg00]][1], [[q1c]] : !quantum.reg, !quantum.bit + + %adj_matrix = arith.constant dense<[1]> : tensor<1xi1> + %qreg = mbqc.graph_state_prep (%adj_matrix : tensor<1xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (DecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_3_qubit_chain(self, run_filecheck): + """Test the decompose-graph-state pass for a 3-qubit graph state. The qubit connectivity is: + + 0 -- 1 -- 2 + + which has the adjacency matrix representation + + 0 1 0 + 1 0 1 + 0 1 0 + + and densely packed adjacency matrix representation + + [1, 0, 1] + """ + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[1, 0, 1]> : tensor<3xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: [[graph_reg:%.+]] = quantum.alloc(3) : !quantum.reg + + // CHECK: [[q0a:%.+]] = quantum.extract [[graph_reg]][0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: [[q1a:%.+]] = quantum.extract [[graph_reg]][1] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: [[q2a:%.+]] = quantum.extract [[graph_reg]][2] : !quantum.reg -> !quantum.bit + + // CHECK: [[q0b:%.+]] = quantum.custom "Hadamard"() [[q0a]] : !quantum.bit + // CHECK-NEXT: [[q1b:%.+]] = quantum.custom "Hadamard"() [[q1a]] : !quantum.bit + // CHECK-NEXT: [[q2b:%.+]] = quantum.custom "Hadamard"() [[q2a]] : !quantum.bit + + // CHECK: [[q0c:%.+]], [[q1c:%.+]] = quantum.custom "CZ"() [[q0b]], [[q1b]] : !quantum.bit, !quantum.bit + // CHECK-NEXT: [[q1d:%.+]], [[q2c:%.+]] = quantum.custom "CZ"() [[q1c]], [[q2b]] : !quantum.bit, !quantum.bit + + // CHECK: [[out_reg00:%.+]] = quantum.insert [[graph_reg]][0], [[q0c]] : !quantum.reg, !quantum.bit + // CHECK-NEXT: [[out_reg01:%.+]] = quantum.insert [[out_reg00]][1], [[q1d]] : !quantum.reg, !quantum.bit + // CHECK-NEXT: [[out_reg02:%.+]] = quantum.insert [[out_reg01]][2], [[q2c]] : !quantum.reg, !quantum.bit + + %adj_matrix = arith.constant dense<[1, 0, 1]> : tensor<3xi1> + %qreg = mbqc.graph_state_prep (%adj_matrix : tensor<3xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (DecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_4_qubit_square_lattice(self, run_filecheck): + """Test the decompose-graph-state pass for a 4-qubit graph state. The qubit connectivity is: + + 0 -- 1 + | | + 2 -- 3 + + which has the adjacency matrix representation + + 0 1 1 0 + 1 0 0 1 + 1 0 0 1 + 0 1 1 0 + + and densely packed adjacency matrix representation + + [1, 1, 0, 0, 1, 1] + + [(0, 1), (0, 2), (1, 3), (2, 3)] + """ + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[1, 1, 0, 0, 1, 1]> : tensor<6xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: [[graph_reg:%.+]] = quantum.alloc(4) : !quantum.reg + + // CHECK: [[q0a:%.+]] = quantum.extract [[graph_reg]][0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: [[q1a:%.+]] = quantum.extract [[graph_reg]][1] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: [[q2a:%.+]] = quantum.extract [[graph_reg]][2] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: [[q3a:%.+]] = quantum.extract [[graph_reg]][3] : !quantum.reg -> !quantum.bit + + // CHECK: [[q0b:%.+]] = quantum.custom "Hadamard"() [[q0a]] : !quantum.bit + // CHECK-NEXT: [[q1b:%.+]] = quantum.custom "Hadamard"() [[q1a]] : !quantum.bit + // CHECK-NEXT: [[q2b:%.+]] = quantum.custom "Hadamard"() [[q2a]] : !quantum.bit + // CHECK-NEXT: [[q3b:%.+]] = quantum.custom "Hadamard"() [[q3a]] : !quantum.bit + + // CHECK: [[q0c:%.+]], [[q1c:%.+]] = quantum.custom "CZ"() [[q0b]], [[q1b]] : !quantum.bit, !quantum.bit + // CHECK-NEXT: [[q0d:%.+]], [[q2c:%.+]] = quantum.custom "CZ"() [[q0c]], [[q2b]] : !quantum.bit, !quantum.bit + // CHECK-NEXT: [[q1d:%.+]], [[q3c:%.+]] = quantum.custom "CZ"() [[q1c]], [[q3b]] : !quantum.bit, !quantum.bit + // CHECK-NEXT: [[q2d:%.+]], [[q3d:%.+]] = quantum.custom "CZ"() [[q2c]], [[q3c]] : !quantum.bit, !quantum.bit + + // CHECK: [[out_reg00:%.+]] = quantum.insert [[graph_reg]][0], [[q0d]] : !quantum.reg, !quantum.bit + // CHECK-NEXT: [[out_reg01:%.+]] = quantum.insert [[out_reg00]][1], [[q1d]] : !quantum.reg, !quantum.bit + // CHECK-NEXT: [[out_reg02:%.+]] = quantum.insert [[out_reg01]][2], [[q2d]] : !quantum.reg, !quantum.bit + // CHECK-NEXT: [[out_reg03:%.+]] = quantum.insert [[out_reg02]][3], [[q3d]] : !quantum.reg, !quantum.bit + + %adj_matrix = arith.constant dense<[1, 1, 0, 0, 1, 1]> : tensor<6xi1> + %qreg = mbqc.graph_state_prep (%adj_matrix : tensor<6xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (DecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_2_qubit_chain_non_standard_init(self, run_filecheck): + """Test the decompose-graph-state pass for a 2-qubit graph state, using non-standard `init` + and `entangle` attributes. + """ + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[1]> : tensor<1xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: [[graph_reg:%.+]] = quantum.alloc(2) : !quantum.reg + + // CHECK: [[q0a:%.+]] = quantum.extract [[graph_reg]][0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: [[q1a:%.+]] = quantum.extract [[graph_reg]][1] : !quantum.reg -> !quantum.bit + + // CHECK: [[q0b:%.+]] = quantum.custom "S"() [[q0a]] : !quantum.bit + // CHECK-NEXT: [[q1b:%.+]] = quantum.custom "S"() [[q1a]] : !quantum.bit + + // CHECK: [[q0c:%.+]], [[q1c:%.+]] = quantum.custom "CNOT"() [[q0b]], [[q1b]] : !quantum.bit, !quantum.bit + + // CHECK: [[out_reg00:%.+]] = quantum.insert [[graph_reg]][0], [[q0c]] : !quantum.reg, !quantum.bit + // CHECK-NEXT: [[out_reg01:%.+]] = quantum.insert [[out_reg00]][1], [[q1c]] : !quantum.reg, !quantum.bit + + %adj_matrix = arith.constant dense<[1]> : tensor<1xi1> + %qreg = mbqc.graph_state_prep (%adj_matrix : tensor<1xi1>) [init "S", entangle "CNOT"] : !quantum.reg + func.return + } + """ + + pipeline = (DecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_register_use(self, run_filecheck): + """Test that the uses of the register resulting from a graph_state_prep op are still correct + after decomposing it into its quantum ops with the decompose-graph-state pass. + + We do not rigorously test the decomposition here, only that the last resulting register from + the decomposition (specifically, from the last quantum.insert op) is correctly picked up by + the ops that used the register resulting from the original graph_state_prep op. + """ + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[1]> : tensor<1xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: [[graph_reg:%.+]] = quantum.alloc(2) : !quantum.reg + + // CHECK: quantum.extract [[graph_reg]][0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: quantum.extract [[graph_reg]][1] : !quantum.reg -> !quantum.bit + + // CHECK: quantum.custom "Hadamard"() {{%.+}} : !quantum.bit + // CHECK-NEXT: quantum.custom "Hadamard"() {{%.+}} : !quantum.bit + + // CHECK: quantum.custom "CZ"() {{%.+}}, {{%.+}} : !quantum.bit, !quantum.bit + + // CHECK: quantum.insert {{%.+}}[0], {{%.+}} : !quantum.reg, !quantum.bit + // CHECK-NEXT: [[out_qreg0:%.+]] = quantum.insert {{%.+}}[1], {{%.+}} : !quantum.reg, !quantum.bit + + %adj_matrix = arith.constant dense<[1]> : tensor<1xi1> + %qreg = mbqc.graph_state_prep (%adj_matrix : tensor<1xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + + // CHECK: quantum.extract [[out_qreg0]][0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: quantum.extract [[out_qreg0]][1] : !quantum.reg -> !quantum.bit + %q0a = quantum.extract %qreg[0] : !quantum.reg -> !quantum.bit + %q1a = quantum.extract %qreg[1] : !quantum.reg -> !quantum.bit + + // CHECK: arith.constant + // CHECK: mbqc.measure_in_basis + %angle_0 = arith.constant 0.0 : f64 + %m0, %q0b = mbqc.measure_in_basis [XY, %angle_0] %q0a : i1, !quantum.bit + + // CHECK: arith.constant + // CHECK: mbqc.measure_in_basis + %angle_pi_2 = arith.constant 1.5707963267948966 : f64 + %m1, %q1b = mbqc.measure_in_basis [XY, %angle_pi_2] %q1a : i1, !quantum.bit + + // CHECK: [[out_qreg1:%.+]] = quantum.insert [[out_qreg0]][0], {{%.+}} : !quantum.reg, !quantum.bit + // CHECK: [[out_qreg2:%.+]] = quantum.insert [[out_qreg1]][1], {{%.+}} : !quantum.reg, !quantum.bit + %reg0 = quantum.insert %qreg[0], %q0b : !quantum.reg, !quantum.bit + %reg1 = quantum.insert %reg0[1], %q1b : !quantum.reg, !quantum.bit + + func.return + } + """ + + pipeline = (DecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_adj_matrix_reuse(self, run_filecheck): + """Test that the decompose-graph-state pass supports the case where we reuse the adjacency + matrix resulting from a constant op multiple times. + """ + + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[1, 0, 1]> : tensor<3xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: quantum.alloc(3) + // CHECK: quantum.alloc(3) + + %adj_matrix = arith.constant dense<[1, 0, 1]> : tensor<3xi1> + %qreg1 = mbqc.graph_state_prep (%adj_matrix : tensor<3xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + %qreg2 = mbqc.graph_state_prep (%adj_matrix : tensor<3xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (DecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_with_stablehlo_constant(self, run_filecheck): + """Test that the decompose-graph-state pass supports the case where the adjacency matrix + results from a `stablehlo.constant` op (rather than an `arith.constant` op). + """ + + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: stablehlo.constant + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: quantum.alloc(3) + + %adj_matrix = "stablehlo.constant"() <{value = dense<[1, 0, 1]> : tensor<3xi1>}> : () -> tensor<3xi1> + %qreg1 = mbqc.graph_state_prep (%adj_matrix : tensor<3xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (DecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + +class TestNullDecomposeGraphStatePass: + """Unit tests for the null-decompose-graph-state pass.""" + + def test_1_qubit(self, run_filecheck): + """Test the null-decompose-graph-state pass for a 1-qubit graph state.""" + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[]> : tensor<0xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: [[graph_reg:%.+]] = quantum.alloc(1) : !quantum.reg + // CHECK-NOT: quantum.extract + // CHECK-NOT: quantum.custom "Hadamard" + // CHECK-NOT: quantum.custom "CZ" + // CHECK-NOT: quantum.insert + + %adj_matrix = arith.constant dense<[]> : tensor<0xi1> + %qreg = mbqc.graph_state_prep (%adj_matrix : tensor<0xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (NullDecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_2_qubit_chain(self, run_filecheck): + """Test the null-decompose-graph-state pass for a 2-qubit graph state.""" + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[1]> : tensor<1xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: [[graph_reg:%.+]] = quantum.alloc(2) : !quantum.reg + // CHECK-NOT: quantum.extract + // CHECK-NOT: quantum.custom "Hadamard" + // CHECK-NOT: quantum.custom "CZ" + // CHECK-NOT: quantum.insert + + %adj_matrix = arith.constant dense<[1]> : tensor<1xi1> + %qreg = mbqc.graph_state_prep (%adj_matrix : tensor<1xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (NullDecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_3_qubit_chain(self, run_filecheck): + """Test the null-decompose-graph-state pass for a 3-qubit graph state.""" + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[1, 0, 1]> : tensor<3xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: [[graph_reg:%.+]] = quantum.alloc(3) : !quantum.reg + // CHECK-NOT: quantum.extract + // CHECK-NOT: quantum.custom "Hadamard" + // CHECK-NOT: quantum.custom "CZ" + // CHECK-NOT: quantum.insert + + %adj_matrix = arith.constant dense<[1, 0, 1]> : tensor<3xi1> + %qreg = mbqc.graph_state_prep (%adj_matrix : tensor<3xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (NullDecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_register_use(self, run_filecheck): + """Test that the uses of the register resulting from a graph_state_prep op are still correct + after decomposing it into its quantum ops with the null-decompose-graph-state pass. + + We do not rigorously test the decomposition here, only that the resulting register from the + decomposition (specifically, from the quantum.alloc op) is correctly picked up by the ops + that used the register resulting from the original graph_state_prep op. + """ + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[1]> : tensor<1xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: [[graph_reg:%.+]] = quantum.alloc(2) : !quantum.reg + + %adj_matrix = arith.constant dense<[1]> : tensor<1xi1> + %qreg = mbqc.graph_state_prep (%adj_matrix : tensor<1xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + + // CHECK: quantum.extract [[graph_reg]][0] : !quantum.reg -> !quantum.bit + // CHECK-NEXT: quantum.extract [[graph_reg]][1] : !quantum.reg -> !quantum.bit + %q0a = quantum.extract %qreg[0] : !quantum.reg -> !quantum.bit + %q1a = quantum.extract %qreg[1] : !quantum.reg -> !quantum.bit + + // CHECK: arith.constant + // CHECK: mbqc.measure_in_basis + %angle_0 = arith.constant 0.0 : f64 + %m0, %q0b = mbqc.measure_in_basis [XY, %angle_0] %q0a : i1, !quantum.bit + + // CHECK: arith.constant + // CHECK: mbqc.measure_in_basis + %angle_pi_2 = arith.constant 1.5707963267948966 : f64 + %m1, %q1b = mbqc.measure_in_basis [XY, %angle_pi_2] %q1a : i1, !quantum.bit + + // CHECK: [[out_qreg1:%.+]] = quantum.insert [[graph_reg]][0], {{%.+}} : !quantum.reg, !quantum.bit + // CHECK: [[out_qreg2:%.+]] = quantum.insert [[out_qreg1]][1], {{%.+}} : !quantum.reg, !quantum.bit + %reg0 = quantum.insert %qreg[0], %q0b : !quantum.reg, !quantum.bit + %reg1 = quantum.insert %reg0[1], %q1b : !quantum.reg, !quantum.bit + + func.return + } + """ + + pipeline = (NullDecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_adj_matrix_reuse(self, run_filecheck): + """Test that the null-decompose-graph-state pass supports the case where we reuse the + adjacency matrix resulting from a constant op multiple times. + """ + + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: arith.constant dense<[1, 0, 1]> : tensor<3xi1> + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: quantum.alloc(3) + // CHECK: quantum.alloc(3) + + %adj_matrix = arith.constant dense<[1, 0, 1]> : tensor<3xi1> + %qreg1 = mbqc.graph_state_prep (%adj_matrix : tensor<3xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + %qreg2 = mbqc.graph_state_prep (%adj_matrix : tensor<3xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (NullDecomposeGraphStatePass(),) + run_filecheck(program, pipeline) + + def test_with_stablehlo_constant(self, run_filecheck): + """Test that the null-decompose-graph-state pass supports the case where the adjacency matrix + results from a `stablehlo.constant` op (rather than an `arith.constant` op). + """ + + program = """ + // CHECK-LABEL: circuit + func.func @circuit() { + // CHECK-NOT: stablehlo.constant + // CHECK-NOT: mbqc.graph_state_prep + + // CHECK: quantum.alloc(3) + + %adj_matrix = "stablehlo.constant"() <{value = dense<[1, 0, 1]> : tensor<3xi1>}> : () -> tensor<3xi1> + %qreg1 = mbqc.graph_state_prep (%adj_matrix : tensor<3xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg + func.return + } + """ + + pipeline = (NullDecomposeGraphStatePass(),) + run_filecheck(program, pipeline) diff --git a/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_outline_state_evolution.py b/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_outline_state_evolution.py new file mode 100644 index 0000000000..af040ecca0 --- /dev/null +++ b/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_outline_state_evolution.py @@ -0,0 +1,390 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the outline state evolution transform""" +import pytest + +pytestmark = pytest.mark.external + +xdsl = pytest.importorskip("xdsl") +catalyst = pytest.importorskip("catalyst") + +import pennylane as qml +from pennylane.ftqc import RotXZX + +# pylint: disable=wrong-import-position +from catalyst.ftqc import mbqc_pipeline +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from catalyst.python_interface.transforms import ( + OutlineStateEvolutionPass, + convert_to_mbqc_formalism_pass, + decompose_graph_state_pass, + diagonalize_final_measurements_pass, + measurements_from_samples_pass, + outline_state_evolution_pass, +) + + +@qml.while_loop(lambda i: i < 1000) +def _while_for(i): + qml.H(i) + qml.S(i) + RotXZX(0.1, 0.2, 0.3, wires=[i]) + qml.RZ(phi=0.1, wires=[i]) + i = i + 1 + return i + + +class TestOutlineStateEvolutionPass: + """Unit tests for OutlineStateEvolutionPass.""" + + def test_func_wo_qnode_attr(self, run_filecheck): + """Test outline state evolution pass is not applied to a func without a qnode attribute.""" + program = """ + module @module_circuit { + func.func public @circuit() -> tensor { + %0 = arith.constant 0 : i64 + quantum.device shots(%0) ["", "", ""] + %1 = quantum.alloc( 50) : !quantum.reg + %2 = quantum.extract %1[ 0] : !quantum.reg -> !quantum.bit + // CHECK: quantum.custom "PauliX"() + // CHECK-NOT: call @circuit.state_evolution + %out_qubits = quantum.custom "PauliX"() %2 : !quantum.bit + %3 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.insert %1[ 0], %out_qubits : !quantum.reg, !quantum.bit + quantum.dealloc %5 : !quantum.reg + quantum.device_release + return %from_elements : tensor + } + // CHECK-NOT: func.func public @circuit.state_evolution + } + """ + + pipeline = (OutlineStateEvolutionPass(),) + run_filecheck(program, pipeline) + + def test_func_w_qnode_attr(self, run_filecheck): + """Test outline state evolution pass would be applied to a func with a qnode attribute.""" + program = """ + module @module_circuit { + func.func public @circuit() -> tensor attributes {qnode} { + %0 = arith.constant 0 : i64 + quantum.device shots(%0) ["", "", ""] + %1 = quantum.alloc( 50) : !quantum.reg + %2 = quantum.extract %1[ 0] : !quantum.reg -> !quantum.bit + // CHECK-NOT: quantum.custom "PauliX"() + // CHECK: call @circuit.state_evolution + %out_qubits = quantum.custom "PauliX"() %2 : !quantum.bit + %3 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.insert %1[ 0], %out_qubits : !quantum.reg, !quantum.bit + quantum.dealloc %5 : !quantum.reg + quantum.device_release + return %from_elements : tensor + } + // CHECK: func.func public @circuit.state_evolution + } + """ + + pipeline = (OutlineStateEvolutionPass(),) + run_filecheck(program, pipeline) + + def test_multiple_func_w_qnode_attr(self, run_filecheck): + """Test outline state evolution pass would be applied to a func with a qnode attribute.""" + program = """ + module @module_circuit { + func.func public @circuit1() -> tensor attributes {qnode} { + %0 = arith.constant 0 : i64 + quantum.device shots(%0) ["", "", ""] + %1 = quantum.alloc( 50) : !quantum.reg + %2 = quantum.extract %1[ 0] : !quantum.reg -> !quantum.bit + // CHECK: call @circuit1.state_evolution + %out_qubits = quantum.custom "PauliX"() %2 : !quantum.bit + %3 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.insert %1[ 0], %out_qubits : !quantum.reg, !quantum.bit + quantum.dealloc %5 : !quantum.reg + quantum.device_release + return %from_elements : tensor + } + func.func public @circuit2() -> tensor attributes {qnode} { + %0 = arith.constant 0 : i64 + quantum.device shots(%0) ["", "", ""] + %1 = quantum.alloc( 50) : !quantum.reg + %2 = quantum.extract %1[ 0] : !quantum.reg -> !quantum.bit + // CHECK: call @circuit2.state_evolution + %out_qubits = quantum.custom "PauliX"() %2 : !quantum.bit + %3 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.insert %1[ 0], %out_qubits : !quantum.reg, !quantum.bit + quantum.dealloc %5 : !quantum.reg + quantum.device_release + return %from_elements : tensor + } + // CHECK: func.func public @circuit1.state_evolution + // CHECK: func.func public @circuit2.state_evolution + } + """ + + pipeline = (OutlineStateEvolutionPass(),) + run_filecheck(program, pipeline) + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_outline_state_evolution_no_error(self): + """Test outline_state_evolution_pass does not raise error for circuit with classical operations only.""" + + @qml.qjit( + target="mlir", + pass_plugins=[getXDSLPluginAbsolutePath()], + ) + @outline_state_evolution_pass + def circuit(x, y): + return x * y + 5 + + circuit(1, 4) + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_outline_state_evolution_no_terminal_op_error(self): + """Test outline_state_evolution_pass raises error when no terminal_boundary_op is found in circuit with quantum operation.""" + # TODOs: we can resolve this issue if the boundary op is inserted when the program is captured. + dev = qml.device("null.qubit", wires=10) + + @qml.qjit( + target="mlir", + pass_plugins=[getXDSLPluginAbsolutePath()], + ) + @outline_state_evolution_pass + @qml.qnode(dev) + def circuit(): + return qml.state() + + with pytest.raises( + RuntimeError, match="A terminal_boundary_op op is not found in the circuit." + ): + circuit() + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_outline_state_evolution_pass_only(self, run_filecheck_qjit): + """Test the outline_state_evolution_pass only.""" + dev = qml.device("lightning.qubit", wires=1000) + + @qml.qjit( + target="mlir", + pass_plugins=[getXDSLPluginAbsolutePath()], + ) + @outline_state_evolution_pass + @qml.set_shots(1000) + @qml.qnode(dev) + def circuit(): + # CHECK-LABEL: func.func public @circuit() + # CHECK-NOT: scf.while + # CHECK-NOT: quantum.custom "Hadamard"() + # CHECK-NOT: quantum.custom "S"() + # CHECK-NOT: quantum.custom "RotXZX" + # CHECK-NOT: quantum.custom "RZ" + # CHECK-NOT: quantum.custom "CNOT" + # CHECK-NOT: func.func public @circuit.state_evolution + # CHECK: quantum.alloc + # CHECK-NEXT: func.call @circuit.state_evolution + # CHECK-LABEL: func.func public @circuit.state_evolution + # CHECK-NOT: quantum.alloc + # CHECK-NOT: quantum.namedobs + # CHECK: scf.while + # CHECK: quantum.custom "Hadamard"() + # CHECK: quantum.custom "S"() + # CHECK: quantum.custom "RotXZX" + # CHECK: quantum.custom "RZ" + # CHECK: quantum.custom "CNOT" + _while_for(0) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.Z(wires=0)) + + run_filecheck_qjit(circuit) + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_outline_state_evolution_pass_with_convert_to_mbqc_formalism(self, run_filecheck_qjit): + """Test if the outline_state_evolution_pass works with the convert-to-mbqc-formalism pass on lightning.qubit.""" + dev = qml.device("lightning.qubit", wires=1000) + + @qml.qjit( + target="mlir", + pass_plugins=[getXDSLPluginAbsolutePath()], + pipelines=mbqc_pipeline(), + ) + @decompose_graph_state_pass + @convert_to_mbqc_formalism_pass + @outline_state_evolution_pass + @qml.set_shots(1000) + @qml.qnode(dev) + def circuit(): + # CHECK-LABEL: func.func public @circuit() + # CHECK-NOT: quantum.custom "Hadamard"() + # CHECK-NOT: quantum.custom "S"() + # CHECK-NOT: quantum.custom "RotXZX" + # CHECK-NOT: quantum.custom "RZ" + # CHECK-NOT: quantum.custom "CNOT"() + # CHECK-NOT: mbqc.measure_in_basis + # CHECK-NOT: scf.if + # CHECK-NOT: quantum.dealloc_qb + # CHECK-LABEL: func.func public @circuit.state_evolution + # CHECK-NOT: quantum.custom "S"() + # CHECK-NOT: quantum.custom "RotXZX" + # CHECK-NOT: quantum.custom "RZ" + # CHECK-NOT: quantum.custom "CNOT"() + # CHECK-NOT: quantum.namedobs + # CHECK: scf.while + # CHECK: quantum.custom "Hadamard"() + # CHECK: quantum.custom "CZ"() + # CHECK: mbqc.measure_in_basis + # CHECK: scf.if + # CHECK: quantum.custom "PauliX"() + # CHECK: quantum.custom "PauliZ"() + # CHECK: quantum.dealloc_qb + _while_for(0) + qml.H(0) + qml.S(1) + RotXZX(0.1, 0.2, 0.3, wires=[2]) + qml.RZ(phi=0.1, wires=[3]) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.X(wires=0)) + + run_filecheck_qjit(circuit) + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_outline_state_evolution_pass_with_mbqc_pipeline(self, run_filecheck_qjit): + """Test if the outline_state_evolution_pass works with all mbqc transform pipeline on null.qubit.""" + dev = qml.device("null.qubit", wires=1000) + + @qml.qjit( + target="mlir", + pass_plugins=[getXDSLPluginAbsolutePath()], + pipelines=mbqc_pipeline(), + ) + @decompose_graph_state_pass + @convert_to_mbqc_formalism_pass + @measurements_from_samples_pass + @diagonalize_final_measurements_pass + @outline_state_evolution_pass + @qml.set_shots(1000) + @qml.qnode(dev) + def circuit(): + # CHECK-LABEL: func.func public @circuit() + # NOTE: There is scf.if, mbqc.measure_in_basis in the circuit() + # scope as X obs is decomposed into H@Z and H is converted to MBQC formalism. + # CHECK-NOT: quantum.custom "S"() + # CHECK-NOT: quantum.custom "RotXZX" + # CHECK-NOT: quantum.custom "RZ" + # CHECK-NOT: quantum.custom "CNOT"() + # CHECK-LABEL: func.func public @circuit.state_evolution + # CHECK-NOT: quantum.custom "S"() + # CHECK-NOT: quantum.custom "RotXZX" + # CHECK-NOT: quantum.custom "RZ" + # CHECK-NOT: quantum.custom "CNOT"() + # CHECK-NOT: quantum.namedobs + # CHECK: scf.while + # CHECK: quantum.custom "Hadamard"() + # CHECK: quantum.custom "CZ"() + # CHECK: mbqc.measure_in_basis + # CHECK: scf.if + # CHECK: quantum.custom "PauliX"() + # CHECK: quantum.custom "PauliZ"() + # CHECK: quantum.dealloc_qb + _while_for(0) + qml.H(0) + qml.S(1) + RotXZX(0.1, 0.2, 0.3, wires=[2]) + qml.RZ(phi=0.1, wires=[3]) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.X(wires=0)) + + run_filecheck_qjit(circuit) + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_outline_state_evolution_pass_with_mbqc_pipeline_run_on_nullqubit(self): + """Test if a circuit can be transfored with the outline_state_evolution_pass and all mbqc transform pipeline can be executed on null.qubit.""" + dev = qml.device("null.qubit", wires=1000) + + @qml.qjit( + target="mlir", + pass_plugins=[getXDSLPluginAbsolutePath()], + pipelines=mbqc_pipeline(), + ) + @decompose_graph_state_pass + @convert_to_mbqc_formalism_pass + @measurements_from_samples_pass + @diagonalize_final_measurements_pass + @outline_state_evolution_pass + @qml.set_shots(1000) + @qml.qnode(dev) + def circuit(): + _while_for(0) + qml.H(0) + qml.S(1) + RotXZX(0.1, 0.2, 0.3, wires=[2]) + qml.RZ(phi=0.1, wires=[3]) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.X(wires=0)) + + res = circuit() + assert res == 1.0 + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_lightning_execution_with_structure(self): + """Test that the outline_state_evolution_pass on lightning.qubit for a circuit with program structure is executable and returns results as expected.""" + dev = qml.device("lightning.qubit", wires=10) + + @qml.for_loop(0, 10, 1) + def for_fn(i): + qml.H(i) + qml.S(i) + qml.RZ(phi=0.1, wires=[i]) + + @qml.while_loop(lambda i: i < 10) + def while_fn(i): + qml.H(i) + qml.S(i) + qml.RZ(phi=0.1, wires=[i]) + i = i + 1 + return i + + @qml.qjit( + target="mlir", + pass_plugins=[getXDSLPluginAbsolutePath()], + ) + @outline_state_evolution_pass + @qml.qnode(dev) + def circuit(): + for_fn() + while_fn(0) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.prod(qml.X(0), qml.Z(1))) + + res = circuit() + + @qml.qjit( + target="mlir", + ) + @qml.qnode(dev) + def circuit_ref(): + for_fn() + while_fn(0) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.prod(qml.X(0), qml.Z(1))) + + res_ref = circuit_ref() + assert res == res_ref diff --git a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_cancel_inverses.py b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_cancel_inverses.py new file mode 100644 index 0000000000..db3ddfa312 --- /dev/null +++ b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_cancel_inverses.py @@ -0,0 +1,234 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the iterative cancel inverses transform""" + +import pytest + +pytestmark = pytest.mark.external + +pytest.importorskip("xdsl") +pytest.importorskip("catalyst") + +import pennylane as qml + +# pylint: disable=wrong-import-position +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from catalyst.python_interface.transforms import ( + IterativeCancelInversesPass, + iterative_cancel_inverses_pass, +) + + +class TestIterativeCancelInversesPass: + """Unit tests for IterativeCancelInversesPass.""" + + def test_no_inverses_same_qubit(self, run_filecheck): + """Test that nothing changes when there are no inverses.""" + program = """ + func.func @test_func() { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = quantum.custom "PauliX"() [[q0]] : !quantum.bit + // CHECK: quantum.custom "PauliY"() [[q1]] : !quantum.bit + %1 = quantum.custom "PauliX"() %0 : !quantum.bit + %2 = quantum.custom "PauliY"() %1 : !quantum.bit + return + } + """ + + pipeline = (IterativeCancelInversesPass(),) + run_filecheck(program, pipeline) + + def test_inverses_different_qubits(self, run_filecheck): + """Test that nothing changes when there are no inverses.""" + program = """ + func.func @test_func() { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + %1 = "test.op"() : () -> !quantum.bit + // CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit + // CHECK: quantum.custom "PauliX"() [[q1]] : !quantum.bit + %2 = quantum.custom "PauliX"() %0 : !quantum.bit + %3 = quantum.custom "PauliX"() %1 : !quantum.bit + return + } + """ + + pipeline = (IterativeCancelInversesPass(),) + run_filecheck(program, pipeline) + + def test_simple_self_inverses(self, run_filecheck): + """Test that inverses are cancelled.""" + program = """ + func.func @test_func() { + %0 = "test.op"() : () -> !quantum.bit + // CHECK-NOT: quantum.custom + %1 = quantum.custom "PauliX"() %0 : !quantum.bit + %2 = quantum.custom "PauliX"() %1 : !quantum.bit + return + } + """ + + pipeline = (IterativeCancelInversesPass(),) + run_filecheck(program, pipeline) + + def test_nested_self_inverses(self, run_filecheck): + """Test that nested self-inverses are cancelled.""" + program = """ + func.func @test_func() { + %0 = "test.op"() : () -> !quantum.bit + // CHECK-NOT: quantum.custom + %1 = quantum.custom "PauliX"() %0 : !quantum.bit + %2 = quantum.custom "PauliY"() %1 : !quantum.bit + %3 = quantum.custom "PauliZ"() %2 : !quantum.bit + %4 = quantum.custom "PauliZ"() %3 : !quantum.bit + %5 = quantum.custom "PauliY"() %4 : !quantum.bit + %6 = quantum.custom "PauliX"() %5 : !quantum.bit + return + } + """ + + pipeline = (IterativeCancelInversesPass(),) + run_filecheck(program, pipeline) + + def test_cancel_ops_with_control_qubits(self, run_filecheck): + """Test that ops with control qubits can be cancelled.""" + program = """ + func.func @test_func() { + %0 = "arith.constant"() <{value = true}> : () -> i1 + %1 = "test.op"() : () -> !quantum.bit + %2 = "test.op"() : () -> !quantum.bit + %3 = "test.op"() : () -> !quantum.bit + // CHECK-NOT: quantum.custom + %4, %5, %6 = quantum.custom "PauliY"() %1 ctrls(%2, %3) ctrlvals(%0, %0) : !quantum.bit ctrls !quantum.bit, !quantum.bit + %7, %8, %9 = quantum.custom "PauliY"() %4 ctrls(%5, %6) ctrlvals(%0, %0) : !quantum.bit ctrls !quantum.bit, !quantum.bit + return + } + """ + + pipeline = (IterativeCancelInversesPass(),) + run_filecheck(program, pipeline) + + def test_cancel_ops_with_same_control_qubits_and_values(self, run_filecheck): + """Test that ops with control qubits and control values can be + cancelled.""" + program = """ + func.func @test_func() { + %0 = arith.constant false + %1 = arith.constant true + %2 = "test.op"() : () -> !quantum.bit + %3 = "test.op"() : () -> !quantum.bit + %4 = "test.op"() : () -> !quantum.bit + // CHECK-NOT: "quantum.custom" + %5, %6, %7 = quantum.custom "PauliY"() %2 ctrls(%3, %4) ctrlvals(%1, %0) : !quantum.bit ctrls !quantum.bit, !quantum.bit + %8, %9, %10 = quantum.custom "PauliY"() %5 ctrls(%6, %7) ctrlvals(%1, %0) : !quantum.bit ctrls !quantum.bit, !quantum.bit + return + } + """ + + pipeline = (IterativeCancelInversesPass(),) + run_filecheck(program, pipeline) + + def test_ops_with_control_qubits_different_control_values(self, run_filecheck): + """Test that ops with the same control qubits but different + control values don't cancel.""" + program = """ + func.func @test_func() { + // CHECK-DAG: [[cval0:%.+]] = arith.constant false + // CHECK-DAG: [[cval1:%.+]] = arith.constant true + %0 = arith.constant false + %1 = arith.constant true + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q2:%.+]] = "test.op"() : () -> !quantum.bit + %2 = "test.op"() : () -> !quantum.bit + %3 = "test.op"() : () -> !quantum.bit + %4 = "test.op"() : () -> !quantum.bit + // CHECK: [[q3:%.+]], [[q4:%.+]], [[q5:%.+]] = quantum.custom "PauliY"() [[q0]] ctrls([[q1]], [[q2]]) ctrlvals([[cval1]], [[cval0]]) : !quantum.bit ctrls !quantum.bit, !quantum.bit + // CHECK: quantum.custom "PauliY"() [[q3]] ctrls([[q4]], [[q5]]) ctrlvals([[cval0]], [[cval1]]) : !quantum.bit ctrls !quantum.bit, !quantum.bit + %5, %6, %7 = quantum.custom "PauliY"() %2 ctrls(%3, %4) ctrlvals(%1, %0) : !quantum.bit ctrls !quantum.bit, !quantum.bit + %8, %9, %10 = quantum.custom "PauliY"() %5 ctrls(%6, %7) ctrlvals(%0, %1) : !quantum.bit ctrls !quantum.bit, !quantum.bit + return + } + """ + + pipeline = (IterativeCancelInversesPass(),) + run_filecheck(program, pipeline) + + def test_non_consecutive_self_inverse_ops(self, run_filecheck): + """Test that self-inverse gates on the same qubit that are not + consecutive are not cancelled.""" + program = """ + func.func @test_func() { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = quantum.custom "PauliX"() [[q0]] : !quantum.bit + // CHECK: [[q2:%.+]] = quantum.custom "PauliY"() [[q1]] : !quantum.bit + // CHECK: quantum.custom "PauliX"() [[q2]] : !quantum.bit + %1 = quantum.custom "PauliX"() %0 : !quantum.bit + %2 = quantum.custom "PauliY"() %1 : !quantum.bit + %3 = quantum.custom "PauliX"() %2 : !quantum.bit + return + } + """ + + pipeline = (IterativeCancelInversesPass(),) + run_filecheck(program, pipeline) + + +# pylint: disable=too-few-public-methods +@pytest.mark.usefixtures("enable_disable_plxpr") +class TestIterativeCancelInversesIntegration: + """Integration tests for the IterativeCancelInversesPass.""" + + def test_qjit(self, run_filecheck_qjit): + """Test that the IterativeCancelInversesPass works correctly with qjit.""" + dev = qml.device("lightning.qubit", wires=2) + + @qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()]) + @iterative_cancel_inverses_pass + @qml.qnode(dev) + def circuit(): + # CHECK-NOT: quantum.custom + qml.H(0) + qml.X(0) + qml.X(0) + qml.H(0) + return qml.state() + + run_filecheck_qjit(circuit) + + def test_qjit_no_cancellation(self, run_filecheck_qjit): + """Test that the IterativeCancelInversesPass works correctly with qjit when + there are no operations that can be cancelled.""" + dev = qml.device("lightning.qubit", wires=2) + + @qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()]) + @iterative_cancel_inverses_pass + @qml.qnode(dev) + def circuit(): + # CHECK-NOT: quantum.custom + qml.H(1) + qml.X(1) + qml.X(0) + qml.H(0) + return qml.state() + + with pytest.raises(AssertionError, match="filecheck failed"): + run_filecheck_qjit(circuit) + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_combine_global_phases.py b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_combine_global_phases.py new file mode 100644 index 0000000000..cdfa2795d4 --- /dev/null +++ b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_combine_global_phases.py @@ -0,0 +1,244 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the combine global phases transform""" +import pytest + +pytestmark = pytest.mark.external + +pytest.importorskip("xdsl") +pytest.importorskip("catalyst") + +import pennylane as qml + +# pylint: disable=wrong-import-position +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from catalyst.python_interface.transforms import ( + CombineGlobalPhasesPass, + combine_global_phases_pass, +) + + +class TestCombineGlobalPhasesPass: + """Unit tests for CombineGlobalPhasesPass.""" + + def test_combinable_ops_without_control_flow(self, run_filecheck): + """Test that combines global phases in a func without control flow.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK: [[phi_sum:%.+]] = arith.addf %arg1, %arg0 : f64 + // CHECK: quantum.gphase([[phi_sum]]) + quantum.gphase %arg0 + quantum.gphase %arg1 + // CHECK: [[q1:%.+]] = quantum.custom "PauliX"() [[q0]] : !quantum.bit + %2 = quantum.custom "PauliX"() %0 : !quantum.bit + return + } + """ + + pipeline = (CombineGlobalPhasesPass(),) + run_filecheck(program, pipeline) + + def test_combinable_ops_with_control_flow(self, run_filecheck): + """Test that combines global phases in a func with control flow.""" + program = """ + func.func @test_func(%cond: i32, %arg0: f64, %arg1: f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + quantum.gphase %arg0 + // CHECK: [[ret:%.+]] = scf.if %cond -> (f64) { + %ret = scf.if %cond -> (f64) { + // CHECK: [[two:%.+]] = arith.constant {{2.+}} : f64 + %two = arith.constant 2 : f64 + // CHECK: [[arg02:%.+]] = arith.mulf %arg0, [[two]] : f64 + %arg0x2 = arith.mulf %arg0, %two : f64 + // CHECK: scf.yield [[arg02]] : f64 + scf.yield %arg0x2 : f64 + } else { + // CHECK: [[two_1:%.+]] = arith.constant {{2.+}} : f64 + %two_1 = arith.constant 2 : f64 + // CHECK: [[arg12:%.+]] = arith.mulf %arg1, [[two_1]] : f64 + %arg1x2 = arith.mulf %arg1, %two_1 : f64 + // CHECK: scf.yield [[arg12:%.+]] : f64 + scf.yield %arg1x2 : f64 + } + // CHECK: [[phi_sum:%.+]] = arith.addf [[ret]], %arg0 : f64 + // CHECK: quantum.gphase([[phi_sum]]) : + quantum.gphase %ret + // CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit + %2 = quantum.custom "PauliX"() %0 : !quantum.bit + return + } + """ + + pipeline = (CombineGlobalPhasesPass(),) + run_filecheck(program, pipeline) + + def test_combinable_ops_in_control_flow_if(self, run_filecheck): + """Test that combines global phases in a func without control a flow. + Here the control flow is an `if` operation. + """ + program = """ + // CHECK: func.func @test_func(%cond : i32, [[arg0:%.+]] : f64, [[arg1:%.+]] : f64) + func.func @test_func(%cond : i32, %arg0 : f64, %arg1 : f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK: [[ret:%.+]] = scf.if [[cond:%.+]] -> (f64) { + %ret = scf.if %cond -> (f64) { + // CHECK: [[two0:%.+]] = arith.constant {{2.+}} : f64 + %two0 = arith.constant 2 : f64 + // CHECK: [[arg02:%.+]] = arith.mulf [[arg0]], [[two0]] : f64 + %arg02 = arith.mulf %arg0, %two0 : f64 + // CHECK: [[t0:%.+]] = "test.op"() : () -> f64 + %t0 = "test.op"() : () -> f64 + // CHECK: [[phi_sum_0:%.+]] = arith.addf [[t0]], [[t0]] : f64 + // CHECK: quantum.gphase([[phi_sum_0]]) + quantum.gphase %t0 + quantum.gphase %t0 + // CHECK: scf.yield [[arg02]] : f64 + scf.yield %arg02 : f64 + // CHECK: } else { + } else { + // CHECK: [[two1:%.+]] = arith.constant {{2.+}} : f64 + %two1 = arith.constant 2 : f64 + // CHECK: [[arg1x2:%.+]] = arith.mulf [[arg1]], [[two1]] : f64 + %arg1x2 = arith.mulf %arg1, %two1 : f64 + // CHECK: [[phi_sum_1:%.+]] = arith.addf [[arg1x2]], [[arg1x2]] : f64 + // CHECK: quantum.gphase([[phi_sum_1]]) + quantum.gphase %arg1x2 + quantum.gphase %arg1x2 + // CHECK: scf.yield [[arg1x2]] : f64 + scf.yield %arg1x2 : f64 + // CHECK: } + } + // CHECK: [[phi_sum_2:%.+]] = arith.addf [[arg1]], [[arg0]] : f64 + // CHECK: quantum.gphase([[phi_sum_2:%.+]]) + quantum.gphase %arg0 + quantum.gphase %arg1 + // CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit + %2 = quantum.custom "PauliX"() %0 : !quantum.bit + return + } + """ + + pipeline = (CombineGlobalPhasesPass(),) + run_filecheck(program, pipeline) + + def test_combinable_ops_in_control_flow_for(self, run_filecheck): + """Test that combines global phases in a func with a control flow. + Here the control flow is a `for` operation. + """ + program = """ + // CHECK: func.func @test_func(%n : i32, [[arg0:%.+]] : f64, [[arg1:%.+]] : f64) + func.func @test_func(%n : i32, %arg0 : f64, %arg1 : f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[c0:%.+]] = arith.constant 0 : i32 + // CHECK: [[c1:%.+]] = arith.constant 1 : i32 + %0 = "test.op"() : () -> !quantum.bit + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + scf.for %i = %c0 to %n step %c1 { + // CHECK: [[two0:%.+]] = arith.constant {{2.+}} : f64 + %two0 = arith.constant 2 : f64 + // CHECK: [[arg02:%.+]] = arith.mulf [[arg0]], [[two0]] : f64 + %arg02 = arith.mulf %arg0, %two0 : f64 + // CHECK: [[t0:%.+]] = "test.op"() : () -> f64 + %t0 = "test.op"() : () -> f64 + // CHECK: [[phi_sum_0:%.+]] = arith.addf [[t0]], [[t0]] : f64 + // CHECK: quantum.gphase([[phi_sum_0]]) + quantum.gphase %t0 + quantum.gphase %t0 + } + // CHECK: [[phi_sum_1:%.+]] = arith.addf [[arg1]], [[arg0]] : f64 + // CHECK: quantum.gphase([[phi_sum_1]]) + quantum.gphase %arg0 + quantum.gphase %arg1 + // CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit + %2 = quantum.custom "PauliX"() %0 : !quantum.bit + return + } + """ + + pipeline = (CombineGlobalPhasesPass(),) + run_filecheck(program, pipeline) + + def test_combinable_ops_in_control_flow_while(self, run_filecheck): + """Test that combines global phases in a func with control flow. + Here the control flow is a `while` operation. + """ + program = """ + // CHECK: func.func @test_func(%n : i32, [[arg0:%.+]] : f64, [[arg1:%.+]] : f64) + func.func @test_func(%n : i32, %arg0 : f64, %arg1 : f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + scf.while (%current_i = %n) : (i32) -> () { + %cond = arith.cmpi slt, %current_i, %n : i32 + scf.condition(%cond) %current_i : i32 + } do { + ^bb0(%i : i32): + %next_i = arith.addi %i, %c1 : i32 + // CHECK: [[two0:%.+]] = arith.constant {{2.+}} : f64 + %two0 = arith.constant 2 : f64 + // CHECK: [[arg02:%.+]] = arith.mulf [[arg0]], [[two0]] : f64 + %arg02 = arith.mulf %arg0, %two0 : f64 + // CHECK: [[t0:%.+]] = "test.op"() : () -> f64 + %t0 = "test.op"() : () -> f64 + // CHECK: [[phi_sum_0:%.+]] = arith.addf [[t0]], [[t0]] : f64 + // CHECK: quantum.gphase([[phi_sum_0]]) + quantum.gphase %t0 + quantum.gphase %t0 + scf.yield %next_i : i32 + } + // CHECK: [[phi_sum_1:%.+]] = arith.addf [[arg1]], [[arg0]] : f64 + // CHECK: quantum.gphase([[phi_sum_1]]) : + quantum.gphase %arg0 + quantum.gphase %arg1 + // CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit + %2 = quantum.custom "PauliX"() %0 : !quantum.bit + return + } + """ + + pipeline = (CombineGlobalPhasesPass(),) + run_filecheck(program, pipeline) + + +# pylint: disable=too-few-public-methods +@pytest.mark.usefixtures("enable_disable_plxpr") +class TestCombineGlobalPhasesIntegration: + """Integration tests for the CombineGlobalPhasesPass.""" + + def test_qjit(self, run_filecheck_qjit): + """Test that the CombineGlobalPhasesPass works correctly with qjit.""" + dev = qml.device("lightning.qubit", wires=2) + + @qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()]) + @combine_global_phases_pass + @qml.qnode(dev) + def circuit(x: float, y: float): + # CHECK: [[phi:%.+]] = arith.addf + # CHECK: quantum.gphase([[phi]]) + # CHECK-NOT: quantum.gphase + qml.GlobalPhase(x) + qml.GlobalPhase(y) + return qml.state() + + run_filecheck_qjit(circuit) + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_diagonalize_measurements.py b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_diagonalize_measurements.py new file mode 100644 index 0000000000..376dd20923 --- /dev/null +++ b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_diagonalize_measurements.py @@ -0,0 +1,597 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the xDSL implementation of the diagonalize_final_measurements pass""" + + +# pylint: disable=wrong-import-position + +import numpy as np +import pytest + +pytestmark = pytest.mark.external + +xdsl = pytest.importorskip("xdsl") + +catalyst = pytest.importorskip("catalyst") +import pennylane as qml + +from catalyst.passes import xdsl_plugin +from catalyst.python_interface.transforms import ( + DiagonalizeFinalMeasurementsPass, + diagonalize_final_measurements_pass, +) + + +class TestDiagonalizeFinalMeasurementsPass: + """Unit tests for the diagonalize-final-measurements pass.""" + + def test_with_pauli_z(self, run_filecheck): + """Test that a PauliZ observable is not affected by diagonalization""" + + program = """ + func.func @test_func() { + %0 = "test.op"() : () -> !quantum.bit + + // CHECK: quantum.namedobs %0[PauliZ] : !quantum.obs + %1 = quantum.namedobs %0[PauliZ] : !quantum.obs + + // CHECK: quantum.expval %1 : f64 + %2 = quantum.expval %1 : f64 + return + } + """ + + pipeline = (DiagonalizeFinalMeasurementsPass(),) + run_filecheck(program, pipeline) + + def test_with_identity(self, run_filecheck): + """Test that an Identity observable is not affected by diagonalization.""" + + program = """ + func.func @test_func() { + // CHECK: [[q0:%.*]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + + // CHECK: quantum.namedobs %0[Identity] : !quantum.obs + %1 = quantum.namedobs %0[Identity] : !quantum.obs + + // CHECK: quantum.var %1 : f64 + %2 = quantum.var %1 : f64 + return + } + """ + pipeline = (DiagonalizeFinalMeasurementsPass(),) + run_filecheck(program, pipeline) + + def test_with_pauli_x(self, run_filecheck): + """Test that when diagonalizing a PauliX observable, the expected diagonalizing + gates are inserted and the observable becomes PauliZ.""" + + program = """ + func.func @test_func() { + // CHECK: [[q0:%.*]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + + // CHECK: [[q0_1:%.*]] = quantum.custom "Hadamard"() [[q0]] + // CHECK-NEXT: [[q0_2:%.*]] = quantum.namedobs [[q0_1]][PauliZ] + // CHECK-NOT: quantum.namedobs [[q:%.+]][PauliX] + %1 = quantum.namedobs %0[PauliX] : !quantum.obs + + // CHECK: quantum.expval [[q0_2]] + %2 = quantum.expval %1 : f64 + return + } + """ + + pipeline = (DiagonalizeFinalMeasurementsPass(),) + run_filecheck(program, pipeline) + + def test_with_pauli_y(self, run_filecheck): + """Test that when diagonalizing a PauliY observable, the expected diagonalizing + gates are inserted and the observable becomes PauliZ.""" + + program = """ + func.func @test_func() { + // CHECK: [[q0:%.*]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + + // CHECK: [[q0_1:%.*]] = quantum.custom "PauliZ"() [[q0]] + // CHECK-NEXT: [[q0_2:%.*]] = quantum.custom "S"() [[q0_1]] + // CHECK-NEXT: [[q0_3:%.*]] = quantum.custom "Hadamard"() [[q0_2]] + // CHECK-NEXT: [[q0_4:%.*]] = quantum.namedobs [[q0_3]][PauliZ] + // CHECK-NOT: quantum.namedobs [[q:%.+]][PauliY] + %1 = quantum.namedobs %0[PauliY] : !quantum.obs + + // CHECK: quantum.expval [[q0_4]] + %2 = quantum.expval %1 : f64 + return + } + """ + + pipeline = (DiagonalizeFinalMeasurementsPass(),) + run_filecheck(program, pipeline) + + def test_with_hadamard(self, run_filecheck): + """Test that when diagonalizing a Hadamard observable, the expected diagonalizing + gates are inserted and the observable becomes PauliZ.""" + + program = """ + func.func @test_func() { + // CHECK: [[q0:%.*]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + + // CHECK: [[quarter_pi:%.*]] = arith.constant -0.78539816339744828 : f64 + // CHECK-NEXT: [[q0_1:%.*]] = quantum.custom "RY"([[quarter_pi]]) [[q0]] + // CHECK-NEXT: [[q0_2:%.*]] = quantum.namedobs [[q0_1]][PauliZ] + // CHECK-NOT: quantum.namedobs [[q:%.+]][Hadamard] + %1 = quantum.namedobs %0[Hadamard] : !quantum.obs + + // CHECK: quantum.expval [[q0_2]] + %2 = quantum.expval %1 : f64 + } + """ + + pipeline = (DiagonalizeFinalMeasurementsPass(),) + run_filecheck(program, pipeline) + + def test_with_composite_observable(self, run_filecheck): + """Test transform on a measurement process with a composite observable. In this + case, the simplified program is based on the MLIR generated by the circuit + + @qml.qjit(target="mlir") + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circuit(): + return qml.expval(qml.Y(0)@qml.X(1) + qml.Z(2)) + """ + + program = """ + func.func @test_func() { + // CHECK: [[q0:%.*]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.*]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q2:%.*]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + %1 = "test.op"() : () -> !quantum.bit + %2 = "test.op"() : () -> !quantum.bit + + // CHECK: [[q0_1:%.*]] = quantum.custom "PauliZ"() [[q0]] + // CHECK: [[q0_2:%.*]] = quantum.custom "S"() [[q0_1]] + // CHECK: [[q0_3:%.*]] = quantum.custom "Hadamard"() [[q0_2]] + // CHECK: [[q_y:%.*]] = quantum.namedobs [[q0_3]][PauliZ] + // CHECK-NOT: quantum.namedobs [[q:%.+]][PauliY] + %3 = quantum.namedobs %0[PauliY] : !quantum.obs + + // CHECK: [[q1_1:%.*]] = quantum.custom "Hadamard"() [[q1]] + // CHECK: [[q_x:%.*]] = quantum.namedobs [[q1_1]][PauliZ] + // CHECK-NOT: quantum.namedobs [[q:%.+]][PauliX] + %4 = quantum.namedobs %1[PauliX] : !quantum.obs + + // CHECK: [[tensor0:%.*]] = quantum.tensor [[q_y]], [[q_x]] : !quantum.obs + %5 = quantum.tensor %3, %4 : !quantum.obs + + // CHECK: [[q_z:%.*]] = quantum.namedobs [[q2]][PauliZ] : !quantum.obs + %6 = quantum.namedobs %2[PauliZ] : !quantum.obs + + // CHECK: [[size:%.*]] = "test.op"() : () -> tensor<2xf64> + %size_info = "test.op"() : () -> tensor<2xf64> + + // CHECK: quantum.hamiltonian([[size]] : tensor<2xf64>) [[tensor0]], [[q_z]] : !quantum.obs + %7 = quantum.hamiltonian(%size_info : tensor<2xf64>) %5, %6 : !quantum.obs + + // CHECK: quantum.expval + %8 = quantum.expval %7 : f64 + return + } + """ + + pipeline = (DiagonalizeFinalMeasurementsPass(),) + run_filecheck(program, pipeline) + + def test_with_multiple_measurements(self, run_filecheck): + """Test diagonalizing a circuit with multiple measurements. The simplified program + for this test is based on the circuit + + @qml.qjit(target="mlir") + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circuit(): + return qml.var(qml.Y(0)), qml.var(qml.X(1)) + """ + + program = """ + func.func @test_func() { + %0 = "test.op"() : () -> !quantum.bit + %1 = "test.op"() : () -> !quantum.bit + + // CHECK: quantum.custom "PauliZ"() + // CHECK-NEXT: quantum.custom "S"() + // CHECK-NEXT: quantum.custom "Hadamard"() + // CHECK-NEXT: quantum.namedobs [[q:%.+]][PauliZ] + %2 = quantum.namedobs %0[PauliY] : !quantum.obs + // CHECK: quantum.var + %3 = quantum.var %2 : f64 + + + // CHECK: quantum.custom "Hadamard"() + // CHECK-NEXT: quantum.namedobs [[q:%.+]][PauliZ] + %4 = quantum.namedobs %1[PauliX] : !quantum.obs + + // CHECK: quantum.expval + %5 = quantum.expval %4 : f64 + return + } + """ + + pipeline = (DiagonalizeFinalMeasurementsPass(),) + run_filecheck(program, pipeline) + + def test_overlapping_observables_raises_error(self, run_filecheck): + """Test the case where multiple overlapping (commuting) observables exist in + the same circuit (an error is raised - split_non_commuting should have been applied). + + @qml.qjit(target="mlir") + @qml.qnode(qml.device("lightning.qubit", wires=1)) + def circuit(): + return qml.var(qml.X(0)), qml.var(qml.X(0)) + """ + + program = """ + func.func @test_func() { + %0 = "test.op"() : () -> !quantum.bit + // CHECK: quantum.custom "Hadamard"() + // CHECK-NEXT: quantum.namedobs [[q:%.+]][PauliZ] + %1 = quantum.namedobs %0[PauliX] : !quantum.obs + %2 = quantum.var %1 : f64 + // CHECK: quantum.custom "Hadamard"() + // CHECK-NEXT: quantum.namedobs [[q:%.+]][PauliZ] + %3 = quantum.namedobs %0[PauliX] : !quantum.obs + %4 = quantum.var %3 : f64 + } + """ + + pipeline = (DiagonalizeFinalMeasurementsPass(),) + + with pytest.raises( + RuntimeError, match="the circuit contains multiple observables with the same wire" + ): + run_filecheck(program, pipeline) + + def test_additional_qubit_uses_are_updated(self, run_filecheck): + """Test that when diagonalizing the circuit, if the MLIR contains + later manipulations of the qubit going into the observable, these are + updated as well. While quantum.custom operations can't be applied to + the same SSA value that is passed to the observable, it can still + be inserted into a register or deallocated. + + The simplified program for this test is based on the circuit + + @qml.qjit(target="mlir") + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circuit(): + return qml.expval(qml.X(1)) + """ + + # we expect that instead of the SSA value that comes out of quantum.extract being passed to + # both quantum.namedobs and the quantum.insert, it will be passed to the Hadamard, and the + # SSA value that is output by the *Hadmard* operation will be passed to namedobs and insert. + program = """ + func.func @test_func() { + %0 = quantum.alloc(3) : !quantum.reg + %1 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %2 = tensor.extract %1[] : tensor + // CHECK: [[q0:%.*]] = quantum.extract + %3 = quantum.extract %0[%2] : !quantum.reg -> !quantum.bit + + // CHECK: [[q0_1:%.*]] = quantum.custom "Hadamard"() [[q0]] + // CHECK-NEXT: quantum.namedobs [[q0_1]][PauliZ] + %4 = quantum.namedobs %3[PauliX] : !quantum.obs + %5 = quantum.expval %4 : f64 + + // CHECK: quantum.insert [[q:%.+]][[[q:%.+]]], [[q0_1]] + %6 = tensor.extract %1[] : tensor + %7 = quantum.insert %0[%6], %3 : !quantum.reg, !quantum.bit + quantum.dealloc %7 : !quantum.reg + } + """ + + pipeline = (DiagonalizeFinalMeasurementsPass(),) + run_filecheck(program, pipeline) + + +class TestDiagonalizeFinalMeasurementsProgramCaptureExecution: + """Integration tests going through plxpr (program capture enabled)""" + + # pylint: disable=unnecessary-lambda + @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.parametrize( + "mp, obs, expected_res", + [ + (qml.expval, qml.Identity, lambda x: 1), + (qml.var, qml.Identity, lambda x: 0), + (qml.expval, qml.X, lambda x: 0), + (qml.var, qml.X, lambda x: 1), + (qml.expval, qml.Y, lambda x: -np.sin(x)), + (qml.var, qml.Y, lambda x: 1 - np.sin(x) ** 2), + (qml.expval, qml.Z, lambda x: np.cos(x)), + (qml.var, qml.Z, lambda x: 1 - np.cos(x) ** 2), + (qml.expval, qml.Hadamard, lambda x: np.cos(x) / np.sqrt(2)), + (qml.var, qml.Hadamard, lambda x: (2 - np.cos(x) ** 2) / 2), + ], + ) + def test_with_single_obs(self, mp, obs, expected_res): + """Test the diagonalization transform for a circuit with a single measurement + of a single supported observable""" + + dev = qml.device("lightning.qubit", wires=1) + + @qml.qnode(dev) + def circuit_ref(phi): + qml.RX(phi, 0) + return mp(obs(0)) + + angle = 0.7692 + + assert np.allclose( + expected_res(angle), circuit_ref(angle) + ), "Sanity check failed, is expected_res correct?" + circuit_compiled = qml.qjit( + diagonalize_final_measurements_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + assert np.allclose(expected_res(angle), circuit_compiled(angle)) + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_with_composite_observables(self): + """Test the transform works for an observable built using operator arithmetic + (sprod, prod, sum)""" + + dev = qml.device("lightning.qubit", wires=3) + + @qml.qnode(dev) + def circuit_ref(x, y): + qml.RX(x, 0) + qml.RY(y, 1) + qml.RY(y / 2, 2) + return qml.expval(qml.Y(0) @ qml.X(1) + 3 * qml.X(2)) + + def expected_res(x, y): + y0_res = -np.sin(x) + x1_res = np.sin(y) + x2_res = np.sin(y / 2) + return y0_res * x1_res + 3 * x2_res + + phi = 0.3867 + theta = 1.394 + + assert np.allclose( + expected_res(phi, theta), circuit_ref(phi, theta) + ), "Sanity check failed, is expected_res correct?" + circuit_compiled = qml.qjit( + diagonalize_final_measurements_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + assert np.allclose(expected_res(phi, theta), circuit_compiled(phi, theta)) + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_with_multiple_measurements(self): + """Test that the transform runs and returns the expected results for + a circuit with multiple measurements""" + + dev = qml.device("lightning.qubit", wires=2) + + @qml.qnode(dev) + def circuit_ref(x, y): + qml.RX(x, 0) + qml.RY(y, 1) + return qml.expval(qml.Y(0)), qml.var(qml.X(1)) + + def expected_res(x, y): + return -np.sin(x), 1 - np.sin(y) ** 2 + + phi = 0.3867 + theta = 1.394 + + assert np.allclose( + expected_res(phi, theta), circuit_ref(phi, theta) + ), "Sanity check failed, is expected_res correct?" + + circuit_compiled = qml.qjit( + diagonalize_final_measurements_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + assert np.allclose(expected_res(phi, theta), circuit_compiled(phi, theta)) + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_overlapping_observables_raises_error(self): + """Test the case where multiple overlapping (commuting) observables exist in + the same circuit (an error is raised - split_non_commuting should have been applied).""" + + dev = qml.device("lightning.qubit", wires=2) + + @qml.qjit(pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()]) + @diagonalize_final_measurements_pass + @qml.qnode(dev) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Y(0)), qml.var(qml.Y(0)) + + with pytest.raises( + RuntimeError, match="the circuit contains multiple observables with the same wire" + ): + _ = circuit(1.23) + + @pytest.mark.xfail(reason="for now, assume split_non_commuting is always applied") + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_non_commuting_observables_raise_error(self): + """Check that an error is raised if we try to diagonalize a circuit that contains + non-commuting observables.""" + dev = qml.device("lightning.qubit", wires=1) + + @qml.qjit(pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()]) + @diagonalize_final_measurements_pass + @qml.qnode(dev) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Y(0)), qml.expval(qml.X(0)) + + with pytest.raises( + RuntimeError, match="cannot diagonalize circuit with non-commuting observables" + ): + _ = circuit(0.7) + + +class TestDiagonalizeFinalMeasurementsCatalystFrontend: + """Integration tests going through the catalyst frontend (program capture disabled)""" + + # pylint: disable=unnecessary-lambda + @pytest.mark.parametrize( + "mp, obs, expected_res", + [ + (qml.expval, qml.Identity, lambda x: 1), + (qml.var, qml.Identity, lambda x: 0), + (qml.expval, qml.X, lambda x: 0), + (qml.var, qml.X, lambda x: 1), + (qml.expval, qml.Y, lambda x: -np.sin(x)), + (qml.var, qml.Y, lambda x: 1 - np.sin(x) ** 2), + (qml.expval, qml.Z, lambda x: np.cos(x)), + (qml.var, qml.Z, lambda x: 1 - np.cos(x) ** 2), + (qml.expval, qml.Hadamard, lambda x: np.cos(x) / np.sqrt(2)), + (qml.var, qml.Hadamard, lambda x: (2 - np.cos(x) ** 2) / 2), + ], + ) + def test_with_single_obs(self, mp, obs, expected_res): + """Test the diagonalization transform for a circuit with a single measurement + of a single supported observable""" + + dev = qml.device("lightning.qubit", wires=1) + + @qml.qnode(dev) + def circuit_ref(phi): + qml.RX(phi, 0) + return mp(obs(0)) + + angle = 0.7692 + + assert np.allclose( + expected_res(angle), circuit_ref(angle) + ), "Sanity check failed, is expected_res correct?" + + circuit_compiled = qml.qjit( + catalyst.passes.apply_pass("catalyst_xdsl_plugin.diagonalize-final-measurements")( + circuit_ref + ), + ) + + np.allclose(expected_res(angle), circuit_compiled(angle)) + + def test_with_composite_observables(self): + """Test the transform works for an observable built using operator arithmetic + (sprod, prod, sum)""" + + dev = qml.device("lightning.qubit", wires=3) + + @qml.qnode(dev) + def circuit_ref(x, y): + qml.RX(x, 0) + qml.RY(y, 1) + qml.RY(y / 2, 2) + return qml.expval(qml.Y(0) @ qml.X(1) + 3 * qml.X(2)) + + def expected_res(x, y): + y0_res = -np.sin(x) + x1_res = np.sin(y) + x2_res = np.sin(y / 2) + return y0_res * x1_res + 3 * x2_res + + phi = 0.3867 + theta = 1.394 + + assert np.allclose( + expected_res(phi, theta), circuit_ref(phi, theta) + ), "Sanity check failed, is expected_res correct?" + + circuit_compiled = qml.qjit( + catalyst.passes.apply_pass("catalyst_xdsl_plugin.diagonalize-final-measurements")( + circuit_ref + ), + ) + + assert np.allclose(expected_res(phi, theta), circuit_compiled(phi, theta)) + + def test_with_multiple_measurements(self): + """Test that the transform runs and returns the expected results for + a circuit with multiple measurements""" + + dev = qml.device("lightning.qubit", wires=2) + + @qml.qnode(dev) + def circuit_ref(x, y): + qml.RX(x, 0) + qml.RY(y, 1) + return qml.expval(qml.Y(0)), qml.var(qml.X(1)) + + def expected_res(x, y): + return -np.sin(x), 1 - np.sin(y) ** 2 + + phi = 0.3867 + theta = 1.394 + + assert np.allclose( + expected_res(phi, theta), circuit_ref(phi, theta) + ), "Sanity check failed, is expected_res correct?" + + circuit_compiled = qml.qjit( + catalyst.passes.apply_pass("catalyst_xdsl_plugin.diagonalize-final-measurements")( + circuit_ref + ), + ) + + assert np.allclose(expected_res(phi, theta), circuit_compiled(phi, theta)) + + def test_overlapping_observables_raises_error(self): + """Test the case where multiple overlapping (commuting) observables exist in + the same circuit (an error is raised - split_non_commuting should have been applied).""" + + dev = qml.device("lightning.qubit", wires=2) + + @qml.qjit() + @catalyst.passes.apply_pass("catalyst_xdsl_plugin.diagonalize-final-measurements") + @qml.qnode(dev) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Y(0)), qml.var(qml.Y(0)) + + with pytest.raises( + RuntimeError, match="the circuit contains multiple observables with the same wire" + ): + _ = circuit(1.23) + + @pytest.mark.xfail(reason="for now, assume split_non_commuting is always applied") + def test_non_commuting_observables_raise_error(self): + """Check that an error is raised if we try to diagonalize a circuit that contains + non-commuting observables.""" + dev = qml.device("lightning.qubit", wires=1) + + @qml.qjit() + @catalyst.passes.apply_pass("catalyst_xdsl_plugin.diagonalize-final-measurements") + @qml.qnode(dev) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Y(0)), qml.expval(qml.X(0)) + + with pytest.raises( + RuntimeError, match="cannot diagonalize circuit with non-commuting observables" + ): + _ = circuit(0.7) diff --git a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_measurements_from_samples.py b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_measurements_from_samples.py new file mode 100644 index 0000000000..baaf6d381a --- /dev/null +++ b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_measurements_from_samples.py @@ -0,0 +1,889 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit and integration tests for the Python compiler `measurements_from_samples` transform.""" + +# pylint: disable=wrong-import-position + +from functools import partial + +import numpy as np +import pytest + +pytestmark = pytest.mark.external + +xdsl = pytest.importorskip("xdsl") +catalyst = pytest.importorskip("catalyst") +import pennylane as qml + +from catalyst.passes import xdsl_plugin +from catalyst.python_interface.transforms import ( + MeasurementsFromSamplesPass, + measurements_from_samples_pass, +) + + +class TestMeasurementsFromSamplesPass: + """Unit tests for the measurements-from-samples pass.""" + + def test_1_wire_expval(self, run_filecheck): + """Test the measurements-from-samples pass on a 1-wire circuit terminating with an expval(Z) + measurement. + """ + program = """ + builtin.module @module_circuit { + // CHECK-LABEL: circuit + func.func public @circuit() -> (tensor) { + %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["", "", ""] + + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %2 = "test.op"() : () -> !quantum.bit + + // CHECK-NOT: quantum.namedobs + %3 = quantum.namedobs %2[PauliZ] : !quantum.obs + + // CHECK: [[obs:%.+]] = quantum.compbasis qubits [[q0]] : !quantum.obs + // CHECK: [[samples:%.+]] = quantum.sample [[obs]] : tensor<1x1xf64> + // CHECK: [[c0:%.+]] = arith.constant dense<0> : tensor + // CHECK: [[res:%.+]] = func.call @expval_from_samples.tensor.1x1xf64([[samples]], [[c0]]) : + // CHECK-SAME: (tensor<1x1xf64>, tensor) -> tensor + // CHECK-NOT: quantum.expval + %4 = quantum.expval %3 : f64 + %5 = "tensor.from_elements"(%4) : (f64) -> tensor + + // CHECK: func.return [[res]] : tensor + func.return %5 : tensor + } + // CHECK-LABEL: func.func public @expval_from_samples.tensor.1x1xf64 + } + """ + + pipeline = (MeasurementsFromSamplesPass(),) + run_filecheck(program, pipeline) + + def test_1_wire_expval_shots_from_arith_constantop(self, run_filecheck): + """Test the measurements-from-samples pass on a 1-wire circuit with shots from an arith.constant op and an expval(Z) measurement.""" + program = """ + builtin.module @module_circuit { + // CHECK-LABEL: circuit + func.func public @circuit() -> (tensor) { + %0 = arith.constant 1 : i64 + quantum.device shots(%0) ["", "", ""] + + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %1 = "test.op"() : () -> !quantum.bit + + // CHECK-NOT: quantum.namedobs + %2 = quantum.namedobs %1[PauliZ] : !quantum.obs + + // CHECK: [[obs:%.+]] = quantum.compbasis qubits [[q0]] : !quantum.obs + // CHECK: [[samples:%.+]] = quantum.sample [[obs]] : tensor<1x1xf64> + // CHECK: [[c0:%.+]] = arith.constant dense<0> : tensor + // CHECK: [[res:%.+]] = func.call @expval_from_samples.tensor.1x1xf64([[samples]], [[c0]]) : + // CHECK-SAME: (tensor<1x1xf64>, tensor) -> tensor + // CHECK-NOT: quantum.expval + %3 = quantum.expval %2 : f64 + %4 = "tensor.from_elements"(%3) : (f64) -> tensor + + // CHECK: func.return [[res]] : tensor + func.return %4 : tensor + } + // CHECK-LABEL: func.func public @expval_from_samples.tensor.1x1xf64 + } + """ + + pipeline = (MeasurementsFromSamplesPass(),) + run_filecheck(program, pipeline) + + def test_1_wire_var(self, run_filecheck): + """Test the measurements-from-samples pass on a 1-wire circuit terminating with a var(Z) + measurement. + """ + program = """ + builtin.module @module_circuit { + // CHECK-LABEL: circuit + func.func public @circuit() -> (tensor) { + %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["", "", ""] + + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %2 = "test.op"() : () -> !quantum.bit + + // CHECK-NOT: quantum.namedobs + %3 = quantum.namedobs %2[PauliZ] : !quantum.obs + + // CHECK: [[obs:%.+]] = quantum.compbasis qubits [[q0]] : !quantum.obs + // CHECK: [[samples:%.+]] = quantum.sample [[obs]] : tensor<1x1xf64> + // CHECK: [[c0:%.+]] = arith.constant dense<0> : tensor + // CHECK: [[res:%.+]] = func.call @var_from_samples.tensor.1x1xf64([[samples]], [[c0]]) : + // CHECK-SAME: (tensor<1x1xf64>, tensor) -> tensor + // CHECK-NOT: quantum.var + %4 = quantum.var %3 : f64 + %5 = "tensor.from_elements"(%4) : (f64) -> tensor + + // CHECK: func.return [[res]] : tensor + func.return %5 : tensor + } + // CHECK-LABEL: func.func public @var_from_samples.tensor.1x1xf64 + } + """ + + pipeline = (MeasurementsFromSamplesPass(),) + run_filecheck(program, pipeline) + + def test_1_wire_probs(self, run_filecheck): + """Test the measurements-from-samples pass on a 1-wire circuit terminating with a probs + measurement. + """ + program = """ + builtin.module @module_circuit { + // CHECK-LABEL: circuit + func.func public @circuit() -> (tensor) { + %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["", "", ""] + + // CHECK: [[q0:%.+]] = "test.op"() <{nqubits_attr = 1 : i64}> : () -> !quantum.reg + %2 = "test.op"() <{nqubits_attr = 1 : i64}> : () -> !quantum.reg + + // CHECK: [[compbasis:%.+]] = quantum.compbasis qreg [[q0]] : !quantum.obs + %3 = quantum.compbasis qreg %2 : !quantum.obs + + // CHECK: [[samples:%.+]] = quantum.sample [[compbasis]] : tensor<1x1xf64> + // CHECK: [[res:%.+]] = func.call @probs_from_samples.tensor.1x1xf64([[samples]]) : + // CHECK-SAME: (tensor<1x1xf64>) -> tensor<2xf64> + // CHECK-NOT: quantum.probs + %4 = quantum.probs %3 : tensor<2xf64> + + // CHECK: func.return [[res]] : tensor<2xf64> + func.return %4 : tensor<2xf64> + } + // CHECK-LABEL: func.func public @probs_from_samples.tensor.1x1xf64 + } + """ + + pipeline = (MeasurementsFromSamplesPass(),) + run_filecheck(program, pipeline) + + def test_1_wire_sample(self, run_filecheck): + """Test the measurements-from-samples pass on a 1-wire circuit terminating with a sample + measurement. + + This pass should be a no-op. + """ + program = """ + builtin.module @module_circuit { + // CHECK-LABEL: circuit + func.func public @circuit() -> (tensor) { + %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["", "", ""] + + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.reg + %2 = "test.op"() : () -> !quantum.reg + + // CHECK: [[compbasis:%.+]] = quantum.compbasis qreg [[q0]] : !quantum.obs + %3 = quantum.compbasis qreg %2 : !quantum.obs + + // CHECK: [[samples:%.+]] = quantum.sample [[compbasis]] : tensor<1x1xf64> + %4 = quantum.sample %3 : tensor<1x1xf64> + + // CHECK: func.return [[samples]] : tensor<1x1xf64> + func.return %4 : tensor<1x1xf64> + } + } + """ + + pipeline = (MeasurementsFromSamplesPass(),) + run_filecheck(program, pipeline) + + @pytest.mark.xfail(reason="Counts not supported", strict=True, raises=NotImplementedError) + def test_1_wire_counts(self, run_filecheck): + """Test the measurements-from-samples pass on a 1-wire circuit terminating with a counts + measurement. + """ + program = """ + builtin.module @module_circuit { + // CHECK-LABEL: circuit + func.func public @circuit() -> (tensor) { + %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["", "", ""] + + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.reg + %2 = "test.op"() : () -> !quantum.reg + + // CHECK: [[compbasis:%.+]] = quantum.compbasis qreg [[q0]] : !quantum.obs + %3 = quantum.compbasis qreg %2 : !quantum.obs + + // CHECK: [[samples:%.+]] = quantum.sample %3 : tensor<1x1xf64> + // CHECK: [[eigvals:%.+]], [[counts:%.+]] = func.call @counts_from_samples.tensor.1x1xf64([[samples]]) : + // CHECK-SAME: (tensor<1x1xf64>) -> tensor<2xf64>, tensor<2xi64> + // CHECK-NOT: quantum.counts + %eigvals, %counts = quantum.counts %3 : tensor<2xf64>, tensor<2xi64> + + // CHECK: [[eigvals_converted:%.+]] = {{.*}}stablehlo.convert{{.+}}[[eigvals]] : + // CHECK-SAME: (tensor<2xf64>) -> tensor<2xi64> + %4 = "stablehlo.convert"(%eigvals) : (tensor<2xf64>) -> tensor<2xi64> + + // CHECK: func.return [[eigvals_converted]], [[counts]] : tensor<1x1xf64> + func.return %4, %counts : tensor<2xi64>, tensor<2xi64> + } + // CHECK-LABEL: func.func public @counts_from_samples.tensor.1x1xf64 + } + """ + + pipeline = (MeasurementsFromSamplesPass(),) + run_filecheck(program, pipeline) + + def test_2_wire_expval(self, run_filecheck): + """Test the measurements-from-samples pass on a 2-wire circuit terminating with an expval(Z) + measurement on each wire. + """ + program = """ + builtin.module @module_circuit { + // CHECK-LABEL: circuit + func.func public @circuit() -> (tensor) { + %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["", "", ""] + + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %2 = "test.op"() : () -> !quantum.bit + + // CHECK: [[q1:%.+]] = "test.op"() : () -> !quantum.bit + %3 = "test.op"() : () -> !quantum.bit + + // CHECK-NOT: quantum.namedobs + %4 = quantum.namedobs %2[PauliZ] : !quantum.obs + %5 = quantum.namedobs %3[PauliZ] : !quantum.obs + + // CHECK: [[obs0:%.+]] = quantum.compbasis qubits [[q0]] : !quantum.obs + // CHECK: [[samples0:%.+]] = quantum.sample [[obs0]] : tensor<1x1xf64> + // CHECK: [[c0:%.+]] = arith.constant dense<0> : tensor + // CHECK: [[obs1:%.+]] = quantum.compbasis qubits [[q1]] : !quantum.obs + // CHECK: [[samples1:%.+]] = quantum.sample [[obs1]] : tensor<1x1xf64> + // CHECK: [[c1:%.+]] = arith.constant dense<0> : tensor + // CHECK: [[res0:%.+]] = func.call @expval_from_samples.tensor.1x1xf64([[samples0]], [[c0]]) : + // CHECK-SAME: (tensor<1x1xf64>, tensor) -> tensor + // CHECK: [[res1:%.+]] = func.call @expval_from_samples.tensor.1x1xf64([[samples1]], [[c1]]) : + // CHECK-SAME: (tensor<1x1xf64>, tensor) -> tensor + // CHECK-NOT: quantum.expval + %6 = quantum.expval %4 : f64 + %7 = "tensor.from_elements"(%6) : (f64) -> tensor + %8 = quantum.expval %5 : f64 + %9 = "tensor.from_elements"(%8) : (f64) -> tensor + + // CHECK: func.return [[res0]], [[res1]] : tensor, tensor + func.return %7, %9 : tensor, tensor + } + // CHECK-LABEL: func.func public @expval_from_samples.tensor.1x1xf64 + } + """ + + pipeline = (MeasurementsFromSamplesPass(),) + run_filecheck(program, pipeline) + + def test_2_wire_var(self, run_filecheck): + """Test the measurements-from-samples pass on a 2-wire circuit terminating with a var(Z) + measurement on each wire. + """ + program = """ + builtin.module @module_circuit { + // CHECK-LABEL: circuit + func.func public @circuit() -> (tensor) { + %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["", "", ""] + + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %2 = "test.op"() : () -> !quantum.bit + + // CHECK: [[q1:%.+]] = "test.op"() : () -> !quantum.bit + %3 = "test.op"() : () -> !quantum.bit + + // CHECK-NOT: quantum.namedobs + %4 = quantum.namedobs %2[PauliZ] : !quantum.obs + %5 = quantum.namedobs %3[PauliZ] : !quantum.obs + + // CHECK: [[obs0:%.+]] = quantum.compbasis qubits [[q0]] : !quantum.obs + // CHECK: [[samples0:%.+]] = quantum.sample [[obs0]] : tensor<1x1xf64> + // CHECK: [[c0:%.+]] = arith.constant dense<0> : tensor + // CHECK: [[obs1:%.+]] = quantum.compbasis qubits [[q1]] : !quantum.obs + // CHECK: [[samples1:%.+]] = quantum.sample [[obs1]] : tensor<1x1xf64> + // CHECK: [[c1:%.+]] = arith.constant dense<0> : tensor + // CHECK: [[res0:%.+]] = func.call @var_from_samples.tensor.1x1xf64([[samples0]], [[c0]]) : + // CHECK-SAME: (tensor<1x1xf64>, tensor) -> tensor + // CHECK: [[res1:%.+]] = func.call @var_from_samples.tensor.1x1xf64([[samples1]], [[c1]]) : + // CHECK-SAME: (tensor<1x1xf64>, tensor) -> tensor + // CHECK-NOT: quantum.var + %6 = quantum.var %4 : f64 + %7 = "tensor.from_elements"(%6) : (f64) -> tensor + %8 = quantum.var %5 : f64 + %9 = "tensor.from_elements"(%8) : (f64) -> tensor + + // CHECK: func.return [[res0]], [[res1]] : tensor, tensor + func.return %7, %9 : tensor, tensor + } + // CHECK-LABEL: func.func public @var_from_samples.tensor.1x1xf64 + } + """ + + pipeline = (MeasurementsFromSamplesPass(),) + run_filecheck(program, pipeline) + + def test_2_wire_probs_global(self, run_filecheck): + """Test the measurements-from-samples pass on a 2-wire circuit terminating with a "global" + probs measurement (one that implicitly acts on all wires). + """ + program = """ + builtin.module @module_circuit { + // CHECK-LABEL: circuit + func.func public @circuit() -> (tensor) { + %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["", "", ""] + + // CHECK: [[qreg:%.+]] = quantum.alloc + %2 = quantum.alloc(2) : !quantum.reg + + // CHECK: [[compbasis:%.+]] = quantum.compbasis qreg [[qreg]] : !quantum.obs + %3 = quantum.compbasis qreg %2 : !quantum.obs + + // CHECK: [[samples:%.+]] = quantum.sample [[compbasis]] : tensor<1x2xf64> + // CHECK: [[res:%.+]] = func.call @probs_from_samples.tensor.1x2xf64([[samples]]) : + // CHECK-SAME: (tensor<1x2xf64>) -> tensor<4xf64> + // CHECK-NOT: quantum.probs + %4 = quantum.probs %3 : tensor<4xf64> + + // CHECK: func.return [[res]] : tensor<4xf64> + func.return %4 : tensor<4xf64> + } + // CHECK-LABEL: func.func public @probs_from_samples.tensor.1x2xf64 + } + """ + + pipeline = (MeasurementsFromSamplesPass(),) + run_filecheck(program, pipeline) + + def test_2_wire_probs_per_wire(self, run_filecheck): + """Test the measurements-from-samples pass on a 2-wire circuit terminating with separate + probs measurements per wire. + """ + program = """ + builtin.module @module_circuit { + // CHECK-LABEL: circuit + func.func public @circuit() -> (tensor) { + %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["", "", ""] + + // CHECK: [[qreg:%.+]] = quantum.alloc + %2 = quantum.alloc(2) : !quantum.reg + + // CHECK: [[q0:%.+]] = quantum.extract [[qreg]][0] + %3 = quantum.extract %2[0] : !quantum.reg -> !quantum.bit + + // CHECK: [[compbasis0:%.+]] = quantum.compbasis qubits [[q0]] : !quantum.obs + %4 = quantum.compbasis qubits %3 : !quantum.obs + + // CHECK: [[samples0:%.+]] = quantum.sample [[compbasis0]] : tensor<1x1xf64> + // CHECK: [[res0:%.+]] = func.call @probs_from_samples.tensor.1x1xf64([[samples0]]) : + // CHECK-SAME: (tensor<1x1xf64>) -> tensor<2xf64> + // CHECK-NOT: quantum.probs + %5 = quantum.probs %4 : tensor<2xf64> + + // CHECK: [[q1:%.+]] = quantum.extract [[qreg]][1] + %6 = quantum.extract %2[1] : !quantum.reg -> !quantum.bit + + // CHECK: [[compbasis1:%.+]] = quantum.compbasis qubits [[q1]] : !quantum.obs + %7 = quantum.compbasis qubits %6 : !quantum.obs + + // CHECK: [[samples1:%.+]] = quantum.sample [[compbasis1]] : tensor<1x1xf64> + // CHECK: [[res1:%.+]] = func.call @probs_from_samples.tensor.1x1xf64([[samples1]]) : + // CHECK-SAME: (tensor<1x1xf64>) -> tensor<2xf64> + // CHECK-NOT: quantum.probs + %8 = quantum.probs %7 : tensor<2xf64> + + // CHECK: func.return [[res0]], [[res1]] : tensor<2xf64>, tensor<2xf64> + func.return %5, %8 : tensor<2xf64>, tensor<2xf64> + } + // CHECK-LABEL: func.func public @probs_from_samples.tensor.1x1xf64 + } + """ + + pipeline = (MeasurementsFromSamplesPass(),) + run_filecheck(program, pipeline) + + +@pytest.mark.usefixtures("enable_disable_plxpr") +class TestMeasurementsFromSamplesIntegration: + """Tests of the execution of simple workloads with the xDSL-based MeasurementsFromSamplesPass + transform. + """ + + @pytest.mark.parametrize("shots", [1, 2]) + @pytest.mark.parametrize( + "initial_op, mp, obs, expected_res", + [ + # PauliZ observables + (qml.I, qml.expval, qml.Z, 1.0), + (qml.X, qml.expval, qml.Z, -1.0), + (qml.I, qml.var, qml.Z, 0.0), + (qml.X, qml.var, qml.Z, 0.0), + # PauliX observables + pytest.param( + partial(qml.RY, phi=np.pi / 2), + qml.expval, + qml.X, + 1.0, + marks=pytest.mark.xfail( + reason="Only PauliZ-basis measurements supported", + strict=True, + raises=NotImplementedError, + ), + ), + pytest.param( + partial(qml.RY, phi=-np.pi / 2), + qml.expval, + qml.X, + -1.0, + marks=pytest.mark.xfail( + reason="Only PauliZ-basis measurements supported", + strict=True, + raises=NotImplementedError, + ), + ), + pytest.param( + partial(qml.RY, phi=np.pi / 2), + qml.var, + qml.X, + 0.0, + marks=pytest.mark.xfail( + reason="Only PauliZ-basis measurements supported", + strict=True, + raises=NotImplementedError, + ), + ), + pytest.param( + partial(qml.RY, phi=-np.pi / 2), + qml.var, + qml.X, + 0.0, + marks=pytest.mark.xfail( + reason="Only PauliZ-basis measurements supported", + strict=True, + raises=NotImplementedError, + ), + ), + # PauliY observables + pytest.param( + partial(qml.RX, phi=-np.pi / 2), + qml.expval, + qml.Y, + 1.0, + marks=pytest.mark.xfail( + reason="Only PauliZ-basis measurements supported", + strict=True, + raises=NotImplementedError, + ), + ), + pytest.param( + partial(qml.RX, phi=np.pi / 2), + qml.expval, + qml.Y, + -1.0, + marks=pytest.mark.xfail( + reason="Only PauliZ-basis measurements supported", + strict=True, + raises=NotImplementedError, + ), + ), + pytest.param( + partial(qml.RX, phi=-np.pi / 2), + qml.var, + qml.Y, + 0.0, + marks=pytest.mark.xfail( + reason="Only PauliZ-basis measurements supported", + strict=True, + raises=NotImplementedError, + ), + ), + pytest.param( + partial(qml.RX, phi=np.pi / 2), + qml.var, + qml.Y, + 0.0, + marks=pytest.mark.xfail( + reason="Only PauliZ-basis measurements supported", + strict=True, + raises=NotImplementedError, + ), + ), + ], + ) + # pylint: disable=too-many-arguments,too-many-positional-arguments + def test_exec_1_wire_mp_with_obs(self, shots, initial_op, mp, obs, expected_res): + """Test the measurements_from_samples transform on a device with a single wire and terminal + measurements that require an observable (i.e. expval and var). + """ + + dev = qml.device("lightning.qubit", wires=1) + + @qml.qnode(dev, shots=shots) + def circuit_ref(): + initial_op(wires=0) + return mp(obs(wires=0)) + + assert expected_res == circuit_ref(), "Sanity check failed, is expected_res correct?" + circuit_compiled = qml.qjit( + measurements_from_samples_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + assert expected_res == circuit_compiled() + + # -------------------------------------------------------------------------------------------- # + + @pytest.mark.parametrize("shots", [1, 2]) + @pytest.mark.parametrize( + "initial_op, expected_res", + [ + (qml.I, [1.0, 0.0]), + (qml.X, [0.0, 1.0]), + ], + ) + def test_exec_1_wire_probs(self, shots, initial_op, expected_res): + """Test the measurements_from_samples transform on a device with a single wire and terminal + probs measurements. + """ + + dev = qml.device("lightning.qubit", wires=1) + + @qml.qnode(dev, shots=shots) + def circuit_ref(): + initial_op(wires=0) + return qml.probs(wires=0) + + assert np.array_equal( + expected_res, circuit_ref() + ), "Sanity check failed, is expected_res correct?" + circuit_compiled = qml.qjit( + measurements_from_samples_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + assert np.array_equal(expected_res, circuit_compiled()) + + # -------------------------------------------------------------------------------------------- # + + @pytest.mark.xfail( + reason="Counts not supported in Catalyst with program capture", + strict=True, + raises=NotImplementedError, + ) + @pytest.mark.parametrize("shots", [1, 2]) + @pytest.mark.parametrize( + "initial_op, expected_res", + [ + (qml.I, {"0": 10, "1": 0}), + (qml.X, {"0": 0, "1": 10}), + ], + ) + def test_exec_1_wire_counts(self, shots, initial_op, expected_res): + """Test the measurements_from_samples transform on a device with a single wire and terminal + counts measurements. + """ + + dev = qml.device("lightning.qubit", wires=1) + + @qml.qnode(dev, shots=shots) + def circuit_ref(): + initial_op(wires=0) + return qml.counts(wires=0) + + assert np.array_equal( + expected_res, circuit_ref() + ), "Sanity check failed, is expected_res correct?" + + circuit_compiled = qml.qjit( + measurements_from_samples_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + assert np.array_equal(expected_res, _counts_catalyst_to_pl(*circuit_compiled())) + + # -------------------------------------------------------------------------------------------- # + + @pytest.mark.parametrize("shots", [1, 2]) + @pytest.mark.parametrize( + "initial_op, expected_res_base", + [ + (qml.I, 0), + (qml.X, 1), + ], + ) + def test_exec_1_wire_sample(self, shots, initial_op, expected_res_base): + """Test the measurements_from_samples transform on a device with a single wire and terminal + sample measurements. + + In this case, the measurements_from_samples pass should effectively be a no-op. + """ + dev = qml.device("lightning.qubit", wires=1) + + @qml.qnode(dev, shots=shots) + def circuit_ref(): + initial_op(wires=0) + return qml.sample(wires=0) + + circuit_compiled = qml.qjit( + measurements_from_samples_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + expected_res = expected_res_base * np.ones(shape=(shots, 1), dtype=int) + + assert np.array_equal(expected_res, circuit_compiled()) + + # -------------------------------------------------------------------------------------------- # + + @pytest.mark.parametrize("shots", [1, 2]) + @pytest.mark.parametrize( + "initial_ops, mp, obs, expected_res", + [ + ((qml.I, qml.I), qml.expval, qml.Z, (1.0, 1.0)), + ((qml.I, qml.X), qml.expval, qml.Z, (1.0, -1.0)), + ((qml.X, qml.I), qml.expval, qml.Z, (-1.0, 1.0)), + ((qml.X, qml.X), qml.expval, qml.Z, (-1.0, -1.0)), + ((qml.I, qml.I), qml.var, qml.Z, (0.0, 0.0)), + ((qml.I, qml.X), qml.var, qml.Z, (0.0, 0.0)), + ((qml.X, qml.I), qml.var, qml.Z, (0.0, 0.0)), + ((qml.X, qml.X), qml.var, qml.Z, (0.0, 0.0)), + ], + ) + # pylint: disable=too-many-arguments,too-many-positional-arguments + def test_exec_2_wire_with_obs_separate(self, shots, initial_ops, mp, obs, expected_res): + """Test the measurements_from_samples transform on a device with two wires and terminal + measurements that require an observable (i.e. expval and var). + + In this test, the terminal measurements are performed separately per wire. + """ + + dev = qml.device("lightning.qubit", wires=2) + + @qml.qnode(dev, shots=shots) + def circuit_ref(): + initial_ops[0](wires=0) + initial_ops[1](wires=1) + return mp(obs(wires=0)), mp(obs(wires=1)) + + assert expected_res == circuit_ref(), "Sanity check failed, is expected_res correct?" + circuit_compiled = qml.qjit( + measurements_from_samples_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + assert expected_res == circuit_compiled() + + # -------------------------------------------------------------------------------------------- # + + @pytest.mark.xfail( + reason="Operator arithmetic not yet supported with capture enabled", strict=True + ) + @pytest.mark.parametrize("shots", [1, 2]) + @pytest.mark.parametrize( + "initial_ops, mp, expected_res", + [ + ((qml.I, qml.I), qml.expval, 1.0), + ((qml.I, qml.X), qml.expval, -1.0), + ((qml.X, qml.I), qml.expval, -1.0), + ((qml.X, qml.X), qml.expval, 1.0), + ((qml.I, qml.I), qml.var, 0.0), + ((qml.I, qml.X), qml.var, 0.0), + ((qml.X, qml.I), qml.var, 0.0), + ((qml.X, qml.X), qml.var, 0.0), + ], + ) + def test_exec_2_wire_with_obs_combined(self, shots, initial_ops, mp, expected_res): + """Test the measurements_from_samples transform on a device with two wires and terminal + measurements that require an observable (i.e. expval and var). + + In this test, the terminal measurements are performed on the combination of both wires. + """ + + dev = qml.device("lightning.qubit", wires=2) + + @qml.qnode(dev, shots=shots) + def circuit_ref(): + initial_ops[0](wires=0) + initial_ops[1](wires=1) + return mp(qml.Z(wires=0) @ qml.Z(wires=1)) + + assert expected_res == circuit_ref(), "Sanity check failed, is expected_res correct?" + + circuit_compiled = qml.qjit( + measurements_from_samples_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + assert expected_res == circuit_compiled() + + # -------------------------------------------------------------------------------------------- # + + @pytest.mark.parametrize("shots", [1, 2]) + @pytest.mark.parametrize( + "initial_ops, expected_res", + [ + ((qml.I, qml.I), [1.0, 0.0, 0.0, 0.0]), + ((qml.I, qml.X), [0.0, 1.0, 0.0, 0.0]), + ((qml.X, qml.I), [0.0, 0.0, 1.0, 0.0]), + ((qml.X, qml.X), [0.0, 0.0, 0.0, 1.0]), + ], + ) + def test_exec_2_wire_probs_global(self, shots, initial_ops, expected_res): + """Test the measurements_from_samples transform on a device with two wires and a terminal, + "global" probs measurements (one that implicitly acts on all wires). + """ + dev = qml.device("lightning.qubit", wires=2) + + @qml.qnode(dev, shots=shots) + def circuit_ref(): + initial_ops[0](wires=0) + initial_ops[1](wires=1) + return qml.probs() + + assert np.array_equal( + expected_res, circuit_ref() + ), "Sanity check failed, is expected_res correct?" + circuit_compiled = qml.qjit( + measurements_from_samples_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + assert np.array_equal(expected_res, circuit_compiled()) + + # -------------------------------------------------------------------------------------------- # + + @pytest.mark.parametrize("shots", [1, 2]) + @pytest.mark.parametrize( + "initial_ops, expected_res", + [ + ((qml.I, qml.I), ([1.0, 0.0], [1.0, 0.0])), + ((qml.I, qml.X), ([1.0, 0.0], [0.0, 1.0])), + ((qml.X, qml.I), ([0.0, 1.0], [1.0, 0.0])), + ((qml.X, qml.X), ([0.0, 1.0], [0.0, 1.0])), + ], + ) + def test_exec_2_wire_probs_per_wire(self, shots, initial_ops, expected_res): + """Test the measurements_from_samples transform on a device with two wires and a terminal, + "global" probs measurements (one that implicitly acts on all wires). + """ + dev = qml.device("lightning.qubit", wires=2) + + @qml.qnode(dev, shots=shots) + def circuit_ref(): + initial_ops[0](wires=0) + initial_ops[1](wires=1) + return qml.probs(wires=0), qml.probs(wires=1) + + assert np.array_equal( + expected_res, circuit_ref() + ), "Sanity check failed, is expected_res correct?" + circuit_compiled = qml.qjit( + measurements_from_samples_pass(circuit_ref), + pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], + ) + + assert np.array_equal(expected_res, circuit_compiled()) + + # -------------------------------------------------------------------------------------------- # + + @pytest.mark.xfail(reason="Dynamic shots not supported") + def test_exec_expval_dynamic_shots(self): + """Test the measurements_from_samples transform where the number of shots is dynamic. + + This use case is not currently supported. + """ + + @qml.qjit(pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()]) + def workload(shots): + dev = qml.device("lightning.qubit", wires=1) + + @measurements_from_samples_pass + @qml.qnode(dev, shots=shots) + def circuit(): + return qml.expval(qml.Z(wires=0)) + + return circuit() + + result = workload(2) + assert result == 1.0 + + def test_qjit_filecheck(self, run_filecheck_qjit): + """Test that the measurements_from_samples_pass works correctly with qjit.""" + dev = qml.device("lightning.qubit", wires=2) + + @qml.qjit(target="mlir", pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()]) + @measurements_from_samples_pass + @qml.qnode(dev, shots=25) + def circuit(): + # CHECK-NOT: quantum.namedobs + # CHECK: [[obs:%.+]] = quantum.compbasis + # CHECK: [[samples:%.+]] = quantum.sample [[obs]] : tensor<25x1xf64> + # CHECK: [[c0:%.+]] = arith.constant dense<0> : tensor + # CHECK: [[res:%.+]] = func.call @expval_from_samples.tensor.25x1xf64([[samples]], [[c0]]) : + # CHECK-SAME: (tensor<25x1xf64>, tensor) -> tensor + # CHECK-NOT: quantum.expval + return qml.expval(qml.Z(wires=0)) + + run_filecheck_qjit(circuit) + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_integrate_with_decompose(self): + dev = qml.device("null.qubit", wires=4) + + @qml.qjit(target="mlir", pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()]) + @measurements_from_samples_pass + @partial( + qml.transforms.decompose, + gate_set={"X", "Y", "Z", "S", "H", "CNOT", "RZ", "GlobalPhase"}, + ) + @qml.set_shots(1000) + @qml.qnode(dev, shots=1000) + def circuit(): + qml.CRX(0.1, wires=[0, 1]) + return qml.expval(qml.Z(0)) + + res = circuit() + assert res == 1.0 + + +def _counts_catalyst_to_pl(basis_states, counts): + """Helper function to convert counts in the Catalyst format to the PennyLane format. + + Example: + + >>> basis_states, counts = ([0, 1], [6, 4]) + >>> _counts_catalyst_to_pl(basis_states, counts) + {'0': 6, '1': 4} + """ + return {format(int(state), "01b"): count for state, count in zip(basis_states, counts)} + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_merge_rotations.py b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_merge_rotations.py new file mode 100644 index 0000000000..a6f8ba4e28 --- /dev/null +++ b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_merge_rotations.py @@ -0,0 +1,247 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the merge rotations transform""" +import pytest + +pytestmark = pytest.mark.external + +pytest.importorskip("xdsl") +pytest.importorskip("catalyst") + +import pennylane as qml + +# pylint: disable=wrong-import-position +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from catalyst.python_interface.transforms import MergeRotationsPass, merge_rotations_pass + + +class TestMergeRotationsPass: + """Unit tests for MergeRotationsPass.""" + + pipeline = (MergeRotationsPass(),) + + def test_no_composable_ops(self, run_filecheck): + """Test that nothing changes when there are no composable gates.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = quantum.custom "RX"(%arg0) [[q0]] : !quantum.bit + // CHECK: quantum.custom "RY"(%arg1) [[q1]] : !quantum.bit + %1 = quantum.custom "RX"(%arg0) %0 : !quantum.bit + %2 = quantum.custom "RY"(%arg1) %1 : !quantum.bit + return + } + """ + + run_filecheck(program, self.pipeline) + + def test_composable_ops(self, run_filecheck): + """Test that composable gates are merged.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK: [[phi0:%.+]] = arith.addf %arg0, %arg1 : f64 + // CHECK: quantum.custom "RX"([[phi0]]) [[q0]] : !quantum.bit + // CHECK-NOT: "quantum.custom" + %1 = quantum.custom "RX"(%arg0) %0 : !quantum.bit + %2 = quantum.custom "RX"(%arg1) %1 : !quantum.bit + return + } + """ + + run_filecheck(program, self.pipeline) + + def test_many_composable_ops(self, run_filecheck): + """Test that more than 2 composable ops are merged correctly.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64, %arg2: f64, %arg3: f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK: [[phi0:%.+]] = arith.addf %arg0, %arg1 : f64 + // CHECK: [[phi1:%.+]] = arith.addf [[phi0]], %arg2 : f64 + // CHECK: [[phi2:%.+]] = arith.addf [[phi1]], %arg3 : f64 + // CHECK: quantum.custom "RX"([[phi2]]) [[q0]] : !quantum.bit + // CHECK-NOT: "quantum.custom" + %1 = quantum.custom "RX"(%arg0) %0 : !quantum.bit + %2 = quantum.custom "RX"(%arg1) %1 : !quantum.bit + %3 = quantum.custom "RX"(%arg2) %2 : !quantum.bit + %4 = quantum.custom "RX"(%arg3) %3 : !quantum.bit + return + } + """ + + run_filecheck(program, self.pipeline) + + def test_non_consecutive_composable_ops(self, run_filecheck): + """Test that non-consecutive composable gates are not merged.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = quantum.custom "RX"(%arg0) [[q0]] : !quantum.bit + // CHECK: [[q2:%.+]] = quantum.custom "RY"(%arg0) [[q1]] : !quantum.bit + // CHECK: quantum.custom "RX"(%arg1) [[q2]] : !quantum.bit + %1 = quantum.custom "RX"(%arg0) %0 : !quantum.bit + %2 = quantum.custom "RY"(%arg0) %1 : !quantum.bit + %3 = quantum.custom "RX"(%arg1) %2 : !quantum.bit + return + } + """ + + run_filecheck(program, self.pipeline) + + def test_composable_ops_different_qubits(self, run_filecheck): + """Test that composable gates on different qubits are not merged.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64) { + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + %1 = "test.op"() : () -> !quantum.bit + // CHECK: quantum.custom "RX"(%arg0) [[q0]] : !quantum.bit + // CHECK: quantum.custom "RX"(%arg1) [[q1]] : !quantum.bit + %2 = quantum.custom "RX"(%arg0) %0 : !quantum.bit + %3 = quantum.custom "RX"(%arg1) %1 : !quantum.bit + return + } + """ + + run_filecheck(program, self.pipeline) + + def test_controlled_composable_ops(self, run_filecheck): + """Test that controlled composable ops can be merged.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64) { + %cst = "arith.constant"() <{value = true}> : () -> i1 + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q2:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + %1 = "test.op"() : () -> !quantum.bit + %2 = "test.op"() : () -> !quantum.bit + // CHECK: [[phi0:%.+]] = arith.addf %arg0, %arg1 : f64 + // CHECK: quantum.custom "RX"([[phi0]]) [[q0]] ctrls([[q1]], [[q2]]) ctrlvals(%cst, %cst) : !quantum.bit ctrls !quantum.bit, !quantum.bit + // CHECK-NOT: "quantum.custom" + %3, %4, %5 = quantum.custom "RX"(%arg0) %0 ctrls(%1, %2) ctrlvals(%cst, %cst) : !quantum.bit ctrls !quantum.bit, !quantum.bit + %6, %7, %8 = quantum.custom "RX"(%arg1) %3 ctrls(%4, %5) ctrlvals(%cst, %cst) : !quantum.bit ctrls !quantum.bit, !quantum.bit + return + } + """ + + run_filecheck(program, self.pipeline) + + def test_controlled_composable_ops_same_control_values(self, run_filecheck): + """Test that controlled composable ops with the same control values + can be merged.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64) { + %cst0 = "arith.constant"() <{value = true}> : () -> i1 + %cst1 = "arith.constant"() <{value = false}> : () -> i1 + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q2:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + %1 = "test.op"() : () -> !quantum.bit + %2 = "test.op"() : () -> !quantum.bit + // CHECK: [[phi0:%.+]] = arith.addf %arg0, %arg1 : f64 + // CHECK: quantum.custom "RX"([[phi0]]) [[q0]] ctrls([[q1]], [[q2]]) ctrlvals(%cst0, %cst1) : !quantum.bit ctrls !quantum.bit, !quantum.bit + // CHECK-NOT: quantum.custom + %3, %4, %5 = quantum.custom "RX"(%arg0) %0 ctrls(%1, %2) ctrlvals(%cst0, %cst1) : !quantum.bit ctrls !quantum.bit, !quantum.bit + %6, %7, %8 = quantum.custom "RX"(%arg1) %3 ctrls(%4, %5) ctrlvals(%cst0, %cst1) : !quantum.bit ctrls !quantum.bit, !quantum.bit + return + } + """ + + run_filecheck(program, self.pipeline) + + def test_controlled_composable_ops_different_control_values(self, run_filecheck): + """Test that controlled composable ops with different control values + are not merged.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64) { + %cst0 = "arith.constant"() <{value = true}> : () -> i1 + %cst1 = "arith.constant"() <{value = false}> : () -> i1 + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q1:%.+]] = "test.op"() : () -> !quantum.bit + // CHECK: [[q2:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + %1 = "test.op"() : () -> !quantum.bit + %2 = "test.op"() : () -> !quantum.bit + // CHECK: [[q3:%.+]], [[q4:%.+]], [[q5:%.+]] = quantum.custom "RX"(%arg0) [[q0]] ctrls([[q1]], [[q2]]) ctrlvals(%cst1, %cst0) : !quantum.bit ctrls !quantum.bit, !quantum.bit + // CHECK: quantum.custom "RX"(%arg1) [[q3]] ctrls([[q4]], [[q5]]) ctrlvals(%cst0, %cst1) : !quantum.bit ctrls !quantum.bit, !quantum.bit + %3, %4, %5 = quantum.custom "RX"(%arg0) %0 ctrls(%1, %2) ctrlvals(%cst1, %cst0) : !quantum.bit ctrls !quantum.bit, !quantum.bit + %6, %7, %8 = quantum.custom "RX"(%arg1) %3 ctrls(%4, %5) ctrlvals(%cst0, %cst1) : !quantum.bit ctrls !quantum.bit, !quantum.bit + return + } + """ + + run_filecheck(program, self.pipeline) + + @pytest.mark.parametrize( + "first_adj, second_adj, sign", + [ + (False, False, "+"), + (False, True, "-"), + (True, False, "-"), + (True, True, "+"), + ], + ) + def test_adjoint_property(self, first_adj, second_adj, sign, run_filecheck): + """Test that composable ops with and without adjoint property are merged correctly.""" + adj_string_0 = "adj" if first_adj else "" + adj_string_1 = "adj" if second_adj else "" + arith_op_str = "arith.addf" if sign == "+" else "arith.subf" + program = f""" + func.func @test_func(%arg0: f64, %arg1: f64) {{ + // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit + %0 = "test.op"() : () -> !quantum.bit + // CHECK: [[new_angle:%.+]] = {arith_op_str} %arg0, %arg1 : f64 + // CHECK: [[q1:%.+]] = quantum.custom "RX"([[new_angle]]) [[q0]] {adj_string_0} : !quantum.bit + // CHECK-NOT: "quantum.custom" + %1 = quantum.custom "RX"(%arg0) %0 {adj_string_0}: !quantum.bit + %2 = quantum.custom "RX"(%arg1) %1 {adj_string_1}: !quantum.bit + return + }} + """ + + run_filecheck(program, self.pipeline) + + +# pylint: disable=too-few-public-methods +@pytest.mark.usefixtures("enable_disable_plxpr") +class TestMergeRotationsIntegration: + """Integration tests for the MergeRotationsPass.""" + + def test_qjit(self, run_filecheck_qjit): + """Test that the MergeRotationsPass works correctly with qjit.""" + dev = qml.device("lightning.qubit", wires=1) + + @qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()]) + @merge_rotations_pass + @qml.qnode(dev) + def circuit(x: float, y: float): + # CHECK: [[phi:%.+]] = arith.addf + # CHECK: quantum.custom "RX"([[phi]]) + # CHECK-NOT: quantum.custom + qml.RX(x, 0) + qml.RX(y, 0) + return qml.state() + + run_filecheck_qjit(circuit) + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_split_non_commuting.py b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_split_non_commuting.py new file mode 100644 index 0000000000..ddf5cddb5f --- /dev/null +++ b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_split_non_commuting.py @@ -0,0 +1,313 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the split non-commuting transform""" +import pytest + +pytestmark = pytest.mark.external + +xdsl = pytest.importorskip("xdsl") +catalyst = pytest.importorskip("catalyst") + +import pennylane as qml + +# pylint: disable=wrong-import-position +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from catalyst.python_interface.transforms import ( + SplitNonCommutingPass, + split_non_commuting_pass, +) + + +class TestSplitNonCommutingPass: + """Unit tests for SplitNonCommutingPass.""" + + def test_func_w_qnode_attr(self, run_filecheck): + """Test split non-commuting pass would be applied to a func with a qnode attribute.""" + program = """ + module @module_circuit { + func.func public @circuit() -> tensor attributes {qnode} { + // CHECK-NOT: arith.constant + // CHECK: [[value:%.+]] = func.call [[dup_func:@[a-zA-Z0-9_.]+]] + // CHECK: func.return [[value]] + %0 = arith.constant 0 : i64 + quantum.device shots(%0) ["", "", ""] + %1 = quantum.alloc( 50) : !quantum.reg + %2 = quantum.extract %1[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "PauliX"() %2 : !quantum.bit + %3 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.insert %1[ 0], %out_qubits : !quantum.reg, !quantum.bit + quantum.dealloc %5 : !quantum.reg + quantum.device_release + return %from_elements : tensor + } + // CHECK: func.func [[dup_func]] + } + """ + + pipeline = (SplitNonCommutingPass(),) + run_filecheck(program, pipeline) + + def test_multiple_func_w_qnode_attr(self, run_filecheck): + """Test split non-commuting pass would be applied to a func with a qnode attribute.""" + program = """ + module @module_circuit { + func.func public @circuit1() -> tensor attributes {qnode} { + // CHECK-NOT: arith.constant + // CHECK: [[value:%.+]] = func.call [[dup_func1:@[a-zA-Z0-9_.]+]] + // CHECK: func.return [[value]] + %0 = arith.constant 0 : i64 + quantum.device shots(%0) ["", "", ""] + %1 = quantum.alloc( 50) : !quantum.reg + %2 = quantum.extract %1[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "PauliX"() %2 : !quantum.bit + %3 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.insert %1[ 0], %out_qubits : !quantum.reg, !quantum.bit + quantum.dealloc %5 : !quantum.reg + quantum.device_release + return %from_elements : tensor + } + func.func public @circuit2() -> tensor attributes {qnode} { + // CHECK-NOT: arith.constant + // CHECK: [[value:%.+]] = func.call [[dup_func2:@[a-zA-Z0-9_.]+]] + // CHECK: func.return [[value]] + %0 = arith.constant 0 : i64 + quantum.device shots(%0) ["", "", ""] + %1 = quantum.alloc( 50) : !quantum.reg + %2 = quantum.extract %1[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "PauliX"() %2 : !quantum.bit + %3 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.insert %1[ 0], %out_qubits : !quantum.reg, !quantum.bit + quantum.dealloc %5 : !quantum.reg + quantum.device_release + return %from_elements : tensor + } + // CHECK: func.func [[dup_func1]] + // CHECK: func.func [[dup_func2]] + } + """ + + pipeline = (SplitNonCommutingPass(),) + run_filecheck(program, pipeline) + + def test_func_w_commuting_measurements(self, run_filecheck): + """Test split non-commuting pass would be applied to a func with commuting ops.""" + program = """ + module @module_circuit { + func.func public @circuit() -> (tensor, tensor) attributes {qnode} { + // CHECK-NOT: arith.constant + // CHECK: [[v0:%.+]], [[v1:%.+]] = func.call [[dup_func:@[a-zA-Z0-9_.]+]] + // CHECK: func.return [[v0]], [[v1]] + %c0 = arith.constant 0 : i64 + quantum.device shots(%c0) ["", "", ""] + %0 = quantum.alloc( 5) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + %out_qubits_0 = quantum.custom "Hadamard"() %2 : !quantum.bit + %3 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.namedobs %out_qubits_0[ PauliZ] : !quantum.obs + %6 = quantum.expval %5 : f64 + %from_elements_1 = tensor.from_elements %6 : tensor + %7 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit + %8 = quantum.insert %7[ 1], %out_qubits_0 : !quantum.reg, !quantum.bit + quantum.dealloc %8 : !quantum.reg + quantum.device_release + return %from_elements, %from_elements_1 : tensor, tensor + } + // CHECK: func.func [[dup_func]] + } + """ + + pipeline = (SplitNonCommutingPass(),) + run_filecheck(program, pipeline) + + def test_func_w_non_commuting_measurements(self, run_filecheck): + """Test split non-commuting pass would be applied to a func with non-commuting ops.""" + program = """ + module @module_circuit { + func.func public @circuit() -> (tensor, tensor) attributes {qnode} { + // CHECK-NOT: arith.constant + // CHECK: [[v0:%.+]] = func.call [[dup_func0:@[a-zA-Z0-9_.]+]] + // CHECK: [[v1:%.+]] = func.call [[dup_func1:@[a-zA-Z0-9_.]+]] + // CHECK: func.return [[v0]], [[v1]] + %c0 = arith.constant 0 : i64 + quantum.device shots(%c0) ["", "", ""] + %0 = quantum.alloc( 5) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + %out_qubits_0 = quantum.custom "Hadamard"() %2 : !quantum.bit + %3 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.namedobs %out_qubits[ PauliZ] : !quantum.obs + %6 = quantum.expval %5 : f64 + %from_elements_1 = tensor.from_elements %6 : tensor + %7 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit + %8 = quantum.insert %7[ 1], %out_qubits_0 : !quantum.reg, !quantum.bit + quantum.dealloc %8 : !quantum.reg + quantum.device_release + return %from_elements, %from_elements_1 : tensor, tensor + } + // CHECK: func.func [[dup_func0]] + // CHECK: PauliX + // CHECK: func.func [[dup_func1]] + // CHECK: PauliZ + } + """ + + pipeline = (SplitNonCommutingPass(),) + run_filecheck(program, pipeline) + + def test_func_w_mixed_measurements(self, run_filecheck): + """Test split non-commuting pass would be applied to a func with mixed ops.""" + program = """ + module @module_circuit { + func.func public @circuit() -> (tensor, tensor) attributes {qnode} { + // CHECK-NOT: arith.constant + // CHECK: [[v0:%.+]], [[v1:%.+]] = func.call [[dup_func0:@[a-zA-Z0-9_.]+]] + // CHECK: [[v2:%.+]] = func.call [[dup_func1:@[a-zA-Z0-9_.]+]] + // CHECK: func.return [[v0]], [[v2]], [[v1]] + %c0 = arith.constant 0 : i64 + quantum.device shots(%c0) ["", "", ""] + %0 = quantum.alloc( 5) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + %out_qubits_0 = quantum.custom "Hadamard"() %2 : !quantum.bit + %3 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.namedobs %out_qubits[ PauliZ] : !quantum.obs + %6 = quantum.expval %5 : f64 + %from_elements_1 = tensor.from_elements %6 : tensor + %7 = quantum.namedobs %out_qubits_0[ PauliY] : !quantum.obs + %8 = quantum.expval %7 : f64 + %from_elements_2 = tensor.from_elements %8 : tensor + %9 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit + %10 = quantum.insert %9[ 1], %out_qubits_0 : !quantum.reg, !quantum.bit + quantum.dealloc %10 : !quantum.reg + quantum.device_release + return %from_elements, %from_elements_1, %from_elements_2 : tensor, tensor, tensor + } + // CHECK: func.func [[dup_func0]] + // CHECK: PauliX + // CHECK: PauliY + // CHECK: func.func [[dup_func1]] + // CHECK: PauliZ + } + """ + + pipeline = (SplitNonCommutingPass(),) + run_filecheck(program, pipeline) + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_split_non_commuting_pass_only(self, run_filecheck_qjit): + """Test the split non-commuting pass only.""" + dev = qml.device("lightning.qubit", wires=5) + + @qml.while_loop(lambda i: i < 5) + def _while_for(i): + qml.H(i) + i = i + 1 + return i + + @qml.qjit( + target="mlir", + pass_plugins=[getXDSLPluginAbsolutePath()], + ) + @split_non_commuting_pass + @qml.set_shots(10) + @qml.qnode(dev) + def circuit(): + # CHECK-LABEL: func.func public @circuit() + # CHECK: [[v0:%.+]], [[v1:%.+]] = func.call [[dup_func:@[a-zA-Z0-9_.]+]] + # CHECK: [[v2:%.+]] = func.call [[dup_func1:@[a-zA-Z0-9_.]+]] + # CHECK: func.return [[v0]], [[v1]], [[v2]] + _while_for(0) + qml.CNOT(wires=[0, 1]) + return ( + qml.expval(qml.Z(wires=0)), + qml.expval(qml.Y(wires=1)), + qml.expval(qml.X(wires=0)), + ) + # CHECK: func.func [[dup_func]] + # CHECK: PauliZ + # CHECK: PauliY + # CHECK: func.func [[dup_func1]] + # CHECK: PauliX + + run_filecheck_qjit(circuit) + + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_lightning_execution_with_structure(self): + """Test that the split non-commuting pass on lightning.qubit for a circuit with program structure is executable and returns results as expected.""" + dev = qml.device("lightning.qubit", wires=10) + + @qml.for_loop(0, 10, 1) + def for_fn(i): + qml.H(i) + qml.S(i) + qml.RZ(phi=0.1, wires=[i]) + + @qml.while_loop(lambda i: i < 10) + def while_fn(i): + qml.H(i) + qml.S(i) + qml.RZ(phi=0.1, wires=[i]) + i = i + 1 + return i + + @qml.qjit( + target="mlir", + pass_plugins=[getXDSLPluginAbsolutePath()], + ) + @split_non_commuting_pass + @qml.qnode(dev) + def circuit(): + for_fn() + while_fn(0) + qml.CNOT(wires=[0, 1]) + return ( + qml.expval(qml.Z(wires=0)), + qml.expval(qml.Y(wires=1)), + qml.expval(qml.X(wires=0)), + ) + + res = circuit() + + @qml.qjit( + target="mlir", + ) + @qml.qnode(dev) + def circuit_ref(): + for_fn() + while_fn(0) + qml.CNOT(wires=[0, 1]) + return ( + qml.expval(qml.Z(wires=0)), + qml.expval(qml.Y(wires=1)), + qml.expval(qml.X(wires=0)), + ) + + res_ref = circuit_ref() + assert res == res_ref diff --git a/frontend/test/pytest/python_interface/visualization/test_draw_unified_compiler.py b/frontend/test/pytest/python_interface/visualization/test_draw_unified_compiler.py new file mode 100644 index 0000000000..c828f8f87c --- /dev/null +++ b/frontend/test/pytest/python_interface/visualization/test_draw_unified_compiler.py @@ -0,0 +1,534 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the draw function in the Python Compiler visualization module.""" + + +import pytest + +pytestmark = pytest.mark.external + +pytest.importorskip("xdsl") +pytest.importorskip("catalyst") + +# pylint: disable=wrong-import-position +import jax +import pennylane as qml + +# pylint: disable=wrong-import-position +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from catalyst.python_interface.transforms import iterative_cancel_inverses_pass +from catalyst.python_interface.visualization import draw + +# pylint: disable=implicit-str-concat, unnecessary-lambda + + +@pytest.mark.usefixtures("enable_disable_plxpr") +class Testdraw: + """Unit tests for the draw function in the Python Compiler visualization module.""" + + @pytest.fixture + def transforms_circuit(self): + """Fixture for a circuit.""" + + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circ(): + qml.RX(1, 0) + qml.RX(2.0, 0) + qml.RY(3.0, 1) + qml.RY(4.0, 1) + qml.RZ(5.0, 2) + qml.RZ(6.0, 2) + qml.Hadamard(0) + qml.Hadamard(0) + qml.CNOT([0, 1]) + qml.CNOT([0, 1]) + qml.Hadamard(1) + qml.Hadamard(1) + qml.RZ(7.0, 0) + qml.RZ(8.0, 0) + qml.CNOT([0, 2]) + qml.CNOT([0, 2]) + return qml.state() + + return circ + + @pytest.mark.parametrize("qjit", [True, False]) + @pytest.mark.parametrize( + "level, expected", + [ + ( + 0, + "0: ──RX──RX──H──H─╭●─╭●──RZ──RZ─╭●─╭●─┤ State\n" + "1: ──RY──RY───────╰X─╰X──H───H──│──│──┤ State\n" + "2: ──RZ──RZ─────────────────────╰X─╰X─┤ State", + ), + ( + 1, + "0: ──RX──H──H─╭●─╭●──RZ────╭●─╭●─┤ State\n" + "1: ──RY───────╰X─╰X──H───H─│──│──┤ State\n" + "2: ──RZ────────────────────╰X─╰X─┤ State", + ), + (2, "0: ──RX──RZ─┤ State\n" "1: ──RY─────┤ State\n" "2: ──RZ─────┤ State"), + (None, "0: ──RX──RZ─┤ State\n" "1: ──RY─────┤ State\n" "2: ──RZ─────┤ State"), + (50, "0: ──RX──RZ─┤ State\n" "1: ──RY─────┤ State\n" "2: ──RZ─────┤ State"), + ], + ) + def test_multiple_levels_xdsl(self, transforms_circuit, level, qjit, expected): + """Test that multiple levels of transformation are applied correctly with xDSL compilation passes.""" + + transforms_circuit = iterative_cancel_inverses_pass( + qml.compiler.python_compiler.transforms.merge_rotations_pass(transforms_circuit) + ) + + if qjit: + transforms_circuit = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])( + transforms_circuit + ) + + assert draw(transforms_circuit, level=level)() == expected + + @pytest.mark.parametrize("qjit", [True, False]) + @pytest.mark.parametrize( + "level, expected", + [ + ( + 0, + "0: ──RX──RX──H──H─╭●─╭●──RZ──RZ─╭●─╭●─┤ State\n" + "1: ──RY──RY───────╰X─╰X──H───H──│──│──┤ State\n" + "2: ──RZ──RZ─────────────────────╰X─╰X─┤ State", + ), + ( + 1, + "0: ──RX──H──H─╭●─╭●──RZ────╭●─╭●─┤ State\n" + "1: ──RY───────╰X─╰X──H───H─│──│──┤ State\n" + "2: ──RZ────────────────────╰X─╰X─┤ State", + ), + (2, "0: ──RX──RZ─┤ State\n" "1: ──RY─────┤ State\n" "2: ──RZ─────┤ State"), + (None, "0: ──RX──RZ─┤ State\n" "1: ──RY─────┤ State\n" "2: ──RZ─────┤ State"), + (50, "0: ──RX──RZ─┤ State\n" "1: ──RY─────┤ State\n" "2: ──RZ─────┤ State"), + ], + ) + def test_multiple_levels_catalyst(self, transforms_circuit, level, qjit, expected): + """Test that multiple levels of transformation are applied correctly with Catalyst compilation passes.""" + + transforms_circuit = qml.transforms.cancel_inverses( + qml.transforms.merge_rotations(transforms_circuit) + ) + + if qjit: + transforms_circuit = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])( + transforms_circuit + ) + + assert draw(transforms_circuit, level=level)() == expected + + @pytest.mark.parametrize("qjit", [True, False]) + @pytest.mark.parametrize( + "level, expected", + [ + ( + 0, + "0: ──RX──RX──H──H─╭●─╭●──RZ──RZ─╭●─╭●─┤ State\n" + "1: ──RY──RY───────╰X─╰X──H───H──│──│──┤ State\n" + "2: ──RZ──RZ─────────────────────╰X─╰X─┤ State", + ), + ( + 1, + "0: ──RX──H──H─╭●─╭●──RZ────╭●─╭●─┤ State\n" + "1: ──RY───────╰X─╰X──H───H─│──│──┤ State\n" + "2: ──RZ────────────────────╰X─╰X─┤ State", + ), + (2, "0: ──RX──RZ─┤ State\n" "1: ──RY─────┤ State\n" "2: ──RZ─────┤ State"), + (None, "0: ──RX──RZ─┤ State\n" "1: ──RY─────┤ State\n" "2: ──RZ─────┤ State"), + (50, "0: ──RX──RZ─┤ State\n" "1: ──RY─────┤ State\n" "2: ──RZ─────┤ State"), + ], + ) + def test_multiple_levels_xdsl_catalyst(self, transforms_circuit, level, qjit, expected): + """Test that multiple levels of transformation are applied correctly with xDSL and Catalyst compilation passes.""" + + transforms_circuit = iterative_cancel_inverses_pass( + qml.transforms.merge_rotations(transforms_circuit) + ) + if qjit: + transforms_circuit = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])( + transforms_circuit + ) + + assert draw(transforms_circuit, level=level)() == expected + + @pytest.mark.parametrize("qjit", [True, False]) + @pytest.mark.parametrize( + "level, expected", + [ + ( + 0, + "0: ──RX──RX──H──H─╭●─╭●──RZ──RZ─╭●─╭●─┤ State\n" + "1: ──RY──RY───────╰X─╰X──H───H──│──│──┤ State\n" + "2: ──RZ──RZ─────────────────────╰X─╰X─┤ State", + ), + ( + 1, + "0: ──RX──RX──H──H─╭●─╭●──RZ──RZ─╭●─╭●─┤ State\n" + "1: ──RY──RY───────╰X─╰X──H───H──│──│──┤ State\n" + "2: ──RZ──RZ─────────────────────╰X─╰X─┤ State", + ), + ( + 2, + "0: ──RX──RX──H──H─╭●─╭●──RZ──RZ─╭●─╭●─┤ State\n" + "1: ──RY──RY───────╰X─╰X──H───H──│──│──┤ State\n" + "2: ──RZ──RZ─────────────────────╰X─╰X─┤ State", + ), + ( + None, + "0: ──RX──RX──H──H─╭●─╭●──RZ──RZ─╭●─╭●─┤ State\n" + "1: ──RY──RY───────╰X─╰X──H───H──│──│──┤ State\n" + "2: ──RZ──RZ─────────────────────╰X─╰X─┤ State", + ), + ( + 50, + "0: ──RX──RX──H──H─╭●─╭●──RZ──RZ─╭●─╭●─┤ State\n" + "1: ──RY──RY───────╰X─╰X──H───H──│──│──┤ State\n" + "2: ──RZ──RZ─────────────────────╰X─╰X─┤ State", + ), + ], + ) + def test_no_passes(self, transforms_circuit, level, qjit, expected): + """Test that if no passes are applied, the circuit is still visualized.""" + + if qjit: + transforms_circuit = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])( + transforms_circuit + ) + + assert draw(transforms_circuit, level=level)() == expected + + @pytest.mark.parametrize( + "op, expected", + [ + ( + lambda: qml.ctrl(qml.RX(0.1, 0), control=(1, 2, 3)), + "1: ─╭●──┤ State\n2: ─├●──┤ State\n3: ─├●──┤ State\n0: ─╰RX─┤ State", + ), + ( + lambda: qml.ctrl(qml.RX(0.1, 0), control=(1, 2, 3), control_values=(0, 1, 0)), + "1: ─╭○──┤ State\n2: ─├●──┤ State\n3: ─├○──┤ State\n0: ─╰RX─┤ State", + ), + ( + lambda: qml.adjoint(qml.ctrl(qml.RX(0.1, 0), (1, 2, 3), control_values=(0, 1, 0))), + "1: ─╭○───┤ State\n2: ─├●───┤ State\n3: ─├○───┤ State\n0: ─╰RX†─┤ State", + ), + ( + lambda: qml.ctrl(qml.adjoint(qml.RX(0.1, 0)), (1, 2, 3), control_values=(0, 1, 0)), + "1: ─╭○───┤ State\n2: ─├●───┤ State\n3: ─├○───┤ State\n0: ─╰RX†─┤ State", + ), + ], + ) + def test_ctrl_adjoint_variants(self, op, expected): + """ + Test the visualization of control and adjoint variants. + """ + + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circuit(): + op() + return qml.state() + + assert draw(circuit)() == expected + + def test_ctrl_before_custom_op(self): + """ + Test the visualization of control operations before custom ops. + """ + + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circuit(): + qml.ctrl(qml.X(3), control=[0, 1, 2], control_values=[1, 0, 1]) + qml.RX(0.1, 2) + return qml.state() + + assert ( + draw(circuit)() + == "0: ─╭●─────┤ State\n1: ─├○─────┤ State\n2: ─├●──RX─┤ State\n3: ─╰X─────┤ State" + ) + + @pytest.mark.parametrize( + "measurement, expected", + [ + ( + lambda: (qml.probs(0), qml.probs(1), qml.probs(2)), + "0: ──RX─┤ Probs\n1: ──RY─┤ Probs\n2: ──RZ─┤ Probs", + ), + ( + lambda: qml.probs(), + "0: ──RX─┤ Probs\n1: ──RY─┤ Probs\n2: ──RZ─┤ Probs", + ), + ( + lambda: qml.sample(), + "0: ──RX─┤ Sample\n1: ──RY─┤ Sample\n2: ──RZ─┤ Sample", + ), + ( + lambda: (qml.expval(qml.X(0)), qml.expval(qml.Y(1)), qml.expval(qml.Z(2))), + "0: ──RX─┤ \n1: ──RY─┤ \n2: ──RZ─┤ ", + ), + ( + lambda: ( + qml.expval(qml.X(0) @ qml.Y(1)), + qml.expval(qml.Y(1) @ qml.Z(2) @ qml.X(0)), + qml.expval(qml.Z(2) @ qml.X(0) @ qml.Y(1)), + ), + "0: ──RX─┤ ╭\n1: ──RY─┤ ╰\n2: ──RZ─┤ ╰", + ), + ( + lambda: ( + qml.expval( + qml.Hamiltonian([0.2, 0.2], [qml.PauliX(0), qml.Y(1)]) + @ qml.Hamiltonian([0.1, 0.1], [qml.PauliZ(2), qml.PauliZ(3)]) + ) + ), + "0: ──RX─┤ ╭<(𝓗)@(𝓗)>\n1: ──RY─┤ ├<(𝓗)@(𝓗)>\n2: ──RZ─┤ ├<(𝓗)@(𝓗)>\n3: ─────┤ ╰<(𝓗)@(𝓗)>", + ), + ( + lambda: (qml.var(qml.X(0)), qml.var(qml.Y(1)), qml.var(qml.Z(2))), + "0: ──RX─┤ Var[X]\n1: ──RY─┤ Var[Y]\n2: ──RZ─┤ Var[Z]", + ), + ( + lambda: ( + qml.var(qml.X(0) @ qml.Y(1)), + qml.var(qml.Y(1) @ qml.Z(2) @ qml.X(0)), + qml.var(qml.Z(2) @ qml.X(0) @ qml.Y(1)), + ), + "0: ──RX─┤ ╭Var[X@Y] ╭Var[Y@Z@X] ╭Var[Z@X@Y]\n1: ──RY─┤ ╰Var[X@Y] ├Var[Y@Z@X] ├Var[Z@X@Y]\n2: ──RZ─┤ ╰Var[Y@Z@X] ╰Var[Z@X@Y]", + ), + ], + ) + def test_measurements(self, measurement, expected): + """ + Test the visualization of measurements. + """ + + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circuit(): + qml.RX(0.1, 0) + qml.RY(0.2, 1) + qml.RZ(0.3, 2) + return measurement() + + if isinstance(measurement(), qml.measurements.SampleMP): + circuit = qml.set_shots(10)(circuit) + + assert draw(circuit)() == expected + + def test_global_phase(self): + """Test the visualization of global phase shifts.""" + + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circuit(): + qml.H(0) + qml.H(1) + qml.H(2) + qml.GlobalPhase(0.5) + return qml.state() + + assert ( + draw(circuit)() + == "0: ──H─╭GlobalPhase─┤ State\n1: ──H─├GlobalPhase─┤ State\n2: ──H─╰GlobalPhase─┤ State" + ) + + @pytest.mark.parametrize( + "postselect, mid_measure_label", + [ + (None, "┤↗├"), + (0, "┤↗₀├"), + (1, "┤↗₁├"), + ], + ) + def test_draw_mid_circuit_measurement_postselect(self, postselect, mid_measure_label): + """Test that mid-circuit measurements are drawn correctly.""" + + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def circuit(): + qml.Hadamard(0) + qml.measure(0, postselect=postselect) + qml.PauliX(0) + return qml.expval(qml.PauliZ(0)) + + drawing = draw(circuit)() + expected_drawing = "0: ──H──" + mid_measure_label + "──X─┤ " + + assert drawing == expected_drawing + + @pytest.mark.jax + @pytest.mark.parametrize( + "ops, expected", + [ + ( + [ + (qml.QubitUnitary, jax.numpy.array([[0, 1], [1, 0]]), [0]), + ( + qml.QubitUnitary, + jax.numpy.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]]), + [0, 1], + ), + (qml.QubitUnitary, jax.numpy.zeros((8, 8)), [0, 1, 2]), + ( + qml.QubitUnitary, + jax.numpy.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]]), + [0, 1], + ), + (qml.QubitUnitary, jax.numpy.array([[0, 1], [1, 0]]), [0]), + ], + "0: ──U(M0)─╭U(M1)─╭U(M2)─╭U(M1)──U(M0)─┤ State\n" + "1: ────────╰U(M1)─├U(M2)─╰U(M1)────────┤ State\n" + "2: ───────────────╰U(M2)───────────────┤ State", + ), + ( + [ + (qml.StatePrep, jax.numpy.array([1, 0]), [0]), + (qml.StatePrep, jax.numpy.array([1, 0, 0, 0]), [0, 1]), + (qml.StatePrep, jax.numpy.array([1, 0, 0, 0, 1, 0, 0, 0]), [0, 1, 2]), + (qml.StatePrep, jax.numpy.array([1, 0, 0, 0]), [0, 1]), + (qml.StatePrep, jax.numpy.array([1, 0]), [0]), + ], + "0: ──|Ψ⟩─╭|Ψ⟩─╭|Ψ⟩─╭|Ψ⟩──|Ψ⟩─┤ State\n" + "1: ──────╰|Ψ⟩─├|Ψ⟩─╰|Ψ⟩──────┤ State\n" + "2: ───────────╰|Ψ⟩───────────┤ State", + ), + ( + [ + (qml.MultiRZ, 0.1, [0]), + (qml.MultiRZ, 0.1, [0, 1]), + (qml.MultiRZ, 0.1, [0, 1, 2]), + (qml.MultiRZ, 0.1, [0, 1]), + (qml.MultiRZ, 0.1, [0]), + ], + "0: ──MultiRZ─╭MultiRZ─╭MultiRZ─╭MultiRZ──MultiRZ─┤ State\n" + "1: ──────────╰MultiRZ─├MultiRZ─╰MultiRZ──────────┤ State\n" + "2: ───────────────────╰MultiRZ───────────────────┤ State", + ), + ( + [ + (qml.BasisState, jax.numpy.array([1]), [0]), + (qml.BasisState, jax.numpy.array([1, 0]), [0, 1]), + (qml.BasisState, jax.numpy.array([1, 0, 0]), [0, 1, 2]), + (qml.BasisState, jax.numpy.array([1, 0]), [0, 1]), + (qml.BasisState, jax.numpy.array([1]), [0]), + ], + "0: ──|Ψ⟩─╭|Ψ⟩─╭|Ψ⟩─╭|Ψ⟩──|Ψ⟩─┤ State\n" + "1: ──────╰|Ψ⟩─├|Ψ⟩─╰|Ψ⟩──────┤ State\n" + "2: ───────────╰|Ψ⟩───────────┤ State", + ), + ], + ) + def test_visualization_cases(self, ops, expected): + """ + Test the visualization of the quantum operations defined in the unified compiler dialect. + """ + + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circuit(): + for op, param, wires in ops: + op(param, wires=wires) + return qml.state() + + assert draw(circuit)() == expected + + def test_reshape(self): + """Test that the visualization works when the parameters are reshaped.""" + + one_dim = jax.numpy.array([1, 0]) + two_dim = jax.numpy.array([[0, 1], [1, 0]]) + eight_dim = jax.numpy.zeros((8, 8)) + + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def circuit(): + qml.RX(one_dim[0], wires=0) + qml.RZ(two_dim[0, 0], wires=0) + qml.QubitUnitary(eight_dim[:2, :2], wires=0) + qml.QubitUnitary(eight_dim[0:4, 0:4], wires=[0, 1]) + return qml.state() + + expected = ( + "0: ──RX(M0)──RZ(M0)──U(M1)─╭U(M2)─┤ State\n" + "1: ────────────────────────╰U(M2)─┤ State" + ) + assert draw(circuit)() == expected + + def test_args_warning(self): + """Test that a warning is raised when dynamic arguments are used.""" + + # pylint: disable=unused-argument + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circ(arg): + qml.RX(0.1, wires=0) + return qml.state() + + with pytest.warns(UserWarning): + draw(circ)(0.1) + + def adjoint_op_not_implemented(self): + """Test that NotImplementedError is raised when AdjointOp is used.""" + + @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + def circuit(): + qml.adjoint(qml.QubitUnitary)(jax.numpy.array([[0, 1], [1, 0]]), wires=[0]) + return qml.expval(qml.PauliZ(0)) + + with pytest.raises(NotImplementedError, match="not yet supported"): + print(draw(circuit)()) + + def test_cond_not_implemented(self): + """Test that NotImplementedError is raised when cond is used.""" + + @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def circuit(): + m0 = qml.measure(0, reset=False, postselect=0) + qml.cond(m0, qml.RX, qml.RY)(1.23, 1) + return qml.expval(qml.PauliZ(0)) + + with pytest.raises(NotImplementedError, match="not yet supported"): + print(draw(circuit)()) + + def test_for_loop_not_implemented(self): + """Test that NotImplementedError is raised when for loop is used.""" + + @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()], autograph=True) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + def circuit(): + for _ in range(3): + qml.RX(0.1, 0) + return qml.expval(qml.PauliZ(0)) + + with pytest.raises(NotImplementedError, match="not yet supported"): + print(draw(circuit)()) + + def test_while_loop_not_implemented(self): + """Test that NotImplementedError is raised when while loop is used.""" + + @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()], autograph=True) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + def circuit(): + i = 0 + while i < 3: + qml.RX(0.1, 0) + i += 1 + return qml.expval(qml.PauliZ(0)) + + with pytest.raises(NotImplementedError, match="not yet supported"): + print(draw(circuit)()) + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/python_interface/visualization/test_mlir_graph.py b/frontend/test/pytest/python_interface/visualization/test_mlir_graph.py new file mode 100644 index 0000000000..d643ab4e82 --- /dev/null +++ b/frontend/test/pytest/python_interface/visualization/test_mlir_graph.py @@ -0,0 +1,292 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the MLIR graph generation in the Unified Compiler visualization module.""" + +from pathlib import Path + +import pytest + +pytestmark = pytest.mark.external + +pytest.importorskip("xdsl") +pytest.importorskip("catalyst") + + +import pennylane as qml + +# pylint: disable=wrong-import-position +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from catalyst.python_interface.transforms import ( + iterative_cancel_inverses_pass, + merge_rotations_pass, +) +from catalyst.python_interface.visualization import generate_mlir_graph + + +@pytest.fixture(autouse=True) +def _chdir_tmp(monkeypatch, tmp_path: Path): + """Ensure all tests run inside a temp directory.""" + monkeypatch.chdir(tmp_path) + return tmp_path + + +def collect_files(tmp_path: Path) -> set[str]: + """Return the set of generated SVG files.""" + out_dir = tmp_path / "mlir_generated_graphs" + return {f.name for f in out_dir.glob("*.svg")} + + +def assert_files(tmp_path: Path, expected: set[str]): + """Check that the generated files match the expected set.""" + files = collect_files(tmp_path) + assert files == expected, f"Expected {expected}, got {files}" + + +@pytest.mark.usefixtures("enable_disable_plxpr") +class TestMLIRGraph: + "Test the MLIR graph generation" + + @pytest.mark.parametrize("qjit", [True, False]) + def test_no_transforms(self, tmp_path: Path, qjit: bool): + "Test the MLIR graph is still generated when no transforms are applied" + + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def _(): + qml.RX(0.1, 0) + qml.RX(2.0, 0) + qml.CNOT([0, 2]) + qml.CNOT([0, 2]) + return qml.state() + + if qjit: + _ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_) + + generate_mlir_graph(_)() + assert collect_files(tmp_path) == {"QNode_level_0_no_transforms.svg"} + + @pytest.mark.parametrize("qjit", [True, False]) + def test_xdsl_transforms_no_args(self, tmp_path: Path, qjit: bool): + "Test the MLIR graph generation with no arguments to the QNode with and without qjit" + + @merge_rotations_pass + @iterative_cancel_inverses_pass + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def _(): + qml.RX(0.1, 0) + qml.RX(2.0, 0) + qml.CNOT([0, 2]) + qml.CNOT([0, 2]) + return qml.state() + + if qjit: + _ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_) + + generate_mlir_graph(_)() + assert_files( + tmp_path, + { + "QNode_level_0_no_transforms.svg", + "QNode_level_1_after_xdsl-cancel-inverses.svg", + "QNode_level_2_after_xdsl-merge-rotations.svg", + }, + ) + + @pytest.mark.parametrize("qjit", [True, False]) + def test_xdsl_transforms_args(self, tmp_path: Path, qjit: bool): + "Test the MLIR graph generation with arguments to the QNode for xDSL transforms" + + @merge_rotations_pass + @iterative_cancel_inverses_pass + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def _(x, y, w1, w2): + qml.RX(x, w1) + qml.RX(y, w2) + return qml.state() + + if qjit: + _ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_) + + generate_mlir_graph(_)(0.1, 0.2, 0, 1) + assert_files( + tmp_path, + { + "QNode_level_0_no_transforms.svg", + "QNode_level_1_after_xdsl-cancel-inverses.svg", + "QNode_level_2_after_xdsl-merge-rotations.svg", + }, + ) + + @pytest.mark.parametrize("qjit", [True, False]) + def test_catalyst_transforms_args(self, tmp_path: Path, qjit: bool): + "Test the MLIR graph generation with arguments to the QNode for catalyst transforms" + + @qml.transforms.merge_rotations + @qml.transforms.cancel_inverses + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def _(x, y, w1, w2): + qml.RX(x, w1) + qml.RX(y, w2) + return qml.state() + + if qjit: + _ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_) + + generate_mlir_graph(_)(0.1, 0.2, 0, 1) + assert_files( + tmp_path, + { + "QNode_level_0_no_transforms.svg", + "QNode_level_1_after_remove-chained-self-inverse.svg", + "QNode_level_2_after_merge-rotations.svg", + }, + ) + + @pytest.mark.parametrize("qjit", [True, False]) + def test_catalyst_xdsl_transforms_args(self, tmp_path: Path, qjit: bool): + "Test the MLIR graph generation with arguments to the QNode for catalyst and xDSL transforms" + + @qml.transforms.merge_rotations + @iterative_cancel_inverses_pass + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def _(x, y, w1, w2): + qml.RX(x, w1) + qml.RX(y, w2) + return qml.state() + + if qjit: + _ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_) + + generate_mlir_graph(_)(0.1, 0.2, 0, 1) + assert_files( + tmp_path, + { + "QNode_level_0_no_transforms.svg", + "QNode_level_1_after_xdsl-cancel-inverses.svg", + "QNode_level_2_after_merge-rotations.svg", + }, + ) + + def test_cond(self, tmp_path: Path): + "Test the MLIR graph generation for a conditional" + + @merge_rotations_pass + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def _(pred, arg1, arg2): + """Quantum circuit with conditional branches.""" + + qml.RX(0.10, wires=0) + + def true_fn(arg1, arg2): + qml.RY(arg1, wires=0) + qml.RX(arg2, wires=0) + qml.RZ(arg1, wires=0) + + def false_fn(arg1, arg2): + qml.RX(arg1, wires=0) + qml.RX(arg2, wires=0) + + qml.cond(pred > 0, true_fn, false_fn)(arg1, arg2) + qml.RX(0.10, wires=0) + return qml.expval(qml.Z(wires=0)) + + generate_mlir_graph(_)(0.5, 0.1, 0.2) + assert_files( + tmp_path, + { + "QNode_level_0_no_transforms.svg", + "QNode_level_1_after_xdsl-merge-rotations.svg", + }, + ) + + def test_cond_with_mcm(self, tmp_path: Path): + "Test the MLIR graph generation for a conditional with MCM" + + def true_fn(arg): + qml.RX(arg, 0) + + def false_fn(arg): + qml.RY(3 * arg, 0) + + @merge_rotations_pass + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def _(x, y): + """Quantum circuit with conditional branches.""" + + qml.RX(x, 0) + m = qml.measure(0) + + qml.cond(m, true_fn, false_fn)(y) + return qml.expval(qml.Z(0)) + + generate_mlir_graph(_)(0.5, 0.1) + assert_files( + tmp_path, + { + "QNode_level_0_no_transforms.svg", + "QNode_level_1_after_xdsl-merge-rotations.svg", + }, + ) + + def test_for_loop(self, tmp_path: Path): + "Test the MLIR graph generation for a for loop" + + @merge_rotations_pass + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def _(): + @qml.for_loop(0, 100) + def loop(_): + qml.RX(0.1, 0) + qml.RX(0.1, 0) + + # pylint: disable=no-value-for-parameter + loop() + return qml.state() + + generate_mlir_graph(_)() + assert_files( + tmp_path, + { + "QNode_level_0_no_transforms.svg", + "QNode_level_1_after_xdsl-merge-rotations.svg", + }, + ) + + def test_while_loop(self, tmp_path: Path): + "Test the MLIR graph generation for a while loop" + + @merge_rotations_pass + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def _(x): + def cond_fn(x): + return x < 2 + + @qml.while_loop(cond_fn) + def loop(x): + return x**2 + + loop(x) + return qml.expval(qml.PauliZ(0)) + + generate_mlir_graph(_)(0.5) + assert_files( + tmp_path, + { + "QNode_level_0_no_transforms.svg", + "QNode_level_1_after_xdsl-merge-rotations.svg", + }, + ) + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/python_interface/xdsl_extras/test_constraints.py b/frontend/test/pytest/python_interface/xdsl_extras/test_constraints.py new file mode 100644 index 0000000000..90ce76b125 --- /dev/null +++ b/frontend/test/pytest/python_interface/xdsl_extras/test_constraints.py @@ -0,0 +1,547 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the constraints defined within the xdsl_extras module.""" + +import pytest + +pytestmark = pytest.mark.external + +pytest.importorskip("xdsl") + +# pylint: disable=wrong-import-position +from xdsl.context import Context +from xdsl.dialects import builtin, test +from xdsl.dialects.builtin import MemRefType, TensorType, TupleType, i1, i32 +from xdsl.ir import Dialect +from xdsl.irdl import ( + BaseAttr, + IRDLOperation, + irdl_op_definition, + operand_def, + result_def, +) +from xdsl.irdl.constraints import ConstraintContext +from xdsl.utils.exceptions import VerifyException +from xdsl_jax.dialects.stablehlo import TokenType + +from catalyst.python_interface import QuantumParser +from catalyst.python_interface.xdsl_extras import ( + MemRefConstraint, + NestedTupleOfConstraint, + TensorConstraint, +) + + +@pytest.fixture(scope="module", name="my_dialect") +def my_dialect_fixture(): + """Returns a test dialect, called 'my_dialect', with simple ops that operate on memref and + tensor types. + """ + + @irdl_op_definition + class Float64MemrefOp(IRDLOperation): + """A test op with float memref types""" + + # pylint: disable=too-few-public-methods + + name = "my_dialect.memref_float64" + in_value = operand_def(MemRefConstraint(element_type=builtin.Float64Type())) + out_value = result_def(MemRefConstraint(element_type=builtin.Float64Type())) + + @irdl_op_definition + class Float64TensorOp(IRDLOperation): + """A test op with float tensor types""" + + # pylint: disable=too-few-public-methods + + name = "my_dialect.tensor_float64" + in_value = operand_def(TensorConstraint(element_type=builtin.Float64Type())) + out_value = result_def(TensorConstraint(element_type=builtin.Float64Type())) + + @irdl_op_definition + class Rank1MemrefOp(IRDLOperation): + """A test op with rank-1 memref types""" + + # pylint: disable=too-few-public-methods + + name = "my_dialect.memref_rank1" + in_value = operand_def(MemRefConstraint(rank=1)) + out_value = result_def(MemRefConstraint(rank=1)) + + @irdl_op_definition + class Rank1TensorOp(IRDLOperation): + """A test op with rank-1 tensor types""" + + # pylint: disable=too-few-public-methods + + name = "my_dialect.tensor_rank1" + in_value = operand_def(TensorConstraint(rank=1)) + out_value = result_def(TensorConstraint(rank=1)) + + @irdl_op_definition + class Rank2Or4MemrefOp(IRDLOperation): + """A test op with rank-2 or -4 memref types""" + + # pylint: disable=too-few-public-methods + + name = "my_dialect.memref_rank24" + in_value = operand_def(MemRefConstraint(rank=(2, 4))) + out_value = result_def(MemRefConstraint(rank=(2, 4))) + + @irdl_op_definition + class Rank2Or4TensorOp(IRDLOperation): + """A test op with rank-2 or -4 tensor types""" + + # pylint: disable=too-few-public-methods + + name = "my_dialect.tensor_rank24" + in_value = operand_def(TensorConstraint(rank=(2, 4))) + out_value = result_def(TensorConstraint(rank=(2, 4))) + + @irdl_op_definition + class Shape123MemrefOp(IRDLOperation): + """A test op with shape-(1, 2, 3) memref types""" + + # pylint: disable=too-few-public-methods + + name = "my_dialect.memref_shape123" + in_value = operand_def(MemRefConstraint(shape=(1, 2, 3))) + out_value = result_def(MemRefConstraint(shape=(1, 2, 3))) + + @irdl_op_definition + class Shape123TensorOp(IRDLOperation): + """A test op with shape-(1, 2, 3) tensor types""" + + # pylint: disable=too-few-public-methods + + name = "my_dialect.tensor_shape123" + in_value = operand_def(TensorConstraint(shape=(1, 2, 3))) + out_value = result_def(TensorConstraint(shape=(1, 2, 3))) + + MyDialect = Dialect( + "my_dialect", + [ + Float64MemrefOp, + Float64TensorOp, + Rank1MemrefOp, + Rank1TensorOp, + Rank2Or4MemrefOp, + Rank2Or4TensorOp, + Shape123MemrefOp, + Shape123TensorOp, + ], + ) + + return MyDialect + + +class TestMemRefConstraint: + """Tests for the MemRefConstraint class.""" + + def test_memref_constraint_init_invalid_element_type(self): + """Test that an error is raised if the provided element_type is invalid""" + + element_type = int + with pytest.raises( + TypeError, match="is not a valid constraint for the 'element_type' argument" + ): + MemRefConstraint(element_type=element_type) + + def test_memref_constraint_init_invalid_shape(self): + """Test that an error is raised if the provided shape is invalid""" + + shape = {1, 2} + with pytest.raises(TypeError, match="is not a valid constraint for the 'shape' argument"): + MemRefConstraint(shape=shape) + + def test_memref_constraint_init_invalid_rank(self): + """Test that an error is raised if the provided rank is invalid""" + + rank = 1.5 + with pytest.raises(TypeError, match="is not a valid constraint for the 'rank' argument"): + MemRefConstraint(rank=rank) + + @pytest.mark.parametrize("element_type", [None, builtin.Float64Type(), builtin.i64]) + def test_memref_constraint_init_rank_and_shape_error(self, element_type): + """Test that an error is raised if both rank and shape are provided.""" + shape = (1, 2) + rank = 2 + + with pytest.raises(ValueError, match="Only one of 'shape' or 'rank' may be provided"): + MemRefConstraint(element_type=element_type, shape=shape, rank=rank) + + def test_memref_constraint_properties(self): + """Test that the properties of MemRefConstraint object are correct.""" + rank = 1 + constraint = MemRefConstraint(rank=rank) + + assert constraint.expected_type == builtin.MemRefType + assert constraint.type_name == "memref" + assert constraint.mapping_type_vars({}) is constraint + + @pytest.mark.parametrize("rank", [0, 1, 2]) + def test_memref_single_rank_constraint_verify_valid(self, rank): + """Test that verifying a MemRefType attribute with the same rank as the MemRefConstraint + does not raise an exception.""" + constraint = MemRefConstraint(rank=rank) + attr = builtin.MemRefType(builtin.i32, [1] * rank) + + constraint.verify(attr, None) + + def test_memref_multi_rank_constraint_verify_valid(self): + """Test that verifying a MemRefType attribute with any of the ranks as the MemRefConstraint + does not raise an exception.""" + rank = (3, 4, 5) + constraint = MemRefConstraint(rank=rank) + + for r in rank: + attr = builtin.MemRefType(builtin.i32, [1] * r) + constraint.verify(attr, None) + + def test_memref_shape_constraint_verify_valid(self): + """Test that verifying an attribute with a valid shape does not raise an exception.""" + constraint = MemRefConstraint(shape=(1, 2, 3)) + attr = builtin.MemRefType(builtin.i32, (1, 2, 3)) + + constraint.verify(attr, None) + + def test_memref_element_type_constraint_verify_valid(self): + """Test that verifying an attribute with a valid element type does not raise an exception.""" + constraint = MemRefConstraint(element_type=builtin.i32) + attr = builtin.MemRefType(builtin.i32, [1]) + + constraint.verify(attr, None) + + @pytest.mark.parametrize("rank", [0, 1, 2]) + def test_memref_single_rank_constraint_verify_invalid(self, rank): + """Test that verifying a MemRefType attribute with a different rank as the + MemRefConstraint raises a VerifyException.""" + constraint = MemRefConstraint(rank=rank) + attr = builtin.MemRefType(builtin.i32, [1] * (rank + 1)) + + with pytest.raises(VerifyException, match=f"Invalid value {rank + 1}, expected {rank}"): + constraint.verify(attr, None) + + def test_memref_multi_rank_constraint_verify_invalid(self): + """Test that verifying a MemRefType attribute with a different rank as the + MemRefConstraint raises a VerifyException.""" + + rank = {3, 4, 5} + invalid_rank = 2 + constraint = MemRefConstraint(rank=rank) + attr = builtin.MemRefType(builtin.i32, [1] * invalid_rank) + + with pytest.raises( + VerifyException, match=f"Invalid value {invalid_rank}, expected one of {rank}" + ): + constraint.verify(attr, None) + + def test_memref_constraint_verify_invalid_type(self): + """Test that verifying an attribute with a type other than MemRefType raises a + VerifyException.""" + constraint = MemRefConstraint(rank=1) + attr = builtin.TensorType(builtin.i32, [1]) + + with pytest.raises(VerifyException, match=f"{attr} should be of type MemRefType"): + constraint.verify(attr, None) + + def test_memref_shape_constraint_verify_invalid(self): + """Test that verifying an attribute with an invalid shape raises an exception.""" + constraint = MemRefConstraint(shape=(1, 2, 3)) + attr = builtin.MemRefType(builtin.i32, (2, 2, 3)) + + with pytest.raises(VerifyException, match=r"Expected attribute \[.*\] but got \[.*\]"): + constraint.verify(attr, None) + + def test_memref_element_type_constraint_verify_invalid(self): + """Test that verifying an attribute with an invalid element type raises an exception.""" + constraint = MemRefConstraint(element_type=builtin.i32) + attr = builtin.MemRefType(builtin.i64, [1]) + + with pytest.raises( + VerifyException, match=f"Expected attribute {builtin.i32} but got {builtin.i64}" + ): + constraint.verify(attr, None) + + def test_memref_constraint_integration(self, my_dialect): + """Test that verification of legal operations with memref operand/result constraints does + not raise an exception.""" + program = """ + func.func public @test_workload() -> () { + %0 = "test.op"() : () -> memref<2xf64> + %1 = "test.op"() : () -> memref<2x4xf64> + %2 = "test.op"() : () -> memref<2x4x5x3xf64> + %3 = "test.op"() : () -> memref<1x2x3xf64> + %4 = "my_dialect.memref_float64"(%0) : (memref<2xf64>) -> memref<2xf64> + %5 = "my_dialect.memref_rank1"(%0) : (memref<2xf64>) -> memref<2xf64> + %6 = "my_dialect.memref_rank24"(%1) : (memref<2x4xf64>) -> memref<2x4xf64> + %7 = "my_dialect.memref_rank24"(%2) : (memref<2x4x5x3xf64>) -> memref<2x4x5x3xf64> + %8 = "my_dialect.memref_shape123"(%3) : (memref<1x2x3xf64>) -> memref<1x2x3xf64> + func.return + } + """ + + ctx = Context(allow_unregistered=False) + xdsl_module: builtin.ModuleOp = QuantumParser( + ctx, program, extra_dialects=(test.Test, my_dialect) + ).parse_module() + xdsl_module.verify() + + def test_memref_constraint_integration_invalid(self, my_dialect): + """Test that verification of illegal operations with memref operands/result constraints + raises a VerifyException.""" + program = """ + func.func public @test_workload() -> () { + %0 = "test.op"() : () -> memref<2x2xi64> + %1 = "my_dialect.memref_float64"(%0) : (memref<2x2xi64>) -> memref<2x2xi64> + func.return + } + """ + + ctx = Context(allow_unregistered=False) + xdsl_module: builtin.ModuleOp = QuantumParser( + ctx, program, extra_dialects=(test.Test, my_dialect) + ).parse_module() + + with pytest.raises( + VerifyException, + match=f"Expected attribute {builtin.Float64Type()} but got {builtin.i64}", + ): + xdsl_module.verify() + + +class TestTensorConstraint: + """Tests for the TensorConstraint class.""" + + def test_tensor_constraint_init_invalid_element_type(self): + """Test that an error is raised if the provided element_type is invalid""" + + element_type = int + with pytest.raises( + TypeError, match="is not a valid constraint for the 'element_type' argument" + ): + TensorConstraint(element_type=element_type) + + def test_tensor_constraint_init_invalid_shape(self): + """Test that an error is raised if the provided shape is invalid""" + + shape = {1, 2} + with pytest.raises(TypeError, match="is not a valid constraint for the 'shape' argument"): + TensorConstraint(shape=shape) + + def test_tensor_constraint_init_invalid_rank(self): + """Test that an error is raised if the provided rank is invalid""" + + rank = 1.5 + with pytest.raises(TypeError, match="is not a valid constraint for the 'rank' argument"): + TensorConstraint(rank=rank) + + @pytest.mark.parametrize("element_type", [None, builtin.Float64Type(), builtin.i64]) + def test_tensor_constraint_init_rank_and_shape_error(self, element_type): + """Test that an error is raised if both rank and shape are provided.""" + shape = (1, 2) + rank = 2 + + with pytest.raises(ValueError, match="Only one of 'shape' or 'rank' may be provided"): + TensorConstraint(element_type=element_type, shape=shape, rank=rank) + + def test_tensor_constraint_properties(self): + """Test that the properties of TensorConstraint object are correct.""" + rank = 1 + constraint = TensorConstraint(rank=rank) + + assert constraint.expected_type == builtin.TensorType + assert constraint.type_name == "tensor" + assert constraint.mapping_type_vars({}) is constraint + + @pytest.mark.parametrize("rank", [0, 1, 2]) + def test_tensor_single_rank_constraint_verify_valid(self, rank): + """Test that verifying a TensorType attribute with the same rank as the TensorConstraint + does not raise an exception.""" + constraint = TensorConstraint(rank=rank) + attr = builtin.TensorType(builtin.i32, [1] * rank) + + constraint.verify(attr, None) + + def test_tensor_multi_rank_constraint_verify_valid(self): + """Test that verifying a TensorType attribute with any of the ranks as the TensorConstraint + does not raise an exception.""" + rank = (3, 4, 5) + constraint = TensorConstraint(rank=rank) + + for r in rank: + attr = builtin.TensorType(builtin.i32, [1] * r) + constraint.verify(attr, None) + + def test_tensor_shape_constraint_verify_valid(self): + """Test that verifying an attribute with a valid shape does not raise an exception.""" + constraint = TensorConstraint(shape=(1, 2, 3)) + attr = builtin.TensorType(builtin.i32, (1, 2, 3)) + + constraint.verify(attr, None) + + def test_tensor_element_type_constraint_verify_valid(self): + """Test that verifying an attribute with a valid element type does not raise an exception.""" + constraint = TensorConstraint(element_type=builtin.i32) + attr = builtin.TensorType(builtin.i32, [1]) + + constraint.verify(attr, None) + + @pytest.mark.parametrize("rank", [0, 1, 2]) + def test_tensor_single_rank_constraint_verify_invalid(self, rank): + """Test that verifying a TensorType attribute with a different rank as the + TensorConstraint raises a VerifyException.""" + constraint = TensorConstraint(rank=rank) + attr = builtin.TensorType(builtin.i32, [1] * (rank + 1)) + + with pytest.raises(VerifyException, match=f"Invalid value {rank + 1}, expected {rank}"): + constraint.verify(attr, None) + + def test_tensor_multi_rank_constraint_verify_invalid(self): + """Test that verifying a TensorType attribute with a different rank as the + TensorConstraint raises a VerifyException.""" + + rank = {3, 4, 5} + invalid_rank = 2 + constraint = TensorConstraint(rank=rank) + attr = builtin.TensorType(builtin.i32, [1] * invalid_rank) + + with pytest.raises( + VerifyException, match=f"Invalid value {invalid_rank}, expected one of {rank}" + ): + constraint.verify(attr, None) + + def test_tensor_constraint_verify_invalid_type(self): + """Test that verifying an attribute with a type other than TensorType raises a + VerifyException.""" + constraint = TensorConstraint(rank=1) + attr = builtin.MemRefType(builtin.i32, [1]) + + with pytest.raises(VerifyException, match=f"{attr} should be of type TensorType"): + constraint.verify(attr, None) + + def test_tensor_shape_constraint_verify_invalid(self): + """Test that verifying an attribute with an invalid shape raises an exception.""" + constraint = TensorConstraint(shape=(1, 2, 3)) + attr = builtin.TensorType(builtin.i32, (2, 2, 3)) + + with pytest.raises(VerifyException, match=r"Expected attribute \[.*\] but got \[.*\]"): + constraint.verify(attr, None) + + def test_tensor_element_type_constraint_verify_invalid(self): + """Test that verifying an attribute with an invalid element type raises an exception.""" + constraint = TensorConstraint(element_type=builtin.i32) + attr = builtin.TensorType(builtin.i64, [1]) + + with pytest.raises( + VerifyException, match=f"Expected attribute {builtin.i32} but got {builtin.i64}" + ): + constraint.verify(attr, None) + + def test_tensor_constraint_integration(self, my_dialect): + """Test that verification of legal operations with tensor operand/result constraints does + not raise an exception.""" + program = """ + func.func public @test_workload() -> () { + %0 = "test.op"() : () -> tensor<2xf64> + %1 = "test.op"() : () -> tensor<2x4xf64> + %2 = "test.op"() : () -> tensor<2x4x5x3xf64> + %3 = "test.op"() : () -> tensor<1x2x3xf64> + %4 = "my_dialect.tensor_float64"(%0) : (tensor<2xf64>) -> tensor<2xf64> + %5 = "my_dialect.tensor_rank1"(%0) : (tensor<2xf64>) -> tensor<2xf64> + %6 = "my_dialect.tensor_rank24"(%1) : (tensor<2x4xf64>) -> tensor<2x4xf64> + %7 = "my_dialect.tensor_rank24"(%2) : (tensor<2x4x5x3xf64>) -> tensor<2x4x5x3xf64> + %8 = "my_dialect.tensor_shape123"(%3) : (tensor<1x2x3xf64>) -> tensor<1x2x3xf64> + func.return + } + """ + + ctx = Context(allow_unregistered=False) + xdsl_module: builtin.ModuleOp = QuantumParser( + ctx, program, extra_dialects=(test.Test, my_dialect) + ).parse_module() + xdsl_module.verify() + + def test_tensor_constraint_integration_invalid(self, my_dialect): + """Test that verification of illegal operations with tensor operands/result constraints + raises a VerifyException.""" + program = """ + func.func public @test_workload() -> () { + %0 = "test.op"() : () -> tensor<2x2xi64> + %1 = "my_dialect.tensor_float64"(%0) : (tensor<2x2xi64>) -> tensor<2x2xi64> + func.return + } + """ + + ctx = Context(allow_unregistered=False) + xdsl_module: builtin.ModuleOp = QuantumParser( + ctx, program, extra_dialects=(test.Test, my_dialect) + ).parse_module() + + with pytest.raises( + VerifyException, + match=f"Expected attribute {builtin.Float64Type()} but got {builtin.i64}", + ): + xdsl_module.verify() + + +class TestNestedTupleOfConstraint: + """Tests for the NestedTupleOfConstraint class.""" + + constraint = NestedTupleOfConstraint([TensorType, TokenType]) + + def test_nested_tuple_of_constraint(self): + """Test that the properties of NestedTupleOfConstraint object are correct.""" + assert self.constraint.elem_constraints == (BaseAttr(TensorType), BaseAttr(TokenType)) + + def test_nested_tuple_of_constraint_verify_valid(self): + """Test that verifying a valid tuple of tensor and token types passes.""" + tensor = TensorType(i32, [2]) + token = TokenType() + tup = TupleType([tensor, token]) + self.constraint.verify(tup, ConstraintContext()) + + def test_nested_tuple_of_constraint_accepts_two_tensors(self): + """Test that any mix of allowed types is accepted.""" + tensor1 = TensorType(i32, [2]) + tensor2 = TensorType(i1, [1]) + tup = TupleType([tensor1, tensor2]) + self.constraint.verify(tup, ConstraintContext()) + + def test_nested_tuple_of_constraint_accepts_reversed_order(self): + """Test that the order of the tuple is not enforced.""" + tensor = TensorType(i32, [2]) + token = TokenType() + tup = TupleType([token, tensor]) + self.constraint.verify(tup, ConstraintContext()) + + def test_nested_tuple_of_constraint_accepts_nested(self): + """Test that nested tuples are accepted.""" + tensor1 = TensorType(i32, [2]) + tensor2 = TensorType(i1, [1]) + token = TokenType() + inner = TupleType([token, tensor2]) + outer = TupleType([tensor1, inner]) + self.constraint.verify(outer, ConstraintContext()) + + def test_nested_tuple_of_constraint_rejects_disallowed_type(self): + """Test that a tuple with a disallowed type raises a VerifyException.""" + tensor = TensorType(i32, [2]) + memref = MemRefType(i32, [2]) + tup = TupleType([tensor, memref]) + with pytest.raises( + VerifyException, match="tuple leaf 1 failed all allowed constraints: memref<2xi32>" + ): + self.constraint.verify(tup, ConstraintContext()) diff --git a/frontend/test/pytest/python_interface/xdsl_extras/test_traits.py b/frontend/test/pytest/python_interface/xdsl_extras/test_traits.py new file mode 100644 index 0000000000..6be241a163 --- /dev/null +++ b/frontend/test/pytest/python_interface/xdsl_extras/test_traits.py @@ -0,0 +1,321 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the traits defined within the xdsl_extras module.""" + +from typing import TypeAlias + +import pytest + +# pylint: disable=wrong-import-position +# pylint: disable=too-few-public-methods + +xdsl = pytest.importorskip("xdsl") + +from xdsl.dialects.builtin import AnyAttr, TensorType, f32, i32, i64 +from xdsl.ir import Attribute + +# xdsl imports +from xdsl.irdl import ( + IRDLOperation, + attr_def, + irdl_op_definition, + operand_def, + result_def, + traits_def, + var_operand_def, +) +from xdsl.utils.exceptions import VerifyException +from xdsl.utils.test_value import create_ssa_value + +# Import the custom traits we want to test +from catalyst.python_interface.xdsl_extras.traits import ( + AllMatchSameOperatorTrait, + Elementwise, + SameOperandsAndResultElementType, + SameOperandsAndResultShape, + SameOperandsElementType, +) + +pytestmark = pytest.mark.external + +AnyTensorType: TypeAlias = TensorType[Attribute] + + +# Test the SameOperandsAndResultShape trait +def test_same_operands_and_result_shape_trait(): + """Test the SameOperandsAndResultShape trait.""" + + @irdl_op_definition + class ShapeTestOp(IRDLOperation): + """Test operation for the SameOperandsAndResultShape trait.""" + + name = "test.shape_test" + traits = traits_def(SameOperandsAndResultShape()) + operand = operand_def(AnyTensorType) + result = result_def(AnyTensorType) + + assert ShapeTestOp.has_trait(SameOperandsAndResultShape) + + operand = create_ssa_value(TensorType(i32, [2, 3])) + op = ShapeTestOp.create(operands=[operand], result_types=[TensorType(i32, [2, 3])]) + op.verify() + + op = ShapeTestOp.create(operands=[operand], result_types=[TensorType(i32, [3, 2])]) + with pytest.raises(VerifyException, match="requires the same shape"): + op.verify() + + +# Test the SameOperandsElementType trait +def test_same_operands_element_type_trait(): + """Test the SameOperandsElementType trait.""" + + @irdl_op_definition + class ShapeTestOp(IRDLOperation): + """Test operation for the SameOperandsElementType trait.""" + + name = "test.element_type_test" + traits = traits_def(SameOperandsElementType()) + + operand1 = operand_def(AnyTensorType) + operand2 = operand_def(AnyTensorType) + result = result_def(AnyTensorType) + + assert ShapeTestOp.has_trait(SameOperandsElementType) + + operand1 = create_ssa_value(TensorType(i32, [2, 3])) + operand2 = create_ssa_value(TensorType(i32, [3, 2])) + op = ShapeTestOp.create(operands=[operand1, operand2], result_types=[TensorType(i32, [2, 2])]) + op.verify() + + operand3 = create_ssa_value(TensorType(f32, [2, 3])) + op = ShapeTestOp.create(operands=[operand1, operand3], result_types=[TensorType(f32, [2, 3])]) + + with pytest.raises(VerifyException, match="requires the same element type for all operands"): + op.verify() + + +# Test the SameOperandsAndResultElementType trait +def test_same_operands_and_result_element_type_trait(): + """Test the SameOperandsAndResultElementType trait.""" + + @irdl_op_definition + class ElementTypeTestOp(IRDLOperation): + """Test operation for the SameOperandsAndResultElementType trait.""" + + name = "test.element_type_test" + traits = traits_def(SameOperandsAndResultElementType()) + + operand = operand_def(AnyTensorType) + result = result_def(AnyTensorType) + + assert ElementTypeTestOp.has_trait(SameOperandsAndResultElementType) + + op = ElementTypeTestOp.create( + operands=[create_ssa_value(TensorType(i32, [2, 3]))], + result_types=[TensorType(i32, [2, 3])], + ) + op.verify() + + op = ElementTypeTestOp.create( + operands=[create_ssa_value(TensorType(i32, [2, 3]))], + result_types=[TensorType(f32, [2, 3])], + ) + with pytest.raises( + VerifyException, match="requires the same element type for all operands and results" + ): + op.verify() + + +# Test the Elementwise trait +@irdl_op_definition +class ElementwiseTestOp(IRDLOperation): + """Test operation for the Elementwise trait.""" + + name = "test.elementwise_test" + traits = traits_def(Elementwise()) + + operand = var_operand_def(AnyAttr()) + result = result_def(AnyAttr()) + + +def test_elementwise_trait(): + """Test the Elementwise trait.""" + assert ElementwiseTestOp.has_trait(Elementwise) + + operand = create_ssa_value(TensorType(i32, [2, 3])) + op = ElementwiseTestOp.create(operands=[operand], result_types=[TensorType(i32, [2, 3])]) + op.verify() + + operand = create_ssa_value(i32) + op = ElementwiseTestOp.create(operands=[operand], result_types=[i32]) + op.verify() + + scalar_operand = create_ssa_value(i32) + tensor_operand = create_ssa_value(TensorType(i32, [2, 3])) + op = ElementwiseTestOp.create( + operands=[scalar_operand, tensor_operand], result_types=[TensorType(i32, [2, 3])] + ) + op.verify() + + +def test_elementwise_trait_failure_no_tensor_result(): + """Test that Elementwise trait fails when operand is tensor but result is scalar.""" + operand = create_ssa_value(TensorType(i32, [2, 3])) + op = ElementwiseTestOp.create(operands=[operand], result_types=[i32]) + + with pytest.raises( + VerifyException, + match="if an operand is non-scalar, then there must be at least one non-scalar result", + ): + op.verify() + + +def test_elementwise_trait_failure_no_tensor_operand(): + """Test that Elementwise trait fails when result is tensor but operand is scalar.""" + operand = create_ssa_value(i32) + op = ElementwiseTestOp.create(operands=[operand], result_types=[TensorType(i32, [2, 3])]) + + with pytest.raises( + VerifyException, + match="if a result is non-scalar, then at least one operand must be non-scalar", + ): + op.verify() + + +def test_elementwise_trait_failure_shape_mismatch(): + """Test that Elementwise trait fails for shape mismatches.""" + operand1 = create_ssa_value(TensorType(i32, [2, 3])) + operand2 = create_ssa_value(TensorType(i32, [3, 2])) + op = ElementwiseTestOp.create( + operands=[operand1, operand2], result_types=[TensorType(i32, [2, 3])] + ) + + with pytest.raises( + VerifyException, + match="all non-scalar operands/results must have the same shape and base type", + ): + op.verify() + + +def test_all_shapes_match(): + """AllMatchSameOperatorTrait: shape equality on TensorType attributes.""" + + @irdl_op_definition + class ShapeMockOp(IRDLOperation): + """Test operation for the AllMatchSameOperatorTrait trait.""" + + name = "test.shapes_match" + traits = traits_def() + a = attr_def(AnyAttr()) + b = attr_def(AnyAttr()) + + op = ShapeMockOp.create(attributes={"a": TensorType(i64, [2, 3]), "b": TensorType(i64, [2, 3])}) + trait = AllMatchSameOperatorTrait(("a", "b"), lambda a: a.get_shape(), "shape") + trait.verify(op) + + op = ShapeMockOp.create(attributes={"a": TensorType(i64, [2, 3]), "b": TensorType(i64, [2, 4])}) + with pytest.raises(VerifyException, match=r"all of \{a, b\} must have the same shape"): + trait.verify(op) + + +def test_all_ranks_match(): + """AllMatchSameOperatorTrait: rank equality on TensorType attributes.""" + + @irdl_op_definition + class RankMockOp(IRDLOperation): + """Test operation for the AllMatchSameOperatorTrait trait.""" + + name = "test.ranks_match" + traits = traits_def() + a = attr_def(AnyAttr()) + b = attr_def(AnyAttr()) + + op = RankMockOp.create(attributes={"a": TensorType(i64, [2, 3]), "b": TensorType(i64, [4, 5])}) + trait = AllMatchSameOperatorTrait(("a", "b"), lambda a: a.get_num_dims(), "rank") + trait.verify(op) + + op = RankMockOp.create( + attributes={"a": TensorType(i64, [2, 3]), "b": TensorType(i64, [1, 2, 3])} + ) + with pytest.raises(VerifyException, match=r"all of \{a, b\} must have the same rank"): + trait.verify(op) + + +def test_all_element_types_match(): + """AllMatchSameOperatorTrait: element type equality on TensorType attributes.""" + + @irdl_op_definition + class ElemTypeMockOp(IRDLOperation): + """Test operation for the AllMatchSameOperatorTrait trait.""" + + name = "test.elem_types_match" + traits = traits_def() + a = attr_def(AnyAttr()) + b = attr_def(AnyAttr()) + + op = ElemTypeMockOp.create( + attributes={"a": TensorType(i64, [2, 3]), "b": TensorType(i64, [1, 6])} + ) + trait = AllMatchSameOperatorTrait(("a", "b"), lambda a: a.get_element_type(), "element type") + trait.verify(op) + + op = ElemTypeMockOp.create( + attributes={"a": TensorType(i64, [2, 3]), "b": TensorType(f32, [2, 3])} + ) + with pytest.raises(VerifyException, match=r"all of \{a, b\} must have the same element type"): + trait.verify(op) + + +def test_all_element_counts_match(): + """AllMatchSameOperatorTrait: element count equality on TensorType attributes.""" + + @irdl_op_definition + class ElemCountMockOp(IRDLOperation): + """Test operation for the AllMatchSameOperatorTrait trait.""" + + name = "test.elem_counts_match" + traits = traits_def() + a = attr_def(AnyAttr()) + b = attr_def(AnyAttr()) + + op = ElemCountMockOp.create( + attributes={"a": TensorType(i64, [2, 3]), "b": TensorType(i64, [6])} + ) + trait = AllMatchSameOperatorTrait(("a", "b"), lambda a: a.element_count(), "element count") + trait.verify(op) + + op = ElemCountMockOp.create( + attributes={"a": TensorType(i64, [2, 3]), "b": TensorType(i64, [2, 4])} + ) + with pytest.raises(VerifyException, match=r"all of \{a, b\} must have the same element count"): + trait.verify(op) + + +def test_operator_cannot_compute_raises_verifyexception(): + """Trait should raise when it cannot compute the property for given attributes.""" + + @irdl_op_definition + class CannotComputeMockOp(IRDLOperation): + name = "test.cannot_compute" + traits = traits_def() + a = attr_def(AnyAttr()) + b = attr_def(AnyAttr()) + + # Use non-shaped attributes; calling get_shape should fail + op = CannotComputeMockOp.create(attributes={"a": i64, "b": f32}) + trait = AllMatchSameOperatorTrait(("a", "b"), lambda x: x.get_shape(), "shape") + + with pytest.raises(VerifyException, match=r"cannot compute shape for \{a, b\}:"): + trait.verify(op) From 953670797baefb20095880ecfc5d941ed46ae99b Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 17 Nov 2025 14:22:30 -0500 Subject: [PATCH 02/38] Fixing migration artifacts; linting --- .dep-versions | 2 - .../dialects/test_catalyst_dialect.py | 5 +- .../dialects/test_mbqc_dialect.py | 19 +++--- .../dialects/test_qec_dialect.py | 6 +- .../dialects/test_quantum_dialect.py | 7 +- .../dialects/test_stablehlo_dialect.py | 9 +-- .../dialects/test_transform_dialect.py | 13 ++-- .../python_interface/test_python_compiler.py | 52 +++++++-------- .../python_interface/test_xdsl_utils.py | 5 +- .../transforms/mbqc/test_graph_state_utils.py | 5 +- .../test_xdsl_convert_to_mbqc_formalism.py | 19 +++--- .../mbqc/test_xdsl_decompose_graph_state.py | 6 +- .../mbqc/test_xdsl_outline_state_evolution.py | 46 +++++++------ .../quantum/test_xdsl_cancel_inverses.py | 10 +-- .../test_xdsl_combine_global_phases.py | 9 +-- .../test_xdsl_diagonalize_measurements.py | 33 ++++------ .../test_xdsl_measurements_from_samples.py | 12 ++-- .../quantum/test_xdsl_merge_rotations.py | 9 +-- .../quantum/test_xdsl_split_non_commuting.py | 18 +++-- .../test_draw_unified_compiler.py | 66 +++++++++++-------- .../visualization/test_mlir_graph.py | 31 ++++----- .../xdsl_extras/test_constraints.py | 26 ++------ .../xdsl_extras/test_traits.py | 8 +-- 23 files changed, 175 insertions(+), 241 deletions(-) diff --git a/.dep-versions b/.dep-versions index e6aaadbec6..e26b195774 100644 --- a/.dep-versions +++ b/.dep-versions @@ -5,8 +5,6 @@ jax=0.6.2 stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d llvm=113f01aa82d055410f22a9d03b3468fa68600589 enzyme=v0.0.203 -xdsl -xdsl-jax=0.1.0 # Always remove custom PL/LQ versions before release. diff --git a/frontend/test/pytest/python_interface/dialects/test_catalyst_dialect.py b/frontend/test/pytest/python_interface/dialects/test_catalyst_dialect.py index e5af80a305..4f5c149ac0 100644 --- a/frontend/test/pytest/python_interface/dialects/test_catalyst_dialect.py +++ b/frontend/test/pytest/python_interface/dialects/test_catalyst_dialect.py @@ -18,10 +18,7 @@ # pylint: disable=wrong-import-position -xdsl = pytest.importorskip("xdsl") -filecheck = pytest.importorskip("filecheck") - -pytestmark = pytest.mark.external +pytestmark = pytest.mark.usefixtures("requires_xdsl") from catalyst.python_interface.dialects import Catalyst diff --git a/frontend/test/pytest/python_interface/dialects/test_mbqc_dialect.py b/frontend/test/pytest/python_interface/dialects/test_mbqc_dialect.py index 3ca61a8ba5..6e776e5175 100644 --- a/frontend/test/pytest/python_interface/dialects/test_mbqc_dialect.py +++ b/frontend/test/pytest/python_interface/dialects/test_mbqc_dialect.py @@ -14,16 +14,15 @@ """Unit test module for pennylane/compiler/python_compiler/dialects/mbqc.py.""" -import pytest - -# pylint: disable=wrong-import-position -xdsl = pytest.importorskip("xdsl") -filecheck = pytest.importorskip("filecheck") +import pytest -pytestmark = pytest.mark.external +# pylint: disable=wrong-import-position,line-too-long +pytestmark = pytest.mark.usefixtures("requires_xdsl") +from xdsl.context import Context from xdsl.dialects import arith, builtin, test +from xdsl.parser import Parser from xdsl.utils.exceptions import VerifyException from catalyst.python_interface.dialects import Quantum, mbqc @@ -115,14 +114,14 @@ def test_measure_in_basis_properties(self, plane, postselect): %mres, %out_qubit = mbqc.measure_in_basis [{plane}, %angle] %qubit {postselect} : i1, !quantum.bit """ - ctx = xdsl.context.Context() + ctx = Context() ctx.load_dialect(builtin.Builtin) ctx.load_dialect(test.Test) ctx.load_dialect(Quantum) ctx.load_dialect(mbqc.MBQC) - module = xdsl.parser.Parser(ctx, program).parse_module() + module = Parser(ctx, program).parse_module() measure_in_basis_op: mbqc.MeasureInBasisOp = module.ops.last assert isinstance(measure_in_basis_op, mbqc.MeasureInBasisOp) @@ -146,14 +145,14 @@ def test_invalid_postselect_raises_on_verify(self, postselect): %mres, %out_qubit = mbqc.measure_in_basis [XY, %angle] %qubit postselect {postselect} : i1, !quantum.bit """ - ctx = xdsl.context.Context() + ctx = Context() ctx.load_dialect(builtin.Builtin) ctx.load_dialect(test.Test) ctx.load_dialect(Quantum) ctx.load_dialect(mbqc.MBQC) - module = xdsl.parser.Parser(ctx, program).parse_module() + module = Parser(ctx, program).parse_module() measure_in_basis_op: mbqc.MeasureInBasisOp = module.ops.last assert isinstance(measure_in_basis_op, mbqc.MeasureInBasisOp) diff --git a/frontend/test/pytest/python_interface/dialects/test_qec_dialect.py b/frontend/test/pytest/python_interface/dialects/test_qec_dialect.py index 274e8f4465..02ed4a6f97 100644 --- a/frontend/test/pytest/python_interface/dialects/test_qec_dialect.py +++ b/frontend/test/pytest/python_interface/dialects/test_qec_dialect.py @@ -17,11 +17,7 @@ import pytest # pylint: disable=wrong-import-position - -xdsl = pytest.importorskip("xdsl") -filecheck = pytest.importorskip("filecheck") - -pytestmark = pytest.mark.external +pytestmark = pytest.mark.usefixtures("requires_xdsl") from catalyst.python_interface.dialects import QEC diff --git a/frontend/test/pytest/python_interface/dialects/test_quantum_dialect.py b/frontend/test/pytest/python_interface/dialects/test_quantum_dialect.py index 843412be09..d83561a193 100644 --- a/frontend/test/pytest/python_interface/dialects/test_quantum_dialect.py +++ b/frontend/test/pytest/python_interface/dialects/test_quantum_dialect.py @@ -17,11 +17,7 @@ import pytest # pylint: disable=wrong-import-position - -xdsl = pytest.importorskip("xdsl") -filecheck = pytest.importorskip("filecheck") - -pytestmark = pytest.mark.external +pytestmark = pytest.mark.usefixtures("requires_xdsl") from xdsl.dialects.builtin import ( I32, @@ -205,6 +201,7 @@ def test_only_existing_operations_are_expected(): @pytest.mark.parametrize("op", all_ops) def test_operation_construction(op): + """Test the constructors of operations in the Quantum dialect.""" kwargs = expected_ops_init_kwargs[op.__name__] _ = op(**kwargs) diff --git a/frontend/test/pytest/python_interface/dialects/test_stablehlo_dialect.py b/frontend/test/pytest/python_interface/dialects/test_stablehlo_dialect.py index ea4492d690..c63ed80e44 100644 --- a/frontend/test/pytest/python_interface/dialects/test_stablehlo_dialect.py +++ b/frontend/test/pytest/python_interface/dialects/test_stablehlo_dialect.py @@ -13,15 +13,10 @@ # limitations under the License. """Unit test module for pennylane/compiler/python_compiler/dialects/stablehlo.py.""" - +# pylint: disable=line-too-long import pytest -# pylint: disable=wrong-import-position - -xdsl = pytest.importorskip("xdsl") -filecheck = pytest.importorskip("filecheck") - -pytestmark = pytest.mark.external +pytestmark = pytest.mark.usefixtures("requires_xdsl") def test_all_unary_operations(run_filecheck): diff --git a/frontend/test/pytest/python_interface/dialects/test_transform_dialect.py b/frontend/test/pytest/python_interface/dialects/test_transform_dialect.py index 04ac810c22..ddfc4ad156 100644 --- a/frontend/test/pytest/python_interface/dialects/test_transform_dialect.py +++ b/frontend/test/pytest/python_interface/dialects/test_transform_dialect.py @@ -18,18 +18,15 @@ import pytest -# pylint: disable=wrong-import-position - -xdsl = pytest.importorskip("xdsl") -filecheck = pytest.importorskip("filecheck") - -pytestmark = pytest.mark.external +# pylint: disable=wrong-import-position,line-too-long +pytestmark = pytest.mark.usefixtures("requires_xdsl") from xdsl import passes from xdsl.context import Context from xdsl.dialects import builtin from xdsl.dialects.builtin import DictionaryAttr, IntegerAttr, i64 from xdsl.dialects.transform import AnyOpType +from xdsl.passes import PassPipeline from xdsl.utils.exceptions import VerifyException from xdsl.utils.test_value import create_ssa_value @@ -138,12 +135,12 @@ def program(): } """ - ctx = xdsl.context.Context() + ctx = Context() ctx.load_dialect(builtin.Builtin) ctx.load_dialect(transform.Transform) mod = program() - pipeline = xdsl.passes.PassPipeline((ApplyTransformSequence(),)) + pipeline = PassPipeline((ApplyTransformSequence(),)) pipeline.apply(ctx, mod) assert "Hello from custom option!" in capsys.readouterr().out diff --git a/frontend/test/pytest/python_interface/test_python_compiler.py b/frontend/test/pytest/python_interface/test_python_compiler.py index 2110477fde..f8b32f2aaf 100644 --- a/frontend/test/pytest/python_interface/test_python_compiler.py +++ b/frontend/test/pytest/python_interface/test_python_compiler.py @@ -16,24 +16,22 @@ from dataclasses import dataclass -# pylint: disable=wrong-import-position +# pylint: disable=wrong-import-position,line-too-long import pytest -pytestmark = pytest.mark.catalyst - -catalyst = pytest.importorskip("catalyst") -jax = pytest.importorskip("jax") -jaxlib = pytest.importorskip("jaxlib") -xdsl = pytest.importorskip("xdsl") +pytestmark = pytest.mark.usefixtures("requires_xdsl") +import jax import pennylane as qml +from jaxlib.mlir.ir import Module as jModule from pennylane.capture import enabled as capture_enabled from xdsl import passes from xdsl.context import Context from xdsl.dialects import builtin from xdsl.interpreters import Interpreter +from xdsl.passes import PassPipeline -from catalyst import CompileError +from catalyst import CompileError, qjit from catalyst.passes import apply_pass from catalyst.passes import cancel_inverses as catalyst_cancel_inverses from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath @@ -88,7 +86,7 @@ def identity(x): input_module = identity(1) retval = Compiler.run(input_module) - assert isinstance(retval, jaxlib.mlir.ir.Module) + assert isinstance(retval, jModule) assert str(retval) == str(input_module) @@ -139,7 +137,7 @@ def program(): """ retval = Compiler.run(program()) - assert isinstance(retval, jaxlib.mlir.ir.Module) + assert isinstance(retval, jModule) def test_generic_catalyst_program_as_string(): @@ -250,11 +248,11 @@ def program(): } """ - ctx = xdsl.context.Context() + ctx = Context() ctx.load_dialect(builtin.Builtin) ctx.load_dialect(transform.Transform) - pipeline = xdsl.passes.PassPipeline((ApplyTransformSequence(),)) + pipeline = PassPipeline((ApplyTransformSequence(),)) pipeline.apply(ctx, program()) captured = capsys.readouterr() assert captured.out.strip() == "hello world" @@ -263,14 +261,14 @@ def program(): class TestCatalystIntegration: """Tests for integration of the Python compiler with Catalyst""" - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_integration_catalyst_no_passes_with_capture(self): """Test that the xDSL plugin can be used even when no passes are applied when capture is enabled.""" assert capture_enabled() - @catalyst.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) @qml.qnode(qml.device("lightning.qubit", wires=2)) def f(x): qml.RX(x, 0) @@ -285,7 +283,7 @@ def test_integration_catalyst_no_passes_no_capture(self): assert not capture_enabled() - @catalyst.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) @qml.qnode(qml.device("lightning.qubit", wires=2)) def f(x): qml.RX(x, 0) @@ -294,14 +292,14 @@ def f(x): out = f(1.5) assert jax.numpy.allclose(out, jax.numpy.cos(1.5)) - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_integration_catalyst_xdsl_pass_with_capture(self, capsys): """Test that a pass is run via the transform interpreter when using with a qjit workflow and capture is enabled.""" assert capture_enabled() - @catalyst.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) @hello_world_pass @qml.qnode(qml.device("lightning.qubit", wires=2)) def f(x): @@ -319,7 +317,7 @@ def test_integration_catalyst_xdsl_pass_no_capture(self, capsys): assert not capture_enabled() - @catalyst.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) @apply_pass("hello-world") @qml.qnode(qml.device("lightning.qubit", wires=2)) def f(x): @@ -331,14 +329,14 @@ def f(x): captured = capsys.readouterr() assert captured.out.strip() == "hello world" - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_integration_catalyst_mixed_passes_with_capture(self, capsys): """Test that both Catalyst and Python compiler passes can be used with qjit when capture is enabled.""" assert capture_enabled() - @catalyst.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) @hello_world_pass @qml.transforms.cancel_inverses @qml.qnode(qml.device("lightning.qubit", wires=2)) @@ -359,7 +357,7 @@ def test_integration_catalyst_mixed_passes_no_capture(self, capsys): assert not capture_enabled() - @catalyst.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) @apply_pass("hello-world") @catalyst_cancel_inverses @qml.qnode(qml.device("lightning.qubit", wires=2)) @@ -408,9 +406,7 @@ def program(): ctx = Context() ctx.load_dialect(builtin.Builtin) - pipeline = xdsl.passes.PassPipeline( - (ApplyTransformSequence(callback=print_between_passes),) - ) + pipeline = PassPipeline((ApplyTransformSequence(callback=print_between_passes),)) pipeline.apply(ctx, program()) captured = capsys.readouterr() assert captured.out.strip() == "hello world" @@ -418,7 +414,6 @@ def program(): def test_callback_prints_module_after_each_pass(self, capsys): """Test that the callback prints the module after each pass""" - # pylint: disable=redefined-outer-name def print_between_passes(_, module, __, pass_level=0): if pass_level == 0: return @@ -460,9 +455,7 @@ def program_2_passes(): ctx = Context() ctx.load_dialect(builtin.Builtin) - pipeline = xdsl.passes.PassPipeline( - (ApplyTransformSequence(callback=print_between_passes),) - ) + pipeline = PassPipeline((ApplyTransformSequence(callback=print_between_passes),)) pipeline.apply(ctx, program_2_passes()) out = capsys.readouterr().out @@ -483,11 +476,10 @@ def program_2_passes(): assert printed_modules[0] != printed_modules[1], "IR should differ between passes" - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_callback_run_integration(self, capsys): """Test that the callback is integrated into the pass pipeline with the Compiler.run() method""" - # pylint: disable=redefined-outer-name def print_between_passes(_, module, __, pass_level=0): if pass_level == 0: return diff --git a/frontend/test/pytest/python_interface/test_xdsl_utils.py b/frontend/test/pytest/python_interface/test_xdsl_utils.py index 23254dd7ec..7d5a7dff8b 100644 --- a/frontend/test/pytest/python_interface/test_xdsl_utils.py +++ b/frontend/test/pytest/python_interface/test_xdsl_utils.py @@ -16,10 +16,9 @@ import pytest -pytestmark = pytest.mark.external -xdsl = pytest.importorskip("xdsl") +pytestmark = pytest.mark.usefixtures("requires_xdsl") -# pylint: disable=wrong-import-position +# pylint: disable=wrong-import-position,line-too-long from xdsl.dialects import arith, builtin, tensor, test from catalyst.python_interface.dialects.stablehlo import ConstantOp as hloConstantOp diff --git a/frontend/test/pytest/python_interface/transforms/mbqc/test_graph_state_utils.py b/frontend/test/pytest/python_interface/transforms/mbqc/test_graph_state_utils.py index 37fbde22ee..96c19ed3e1 100644 --- a/frontend/test/pytest/python_interface/transforms/mbqc/test_graph_state_utils.py +++ b/frontend/test/pytest/python_interface/transforms/mbqc/test_graph_state_utils.py @@ -17,10 +17,7 @@ import pytest -pytestmark = pytest.mark.external - -pytest.importorskip("xdsl") -pytest.importorskip("catalyst") +pytestmark = pytest.mark.usefixtures("requires_xdsl") from pennylane.exceptions import CompileError diff --git a/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_convert_to_mbqc_formalism.py b/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_convert_to_mbqc_formalism.py index e6fc9a3229..9bfe8328f0 100644 --- a/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_convert_to_mbqc_formalism.py +++ b/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_convert_to_mbqc_formalism.py @@ -14,15 +14,12 @@ """Unit test module for the convert to MBQC formalism transform""" import pytest -pytestmark = pytest.mark.external - -xdsl = pytest.importorskip("xdsl") -catalyst = pytest.importorskip("catalyst") +# pylint: disable=wrong-import-position,line-too-long +pytestmark = pytest.mark.usefixtures("requires_xdsl") import pennylane as qml from pennylane.ftqc import RotXZX -# pylint: disable=wrong-import-position from catalyst.ftqc import mbqc_pipeline from catalyst.python_interface.transforms import ( ConvertToMBQCFormalismPass, @@ -377,7 +374,7 @@ def test_function_no_body(self, run_filecheck): pipeline = (ConvertToMBQCFormalismPass(),) run_filecheck(program, pipeline) - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_gates_in_mbqc_gate_set_lowering(self, run_filecheck_qjit): """Test that the convert_to_mbqc_formalism_pass works correctly with qjit and unrolled loops.""" dev = qml.device("null.qubit", wires=1000) @@ -421,7 +418,7 @@ def circuit(): run_filecheck_qjit(circuit) - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_gates_in_mbqc_gate_set_lowering_for(self, run_filecheck_qjit): """Test that the convert_to_mbqc_formalism_pass works correctly with qjit and for-loop structure.""" dev = qml.device("null.qubit", wires=1000) @@ -461,13 +458,13 @@ def circuit(): # CHECK: quantum.custom "PauliX"() # CHECK: quantum.custom "PauliZ"() # CHECK: quantum.dealloc_qb - loop_for() + loop_for() # pylint: disable=no-value-for-parameter qml.CNOT(wires=[0, 1]) return qml.expval(qml.Z(wires=0)) run_filecheck_qjit(circuit) - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_gates_in_mbqc_gate_set_lowering_graph_state_decomp(self, run_filecheck_qjit): """Test that the convert_to_mbqc_formalism_pass works correctly with qjit and for-loop structure.""" dev = qml.device("null.qubit", wires=1000) @@ -508,7 +505,7 @@ def circuit(): run_filecheck_qjit(circuit) - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_gates_in_mbqc_gate_set_lowering_while(self, run_filecheck_qjit): """Test that the convert_to_mbqc_formalism_pass works correctly with qjit and while-loop structure.""" dev = qml.device("null.qubit", wires=1000) @@ -548,7 +545,7 @@ def circuit(): run_filecheck_qjit(circuit) - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_gates_in_mbqc_gate_set_e2e(self): """Test that the convert_to_mbqc_formalism_pass end to end on null.qubit.""" dev = qml.device("null.qubit", wires=1000) diff --git a/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_decompose_graph_state.py b/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_decompose_graph_state.py index 6c74d782e0..9b48c93835 100644 --- a/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_decompose_graph_state.py +++ b/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_decompose_graph_state.py @@ -21,13 +21,11 @@ each use. """ -# pylint: disable=wrong-import-position +# pylint: disable=wrong-import-position,line-too-long import pytest -pytestmark = pytest.mark.external - -xdsl = pytest.importorskip("xdsl") +pytestmark = pytest.mark.usefixtures("requires_xdsl") from catalyst.python_interface.transforms import ( DecomposeGraphStatePass, diff --git a/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_outline_state_evolution.py b/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_outline_state_evolution.py index af040ecca0..eb9525ec5e 100644 --- a/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_outline_state_evolution.py +++ b/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_outline_state_evolution.py @@ -14,15 +14,12 @@ """Unit test module for the outline state evolution transform""" import pytest -pytestmark = pytest.mark.external - -xdsl = pytest.importorskip("xdsl") -catalyst = pytest.importorskip("catalyst") +# pylint: disable=wrong-import-position +pytestmark = pytest.mark.usefixtures("requires_xdsl") import pennylane as qml from pennylane.ftqc import RotXZX -# pylint: disable=wrong-import-position from catalyst.ftqc import mbqc_pipeline from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath from catalyst.python_interface.transforms import ( @@ -144,9 +141,10 @@ def test_multiple_func_w_qnode_attr(self, run_filecheck): pipeline = (OutlineStateEvolutionPass(),) run_filecheck(program, pipeline) - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_outline_state_evolution_no_error(self): - """Test outline_state_evolution_pass does not raise error for circuit with classical operations only.""" + """Test outline_state_evolution_pass does not raise error for circuit with classical + operations only.""" @qml.qjit( target="mlir", @@ -158,10 +156,12 @@ def circuit(x, y): circuit(1, 4) - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_outline_state_evolution_no_terminal_op_error(self): - """Test outline_state_evolution_pass raises error when no terminal_boundary_op is found in circuit with quantum operation.""" - # TODOs: we can resolve this issue if the boundary op is inserted when the program is captured. + """Test outline_state_evolution_pass raises error when no terminal_boundary_op is found in + circuit with quantum operation.""" + # TODOs: we can resolve this issue if the boundary op is inserted when + # the program is captured. dev = qml.device("null.qubit", wires=10) @qml.qjit( @@ -178,7 +178,7 @@ def circuit(): ): circuit() - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_outline_state_evolution_pass_only(self, run_filecheck_qjit): """Test the outline_state_evolution_pass only.""" dev = qml.device("lightning.qubit", wires=1000) @@ -216,9 +216,10 @@ def circuit(): run_filecheck_qjit(circuit) - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_outline_state_evolution_pass_with_convert_to_mbqc_formalism(self, run_filecheck_qjit): - """Test if the outline_state_evolution_pass works with the convert-to-mbqc-formalism pass on lightning.qubit.""" + """Test if the outline_state_evolution_pass works with the convert-to-mbqc-formalism pass + on lightning.qubit.""" dev = qml.device("lightning.qubit", wires=1000) @qml.qjit( @@ -265,9 +266,10 @@ def circuit(): run_filecheck_qjit(circuit) - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_outline_state_evolution_pass_with_mbqc_pipeline(self, run_filecheck_qjit): - """Test if the outline_state_evolution_pass works with all mbqc transform pipeline on null.qubit.""" + """Test if the outline_state_evolution_pass works with all mbqc transform pipeline on + null.qubit.""" dev = qml.device("null.qubit", wires=1000) @qml.qjit( @@ -314,9 +316,10 @@ def circuit(): run_filecheck_qjit(circuit) - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_outline_state_evolution_pass_with_mbqc_pipeline_run_on_nullqubit(self): - """Test if a circuit can be transfored with the outline_state_evolution_pass and all mbqc transform pipeline can be executed on null.qubit.""" + """Test if a circuit can be transfored with the outline_state_evolution_pass and all mbqc + transform pipeline can be executed on null.qubit.""" dev = qml.device("null.qubit", wires=1000) @qml.qjit( @@ -343,9 +346,10 @@ def circuit(): res = circuit() assert res == 1.0 - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_lightning_execution_with_structure(self): - """Test that the outline_state_evolution_pass on lightning.qubit for a circuit with program structure is executable and returns results as expected.""" + """Test that the outline_state_evolution_pass on lightning.qubit for a circuit with program + structure is executable and returns results as expected.""" dev = qml.device("lightning.qubit", wires=10) @qml.for_loop(0, 10, 1) @@ -369,7 +373,7 @@ def while_fn(i): @outline_state_evolution_pass @qml.qnode(dev) def circuit(): - for_fn() + for_fn() # pylint: disable=no-value-for-parameter while_fn(0) qml.CNOT(wires=[0, 1]) return qml.expval(qml.prod(qml.X(0), qml.Z(1))) @@ -381,7 +385,7 @@ def circuit(): ) @qml.qnode(dev) def circuit_ref(): - for_fn() + for_fn() # pylint: disable=no-value-for-parameter while_fn(0) qml.CNOT(wires=[0, 1]) return qml.expval(qml.prod(qml.X(0), qml.Z(1))) diff --git a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_cancel_inverses.py b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_cancel_inverses.py index db3ddfa312..f5b8eb1159 100644 --- a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_cancel_inverses.py +++ b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_cancel_inverses.py @@ -15,14 +15,11 @@ import pytest -pytestmark = pytest.mark.external - -pytest.importorskip("xdsl") -pytest.importorskip("catalyst") +# pylint: disable=wrong-import-position +pytestmark = pytest.mark.usefixtures("requires_xdsl") import pennylane as qml -# pylint: disable=wrong-import-position from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath from catalyst.python_interface.transforms import ( IterativeCancelInversesPass, @@ -188,8 +185,7 @@ def test_non_consecutive_self_inverse_ops(self, run_filecheck): run_filecheck(program, pipeline) -# pylint: disable=too-few-public-methods -@pytest.mark.usefixtures("enable_disable_plxpr") +@pytest.mark.usefixtures("use_capture") class TestIterativeCancelInversesIntegration: """Integration tests for the IterativeCancelInversesPass.""" diff --git a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_combine_global_phases.py b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_combine_global_phases.py index cdfa2795d4..6d35b7c2ef 100644 --- a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_combine_global_phases.py +++ b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_combine_global_phases.py @@ -14,14 +14,11 @@ """Unit test module for the combine global phases transform""" import pytest -pytestmark = pytest.mark.external - -pytest.importorskip("xdsl") -pytest.importorskip("catalyst") +# pylint: disable=wrong-import-position +pytestmark = pytest.mark.usefixtures("requires_xdsl") import pennylane as qml -# pylint: disable=wrong-import-position from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath from catalyst.python_interface.transforms import ( CombineGlobalPhasesPass, @@ -218,7 +215,7 @@ def test_combinable_ops_in_control_flow_while(self, run_filecheck): # pylint: disable=too-few-public-methods -@pytest.mark.usefixtures("enable_disable_plxpr") +@pytest.mark.usefixtures("use_capture") class TestCombineGlobalPhasesIntegration: """Integration tests for the CombineGlobalPhasesPass.""" diff --git a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_diagonalize_measurements.py b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_diagonalize_measurements.py index 376dd20923..d719ca9f7e 100644 --- a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_diagonalize_measurements.py +++ b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_diagonalize_measurements.py @@ -19,14 +19,11 @@ import numpy as np import pytest -pytestmark = pytest.mark.external +pytestmark = pytest.mark.usefixtures("requires_xdsl") -xdsl = pytest.importorskip("xdsl") - -catalyst = pytest.importorskip("catalyst") import pennylane as qml -from catalyst.passes import xdsl_plugin +from catalyst.passes import apply_pass, xdsl_plugin from catalyst.python_interface.transforms import ( DiagonalizeFinalMeasurementsPass, diagonalize_final_measurements_pass, @@ -311,7 +308,7 @@ class TestDiagonalizeFinalMeasurementsProgramCaptureExecution: """Integration tests going through plxpr (program capture enabled)""" # pylint: disable=unnecessary-lambda - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") @pytest.mark.parametrize( "mp, obs, expected_res", [ @@ -350,7 +347,7 @@ def circuit_ref(phi): assert np.allclose(expected_res(angle), circuit_compiled(angle)) - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_with_composite_observables(self): """Test the transform works for an observable built using operator arithmetic (sprod, prod, sum)""" @@ -383,7 +380,7 @@ def expected_res(x, y): assert np.allclose(expected_res(phi, theta), circuit_compiled(phi, theta)) - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_with_multiple_measurements(self): """Test that the transform runs and returns the expected results for a circuit with multiple measurements""" @@ -413,7 +410,7 @@ def expected_res(x, y): assert np.allclose(expected_res(phi, theta), circuit_compiled(phi, theta)) - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_overlapping_observables_raises_error(self): """Test the case where multiple overlapping (commuting) observables exist in the same circuit (an error is raised - split_non_commuting should have been applied).""" @@ -433,7 +430,7 @@ def circuit(x): _ = circuit(1.23) @pytest.mark.xfail(reason="for now, assume split_non_commuting is always applied") - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_non_commuting_observables_raise_error(self): """Check that an error is raised if we try to diagonalize a circuit that contains non-commuting observables.""" @@ -489,9 +486,7 @@ def circuit_ref(phi): ), "Sanity check failed, is expected_res correct?" circuit_compiled = qml.qjit( - catalyst.passes.apply_pass("catalyst_xdsl_plugin.diagonalize-final-measurements")( - circuit_ref - ), + apply_pass("catalyst_xdsl_plugin.diagonalize-final-measurements")(circuit_ref), ) np.allclose(expected_res(angle), circuit_compiled(angle)) @@ -523,9 +518,7 @@ def expected_res(x, y): ), "Sanity check failed, is expected_res correct?" circuit_compiled = qml.qjit( - catalyst.passes.apply_pass("catalyst_xdsl_plugin.diagonalize-final-measurements")( - circuit_ref - ), + apply_pass("catalyst_xdsl_plugin.diagonalize-final-measurements")(circuit_ref), ) assert np.allclose(expected_res(phi, theta), circuit_compiled(phi, theta)) @@ -553,9 +546,7 @@ def expected_res(x, y): ), "Sanity check failed, is expected_res correct?" circuit_compiled = qml.qjit( - catalyst.passes.apply_pass("catalyst_xdsl_plugin.diagonalize-final-measurements")( - circuit_ref - ), + apply_pass("catalyst_xdsl_plugin.diagonalize-final-measurements")(circuit_ref), ) assert np.allclose(expected_res(phi, theta), circuit_compiled(phi, theta)) @@ -567,7 +558,7 @@ def test_overlapping_observables_raises_error(self): dev = qml.device("lightning.qubit", wires=2) @qml.qjit() - @catalyst.passes.apply_pass("catalyst_xdsl_plugin.diagonalize-final-measurements") + @apply_pass("catalyst_xdsl_plugin.diagonalize-final-measurements") @qml.qnode(dev) def circuit(x): qml.RX(x, 0) @@ -585,7 +576,7 @@ def test_non_commuting_observables_raise_error(self): dev = qml.device("lightning.qubit", wires=1) @qml.qjit() - @catalyst.passes.apply_pass("catalyst_xdsl_plugin.diagonalize-final-measurements") + @apply_pass("catalyst_xdsl_plugin.diagonalize-final-measurements") @qml.qnode(dev) def circuit(x): qml.RX(x, 0) diff --git a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_measurements_from_samples.py b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_measurements_from_samples.py index baaf6d381a..8889bb08ac 100644 --- a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_measurements_from_samples.py +++ b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_measurements_from_samples.py @@ -14,17 +14,15 @@ """Unit and integration tests for the Python compiler `measurements_from_samples` transform.""" -# pylint: disable=wrong-import-position +# pylint: disable=wrong-import-position,line-too-long from functools import partial import numpy as np import pytest -pytestmark = pytest.mark.external +pytestmark = pytest.mark.usefixtures("requires_xdsl") -xdsl = pytest.importorskip("xdsl") -catalyst = pytest.importorskip("catalyst") import pennylane as qml from catalyst.passes import xdsl_plugin @@ -430,7 +428,7 @@ def test_2_wire_probs_per_wire(self, run_filecheck): run_filecheck(program, pipeline) -@pytest.mark.usefixtures("enable_disable_plxpr") +@pytest.mark.usefixtures("use_capture") class TestMeasurementsFromSamplesIntegration: """Tests of the execution of simple workloads with the xDSL-based MeasurementsFromSamplesPass transform. @@ -853,8 +851,10 @@ def circuit(): run_filecheck_qjit(circuit) - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_integrate_with_decompose(self): + """Test that the measurements_from_samples pass works correctly when used in combination + with the decompose pass.""" dev = qml.device("null.qubit", wires=4) @qml.qjit(target="mlir", pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()]) diff --git a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_merge_rotations.py b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_merge_rotations.py index a6f8ba4e28..3dfb9dad77 100644 --- a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_merge_rotations.py +++ b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_merge_rotations.py @@ -14,14 +14,11 @@ """Unit test module for the merge rotations transform""" import pytest -pytestmark = pytest.mark.external - -pytest.importorskip("xdsl") -pytest.importorskip("catalyst") +# pylint: disable=wrong-import-position,line-too-long +pytestmark = pytest.mark.usefixtures("requires_xdsl") import pennylane as qml -# pylint: disable=wrong-import-position from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath from catalyst.python_interface.transforms import MergeRotationsPass, merge_rotations_pass @@ -221,7 +218,7 @@ def test_adjoint_property(self, first_adj, second_adj, sign, run_filecheck): # pylint: disable=too-few-public-methods -@pytest.mark.usefixtures("enable_disable_plxpr") +@pytest.mark.usefixtures("use_capture") class TestMergeRotationsIntegration: """Integration tests for the MergeRotationsPass.""" diff --git a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_split_non_commuting.py b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_split_non_commuting.py index ddf5cddb5f..30179ec736 100644 --- a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_split_non_commuting.py +++ b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_split_non_commuting.py @@ -14,14 +14,11 @@ """Unit test module for the split non-commuting transform""" import pytest -pytestmark = pytest.mark.external - -xdsl = pytest.importorskip("xdsl") -catalyst = pytest.importorskip("catalyst") +# pylint: disable=wrong-import-position +pytestmark = pytest.mark.usefixtures("requires_xdsl") import pennylane as qml -# pylint: disable=wrong-import-position from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath from catalyst.python_interface.transforms import ( SplitNonCommutingPass, @@ -220,7 +217,7 @@ def test_func_w_mixed_measurements(self, run_filecheck): pipeline = (SplitNonCommutingPass(),) run_filecheck(program, pipeline) - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_split_non_commuting_pass_only(self, run_filecheck_qjit): """Test the split non-commuting pass only.""" dev = qml.device("lightning.qubit", wires=5) @@ -258,9 +255,10 @@ def circuit(): run_filecheck_qjit(circuit) - @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.usefixtures("use_capture") def test_lightning_execution_with_structure(self): - """Test that the split non-commuting pass on lightning.qubit for a circuit with program structure is executable and returns results as expected.""" + """Test that the split non-commuting pass on lightning.qubit for a circuit with program + structure is executable and returns results as expected.""" dev = qml.device("lightning.qubit", wires=10) @qml.for_loop(0, 10, 1) @@ -284,7 +282,7 @@ def while_fn(i): @split_non_commuting_pass @qml.qnode(dev) def circuit(): - for_fn() + for_fn() # pylint: disable=no-value-for-parameter while_fn(0) qml.CNOT(wires=[0, 1]) return ( @@ -300,7 +298,7 @@ def circuit(): ) @qml.qnode(dev) def circuit_ref(): - for_fn() + for_fn() # pylint: disable=no-value-for-parameter while_fn(0) qml.CNOT(wires=[0, 1]) return ( diff --git a/frontend/test/pytest/python_interface/visualization/test_draw_unified_compiler.py b/frontend/test/pytest/python_interface/visualization/test_draw_unified_compiler.py index c828f8f87c..80cfe437ad 100644 --- a/frontend/test/pytest/python_interface/visualization/test_draw_unified_compiler.py +++ b/frontend/test/pytest/python_interface/visualization/test_draw_unified_compiler.py @@ -16,24 +16,21 @@ import pytest -pytestmark = pytest.mark.external +pytestmark = pytest.mark.usefixtures("requires_xdsl") -pytest.importorskip("xdsl") -pytest.importorskip("catalyst") - -# pylint: disable=wrong-import-position +# pylint: disable=wrong-import-position,unnecessary-lambda import jax import pennylane as qml -# pylint: disable=wrong-import-position from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath -from catalyst.python_interface.transforms import iterative_cancel_inverses_pass +from catalyst.python_interface.transforms import ( + iterative_cancel_inverses_pass, + merge_rotations_pass, +) from catalyst.python_interface.visualization import draw -# pylint: disable=implicit-str-concat, unnecessary-lambda - -@pytest.mark.usefixtures("enable_disable_plxpr") +@pytest.mark.usefixtures("use_capture") class Testdraw: """Unit tests for the draw function in the Python Compiler visualization module.""" @@ -79,16 +76,17 @@ def circ(): "1: ──RY───────╰X─╰X──H───H─│──│──┤ State\n" "2: ──RZ────────────────────╰X─╰X─┤ State", ), - (2, "0: ──RX──RZ─┤ State\n" "1: ──RY─────┤ State\n" "2: ──RZ─────┤ State"), - (None, "0: ──RX──RZ─┤ State\n" "1: ──RY─────┤ State\n" "2: ──RZ─────┤ State"), - (50, "0: ──RX──RZ─┤ State\n" "1: ──RY─────┤ State\n" "2: ──RZ─────┤ State"), + (2, "0: ──RX──RZ─┤ State\n1: ──RY─────┤ State\n2: ──RZ─────┤ State"), + (None, "0: ──RX──RZ─┤ State\n1: ──RY─────┤ State\n2: ──RZ─────┤ State"), + (50, "0: ──RX──RZ─┤ State1: ──RY─────┤ State2: ──RZ─────┤ State"), ], ) def test_multiple_levels_xdsl(self, transforms_circuit, level, qjit, expected): - """Test that multiple levels of transformation are applied correctly with xDSL compilation passes.""" + """Test that multiple levels of transformation are applied correctly with xDSL + compilation passes.""" transforms_circuit = iterative_cancel_inverses_pass( - qml.compiler.python_compiler.transforms.merge_rotations_pass(transforms_circuit) + merge_rotations_pass(transforms_circuit) ) if qjit: @@ -114,13 +112,14 @@ def test_multiple_levels_xdsl(self, transforms_circuit, level, qjit, expected): "1: ──RY───────╰X─╰X──H───H─│──│──┤ State\n" "2: ──RZ────────────────────╰X─╰X─┤ State", ), - (2, "0: ──RX──RZ─┤ State\n" "1: ──RY─────┤ State\n" "2: ──RZ─────┤ State"), - (None, "0: ──RX──RZ─┤ State\n" "1: ──RY─────┤ State\n" "2: ──RZ─────┤ State"), - (50, "0: ──RX──RZ─┤ State\n" "1: ──RY─────┤ State\n" "2: ──RZ─────┤ State"), + (2, "0: ──RX──RZ─┤ State\n1: ──RY─────┤ State\n2: ──RZ─────┤ State"), + (None, "0: ──RX──RZ─┤ State\n1: ──RY─────┤ State\n2: ──RZ─────┤ State"), + (50, "0: ──RX──RZ─┤ State\n1: ──RY─────┤ State\n2: ──RZ─────┤ State"), ], ) def test_multiple_levels_catalyst(self, transforms_circuit, level, qjit, expected): - """Test that multiple levels of transformation are applied correctly with Catalyst compilation passes.""" + """Test that multiple levels of transformation are applied correctly with Catalyst + compilation passes.""" transforms_circuit = qml.transforms.cancel_inverses( qml.transforms.merge_rotations(transforms_circuit) @@ -149,13 +148,14 @@ def test_multiple_levels_catalyst(self, transforms_circuit, level, qjit, expecte "1: ──RY───────╰X─╰X──H───H─│──│──┤ State\n" "2: ──RZ────────────────────╰X─╰X─┤ State", ), - (2, "0: ──RX──RZ─┤ State\n" "1: ──RY─────┤ State\n" "2: ──RZ─────┤ State"), - (None, "0: ──RX──RZ─┤ State\n" "1: ──RY─────┤ State\n" "2: ──RZ─────┤ State"), - (50, "0: ──RX──RZ─┤ State\n" "1: ──RY─────┤ State\n" "2: ──RZ─────┤ State"), + (2, "0: ──RX──RZ─┤ State\n1: ──RY─────┤ State\n2: ──RZ─────┤ State"), + (None, "0: ──RX──RZ─┤ State\n1: ──RY─────┤ State\n2: ──RZ─────┤ State"), + (50, "0: ──RX──RZ─┤ State\n1: ──RY─────┤ State\n2: ──RZ─────┤ State"), ], ) def test_multiple_levels_xdsl_catalyst(self, transforms_circuit, level, qjit, expected): - """Test that multiple levels of transformation are applied correctly with xDSL and Catalyst compilation passes.""" + """Test that multiple levels of transformation are applied correctly with xDSL and + Catalyst compilation passes.""" transforms_circuit = iterative_cancel_inverses_pass( qml.transforms.merge_rotations(transforms_circuit) @@ -287,7 +287,9 @@ def circuit(): qml.expval(qml.Y(1) @ qml.Z(2) @ qml.X(0)), qml.expval(qml.Z(2) @ qml.X(0) @ qml.Y(1)), ), - "0: ──RX─┤ ╭\n1: ──RY─┤ ╰\n2: ──RZ─┤ ╰", + "0: ──RX─┤ ╭\n" + "1: ──RY─┤ ╰\n" + "2: ──RZ─┤ ╰", ), ( lambda: ( @@ -296,7 +298,10 @@ def circuit(): @ qml.Hamiltonian([0.1, 0.1], [qml.PauliZ(2), qml.PauliZ(3)]) ) ), - "0: ──RX─┤ ╭<(𝓗)@(𝓗)>\n1: ──RY─┤ ├<(𝓗)@(𝓗)>\n2: ──RZ─┤ ├<(𝓗)@(𝓗)>\n3: ─────┤ ╰<(𝓗)@(𝓗)>", + "0: ──RX─┤ ╭<(𝓗)@(𝓗)>\n" + "1: ──RY─┤ ├<(𝓗)@(𝓗)>\n" + "2: ──RZ─┤ ├<(𝓗)@(𝓗)>\n" + "3: ─────┤ ╰<(𝓗)@(𝓗)>", ), ( lambda: (qml.var(qml.X(0)), qml.var(qml.Y(1)), qml.var(qml.Z(2))), @@ -308,7 +313,9 @@ def circuit(): qml.var(qml.Y(1) @ qml.Z(2) @ qml.X(0)), qml.var(qml.Z(2) @ qml.X(0) @ qml.Y(1)), ), - "0: ──RX─┤ ╭Var[X@Y] ╭Var[Y@Z@X] ╭Var[Z@X@Y]\n1: ──RY─┤ ╰Var[X@Y] ├Var[Y@Z@X] ├Var[Z@X@Y]\n2: ──RZ─┤ ╰Var[Y@Z@X] ╰Var[Z@X@Y]", + "0: ──RX─┤ ╭Var[X@Y] ╭Var[Y@Z@X] ╭Var[Z@X@Y]\n" + "1: ──RY─┤ ╰Var[X@Y] ├Var[Y@Z@X] ├Var[Z@X@Y]\n" + "2: ──RZ─┤ ╰Var[Y@Z@X] ╰Var[Z@X@Y]", ), ], ) @@ -340,9 +347,10 @@ def circuit(): qml.GlobalPhase(0.5) return qml.state() - assert ( - draw(circuit)() - == "0: ──H─╭GlobalPhase─┤ State\n1: ──H─├GlobalPhase─┤ State\n2: ──H─╰GlobalPhase─┤ State" + assert draw(circuit)() == ( + "0: ──H─╭GlobalPhase─┤ State\n" + "1: ──H─├GlobalPhase─┤ State\n" + "2: ──H─╰GlobalPhase─┤ State" ) @pytest.mark.parametrize( diff --git a/frontend/test/pytest/python_interface/visualization/test_mlir_graph.py b/frontend/test/pytest/python_interface/visualization/test_mlir_graph.py index d643ab4e82..6e564637b8 100644 --- a/frontend/test/pytest/python_interface/visualization/test_mlir_graph.py +++ b/frontend/test/pytest/python_interface/visualization/test_mlir_graph.py @@ -17,15 +17,11 @@ import pytest -pytestmark = pytest.mark.external - -pytest.importorskip("xdsl") -pytest.importorskip("catalyst") - +# pylint: disable=wrong-import-position +pytestmark = pytest.mark.usefixtures("requires_xdsl") import pennylane as qml -# pylint: disable=wrong-import-position from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath from catalyst.python_interface.transforms import ( iterative_cancel_inverses_pass, @@ -53,13 +49,13 @@ def assert_files(tmp_path: Path, expected: set[str]): assert files == expected, f"Expected {expected}, got {files}" -@pytest.mark.usefixtures("enable_disable_plxpr") +@pytest.mark.usefixtures("use_capture") class TestMLIRGraph: - "Test the MLIR graph generation" + """Test the MLIR graph generation""" @pytest.mark.parametrize("qjit", [True, False]) def test_no_transforms(self, tmp_path: Path, qjit: bool): - "Test the MLIR graph is still generated when no transforms are applied" + """Test the MLIR graph is still generated when no transforms are applied""" @qml.qnode(qml.device("lightning.qubit", wires=3)) def _(): @@ -77,7 +73,7 @@ def _(): @pytest.mark.parametrize("qjit", [True, False]) def test_xdsl_transforms_no_args(self, tmp_path: Path, qjit: bool): - "Test the MLIR graph generation with no arguments to the QNode with and without qjit" + """Test the MLIR graph generation with no arguments to the QNode with and without qjit""" @merge_rotations_pass @iterative_cancel_inverses_pass @@ -104,7 +100,7 @@ def _(): @pytest.mark.parametrize("qjit", [True, False]) def test_xdsl_transforms_args(self, tmp_path: Path, qjit: bool): - "Test the MLIR graph generation with arguments to the QNode for xDSL transforms" + """Test the MLIR graph generation with arguments to the QNode for xDSL transforms""" @merge_rotations_pass @iterative_cancel_inverses_pass @@ -129,7 +125,7 @@ def _(x, y, w1, w2): @pytest.mark.parametrize("qjit", [True, False]) def test_catalyst_transforms_args(self, tmp_path: Path, qjit: bool): - "Test the MLIR graph generation with arguments to the QNode for catalyst transforms" + """Test the MLIR graph generation with arguments to the QNode for catalyst transforms""" @qml.transforms.merge_rotations @qml.transforms.cancel_inverses @@ -154,7 +150,8 @@ def _(x, y, w1, w2): @pytest.mark.parametrize("qjit", [True, False]) def test_catalyst_xdsl_transforms_args(self, tmp_path: Path, qjit: bool): - "Test the MLIR graph generation with arguments to the QNode for catalyst and xDSL transforms" + """Test the MLIR graph generation with arguments to the QNode for catalyst and xDSL + transforms""" @qml.transforms.merge_rotations @iterative_cancel_inverses_pass @@ -178,7 +175,7 @@ def _(x, y, w1, w2): ) def test_cond(self, tmp_path: Path): - "Test the MLIR graph generation for a conditional" + """Test the MLIR graph generation for a conditional""" @merge_rotations_pass @qml.qnode(qml.device("lightning.qubit", wires=3)) @@ -210,7 +207,7 @@ def false_fn(arg1, arg2): ) def test_cond_with_mcm(self, tmp_path: Path): - "Test the MLIR graph generation for a conditional with MCM" + """Test the MLIR graph generation for a conditional with MCM""" def true_fn(arg): qml.RX(arg, 0) @@ -239,7 +236,7 @@ def _(x, y): ) def test_for_loop(self, tmp_path: Path): - "Test the MLIR graph generation for a for loop" + """Test the MLIR graph generation for a for loop""" @merge_rotations_pass @qml.qnode(qml.device("lightning.qubit", wires=3)) @@ -263,7 +260,7 @@ def loop(_): ) def test_while_loop(self, tmp_path: Path): - "Test the MLIR graph generation for a while loop" + """Test the MLIR graph generation for a while loop""" @merge_rotations_pass @qml.qnode(qml.device("lightning.qubit", wires=3)) diff --git a/frontend/test/pytest/python_interface/xdsl_extras/test_constraints.py b/frontend/test/pytest/python_interface/xdsl_extras/test_constraints.py index 90ce76b125..46a53070f0 100644 --- a/frontend/test/pytest/python_interface/xdsl_extras/test_constraints.py +++ b/frontend/test/pytest/python_interface/xdsl_extras/test_constraints.py @@ -16,9 +16,7 @@ import pytest -pytestmark = pytest.mark.external - -pytest.importorskip("xdsl") +pytestmark = pytest.mark.usefixtures("requires_xdsl") # pylint: disable=wrong-import-position from xdsl.context import Context @@ -54,8 +52,6 @@ def my_dialect_fixture(): class Float64MemrefOp(IRDLOperation): """A test op with float memref types""" - # pylint: disable=too-few-public-methods - name = "my_dialect.memref_float64" in_value = operand_def(MemRefConstraint(element_type=builtin.Float64Type())) out_value = result_def(MemRefConstraint(element_type=builtin.Float64Type())) @@ -64,8 +60,6 @@ class Float64MemrefOp(IRDLOperation): class Float64TensorOp(IRDLOperation): """A test op with float tensor types""" - # pylint: disable=too-few-public-methods - name = "my_dialect.tensor_float64" in_value = operand_def(TensorConstraint(element_type=builtin.Float64Type())) out_value = result_def(TensorConstraint(element_type=builtin.Float64Type())) @@ -74,8 +68,6 @@ class Float64TensorOp(IRDLOperation): class Rank1MemrefOp(IRDLOperation): """A test op with rank-1 memref types""" - # pylint: disable=too-few-public-methods - name = "my_dialect.memref_rank1" in_value = operand_def(MemRefConstraint(rank=1)) out_value = result_def(MemRefConstraint(rank=1)) @@ -84,8 +76,6 @@ class Rank1MemrefOp(IRDLOperation): class Rank1TensorOp(IRDLOperation): """A test op with rank-1 tensor types""" - # pylint: disable=too-few-public-methods - name = "my_dialect.tensor_rank1" in_value = operand_def(TensorConstraint(rank=1)) out_value = result_def(TensorConstraint(rank=1)) @@ -94,8 +84,6 @@ class Rank1TensorOp(IRDLOperation): class Rank2Or4MemrefOp(IRDLOperation): """A test op with rank-2 or -4 memref types""" - # pylint: disable=too-few-public-methods - name = "my_dialect.memref_rank24" in_value = operand_def(MemRefConstraint(rank=(2, 4))) out_value = result_def(MemRefConstraint(rank=(2, 4))) @@ -104,8 +92,6 @@ class Rank2Or4MemrefOp(IRDLOperation): class Rank2Or4TensorOp(IRDLOperation): """A test op with rank-2 or -4 tensor types""" - # pylint: disable=too-few-public-methods - name = "my_dialect.tensor_rank24" in_value = operand_def(TensorConstraint(rank=(2, 4))) out_value = result_def(TensorConstraint(rank=(2, 4))) @@ -114,8 +100,6 @@ class Rank2Or4TensorOp(IRDLOperation): class Shape123MemrefOp(IRDLOperation): """A test op with shape-(1, 2, 3) memref types""" - # pylint: disable=too-few-public-methods - name = "my_dialect.memref_shape123" in_value = operand_def(MemRefConstraint(shape=(1, 2, 3))) out_value = result_def(MemRefConstraint(shape=(1, 2, 3))) @@ -124,8 +108,6 @@ class Shape123MemrefOp(IRDLOperation): class Shape123TensorOp(IRDLOperation): """A test op with shape-(1, 2, 3) tensor types""" - # pylint: disable=too-few-public-methods - name = "my_dialect.tensor_shape123" in_value = operand_def(TensorConstraint(shape=(1, 2, 3))) out_value = result_def(TensorConstraint(shape=(1, 2, 3))) @@ -218,7 +200,8 @@ def test_memref_shape_constraint_verify_valid(self): constraint.verify(attr, None) def test_memref_element_type_constraint_verify_valid(self): - """Test that verifying an attribute with a valid element type does not raise an exception.""" + """Test that verifying an attribute with a valid element type does not raise an + exception.""" constraint = MemRefConstraint(element_type=builtin.i32) attr = builtin.MemRefType(builtin.i32, [1]) @@ -393,7 +376,8 @@ def test_tensor_shape_constraint_verify_valid(self): constraint.verify(attr, None) def test_tensor_element_type_constraint_verify_valid(self): - """Test that verifying an attribute with a valid element type does not raise an exception.""" + """Test that verifying an attribute with a valid element type does not raise an + exception.""" constraint = TensorConstraint(element_type=builtin.i32) attr = builtin.TensorType(builtin.i32, [1]) diff --git a/frontend/test/pytest/python_interface/xdsl_extras/test_traits.py b/frontend/test/pytest/python_interface/xdsl_extras/test_traits.py index 6be241a163..45ba543a42 100644 --- a/frontend/test/pytest/python_interface/xdsl_extras/test_traits.py +++ b/frontend/test/pytest/python_interface/xdsl_extras/test_traits.py @@ -19,9 +19,7 @@ import pytest # pylint: disable=wrong-import-position -# pylint: disable=too-few-public-methods - -xdsl = pytest.importorskip("xdsl") +pytestmark = pytest.mark.usefixtures("requires_xdsl") from xdsl.dialects.builtin import AnyAttr, TensorType, f32, i32, i64 from xdsl.ir import Attribute @@ -48,7 +46,7 @@ SameOperandsElementType, ) -pytestmark = pytest.mark.external +pytestmark = pytest.mark.usefixtures("requires_xdsl") AnyTensorType: TypeAlias = TensorType[Attribute] @@ -308,6 +306,8 @@ def test_operator_cannot_compute_raises_verifyexception(): @irdl_op_definition class CannotComputeMockOp(IRDLOperation): + """Test operation with no traits.""" + name = "test.cannot_compute" traits = traits_def() a = attr_def(AnyAttr()) From 1453aae5c2561737679854313f19436d6ac33ef4 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 17 Nov 2025 14:54:14 -0500 Subject: [PATCH 03/38] Fix CI errors; Ignore unified compiler in coverage reports --- .codecov.yml | 3 + .github/workflows/check-catalyst.yaml | 2 + frontend/catalyst/compiler.py | 6 +- frontend/catalyst/jit.py | 6 +- .../catalyst/python_interface/__init__.py | 1 - .../dialects/stablehlo/__init__.py | 105 ++++++++---------- .../python_interface/pass_api/__init__.py | 2 +- .../pass_api/apply_transform_sequence.py | 3 +- .../pass_api/compiler_transform.py | 4 +- .../python_interface/transforms/__init__.py | 22 ++-- .../python_interface/xdsl_extras/__init__.py | 2 +- 11 files changed, 75 insertions(+), 81 deletions(-) diff --git a/.codecov.yml b/.codecov.yml index e45b6334f3..dc9a2b56c6 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -7,3 +7,6 @@ coverage: status: project: false patch: true + +ignore: + - "frontend/catalyst/python_interface" \ No newline at end of file diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index 999e5ad137..7ffc60a8a1 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -466,6 +466,8 @@ jobs: run: | sudo apt-get update sudo apt-get install -y libasan6 make + # Install graphviz for testing the mlir-op-graph integration + sudo apt-get install -y graphviz python3 --version | grep ${{ needs.constants.outputs.primary_python_version }} python3 -m pip install -r requirements.txt # cuda-quantum is added manually here. diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index 2b29fdcc2f..71537e8e53 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -376,14 +376,14 @@ def to_llvmir(*args, stdin=None, options: Optional[CompileOptions] = None): def to_mlir_opt( - *args, stdin=None, options: Optional[CompileOptions] = None, using_python_compiler=False + *args, stdin=None, options: Optional[CompileOptions] = None, using_unified_compiler=False ): """echo ${input} | catalyst --tool=opt *args *opts -""" # Check if we need to use Python compiler for xDSL passes - if using_python_compiler: + if using_unified_compiler: # Use Python compiler path for xDSL passes # pylint: disable-next=import-outside-toplevel - from pennylane.compiler.python_compiler import Compiler as PythonCompiler + from catalyst.python_interface import Compiler as PythonCompiler compiler = PythonCompiler() stdin = compiler.run(stdin, callback=None) diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index 81f873917c..62b0acbde2 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -582,13 +582,13 @@ def mlir_opt(self): """Obtain the MLIR representation after optimization""" if not self.mlir_module: return None - using_python_compiler = self.compiler.is_using_python_compiler(self.mlir_module) + using_unified_compiler = self.compiler.is_using_unified_compiler(self.mlir_module) stdin = self.mlir_module.operation.get_asm( - print_generic_op_form=using_python_compiler, + print_generic_op_form=using_unified_compiler, enable_debug_info=self.compile_options.use_nameloc, ) return to_mlir_opt( - stdin=stdin, options=self.compile_options, using_python_compiler=using_python_compiler + stdin=stdin, options=self.compile_options, using_unified_compiler=using_unified_compiler ) @debug_logger diff --git a/frontend/catalyst/python_interface/__init__.py b/frontend/catalyst/python_interface/__init__.py index b38cdd4744..85fb8ceef5 100644 --- a/frontend/catalyst/python_interface/__init__.py +++ b/frontend/catalyst/python_interface/__init__.py @@ -18,7 +18,6 @@ from .pass_api import compiler_transform from .visualization import QMLCollector - __all__ = [ "Compiler", "compiler_transform", diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/__init__.py b/frontend/catalyst/python_interface/dialects/stablehlo/__init__.py index 2eb1cb67ce..64380529ac 100644 --- a/frontend/catalyst/python_interface/dialects/stablehlo/__init__.py +++ b/frontend/catalyst/python_interface/dialects/stablehlo/__init__.py @@ -19,6 +19,51 @@ functionality. """ +from .attributes import ( + CustomCallApiVersion, + CustomCallApiVersionAttr, + GatherDimensionNumbers, + OutputOperandAlias, + ResultAccuracyModeAttr, + ScatterDimensionNumbers, +) +from .control_flow import ( + IfOp, + OptimizationBarrierOp, + WhileOp, +) +from .data_movement import ( + BroadcastInDimOp, + ConcatenateOp, + DynamicSliceOp, + GatherOp, + ReshapeOp, + ScatterOp, + SliceOp, +) + +# Import the main StableHLO dialect +from .dialect import StableHLO +from .dynamism import ( + DynamicBroadcastInDimOp, +) +from .elementwise_binary import ( + ComplexOp, + DivideOp, + MaximumOp, + MinimumOp, + PowerOp, + RemainderOp, +) +from .elementwise_other import ( + ClampOp, + CompareOp, + ConstantOp, + MapOp, + ReducePrecisionOp, + SelectOp, +) + # Import all elementwise operations explicitly from .elementwise_unary import ( ConvertOp, @@ -28,9 +73,9 @@ FloorOp, ImagOp, IsFiniteOp, + LogisticOp, LogOp, LogPlusOneOp, - LogisticOp, NegateOp, RealOp, RoundNearestAfzOp, @@ -39,68 +84,16 @@ SignOp, SineOp, SqrtOp, - TanOp, TanhOp, + TanOp, ) - -from .elementwise_binary import ( - ComplexOp, - DivideOp, - MaximumOp, - MinimumOp, - PowerOp, - RemainderOp, -) - -from .elementwise_other import ( - ClampOp, - CompareOp, - ConstantOp, - MapOp, - ReducePrecisionOp, - SelectOp, -) - -from .control_flow import ( - IfOp, - WhileOp, - OptimizationBarrierOp, -) - -from .data_movement import ( - BroadcastInDimOp, - ConcatenateOp, - DynamicSliceOp, - GatherOp, - ReshapeOp, - ScatterOp, - SliceOp, -) - -from .dynamism import ( - DynamicBroadcastInDimOp, -) - -from .reduction import ( - ReduceOp, -) - from .extensibility import ( CustomCallOp, ) - -from .attributes import ( - GatherDimensionNumbers, - ResultAccuracyModeAttr, - ScatterDimensionNumbers, - CustomCallApiVersion, - CustomCallApiVersionAttr, - OutputOperandAlias, +from .reduction import ( + ReduceOp, ) -# Import the main StableHLO dialect -from .dialect import StableHLO - # Export all operations and the dialect for external use __all__ = [ # Main dialect diff --git a/frontend/catalyst/python_interface/pass_api/__init__.py b/frontend/catalyst/python_interface/pass_api/__init__.py index 8f2560030f..ceb32f9de5 100644 --- a/frontend/catalyst/python_interface/pass_api/__init__.py +++ b/frontend/catalyst/python_interface/pass_api/__init__.py @@ -19,8 +19,8 @@ is_xdsl_pass, register_pass, ) -from .transform_interpreter import TransformFunctionsExt, TransformInterpreterPass from .compiler_transform import PassDispatcher, compiler_transform +from .transform_interpreter import TransformFunctionsExt, TransformInterpreterPass __all__ = [ "ApplyTransformSequence", diff --git a/frontend/catalyst/python_interface/pass_api/apply_transform_sequence.py b/frontend/catalyst/python_interface/pass_api/apply_transform_sequence.py index a9da6640a0..ba93e75eef 100644 --- a/frontend/catalyst/python_interface/pass_api/apply_transform_sequence.py +++ b/frontend/catalyst/python_interface/pass_api/apply_transform_sequence.py @@ -16,12 +16,11 @@ from dataclasses import dataclass +from pennylane.typing import Callable from xdsl.context import Context from xdsl.dialects import builtin from xdsl.passes import ModulePass, PassPipeline -from pennylane.typing import Callable - from .transform_interpreter import TransformInterpreterPass available_passes = {} diff --git a/frontend/catalyst/python_interface/pass_api/compiler_transform.py b/frontend/catalyst/python_interface/pass_api/compiler_transform.py index 2016c16866..87943c555d 100644 --- a/frontend/catalyst/python_interface/pass_api/compiler_transform.py +++ b/frontend/catalyst/python_interface/pass_api/compiler_transform.py @@ -15,10 +15,10 @@ from collections.abc import Callable -from catalyst.from_plxpr import register_transform +from pennylane.transforms.core.transform_dispatcher import TransformDispatcher from xdsl.passes import ModulePass -from pennylane.transforms.core.transform_dispatcher import TransformDispatcher +from catalyst.from_plxpr import register_transform from .apply_transform_sequence import register_pass diff --git a/frontend/catalyst/python_interface/transforms/__init__.py b/frontend/catalyst/python_interface/transforms/__init__.py index ac661ba0de..905893b2bf 100644 --- a/frontend/catalyst/python_interface/transforms/__init__.py +++ b/frontend/catalyst/python_interface/transforms/__init__.py @@ -14,32 +14,30 @@ """PennyLane-xDSL transformations API.""" from .mbqc import ( - convert_to_mbqc_formalism_pass, ConvertToMBQCFormalismPass, - decompose_graph_state_pass, DecomposeGraphStatePass, - outline_state_evolution_pass, + NullDecomposeGraphStatePass, OutlineStateEvolutionPass, + convert_to_mbqc_formalism_pass, + decompose_graph_state_pass, null_decompose_graph_state_pass, - NullDecomposeGraphStatePass, + outline_state_evolution_pass, ) - from .quantum import ( - combine_global_phases_pass, CombineGlobalPhasesPass, - diagonalize_final_measurements_pass, DiagonalizeFinalMeasurementsPass, - iterative_cancel_inverses_pass, IterativeCancelInversesPass, - measurements_from_samples_pass, MeasurementsFromSamplesPass, - merge_rotations_pass, MergeRotationsPass, - split_non_commuting_pass, SplitNonCommutingPass, + combine_global_phases_pass, + diagonalize_final_measurements_pass, + iterative_cancel_inverses_pass, + measurements_from_samples_pass, + merge_rotations_pass, + split_non_commuting_pass, ) - __all__ = [ # Quantum "combine_global_phases_pass", diff --git a/frontend/catalyst/python_interface/xdsl_extras/__init__.py b/frontend/catalyst/python_interface/xdsl_extras/__init__.py index 094b56f4fc..16785fe2c6 100644 --- a/frontend/catalyst/python_interface/xdsl_extras/__init__.py +++ b/frontend/catalyst/python_interface/xdsl_extras/__init__.py @@ -14,7 +14,7 @@ """This module contains additional utilities and functionality not available upstream in xDSL.""" -from .constraints import MemRefConstraint, TensorConstraint, NestedTupleOfConstraint +from .constraints import MemRefConstraint, NestedTupleOfConstraint, TensorConstraint from .traits import ( AllMatchSameOperatorTrait, Elementwise, From efc0f88bb2228094f4618af68297ff30ab136a17 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 17 Nov 2025 15:11:31 -0500 Subject: [PATCH 04/38] Test out how graphviz is installed --- .github/workflows/check-catalyst.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index 7ffc60a8a1..6ed3d7edfd 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -466,8 +466,6 @@ jobs: run: | sudo apt-get update sudo apt-get install -y libasan6 make - # Install graphviz for testing the mlir-op-graph integration - sudo apt-get install -y graphviz python3 --version | grep ${{ needs.constants.outputs.primary_python_version }} python3 -m pip install -r requirements.txt # cuda-quantum is added manually here. @@ -475,6 +473,8 @@ jobs: # macOS requirements.txt python3 -m pip install cuda-quantum==0.6.0 python3 -m pip install oqc-qcaas-client + # Install graphviz for testing the mlir-op-graph integration + python3 -m pip install graphviz make frontend - name: Get Cached LLVM Build From 10cd74937ff34f3b42ba4d1a68229b74ef210185 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 17 Nov 2025 15:50:37 -0500 Subject: [PATCH 05/38] Lint some more --- .../catalyst/python_interface/compiler.py | 4 +- .../catalyst/python_interface/conversion.py | 6 +- .../python_interface/dialects/catalyst.py | 6 +- .../python_interface/dialects/mbqc.py | 10 +-- .../catalyst/python_interface/dialects/qec.py | 2 - .../python_interface/dialects/quantum.py | 3 +- .../dialects/stablehlo/attributes.py | 2 +- .../dialects/stablehlo/control_flow.py | 4 +- .../dialects/stablehlo/data_movement.py | 7 +- .../dialects/stablehlo/dynamism.py | 5 +- .../dialects/stablehlo/elementwise_binary.py | 2 - .../dialects/stablehlo/elementwise_other.py | 2 - .../dialects/stablehlo/elementwise_unary.py | 2 - .../dialects/stablehlo/extensibility.py | 8 +-- .../dialects/stablehlo/reduction.py | 5 +- .../dialects/stablehlo/types.py | 2 - .../python_interface/dialects/transform.py | 3 +- frontend/catalyst/python_interface/parser.py | 3 +- .../pass_api/apply_transform_sequence.py | 2 +- .../pass_api/transform_interpreter.py | 2 +- .../mbqc/convert_to_mbqc_formalism.py | 65 ++++++++++++------- .../transforms/mbqc/decompose_graph_state.py | 10 ++- .../transforms/mbqc/graph_state_utils.py | 24 ++++--- .../mbqc/outline_state_evolution.py | 47 ++++++++------ .../transforms/quantum/cancel_inverses.py | 8 +-- .../quantum/combine_global_phases.py | 11 ++-- .../quantum/diagonalize_measurements.py | 33 +++++----- .../quantum/measurements_from_samples.py | 18 ++--- .../transforms/quantum/merge_rotations.py | 8 +-- .../transforms/quantum/split_non_commuting.py | 28 +++++--- .../python_interface/visualization/draw.py | 10 +-- .../visualization/mlir_graph.py | 8 ++- .../visualization/xdsl_conversion.py | 4 +- .../xdsl_extras/constraints.py | 3 - .../python_interface/xdsl_extras/traits.py | 3 +- 35 files changed, 188 insertions(+), 172 deletions(-) diff --git a/frontend/catalyst/python_interface/compiler.py b/frontend/catalyst/python_interface/compiler.py index bded5d4552..f98ff2b8d8 100644 --- a/frontend/catalyst/python_interface/compiler.py +++ b/frontend/catalyst/python_interface/compiler.py @@ -18,8 +18,8 @@ from jax._src.interpreters import mlir from jaxlib.mlir.dialects import stablehlo -from jaxlib.mlir.ir import Context as jaxContext # pylint: disable=no-name-in-module -from jaxlib.mlir.ir import Module as jaxModule # pylint: disable=no-name-in-module +from jaxlib.mlir.ir import Context as jaxContext +from jaxlib.mlir.ir import Module as jaxModule from pennylane.typing import Callable from xdsl.context import Context as xContext from xdsl.dialects.builtin import ModuleOp diff --git a/frontend/catalyst/python_interface/conversion.py b/frontend/catalyst/python_interface/conversion.py index 614bc65ca0..975a46b3cb 100644 --- a/frontend/catalyst/python_interface/conversion.py +++ b/frontend/catalyst/python_interface/conversion.py @@ -19,9 +19,9 @@ from typing import TypeAlias from jax._src.lib import _jax -from jaxlib.mlir.dialects import stablehlo as jstablehlo # pylint: disable=no-name-in-module -from jaxlib.mlir.ir import Context as jContext # pylint: disable=no-name-in-module -from jaxlib.mlir.ir import Module as jModule # pylint: disable=no-name-in-module +from jaxlib.mlir.dialects import stablehlo as jstablehlo +from jaxlib.mlir.ir import Context as jContext +from jaxlib.mlir.ir import Module as jModule from xdsl.context import Context as xContext from xdsl.dialects import builtin as xbuiltin from xdsl.dialects import func as xfunc diff --git a/frontend/catalyst/python_interface/dialects/catalyst.py b/frontend/catalyst/python_interface/dialects/catalyst.py index 7b3c474063..b42b256c37 100644 --- a/frontend/catalyst/python_interface/dialects/catalyst.py +++ b/frontend/catalyst/python_interface/dialects/catalyst.py @@ -14,15 +14,13 @@ """ This file contains the Catalyst dialect for the Python compiler. -This file was originally ported automatically by xDSL (using the ``xdsl-tblgen`` tool) and modified manually -to support the unified compiler. +This file was originally ported automatically by xDSL (using the ``xdsl-tblgen`` tool) +and modified manually to support the unified compiler. The catalyst dialect serves as a standard library for the Catalyst compiler. It contains data structures that support core compiler functionality. """ -# pylint: disable=too-few-public-methods - from typing import ClassVar from xdsl.dialects.builtin import ( diff --git a/frontend/catalyst/python_interface/dialects/mbqc.py b/frontend/catalyst/python_interface/dialects/mbqc.py index 8cf96b4910..b608e037fb 100644 --- a/frontend/catalyst/python_interface/dialects/mbqc.py +++ b/frontend/catalyst/python_interface/dialects/mbqc.py @@ -58,8 +58,6 @@ class MeasurementPlaneEnum(StrEnum): """Enum containing supported measurement-plane attributes""" - # pylint: disable=too-few-public-methods - XY = "XY" YZ = "YZ" ZX = "ZX" @@ -69,8 +67,6 @@ class MeasurementPlaneEnum(StrEnum): class MeasurementPlaneAttr(EnumAttribute[MeasurementPlaneEnum], SpacedOpaqueSyntaxAttribute): """Planes in the Bloch sphere representation with support for arbitrary-basis measurements""" - # pylint: disable=too-few-public-methods - name = "mbqc.measurement_plane" @@ -78,8 +74,6 @@ class MeasurementPlaneAttr(EnumAttribute[MeasurementPlaneEnum], SpacedOpaqueSynt class MeasureInBasisOp(IRDLOperation): """A parametric single-qubit projective measurement in an arbitrary basis.""" - # pylint: disable=too-few-public-methods - name = "mbqc.measure_in_basis" assembly_format = """ @@ -124,7 +118,7 @@ def verify_(self): if self.postselect is None: return - if self.postselect.value.data not in [0, 1]: + if self.postselect.value.data not in [0, 1]: # pylint: disable=no-member raise VerifyException("'postselect' must be 0 or 1.") @@ -132,8 +126,6 @@ def verify_(self): class GraphStatePrepOp(IRDLOperation): """Allocate resources for a new graph state.""" - # pylint: disable=too-few-public-methods - name = "mbqc.graph_state_prep" assembly_format = """ diff --git a/frontend/catalyst/python_interface/dialects/qec.py b/frontend/catalyst/python_interface/dialects/qec.py index 962adaba9f..be5ba288f3 100644 --- a/frontend/catalyst/python_interface/dialects/qec.py +++ b/frontend/catalyst/python_interface/dialects/qec.py @@ -21,8 +21,6 @@ catalyst/mlir/include/QEC/IR/QECDialect.td file in the catalyst repository. """ -# pylint: disable=too-few-public-methods - from xdsl.dialects.builtin import I16, ArrayAttr, IntegerAttr, IntegerType, StringAttr, i16 from xdsl.dialects.utils import AbstractYieldOperation from xdsl.ir import Attribute, Dialect, EnumAttribute, SpacedOpaqueSyntaxAttribute diff --git a/frontend/catalyst/python_interface/dialects/quantum.py b/frontend/catalyst/python_interface/dialects/quantum.py index a4915fd5fc..8b549f5440 100644 --- a/frontend/catalyst/python_interface/dialects/quantum.py +++ b/frontend/catalyst/python_interface/dialects/quantum.py @@ -20,8 +20,7 @@ It was initially generated by xDSL (using the ``xdsl-tblgen`` tool) starting from the catalyst/mlir/include/Quantum/IR/QuantumOps.td file in the catalyst repository. """ - -# pylint: disable=too-few-public-methods +# pylint: disable=too-many-lines from collections.abc import Sequence from typing import TypeAlias diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/attributes.py b/frontend/catalyst/python_interface/dialects/stablehlo/attributes.py index 7f2e60e2f9..7618d77afc 100644 --- a/frontend/catalyst/python_interface/dialects/stablehlo/attributes.py +++ b/frontend/catalyst/python_interface/dialects/stablehlo/attributes.py @@ -20,7 +20,7 @@ attributes for StableHLO operations. """ -# pylint: disable=too-few-public-methods +# pylint: disable=line-too-long from collections.abc import Sequence diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/control_flow.py b/frontend/catalyst/python_interface/dialects/stablehlo/control_flow.py index 293c831eaa..c9edced70d 100644 --- a/frontend/catalyst/python_interface/dialects/stablehlo/control_flow.py +++ b/frontend/catalyst/python_interface/dialects/stablehlo/control_flow.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint: disable=too-few-public-methods - """ Control flow operations for the StableHLO dialect. """ @@ -84,6 +82,7 @@ class IfOp(IRDLOperation): # TODO: Add custom assembly format +# pylint: disable=line-too-long @irdl_op_definition class WhileOp(IRDLOperation): """ @@ -125,6 +124,7 @@ class WhileOp(IRDLOperation): ) +# pylint: disable=line-too-long @irdl_op_definition class OptimizationBarrierOp(IRDLOperation): """ diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/data_movement.py b/frontend/catalyst/python_interface/dialects/stablehlo/data_movement.py index ee6a9a3109..68ce2bddc7 100644 --- a/frontend/catalyst/python_interface/dialects/stablehlo/data_movement.py +++ b/frontend/catalyst/python_interface/dialects/stablehlo/data_movement.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint: disable=too-few-public-methods - """ Data movement operations for the StableHLO dialect. """ @@ -53,6 +51,7 @@ from .types import HLO_AnyIntegerOrIndexTensor, HLO_AnyTensor, HLO_Int, HLO_IntTensor, HLO_Tensor +# pylint: disable=line-too-long @irdl_op_definition class BroadcastInDimOp(IRDLOperation): """ @@ -93,7 +92,7 @@ def verify_(self) -> None: assert isinstance(o_type, TensorType) and isinstance(r_type, TensorType) # broadcast_in_dim_c2: broadcast_dimensions size == operand rank - dims = tuple(self.broadcast_dimensions.get_values()) + dims = tuple(self.broadcast_dimensions.get_values()) # pylint: disable=no-member operand_rank = o_type.get_num_dims() if len(dims) != operand_rank: raise VerifyException( @@ -133,6 +132,7 @@ def verify_(self) -> None: ) +# pylint: disable=line-too-long @irdl_op_definition class ConcatenateOp(IRDLOperation): """ @@ -205,6 +205,7 @@ class DynamicSliceOp(IRDLOperation): ) +# pylint: disable=line-too-long @irdl_op_definition class GatherOp(IRDLOperation): """ diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py b/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py index e012c07529..6da7b376d8 100644 --- a/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py +++ b/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint: disable=too-few-public-methods - """ Dynamism operations for the StableHLO dialect. """ @@ -147,7 +145,8 @@ def verify_(self): # dynamic_broadcast_in_dim_c7: output_dimensions shape compatible with result rank out_dims_ty = self.output_dimensions.type # pylint: disable=no-member assert isinstance(out_dims_ty, TensorType) - # Must be rank-1 tensor (enforced by type constraint), and length must match result rank when statically known + # Must be rank-1 tensor (enforced by type constraint), and length must match result + # rank when statically known out_shape = out_dims_ty.get_shape() if len(out_shape) != 1: raise VerifyException("output_dimensions must be a 1D tensor") diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_binary.py b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_binary.py index 135c865005..8180c7a208 100644 --- a/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_binary.py +++ b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_binary.py @@ -16,8 +16,6 @@ Binary elementwise operations for the StableHLO dialect. """ -# pylint: disable=too-few-public-methods - import abc from typing import Generic, TypeVar diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_other.py b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_other.py index e4d4c7087d..18527a028d 100644 --- a/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_other.py +++ b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_other.py @@ -16,8 +16,6 @@ Other elementwise operations for the StableHLO dialect. """ -# pylint: disable=too-few-public-methods - import xdsl_jax.dialects.stablehlo as xstablehlo from xdsl.dialects.builtin import ( AnyFloat, diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_unary.py b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_unary.py index 5e80f6adb7..a8a6e6ca86 100644 --- a/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_unary.py +++ b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_unary.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint: disable=too-few-public-methods - """ Unary elementwise operations for the StableHLO dialect. """ diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/extensibility.py b/frontend/catalyst/python_interface/dialects/stablehlo/extensibility.py index 6133b5e6c7..e6f8f0542d 100644 --- a/frontend/catalyst/python_interface/dialects/stablehlo/extensibility.py +++ b/frontend/catalyst/python_interface/dialects/stablehlo/extensibility.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint: disable=too-few-public-methods - """ Dynamism operations for the StableHLO dialect. """ @@ -108,7 +106,8 @@ def verify_(self) -> None: # Layout constraints for either both operands & results or none should be specified. if (self.operand_layouts is None) != (self.result_layouts is None): raise VerifyException( - "Layout attributes should be specified for either both operands and results or none." + "Layout attributes should be specified for either both operands and results " + "or none." ) assert self.operand_layouts is not None and self.result_layouts is not None @@ -148,7 +147,8 @@ def verify_types_and_layouts( rank = ty.get_num_dims() if rank != len(dims) or sorted(dims) != list(range(rank)): raise VerifyException( - f"incorrect layout {dims} for type {ty}, layout must be a permutation of [0, {rank})" + f"incorrect layout {dims} for type {ty}, layout must be a permutation " + f"of [0, {rank})" ) # Operand types diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py b/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py index 3597aaaeae..28b332d22f 100644 --- a/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py +++ b/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint: disable=too-few-public-methods - """ Dynamism operations for the StableHLO dialect. """ @@ -122,7 +120,8 @@ def verify_(self): raise VerifyException("input and init_value must have the same element type") # reduce_c2/c6: verify reducer region shape - # Expect block with arity 2 * number of inputs, with matching tensor element types and 0D tensors + # Expect block with arity 2 * number of inputs, with matching tensor element types + # and 0D tensors if len(self.body.blocks) != 1: raise VerifyException("reducer must have a single block") block = self.body.blocks[0] diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/types.py b/frontend/catalyst/python_interface/dialects/stablehlo/types.py index 7bbbf5ad5f..a27f4fc8f2 100644 --- a/frontend/catalyst/python_interface/dialects/stablehlo/types.py +++ b/frontend/catalyst/python_interface/dialects/stablehlo/types.py @@ -20,8 +20,6 @@ token types and other necessary type definitions for StableHLO operations. """ -# pylint: disable=too-few-public-methods - from typing import TypeAlias from xdsl.dialects.builtin import ( diff --git a/frontend/catalyst/python_interface/dialects/transform.py b/frontend/catalyst/python_interface/dialects/transform.py index 1e0d58963f..32449e7d5e 100644 --- a/frontend/catalyst/python_interface/dialects/transform.py +++ b/frontend/catalyst/python_interface/dialects/transform.py @@ -40,8 +40,6 @@ """ from xdsl.dialects.builtin import Dialect - -# pylint: disable=too-few-public-methods from xdsl.dialects.transform import ApplyRegisteredPassOp as xApplyRegisteredPassOp from xdsl.dialects.transform import ( DictionaryAttr, @@ -59,6 +57,7 @@ from xdsl.irdl import IRDLOperation, ParsePropInAttrDict +# pylint: disable=line-too-long @irdl_op_definition class ApplyRegisteredPassOp(IRDLOperation): """ diff --git a/frontend/catalyst/python_interface/parser.py b/frontend/catalyst/python_interface/parser.py index c6b52f05b2..b7169362e8 100644 --- a/frontend/catalyst/python_interface/parser.py +++ b/frontend/catalyst/python_interface/parser.py @@ -28,7 +28,7 @@ from catalyst.python_interface.dialects import MBQC, QEC, Catalyst, Quantum, StableHLO, Transform -class QuantumParser(xParser): # pylint: disable=abstract-method,too-few-public-methods +class QuantumParser(xParser): # pylint: disable=abstract-method """A subclass of ``xdsl.parser.Parser`` that automatically loads relevant dialects into the input context. @@ -54,6 +54,7 @@ class QuantumParser(xParser): # pylint: disable=abstract-method,too-few-public- QEC, ) + # pylint: disable=redefined-builtin def __init__( self, ctx: xContext, diff --git a/frontend/catalyst/python_interface/pass_api/apply_transform_sequence.py b/frontend/catalyst/python_interface/pass_api/apply_transform_sequence.py index ba93e75eef..7a07218b2a 100644 --- a/frontend/catalyst/python_interface/pass_api/apply_transform_sequence.py +++ b/frontend/catalyst/python_interface/pass_api/apply_transform_sequence.py @@ -59,7 +59,7 @@ class ApplyTransformSequence(ModulePass): name = "apply-transform-sequence" callback: Callable[[ModulePass, builtin.ModuleOp, ModulePass], None] | None = None - def apply(self, ctx: Context, op: builtin.ModuleOp) -> None: # pylint: disable=no-self-use + def apply(self, ctx: Context, op: builtin.ModuleOp) -> None: """Applies the transformation""" nested_modules = [] for region in op.regions: diff --git a/frontend/catalyst/python_interface/pass_api/transform_interpreter.py b/frontend/catalyst/python_interface/pass_api/transform_interpreter.py index 87983ec491..cf635c1591 100644 --- a/frontend/catalyst/python_interface/pass_api/transform_interpreter.py +++ b/frontend/catalyst/python_interface/pass_api/transform_interpreter.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +# pylint: disable=line-too-long """Custom Transform Dialect Interpreter Pass Differs from xDSL's upstream implementation by allowing passes @@ -43,7 +44,6 @@ from catalyst.python_interface.dialects.transform import ApplyRegisteredPassOp -# pylint: disable=too-few-public-methods @register_impls class TransformFunctionsExt(TransformFunctions): """ diff --git a/frontend/catalyst/python_interface/transforms/mbqc/convert_to_mbqc_formalism.py b/frontend/catalyst/python_interface/transforms/mbqc/convert_to_mbqc_formalism.py index fcd8bd81df..4476d62e96 100644 --- a/frontend/catalyst/python_interface/transforms/mbqc/convert_to_mbqc_formalism.py +++ b/frontend/catalyst/python_interface/transforms/mbqc/convert_to_mbqc_formalism.py @@ -70,8 +70,8 @@ class ConvertToMBQCFormalismPass(passes.ModulePass): name = "convert-to-mbqc-formalism" def _prep_graph_state(self, gate_name: str): - """Add a graph state prep operation into the subroutine for each gate and extract and return auxiliary qubits - in the graph state. + """Add a graph state prep operation into the subroutine for each gate and extract and + return auxiliary qubits in the graph state. Args: gate_name[str]: Name of gate operation. @@ -115,6 +115,7 @@ def _insert_xy_basis_measure_op( qubit: QubitType, ): """Add an arbitrary basis measure related operations to the subroutine. + Args: const_angle_op (arith.ConstantOp) : The angle of measurement basis. qubit (QubitType) : The target qubit to be measured. @@ -135,12 +136,15 @@ def _insert_cond_arbitrary_basis_measure_op( angle: SSAValue[builtin.Float64Type], plane: str, qubit: QubitType, - ): # pylint: disable=too-many-arguments, too-many-positional-arguments + ): """ - Add a conditional arbitrary basis measurement operation based on a previous measurement result. + Add a conditional arbitrary basis measurement operation based on a previous measurement + result. + Args: meas_parity (builtin.IntegerType) : A parity of previous measurements. - angle (SSAValue[builtin.Float64Type]) : An angle SSAValue from a parametric gate operation. + angle (SSAValue[builtin.Float64Type]) : An angle SSAValue from a parametric gate + operation. plane (str): Plane of the measurement basis. qubit (QubitType) : The target qubit to be measured. @@ -314,10 +318,12 @@ def _parity_check( additional_const_one: bool = False, ): """Add parity check related operations to the subroutine. + Args: mres (list[builtin.IntegerType]): A list of the mid-measurement results. - additional_const_one (bool) : Whether we need to add an additional const one to get the - parity or not. Defaults to False. + additional_const_one (bool) : Whether we need to add an additional const one to + get the parity or not. Defaults to False. + Returns: The result of parity check. """ @@ -341,8 +347,9 @@ def _insert_cond_byproduct_op( parity_res: OpResult, gate_name: str, qubit: QubitType, - ): # pylint: disable=too-many-arguments, too-many-positional-arguments + ): """Add a byproduct op related operations to the subroutine. + Args: parity_res (OpResult) : Parity check result. gate_name (str) : The name of the gate to be corrected. @@ -368,6 +375,7 @@ def _hadamard_corrections( qubit: QubitType, ): """Add correction ops of a Hadamard gate to the subroutine. + Args: mres (list[builtin.IntegerType]): A list of the mid-measurement results. qubit (QubitType) : An auxiliary result qubit. @@ -393,6 +401,7 @@ def _s_corrections( qubit: QubitType, ): """Add correction ops of a S gate to the subroutine. + Args: mres (list[builtin.IntegerType]): A list of the mid-measurement results. qubit (QubitType) : An auxiliary result qubit. @@ -417,6 +426,7 @@ def _rot_corrections( qubit: QubitType, ): """Add correction ops of a RotXZX or RZ gate to the subroutine. + Args: mres (list[builtin.IntegerType]): A list of the mid-measurement results. qubit (QubitType) : An auxiliary result qubit. @@ -440,9 +450,11 @@ def _cnot_corrections( qubits: list[QubitType], ): """Add correction ops of a CNOT gate to the subroutine. + Args: mres (list[builtin.IntegerType]): A list of the mid-measurement results. qubits (list[QubitType]) : A list of auxiliary result qubits. + Returns: The result auxiliary qubits. """ @@ -465,6 +477,7 @@ def _queue_measurements( self, gate_name: str, graph_qubit_dict, params: None | list[builtin.Float64Type] = None ): """Add measurement ops to the subroutine. + Args: gate_name (str): Gate name. graph_qubit_dict (list[builtin.IntegerType]): A list of the mid-measurement results. @@ -487,7 +500,8 @@ def _queue_measurements( return self._cnot_measurements(graph_qubit_dict) case _: raise ValueError( - f"{gate_name} is not supported in the MBQC formalism. Please decompose it into the MBQC gate set." + f"{gate_name} is not supported in the MBQC formalism. Please decompose it " + "into the MBQC gate set." ) def _insert_byprod_corrections( @@ -518,7 +532,8 @@ def _insert_byprod_corrections( return self._cnot_corrections(mres, qubits) case _: raise ValueError( - f"{gate_name} is not supported in the MBQC formalism. Please decompose it into the MBQC gate set." + f"{gate_name} is not supported in the MBQC formalism. Please decompose it " + "into the MBQC gate set." ) def _create_single_qubit_gate_subroutine(self, gate_name: str): @@ -576,8 +591,10 @@ def _create_single_qubit_gate_subroutine(self, gate_name: str): func.ReturnOp(graph_qubit_dict[5]) region = Region([block]) + # pylint: disable=line-too-long # Note that visibility is set as private to ensure the subroutines that are - # not called (dead code) can be eliminated as the ["symbol-dce"](https://github.com/PennyLaneAI/catalyst/blob/372c376eb821e830da778fdc8af423eeb487eab6/frontend/catalyst/pipelines.py#L248)_ + # not called (dead code) can be eliminated as the + # ["symbol-dce"](https://github.com/PennyLaneAI/catalyst/blob/372c376eb821e830da778fdc8af423eeb487eab6/frontend/catalyst/pipelines.py#L248)_ # pass was added to the pipeline. funcOp = func.FuncOp( gate_name.lower() + "_in_mbqc", @@ -634,8 +651,10 @@ def _create_cnot_gate_subroutine(self): ) region = Region([block]) + # pylint: disable=line-too-long # Note that visibility is set as private to ensure the subroutines that are - # not called (dead code) can be eliminated as the ["symbol-dce"](https://github.com/PennyLaneAI/catalyst/blob/372c376eb821e830da778fdc8af423eeb487eab6/frontend/catalyst/pipelines.py#L248)_ + # not called (dead code) can be eliminated as the + # ["symbol-dce"](https://github.com/PennyLaneAI/catalyst/blob/372c376eb821e830da778fdc8af423eeb487eab6/frontend/catalyst/pipelines.py#L248)_ # pass was added to the pipeline. funcOp = func.FuncOp( gate_name.lower() + "_in_mbqc", @@ -647,32 +666,30 @@ def _create_cnot_gate_subroutine(self): funcOp.attributes["mbqc_transform"] = builtin.NoneAttr() return funcOp - # pylint: disable=no-self-use - def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + def apply(self, _ctx: context.Context, op: builtin.ModuleOp) -> None: """Apply the convert-to-mbqc-formalism pass.""" + # pylint: disable=line-too-long # Insert subroutines for all gates in the MBQC gate set to the module. - # Note that the visibility of those subroutines are set as private, which ensure - # the ["symbol-dce"](https://github.com/PennyLaneAI/catalyst/blob/372c376eb821e830da778fdc8af423eeb487eab6/frontend/catalyst/pipelines.py#L248)_ + # Note that the visibility of those subroutines are set as private, which ensure the + # ["symbol-dce"](https://github.com/PennyLaneAI/catalyst/blob/372c376eb821e830da778fdc8af423eeb487eab6/frontend/catalyst/pipelines.py#L248)_ # pass could eliminate the unreferenced subroutines. subroutine_dict = {} for gate_name in _MBQC_ONE_QUBIT_GATES: funcOp = self._create_single_qubit_gate_subroutine(gate_name) - module.regions[0].blocks.first.add_op(funcOp) + op.regions[0].blocks.first.add_op(funcOp) subroutine_dict[gate_name] = funcOp cnot_funcOp = self._create_cnot_gate_subroutine() - module.regions[0].blocks.first.add_op(cnot_funcOp) + op.regions[0].blocks.first.add_op(cnot_funcOp) subroutine_dict["CNOT"] = cnot_funcOp pattern_rewriter.PatternRewriteWalker( pattern_rewriter.GreedyRewritePatternApplier( - [ - ConvertToMBQCFormalismPattern(subroutine_dict), - ] + [ConvertToMBQCFormalismPattern(subroutine_dict)] ), apply_recursively=False, - ).rewrite_module(module) + ).rewrite_module(op) convert_to_mbqc_formalism_pass = compiler_transform(ConvertToMBQCFormalismPass) @@ -680,18 +697,18 @@ def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: class ConvertToMBQCFormalismPattern( pattern_rewriter.RewritePattern -): # pylint: disable=too-few-public-methods,no-self-use +): # pylint: disable=too-few-public-methods """RewritePattern for converting to the MBQC formalism.""" def __init__(self, subroutines_dict): self.subroutine_dict = subroutines_dict - # pylint: disable=no-self-use @pattern_rewriter.op_type_rewrite_pattern def match_and_rewrite( self, root: func.FuncOp | IfOp | WhileOp | ForOp | IndexSwitchOp, rewriter: pattern_rewriter.PatternRewriter, + /, ): """Match and rewrite for converting to the MBQC formalism.""" diff --git a/frontend/catalyst/python_interface/transforms/mbqc/decompose_graph_state.py b/frontend/catalyst/python_interface/transforms/mbqc/decompose_graph_state.py index 294c0dd467..0d89b171ba 100644 --- a/frontend/catalyst/python_interface/transforms/mbqc/decompose_graph_state.py +++ b/frontend/catalyst/python_interface/transforms/mbqc/decompose_graph_state.py @@ -46,12 +46,11 @@ class DecomposeGraphStatePass(passes.ModulePass): name = "decompose-graph-state" - # pylint: disable=no-self-use - def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + def apply(self, _ctx: context.Context, op: builtin.ModuleOp) -> None: """Apply the decompose-graph-state pass.""" walker = pattern_rewriter.PatternRewriteWalker(DecomposeGraphStatePattern()) - walker.rewrite_module(module) + walker.rewrite_module(op) decompose_graph_state_pass = compiler_transform(DecomposeGraphStatePass) @@ -145,12 +144,11 @@ class NullDecomposeGraphStatePass(passes.ModulePass): name = "null-decompose-graph-state" - # pylint: disable=no-self-use - def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + def apply(self, _ctx: context.Context, op: builtin.ModuleOp) -> None: """Apply the null-decompose-graph-state pass.""" walker = pattern_rewriter.PatternRewriteWalker(NullDecomposeGraphStatePattern()) - walker.rewrite_module(module) + walker.rewrite_module(op) null_decompose_graph_state_pass = compiler_transform(NullDecomposeGraphStatePass) diff --git a/frontend/catalyst/python_interface/transforms/mbqc/graph_state_utils.py b/frontend/catalyst/python_interface/transforms/mbqc/graph_state_utils.py index bf13b09971..f5b7369d3a 100644 --- a/frontend/catalyst/python_interface/transforms/mbqc/graph_state_utils.py +++ b/frontend/catalyst/python_interface/transforms/mbqc/graph_state_utils.py @@ -51,7 +51,8 @@ def get_graph_state_edges(gate_name: str) -> list[tuple[int, int]]: """ Return a list of edges information in the graph state of a gate. - - The connectivity of the target qubits in the register and auxiliary qubits for a single-qubit gate is: + - The connectivity of the target qubits in the register and auxiliary qubits for a + single-qubit gate is: tgt -- 0 -- 1 -- 2 -- 3 @@ -63,10 +64,12 @@ def get_graph_state_edges(gate_name: str) -> list[tuple[int, int]]: (2, 3), ] - Wire 1 in the above isn't the target wire described in the Fig.2 of [`arXiv:quant-ph/0301052 `_], + Wire 1 in the above isn't the target wire described in the Fig.2 of + `arXiv:quant-ph/0301052 `_], 1 in the above maps to 3 in the figure. - - The connectivity of the ctrl/target qubits in the register and auxiliary qubits for a CNOT gate is: + - The connectivity of the ctrl/target qubits in the register and auxiliary qubits for a + CNOT gate is: ctl -- 0 -- 1 -- 2 -- 3 -- 4 -- 5 | @@ -91,9 +94,10 @@ def get_graph_state_edges(gate_name: str) -> list[tuple[int, int]]: (11, 12), ] - This graph is labelled based on the rows and columns of the adjacent matrix, but maps on to the graph described in - the Fig.2 of [`arXiv:quant-ph/0301052 `_], where wire 1 is the control and - wire 9 is the target. + This graph is labelled based on the rows and columns of the adjacent matrix, but maps on + to the graph described in the Fig.2 of + [`arXiv:quant-ph/0301052 `_], where wire 1 is the + control and wire 9 is the target. Args: gate_name (str): The name of a gate. @@ -214,10 +218,10 @@ def edge_iter(adj_matrix: DenselyPackedAdjMatrix) -> Generator[tuple[int, int], def _adj_matrix_generation_helper( num_vertices: int, edges_in_adj_matrix: list[tuple[int, int]] ) -> list: - """Helper function to generate an adjacency matrix to represent the connectivity of auxiliary qubits in - a graph state for a gate operation with the number of vertices and edges information. - Note that the adjacency matrix here means the lower triangular part of the full adjacency matrix. - It can be represented as below and `x` marks here denotes the matrix diagonal. + """Helper function to generate an adjacency matrix to represent the connectivity of auxiliary + qubits in a graph state for a gate operation with the number of vertices and edges information. + Note that the adjacency matrix here means the lower triangular part of the full adjacency + matrix. It can be represented as below and `x` marks here denotes the matrix diagonal. x + x + + x diff --git a/frontend/catalyst/python_interface/transforms/mbqc/outline_state_evolution.py b/frontend/catalyst/python_interface/transforms/mbqc/outline_state_evolution.py index dc9fbcfc98..6035040c18 100644 --- a/frontend/catalyst/python_interface/transforms/mbqc/outline_state_evolution.py +++ b/frontend/catalyst/python_interface/transforms/mbqc/outline_state_evolution.py @@ -17,11 +17,14 @@ Known limitations ----------------- - * If the current pass is applied multiple times, the transform will fail as it would redefined the `state_evolution` func. This is - caused by the way we define the terminal_boundary_op. Each time the pass is applied to the IR, it would insert a new - terminal_boundary_op into the IR. TODOs: Instead of inserting a new `terminal_boundary_op` op to the IR when applying the pass, it - would be better to: 1. define a quantum.terminator op before this pass and use it as a delineation of quantum gate operation; - 2. move the `simplify_io` to a separate pass. + * If the current pass is applied multiple times, the transform will fail as it would redefined + the `state_evolution` func. This is caused by the way we define the terminal_boundary_op. + Each time the pass is applied to the IR, it would insert a new terminal_boundary_op into the + IR. TODOs: Instead of inserting a new `terminal_boundary_op` op to the IR when applying the + pass, it would be better to: + 1. define a quantum.terminator op before this pass and use it as a delineation of + quantum gate operation; + 2. move the `simplify_io` to a separate pass. """ from dataclasses import dataclass @@ -42,13 +45,12 @@ class OutlineStateEvolutionPass(passes.ModulePass): name = "outline-state-evolution" - # pylint: disable=no-self-use - def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + def apply(self, _ctx: context.Context, op: builtin.ModuleOp) -> None: """Apply the outline-state-evolution pass.""" - for op in module.ops: - if isinstance(op, func.FuncOp) and "qnode" in op.attributes: - rewriter = pattern_rewriter.PatternRewriter(op) - OutlineStateEvolutionPattern().match_and_rewrite(op, rewriter) + for op_ in op.ops: + if isinstance(op_, func.FuncOp) and "qnode" in op_.attributes: + rewriter = pattern_rewriter.PatternRewriter(op_) + OutlineStateEvolutionPattern().match_and_rewrite(op_, rewriter) outline_state_evolution_pass = compiler_transform(OutlineStateEvolutionPass) @@ -64,7 +66,8 @@ def _get_parent_module(self, op: func.FuncOp) -> builtin.ModuleOp: pass if op is None: raise RuntimeError( - "The given qnode func is not nested within a builtin.module. Please ensure the qnode func is defined in a builtin.module." + "The given qnode func is not nested within a builtin.module. Please ensure the " + "qnode func is defined in a builtin.module." ) return op @@ -83,8 +86,9 @@ def __init__(self): # State evolution function region self.state_evolution_func: func.FuncOp = None - # pylint: disable=too-many-arguments - def match_and_rewrite(self, func_op: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter): + def match_and_rewrite( + self, func_op: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter, / + ): """Transform a quantum function (qnode) to outline state evolution regions. This implementation assumes that there is only one `quantum.alloc` operation in the func operations with a "qnode" attribute and all quantum operations are between @@ -109,13 +113,13 @@ def match_and_rewrite(self, func_op: func.FuncOp, rewriter: pattern_rewriter.Pat # in the qnode func. self._finalize_transformation() - # pylint: disable=no-else-return def _get_qubit_idx(self, op: Operation) -> int | None: """Get the index of qubit that an ExtractOp op extracts.""" if getattr(op, "idx", None): return op.idx return getattr(op, "idx_attr", None) + # pylint: disable=too-many-arguments,too-many-positional-arguments def _set_up_terminal_boundary_op( self, current_reg: quantum.QuregType, @@ -159,7 +163,7 @@ def _set_up_terminal_boundary_op( del qubit_to_reg_idx[qb] return current_reg, prev_qreg, terminal_boundary_op - # pylint: disable=cell-var-from-loop, too-many-branches + # pylint: disable=too-many-branches def _simplify_quantum_io( self, func_op: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter ) -> func.FuncOp: @@ -285,12 +289,15 @@ def _create_state_evolution_function(self, rewriter: pattern_rewriter.PatternRew ) rewriter.insert_op(state_evolution_func, InsertPoint.at_end(self.module.body.block)) - # TODOs: how to define the `value_mapper` arg is not stated in the xdl.core module [here](https://github.com/xdslproject/xdsl/blob/e1301e0204bcf6ea5ed433e7da00bee57d07e695/xdsl/ir/core.py#L1429)_. - # It looks like storing ssa value to be cloned would maintain the dependency relationship required to build the new DAG for the new ops. + # pylint: disable=line-too-long + # TODOs: how to define the `value_mapper` arg is not stated in the xdl.core module + # [here](https://github.com/xdslproject/xdsl/blob/e1301e0204bcf6ea5ed433e7da00bee57d07e695/xdsl/ir/core.py#L1429)_. + # It looks like storing ssa value to be cloned would maintain the dependency + # relationship required to build the new DAG for the new ops. block = state_evolution_func.regions[0].block value_mapper = {} # only args ssavlue is required - for input, block_arg in zip(ordered_inputs, block.args): - value_mapper[input] = block_arg + for input_, block_arg in zip(ordered_inputs, block.args): + value_mapper[input_] = block_arg self._clone_operations_to_block(ops_to_clone, block, value_mapper) self._add_return_statement(block, ordered_outputs, value_mapper) diff --git a/frontend/catalyst/python_interface/transforms/quantum/cancel_inverses.py b/frontend/catalyst/python_interface/transforms/quantum/cancel_inverses.py index 9f5053cc67..51d668b2ba 100644 --- a/frontend/catalyst/python_interface/transforms/quantum/cancel_inverses.py +++ b/frontend/catalyst/python_interface/transforms/quantum/cancel_inverses.py @@ -57,9 +57,8 @@ class IterativeCancelInversesPattern( ): # pylint: disable=too-few-public-methods """RewritePattern for iteratively cancelling consecutive self-inverse gates.""" - # pylint: disable=no-self-use @pattern_rewriter.op_type_rewrite_pattern - def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter): + def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter, /): """Implementation of rewriting FuncOps that may contain operations corresponding to self-inverse gates.""" for op in funcOp.body.walk(): @@ -91,12 +90,11 @@ class IterativeCancelInversesPass(passes.ModulePass): name = "xdsl-cancel-inverses" - # pylint: disable=no-self-use - def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + def apply(self, _ctx: context.Context, op: builtin.ModuleOp) -> None: """Apply the iterative cancel inverses pass.""" pattern_rewriter.PatternRewriteWalker( pattern_rewriter.GreedyRewritePatternApplier([IterativeCancelInversesPattern()]) - ).rewrite_module(module) + ).rewrite_module(op) iterative_cancel_inverses_pass = compiler_transform(IterativeCancelInversesPass) diff --git a/frontend/catalyst/python_interface/transforms/quantum/combine_global_phases.py b/frontend/catalyst/python_interface/transforms/quantum/combine_global_phases.py index 4b9d027f84..cea795a3e3 100644 --- a/frontend/catalyst/python_interface/transforms/quantum/combine_global_phases.py +++ b/frontend/catalyst/python_interface/transforms/quantum/combine_global_phases.py @@ -32,10 +32,12 @@ class CombineGlobalPhasesPattern( """RewritePattern for combining all :class:`~pennylane.GlobalPhase` gates within the same region at the last global phase gate.""" - # pylint: disable=no-self-use @pattern_rewriter.op_type_rewrite_pattern def match_and_rewrite( - self, root: func.FuncOp | IfOp | ForOp | WhileOp, rewriter: pattern_rewriter.PatternRewriter + self, + root: func.FuncOp | IfOp | ForOp | WhileOp, + rewriter: pattern_rewriter.PatternRewriter, + /, ): # pylint: disable=cell-var-from-loop """Match and rewrite for the combine-global-phases pattern acting on functions or control-flow blocks containing GlobalPhase operations. @@ -74,13 +76,12 @@ class CombineGlobalPhasesPass(passes.ModulePass): name = "combine-global-phases" - # pylint: disable=no-self-use - def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + def apply(self, _ctx: context.Context, op: builtin.ModuleOp) -> None: """Apply the combine-global-phases pass.""" pattern_rewriter.PatternRewriteWalker( CombineGlobalPhasesPattern(), apply_recursively=False, - ).rewrite_module(module) + ).rewrite_module(op) combine_global_phases_pass = compiler_transform(CombineGlobalPhasesPass) diff --git a/frontend/catalyst/python_interface/transforms/quantum/diagonalize_measurements.py b/frontend/catalyst/python_interface/transforms/quantum/diagonalize_measurements.py index 763bff79e9..0e07f013fa 100644 --- a/frontend/catalyst/python_interface/transforms/quantum/diagonalize_measurements.py +++ b/frontend/catalyst/python_interface/transforms/quantum/diagonalize_measurements.py @@ -17,15 +17,17 @@ Known Limitations ----------------- - * Only observables PauliX, PauliY, PauliZ, Hadamard and Identity are currently supported when using this transform (but - these are also the only observables currently supported in the Quantum dialect as NamedObservable). - * Unlike the current tape-based implementation of the transform, it doesn't allow for diagonalization of a subset - of observables. - * Unlike the current tape-based implementation of the transform, conversion to measurements based on eigvals - and wires (rather than the PauliZ observable) is not currently supported. - * Unlike the tape-based implementation, this pass will NOT raise an error if given a circuit that is invalid because - it contains non-commuting measurements. It should be assumed that this transform results in incorrect outputs unless - split_non_commuting is applied to break non-commuting measurements into separate tapes. + * Only observables PauliX, PauliY, PauliZ, Hadamard and Identity are currently supported when + using this transform (but these are also the only observables currently supported in the + Quantum dialect as NamedObservable). + * Unlike the current tape-based implementation of the transform, it doesn't allow for + diagonalization of a subset of observables. + * Unlike the current tape-based implementation of the transform, conversion to measurements + based on eigvals and wires (rather than the PauliZ observable) is not currently supported. + * Unlike the tape-based implementation, this pass will NOT raise an error if given a circuit + that is invalid because it contains non-commuting measurements. It should be assumed that + this transform results in incorrect outputs unless split_non_commuting is applied to break + non-commuting measurements into separate tapes. """ from dataclasses import dataclass @@ -78,9 +80,10 @@ class DiagonalizeFinalMeasurementsPattern( ): # pylint: disable=too-few-public-methods """RewritePattern for diagonalizing final measurements.""" - # pylint: disable=no-self-use @pattern_rewriter.op_type_rewrite_pattern - def match_and_rewrite(self, observable: NamedObsOp, rewriter: pattern_rewriter.PatternRewriter): + def match_and_rewrite( + self, observable: NamedObsOp, rewriter: pattern_rewriter.PatternRewriter, / + ): """Replace non-diagonalized observables with their diagonalizing gates and PauliZ.""" if _diagonalize(observable): @@ -126,7 +129,8 @@ def match_and_rewrite(self, observable: NamedObsOp, rewriter: pattern_rewriter.P if num_observables > 1: raise RuntimeError( - "Each wire can only have one set of diagonalizing gates applied, but the circuit contains multiple observables with the same wire." + "Each wire can only have one set of diagonalizing gates applied, but the " + "circuit contains multiple observables with the same wire." ) observable.qubit.replace_by_if(qubit, lambda use: use in uses_to_change) @@ -147,11 +151,10 @@ class DiagonalizeFinalMeasurementsPass(passes.ModulePass): name = "diagonalize-final-measurements" - # pylint: disable= no-self-use - def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + def apply(self, _ctx: context.Context, op: builtin.ModuleOp) -> None: """Apply the diagonalize final measurements pass.""" pattern_rewriter.PatternRewriteWalker(DiagonalizeFinalMeasurementsPattern()).rewrite_module( - module + op ) diff --git a/frontend/catalyst/python_interface/transforms/quantum/measurements_from_samples.py b/frontend/catalyst/python_interface/transforms/quantum/measurements_from_samples.py index 527c433ff2..0d6ec6b3a4 100644 --- a/frontend/catalyst/python_interface/transforms/quantum/measurements_from_samples.py +++ b/frontend/catalyst/python_interface/transforms/quantum/measurements_from_samples.py @@ -25,8 +25,9 @@ * Usage patterns that are not yet supported with program capture are also not supported in the compilation pass. For example, operator arithmetic is not currently supported, such as qml.expval(qml.Y(0) @ qml.X(1)). - * qml.counts() is not supported since the return type/shape is different in PennyLane and Catalyst. - See https://docs.pennylane.ai/projects/catalyst/en/stable/dev/quick_start.html#measurements + * qml.counts() is not supported since the return type/shape is different in PennyLane and + Catalyst. See + https://docs.pennylane.ai/projects/catalyst/en/stable/dev/quick_start.html#measurements for more information. """ @@ -56,10 +57,9 @@ class MeasurementsFromSamplesPass(passes.ModulePass): name = "measurements-from-samples" - # pylint: disable=no-self-use - def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + def apply(self, _ctx: context.Context, op: builtin.ModuleOp) -> None: """Apply the measurements-from-samples pass.""" - shots = _get_static_shots_value_from_first_device_op(module) + shots = _get_static_shots_value_from_first_device_op(op) greedy_applier = pattern_rewriter.GreedyRewritePatternApplier( [ @@ -70,7 +70,7 @@ def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: ] ) walker = pattern_rewriter.PatternRewriteWalker(greedy_applier, apply_recursively=False) - walker.rewrite_module(module) + walker.rewrite_module(op) measurements_from_samples_pass = compiler_transform(MeasurementsFromSamplesPass) @@ -296,7 +296,8 @@ def insert_constant_int_op( value (int): The integer value. insert_point (InsertPoint): The insertion point for the constant op. rewriter (PatternRewriter): The xDSL pattern rewriter. - value_type (int, optional): The integer value type (i.e. number of bits). Defaults to 64. + value_type (int, optional): The integer value type (i.e. number of bits). + Defaults to 64. Returns: arith.ConstantOp: The created constant op. @@ -424,7 +425,8 @@ def match_and_rewrite( return_types=[builtin.TensorType(builtin.Float64Type(), shape=())], ) - # The op to replace is not the expval/var op itself, but the tensor.from_elements op that follows + # The op to replace is not the expval/var op itself, but the tensor.from_elements + # op that follows op_to_replace = list(matched_op.results[0].uses)[0].operation assert isinstance( op_to_replace, tensor.FromElementsOp diff --git a/frontend/catalyst/python_interface/transforms/quantum/merge_rotations.py b/frontend/catalyst/python_interface/transforms/quantum/merge_rotations.py index f1477f7c7b..1f4a553bc8 100644 --- a/frontend/catalyst/python_interface/transforms/quantum/merge_rotations.py +++ b/frontend/catalyst/python_interface/transforms/quantum/merge_rotations.py @@ -68,9 +68,8 @@ class MergeRotationsPattern( ): # pylint: disable=too-few-public-methods """RewritePattern for merging consecutive composable rotations.""" - # pylint: disable=no-self-use @pattern_rewriter.op_type_rewrite_pattern - def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter): + def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter, /): """Implementation of rewriting FuncOps that may contain operations corresponding to consecutive composable rotations.""" for op in funcOp.body.walk(): @@ -130,12 +129,11 @@ class MergeRotationsPass(passes.ModulePass): name = "xdsl-merge-rotations" - # pylint: disable=no-self-use - def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + def apply(self, _ctx: context.Context, op: builtin.ModuleOp) -> None: """Apply the merge rotations pass.""" pattern_rewriter.PatternRewriteWalker( pattern_rewriter.GreedyRewritePatternApplier([MergeRotationsPattern()]) - ).rewrite_module(module) + ).rewrite_module(op) merge_rotations_pass = compiler_transform(MergeRotationsPass) diff --git a/frontend/catalyst/python_interface/transforms/quantum/split_non_commuting.py b/frontend/catalyst/python_interface/transforms/quantum/split_non_commuting.py index 8dd0378ed9..3cd8a47039 100644 --- a/frontend/catalyst/python_interface/transforms/quantum/split_non_commuting.py +++ b/frontend/catalyst/python_interface/transforms/quantum/split_non_commuting.py @@ -12,16 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +# pylint: disable=line-too-long """This file contains a limited prototype of the split_non_commuting pass. Known Limitations ----------------- - * Only single-term observables with no coefficients are supported - there is no support for CompositeOp or SymbolicOp observables + * Only single-term observables with no coefficients are supported - there is no support + for CompositeOp or SymbolicOp observables * Only the Expval measurement process is supported - * There is no option to specify a grouping strategy (this will be more relevant once CompositeOp support is added) - * Hence, only the "wires" grouping strategy is implemented, not taking into account observable-commutation logic yet. - * There is no efficient handling of duplicate observables - a circuit that returns multiple measurements on the same observable will split into multiple executions (this will be more relevant once CompositeOp support is added) + * There is no option to specify a grouping strategy (this will be more relevant once + CompositeOp support is added) + * Hence, only the "wires" grouping strategy is implemented, not taking into account + observable-commutation logic yet. + * There is no efficient handling of duplicate observables - a circuit that returns + multiple measurements on the same observable will split into multiple executions + (this will be more relevant once CompositeOp support is added) Example: ------------------ @@ -103,12 +109,12 @@ class SplitNonCommutingPass(passes.ModulePass): name = "split-non-commuting" - def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + def apply(self, _ctx: context.Context, op: builtin.ModuleOp) -> None: """Apply the split non-commuting pass to all QNode functions in the module.""" - for op in module.ops: - if isinstance(op, func.FuncOp) and "qnode" in op.attributes: - rewriter = pattern_rewriter.PatternRewriter(op) - SplitNonCommutingPattern().match_and_rewrite(op, rewriter) + for op_ in op.ops: + if isinstance(op_, func.FuncOp) and "qnode" in op_.attributes: + rewriter = pattern_rewriter.PatternRewriter(op_) + SplitNonCommutingPattern().match_and_rewrite(op_, rewriter) split_non_commuting_pass = compiler_transform(SplitNonCommutingPass) @@ -467,7 +473,9 @@ def replace_original_with_calls( return_op = func.ReturnOp(*final_return_values) original_block.add_op(return_op) - def match_and_rewrite(self, func_op: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter): + def match_and_rewrite( + self, func_op: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter, / + ): """Split a quantum function into multiple functions using wire-based grouping. Creates one duplicate function per group, where each duplicate function contains diff --git a/frontend/catalyst/python_interface/visualization/draw.py b/frontend/catalyst/python_interface/visualization/draw.py index 87088b73a1..ac26977404 100644 --- a/frontend/catalyst/python_interface/visualization/draw.py +++ b/frontend/catalyst/python_interface/visualization/draw.py @@ -52,16 +52,18 @@ def draw(qnode: QNode, *, level: None | int = None) -> Callable: """ Draw the QNode at the specified level. - This function can be used to visualize the QNode at different stages of the transformation pipeline - when xDSL or Catalyst compilation passes are applied. + This function can be used to visualize the QNode at different stages of the transformation + pipeline when xDSL or Catalyst compilation passes are applied. If the specified level is not available, the highest level will be used as a fallback. The provided QNode is assumed to be decorated with compilation passes. If no passes are applied, the original QNode is visualized. Args: - qnode (.QNode): the input QNode that is to be visualized. The QNode is assumed to be compiled with ``qjit``. - level (None | int): the level of transformation to visualize. If `None`, the final level is visualized. + qnode (.QNode): the input QNode that is to be visualized. The QNode is assumed to be + compiled with ``qjit``. + level (None | int): the level of transformation to visualize. If `None`, the final + level is visualized. Returns: diff --git a/frontend/catalyst/python_interface/visualization/mlir_graph.py b/frontend/catalyst/python_interface/visualization/mlir_graph.py index 7683feaecf..79196185ba 100644 --- a/frontend/catalyst/python_interface/visualization/mlir_graph.py +++ b/frontend/catalyst/python_interface/visualization/mlir_graph.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -This file contains the implementation of the MLIR graph generation for the Unified Compiler framework. +This file contains the implementation of the MLIR graph generation for the Unified +Compiler framework. """ from __future__ import annotations @@ -87,8 +88,9 @@ def generate_mlir_graph(qnode: QNode) -> Callable: Generate an MLIR graph for the given QNode and saves it to a file. This function uses the callback mechanism of the unified compiler framework to generate - the MLIR graph in between compilation passes. The provided QNode is assumed to be decorated with xDSL compilation passes. - The ``qjit`` decorator is used to recompile the QNode with the passes and the provided arguments. + the MLIR graph in between compilation passes. The provided QNode is assumed to be decorated + with xDSL compilation passes. The ``qjit`` decorator is used to recompile the QNode with the + passes and the provided arguments. If no passes are applied, the original QNode is visualized. diff --git a/frontend/catalyst/python_interface/visualization/xdsl_conversion.py b/frontend/catalyst/python_interface/visualization/xdsl_conversion.py index 1fa496892c..917a292e94 100644 --- a/frontend/catalyst/python_interface/visualization/xdsl_conversion.py +++ b/frontend/catalyst/python_interface/visualization/xdsl_conversion.py @@ -158,8 +158,8 @@ def resolve_constant_params(ssa: SSAValue) -> float | int: case "stablehlo.reshape": res_type = op.result_types[0] shape = res_type.get_shape() - type = res_type.get_element_type() - return jax.numpy.array(shape, dtype=int if isinstance(type, IntegerType) else float) + type_ = res_type.get_element_type() + return jax.numpy.array(shape, dtype=int if isinstance(type_, IntegerType) else float) case _: raise NotImplementedError(f"Cannot resolve parameters for operation: {op}") diff --git a/frontend/catalyst/python_interface/xdsl_extras/constraints.py b/frontend/catalyst/python_interface/xdsl_extras/constraints.py index 8c812e35f2..2e813e7823 100644 --- a/frontend/catalyst/python_interface/xdsl_extras/constraints.py +++ b/frontend/catalyst/python_interface/xdsl_extras/constraints.py @@ -153,7 +153,6 @@ def get_bases(self) -> set[type[Attribute]]: return {self.expected_type} def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None: - # pylint: disable=missing-function-docstring if not isinstance(attr, self.expected_type): raise VerifyException(f"{attr} should be of type {self.expected_type.__name__}.") constr = self.expected_type.constr(element_type=self.element_type, shape=self.shape) @@ -162,7 +161,6 @@ def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None def mapping_type_vars( self, type_var_mapping: dict[TypeVar, AttrConstraint] ) -> "ContainerConstraint": - # pylint: disable=unused-argument,missing-function-docstring return self @@ -230,5 +228,4 @@ def mapping_type_vars( type_var_mapping: Mapping[TypeVar, AttrConstraint | IntConstraint], ) -> AttrConstraint: """Map type variables to constraints.""" - # pylint: disable=unused-argument return self diff --git a/frontend/catalyst/python_interface/xdsl_extras/traits.py b/frontend/catalyst/python_interface/xdsl_extras/traits.py index 092241dd82..937d960a28 100644 --- a/frontend/catalyst/python_interface/xdsl_extras/traits.py +++ b/frontend/catalyst/python_interface/xdsl_extras/traits.py @@ -236,5 +236,6 @@ def verify(self, op: Operation) -> None: if any(res != first for res in results[1:]): results_str = ", ".join(str(r) for r in results) raise VerifyException( - f"all of {{{names_str}}} must have the same {self.summary}: got {self.summary}s {results_str}" + f"all of {{{names_str}}} must have the same {self.summary}: got " + f"{self.summary}s {results_str}" ) From f7aa2b65a8ace67fee49963cb1cc6e1c48692947 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 17 Nov 2025 15:51:13 -0500 Subject: [PATCH 06/38] [skip ci] Skip CI From 012f58b9d299edb2dc0e8a8d9d6f96ac22d6d42a Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 17 Nov 2025 15:51:34 -0500 Subject: [PATCH 07/38] [skip ci] Skip CI From c98755edbe394f514aa9561668ca58057f319ad2 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 17 Nov 2025 16:34:05 -0500 Subject: [PATCH 08/38] Try installing graphviz with apt --- .github/workflows/check-catalyst.yaml | 4 ++-- .../catalyst/python_interface/xdsl_extras/constraints.py | 9 +++++++-- .../python_interface/dialects/test_transform_dialect.py | 1 + .../test/pytest/python_interface/test_python_compiler.py | 3 +++ 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index 6ed3d7edfd..7ffc60a8a1 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -466,6 +466,8 @@ jobs: run: | sudo apt-get update sudo apt-get install -y libasan6 make + # Install graphviz for testing the mlir-op-graph integration + sudo apt-get install -y graphviz python3 --version | grep ${{ needs.constants.outputs.primary_python_version }} python3 -m pip install -r requirements.txt # cuda-quantum is added manually here. @@ -473,8 +475,6 @@ jobs: # macOS requirements.txt python3 -m pip install cuda-quantum==0.6.0 python3 -m pip install oqc-qcaas-client - # Install graphviz for testing the mlir-op-graph integration - python3 -m pip install graphviz make frontend - name: Get Cached LLVM Build diff --git a/frontend/catalyst/python_interface/xdsl_extras/constraints.py b/frontend/catalyst/python_interface/xdsl_extras/constraints.py index 2e813e7823..23e974fd88 100644 --- a/frontend/catalyst/python_interface/xdsl_extras/constraints.py +++ b/frontend/catalyst/python_interface/xdsl_extras/constraints.py @@ -153,6 +153,7 @@ def get_bases(self) -> set[type[Attribute]]: return {self.expected_type} def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None: + """Verify that the attribute meets the constraint.""" if not isinstance(attr, self.expected_type): raise VerifyException(f"{attr} should be of type {self.expected_type.__name__}.") constr = self.expected_type.constr(element_type=self.element_type, shape=self.shape) @@ -160,7 +161,11 @@ def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None def mapping_type_vars( self, type_var_mapping: dict[TypeVar, AttrConstraint] - ) -> "ContainerConstraint": + ) -> "ContainerConstraint": # pylint: disable=unused-argument + """ + A helper function to make type vars used in attribute definitions concrete when + creating constraints for new attributes or operations. + """ return self @@ -226,6 +231,6 @@ def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None def mapping_type_vars( self, type_var_mapping: Mapping[TypeVar, AttrConstraint | IntConstraint], - ) -> AttrConstraint: + ) -> AttrConstraint: # pylint: disable=unused-argument """Map type variables to constraints.""" return self diff --git a/frontend/test/pytest/python_interface/dialects/test_transform_dialect.py b/frontend/test/pytest/python_interface/dialects/test_transform_dialect.py index ddfc4ad156..64aba9a85b 100644 --- a/frontend/test/pytest/python_interface/dialects/test_transform_dialect.py +++ b/frontend/test/pytest/python_interface/dialects/test_transform_dialect.py @@ -117,6 +117,7 @@ class _HelloWorld(passes.ModulePass): custom_print: str | None = None def apply(self, _ctx: Context, _module: builtin.ModuleOp) -> None: + """Apply the pass.""" if self.custom_print: print(self.custom_print) else: diff --git a/frontend/test/pytest/python_interface/test_python_compiler.py b/frontend/test/pytest/python_interface/test_python_compiler.py index f8b32f2aaf..3e4b29cc9e 100644 --- a/frontend/test/pytest/python_interface/test_python_compiler.py +++ b/frontend/test/pytest/python_interface/test_python_compiler.py @@ -82,6 +82,7 @@ def test_compiler(): @mlir_module @jax.jit def identity(x): + """Identity function""" return x input_module = identity(1) @@ -387,6 +388,7 @@ class _(passes.ModulePass): def apply(self, _ctx: Context, _module: builtin.ModuleOp) -> None: ... def print_between_passes(*_, pass_level=0): + """Print between passes callback.""" if pass_level == 0: return print("hello world") @@ -481,6 +483,7 @@ def test_callback_run_integration(self, capsys): """Test that the callback is integrated into the pass pipeline with the Compiler.run() method""" def print_between_passes(_, module, __, pass_level=0): + """Callback to print something between passes.""" if pass_level == 0: return print("=== Between Pass ===") From bb5bd007557352a67166dcd38ebdaa62872f649b Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 17 Nov 2025 16:36:29 -0500 Subject: [PATCH 09/38] Fix codefactor complaints --- .../catalyst/python_interface/xdsl_extras/constraints.py | 6 ++++-- .../test/pytest/python_interface/test_python_compiler.py | 3 +++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/frontend/catalyst/python_interface/xdsl_extras/constraints.py b/frontend/catalyst/python_interface/xdsl_extras/constraints.py index 23e974fd88..c5e597f20b 100644 --- a/frontend/catalyst/python_interface/xdsl_extras/constraints.py +++ b/frontend/catalyst/python_interface/xdsl_extras/constraints.py @@ -159,9 +159,10 @@ def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None constr = self.expected_type.constr(element_type=self.element_type, shape=self.shape) constr.verify(attr, constraint_context) + # pylint: disable=unused-argument def mapping_type_vars( self, type_var_mapping: dict[TypeVar, AttrConstraint] - ) -> "ContainerConstraint": # pylint: disable=unused-argument + ) -> "ContainerConstraint": """ A helper function to make type vars used in attribute definitions concrete when creating constraints for new attributes or operations. @@ -228,9 +229,10 @@ def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None if not matched: raise VerifyException(f"tuple leaf {i} failed all allowed constraints: {leaf}") + # pylint: disable=unused-argument def mapping_type_vars( self, type_var_mapping: Mapping[TypeVar, AttrConstraint | IntConstraint], - ) -> AttrConstraint: # pylint: disable=unused-argument + ) -> AttrConstraint: """Map type variables to constraints.""" return self diff --git a/frontend/test/pytest/python_interface/test_python_compiler.py b/frontend/test/pytest/python_interface/test_python_compiler.py index 3e4b29cc9e..0bc84fe4b3 100644 --- a/frontend/test/pytest/python_interface/test_python_compiler.py +++ b/frontend/test/pytest/python_interface/test_python_compiler.py @@ -62,6 +62,7 @@ class HelloWorldPass(passes.ModulePass): name = "hello-world" def apply(self, _ctx: Context, _module: builtin.ModuleOp) -> None: + """Apply the pass.""" print("hello world") @@ -387,6 +388,8 @@ class _(passes.ModulePass): def apply(self, _ctx: Context, _module: builtin.ModuleOp) -> None: ... + """Dummy apply pass.""" + def print_between_passes(*_, pass_level=0): """Print between passes callback.""" if pass_level == 0: From 5e745f5975d2d5371c2a0d2c498707db96698ac7 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 17 Nov 2025 16:38:10 -0500 Subject: [PATCH 10/38] Try appeasing codefactor again --- frontend/test/pytest/python_interface/test_python_compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/test/pytest/python_interface/test_python_compiler.py b/frontend/test/pytest/python_interface/test_python_compiler.py index 0bc84fe4b3..e4b6bddaf1 100644 --- a/frontend/test/pytest/python_interface/test_python_compiler.py +++ b/frontend/test/pytest/python_interface/test_python_compiler.py @@ -384,12 +384,12 @@ def test_callback_integration(self, capsys): @compiler_transform @dataclass(frozen=True) class _(passes.ModulePass): + """Dummy pass for testing.""" + name = "none-pass" def apply(self, _ctx: Context, _module: builtin.ModuleOp) -> None: ... - """Dummy apply pass.""" - def print_between_passes(*_, pass_level=0): """Print between passes callback.""" if pass_level == 0: From 040993396b25b68753f004e6f8f9f956eb003b78 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 17 Nov 2025 16:40:00 -0500 Subject: [PATCH 11/38] Try appeasing codefactor once again --- .../test/pytest/python_interface/test_python_compiler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/frontend/test/pytest/python_interface/test_python_compiler.py b/frontend/test/pytest/python_interface/test_python_compiler.py index e4b6bddaf1..0a1cdcb9ff 100644 --- a/frontend/test/pytest/python_interface/test_python_compiler.py +++ b/frontend/test/pytest/python_interface/test_python_compiler.py @@ -383,12 +383,14 @@ def test_callback_integration(self, capsys): @compiler_transform @dataclass(frozen=True) - class _(passes.ModulePass): + class NonePass(passes.ModulePass): """Dummy pass for testing.""" name = "none-pass" - def apply(self, _ctx: Context, _module: builtin.ModuleOp) -> None: ... + def apply(self, _ctx: Context, _module: builtin.ModuleOp) -> None: + """Apply the pass. Do nothing; the test if for callbacks.""" + return def print_between_passes(*_, pass_level=0): """Print between passes callback.""" From e786a1242852c72910b4611dc068cc9958a159a7 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 17 Nov 2025 16:40:47 -0500 Subject: [PATCH 12/38] Pylint suppression --- frontend/test/pytest/python_interface/test_python_compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/frontend/test/pytest/python_interface/test_python_compiler.py b/frontend/test/pytest/python_interface/test_python_compiler.py index 0a1cdcb9ff..12c329f829 100644 --- a/frontend/test/pytest/python_interface/test_python_compiler.py +++ b/frontend/test/pytest/python_interface/test_python_compiler.py @@ -381,6 +381,7 @@ class TestCallbackIntegration: def test_callback_integration(self, capsys): """Test that the callback mechanism works with the transform interpreter""" + # pylint: disable=unused-variable @compiler_transform @dataclass(frozen=True) class NonePass(passes.ModulePass): From 4ae10ff8585e9c8f2c6ad18c08cf5f9f5ad0324c Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 17 Nov 2025 17:04:54 -0500 Subject: [PATCH 13/38] Reduce complexity of stablehlo.reduce and stablehlo.dynamic_broadcast_in_dim verify methods --- .../dialects/stablehlo/dynamism.py | 64 ++++++++++++------- .../dialects/stablehlo/reduction.py | 15 ++++- .../test_draw_unified_compiler.py | 2 +- 3 files changed, 53 insertions(+), 28 deletions(-) diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py b/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py index 6da7b376d8..fa1a8d0e17 100644 --- a/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py +++ b/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py @@ -90,14 +90,28 @@ class DynamicBroadcastInDimOp(IRDLOperation): # pylint: disable=too-many-branches def verify_(self): """Verify the operation.""" - # Operand and result must be tensors operand_ty = self.operand_types[0] result_ty = self.result_types[0] + bcast_dims = tuple(self.broadcast_dimensions.get_values()) # pylint: disable=no-member + + # Operand and result must be tensors assert isinstance(operand_ty, TensorType) and isinstance(result_ty, TensorType) - # dynamic_broadcast_in_dim_c2: broadcast_dimensions size == operand rank - bcast_dims = tuple(self.broadcast_dimensions.get_values()) # pylint: disable=no-member + # dynamic_broadcast_in_dim_c4: broadcast_dimensions should not have duplicates + if len(set(bcast_dims)) != len(bcast_dims): + raise VerifyException("broadcast_dimensions should not have duplicates") + + self._verify_rank_constraints(bcast_dims, operand_ty, result_ty) + self._verify_per_dimension_bounds(bcast_dims, operand_ty, result_ty) + self._verify_expansion_hints(operand_ty) + + def _verify_rank_constraints(self, bcast_dims, operand_ty, result_ty): + """Verify then operand and result tensors against the rank constraints.""" + operand_rank = operand_ty.get_num_dims() + result_rank = result_ty.get_num_dims() + + # dynamic_broadcast_in_dim_c2: broadcast_dimensions size == operand rank if len(bcast_dims) != operand_rank: raise VerifyException( "broadcast_dimensions size (" @@ -108,7 +122,6 @@ def verify_(self): ) # dynamic_broadcast_in_dim_c3: result rank >= operand rank - result_rank = result_ty.get_num_dims() if result_rank < operand_rank: raise VerifyException( "result rank (" @@ -118,13 +131,29 @@ def verify_(self): ")" ) - # dynamic_broadcast_in_dim_c4: broadcast_dimensions should not have duplicates - if len(set(bcast_dims)) != len(bcast_dims): - raise VerifyException("broadcast_dimensions should not have duplicates") + # dynamic_broadcast_in_dim_c7: output_dimensions shape compatible with result rank + out_dims_ty = self.output_dimensions.type # pylint: disable=no-member + assert isinstance(out_dims_ty, TensorType) + # Must be rank-1 tensor (enforced by type constraint), and length must match result + # rank when statically known + out_shape = out_dims_ty.get_shape() + if len(out_shape) != 1: + raise VerifyException("output_dimensions must be a 1D tensor") + if out_shape[0] != -1 and out_shape[0] != result_rank: + raise VerifyException( + "length of output_dimensions (" + f"{out_shape[0]}" + ") is not compatible with result rank (" + f"{result_rank}" + ")" + ) + def _verify_per_dimension_bounds(self, bcast_dims, operand_ty, result_ty): # dynamic_broadcast_in_dim_c5: bounds and per-dimension compatibility operand_shape = operand_ty.get_shape() result_shape = result_ty.get_shape() + result_rank = result_ty.get_num_dims() + for i, dim_index in enumerate(bcast_dims): if dim_index < 0 or dim_index >= result_rank: raise VerifyException( @@ -142,24 +171,11 @@ def verify_(self): f"{dim_index} ({res_dim})" ) - # dynamic_broadcast_in_dim_c7: output_dimensions shape compatible with result rank - out_dims_ty = self.output_dimensions.type # pylint: disable=no-member - assert isinstance(out_dims_ty, TensorType) - # Must be rank-1 tensor (enforced by type constraint), and length must match result - # rank when statically known - out_shape = out_dims_ty.get_shape() - if len(out_shape) != 1: - raise VerifyException("output_dimensions must be a 1D tensor") - if out_shape[0] != -1 and out_shape[0] != result_rank: - raise VerifyException( - "length of output_dimensions (" - f"{out_shape[0]}" - ") is not compatible with result rank (" - f"{result_rank}" - ")" - ) - + def _verify_expansion_hints(self, operand_ty): + """Verify the operation's expansion hints.""" # dynamic_broadcast_in_dim_c8: no duplicate expansion hints across both lists + operand_rank = operand_ty.get_num_dims() + hints = [] if self.known_expanding_dimensions is not None: hints.extend(self.known_expanding_dimensions.get_values()) # pylint: disable=no-member diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py b/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py index 28b332d22f..8f84b7b7b9 100644 --- a/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py +++ b/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py @@ -84,9 +84,11 @@ def verify_(self): input_types = [op.type for op in self.inputs] init_types = [op.type for op in self.init_values] - # reduce_c1/c4/c5/i3: verify inputs and infer shape compatibility - dims_attr = self.dimensions - dims = tuple(dims_attr.get_values()) if dims_attr is not None else tuple() + self._verify_input_and_init_types(input_types, init_types) + self._verify_reducer_region(input_types) + + def _verify_input_and_init_types(self, input_types, init_types): + """Verify the types of the inputs and init values.""" # Basic structural checks mirroring verifyReduceOpInputsAndInferShape if len(input_types) == 0: @@ -94,6 +96,10 @@ def verify_(self): if len(input_types) != len(init_types): raise VerifyException("number of inputs must match number of init_values") + # reduce_c1/c4/c5/i3: verify inputs and infer shape compatibility + dims_attr = self.dimensions + dims = tuple(dims_attr.get_values()) if dims_attr is not None else tuple() + # All inputs must have equal rank; dimensions must be within rank and unique # and not empty. ranks = [] @@ -119,6 +125,9 @@ def verify_(self): if it_elem != iv_elem: raise VerifyException("input and init_value must have the same element type") + def _verify_reducer_region(self, input_types): + """Verify the operation's reducer region.""" + # reduce_c2/c6: verify reducer region shape # Expect block with arity 2 * number of inputs, with matching tensor element types # and 0D tensors diff --git a/frontend/test/pytest/python_interface/visualization/test_draw_unified_compiler.py b/frontend/test/pytest/python_interface/visualization/test_draw_unified_compiler.py index 80cfe437ad..26fa8d47ea 100644 --- a/frontend/test/pytest/python_interface/visualization/test_draw_unified_compiler.py +++ b/frontend/test/pytest/python_interface/visualization/test_draw_unified_compiler.py @@ -78,7 +78,7 @@ def circ(): ), (2, "0: ──RX──RZ─┤ State\n1: ──RY─────┤ State\n2: ──RZ─────┤ State"), (None, "0: ──RX──RZ─┤ State\n1: ──RY─────┤ State\n2: ──RZ─────┤ State"), - (50, "0: ──RX──RZ─┤ State1: ──RY─────┤ State2: ──RZ─────┤ State"), + (50, "0: ──RX──RZ─┤ State\n1: ──RY─────┤ State\n2: ──RZ─────┤ State"), ], ) def test_multiple_levels_xdsl(self, transforms_circuit, level, qjit, expected): From dc41a7ba05ead8966b88ac4836ed0cd75c9fe497 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 18 Nov 2025 10:20:48 -0500 Subject: [PATCH 14/38] Fix some failures --- .github/workflows/check-catalyst.yaml | 4 ++-- frontend/catalyst/jax_primitives_utils.py | 4 +--- .../python_interface/dialects/stablehlo/dynamism.py | 5 ++++- .../python_interface/dialects/stablehlo/reduction.py | 1 + frontend/test/lit/test_xdsl_passes.py | 3 ++- frontend/test/pytest/python_interface/conftest.py | 7 ++++--- frontend/test/pytest/test_jit_behaviour.py | 2 +- 7 files changed, 15 insertions(+), 11 deletions(-) diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index 7ffc60a8a1..43915978d9 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -466,8 +466,6 @@ jobs: run: | sudo apt-get update sudo apt-get install -y libasan6 make - # Install graphviz for testing the mlir-op-graph integration - sudo apt-get install -y graphviz python3 --version | grep ${{ needs.constants.outputs.primary_python_version }} python3 -m pip install -r requirements.txt # cuda-quantum is added manually here. @@ -475,6 +473,8 @@ jobs: # macOS requirements.txt python3 -m pip install cuda-quantum==0.6.0 python3 -m pip install oqc-qcaas-client + # Install graphviz for testing the mlir-op-graph integration + sudo apt-get install -y graphviz make frontend - name: Get Cached LLVM Build diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py index 8d7e27acfe..749a8b3125 100644 --- a/frontend/catalyst/jax_primitives_utils.py +++ b/frontend/catalyst/jax_primitives_utils.py @@ -378,9 +378,7 @@ def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipelin try: # pylint: disable=import-outside-toplevel - from pennylane.compiler.python_compiler.pass_api import ( - is_xdsl_pass, - ) + from catalyst.python_interface.pass_api import is_xdsl_pass if is_xdsl_pass(_pass.name): uses_xdsl_passes = True diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py b/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py index fa1a8d0e17..e38b4986f7 100644 --- a/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py +++ b/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py @@ -97,12 +97,14 @@ def verify_(self): # Operand and result must be tensors assert isinstance(operand_ty, TensorType) and isinstance(result_ty, TensorType) + self._verify_rank_constraints(bcast_dims, operand_ty, result_ty) + # dynamic_broadcast_in_dim_c4: broadcast_dimensions should not have duplicates if len(set(bcast_dims)) != len(bcast_dims): raise VerifyException("broadcast_dimensions should not have duplicates") - self._verify_rank_constraints(bcast_dims, operand_ty, result_ty) self._verify_per_dimension_bounds(bcast_dims, operand_ty, result_ty) + self._verify_expansion_hints(operand_ty) def _verify_rank_constraints(self, bcast_dims, operand_ty, result_ty): @@ -149,6 +151,7 @@ def _verify_rank_constraints(self, bcast_dims, operand_ty, result_ty): ) def _verify_per_dimension_bounds(self, bcast_dims, operand_ty, result_ty): + """Verify compatibility of operand and result dimensions.""" # dynamic_broadcast_in_dim_c5: bounds and per-dimension compatibility operand_shape = operand_ty.get_shape() result_shape = result_ty.get_shape() diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py b/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py index 8f84b7b7b9..7e9bdcfdf8 100644 --- a/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py +++ b/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py @@ -85,6 +85,7 @@ def verify_(self): init_types = [op.type for op in self.init_values] self._verify_input_and_init_types(input_types, init_types) + self._verify_reducer_region(input_types) def _verify_input_and_init_types(self, input_types, init_types): diff --git a/frontend/test/lit/test_xdsl_passes.py b/frontend/test/lit/test_xdsl_passes.py index 88c113c6f6..57cd305560 100644 --- a/frontend/test/lit/test_xdsl_passes.py +++ b/frontend/test/lit/test_xdsl_passes.py @@ -19,7 +19,8 @@ """ import pennylane as qml -from pennylane.compiler.python_compiler.transforms import merge_rotations_pass + +from catalyst.python_interface.transforms import merge_rotations_pass def test_mlir_pass_no_attribute(): diff --git a/frontend/test/pytest/python_interface/conftest.py b/frontend/test/pytest/python_interface/conftest.py index d0d7cf38ce..ca18891cd6 100644 --- a/frontend/test/pytest/python_interface/conftest.py +++ b/frontend/test/pytest/python_interface/conftest.py @@ -11,13 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Pytest configuration for tests for the pennylane.compiler.python_compiler submodule.""" +"""Pytest configuration for tests for the catalyst.python_interface submodule.""" from inspect import getsource from io import StringIO import pytest +from catalyst.python_interface import Compiler, QuantumParser +from catalyst.python_interface.conversion import parse_generic_to_xdsl_module + deps_available = True try: @@ -30,8 +33,6 @@ from xdsl.passes import PassPipeline from xdsl.printer import Printer - from catalyst.python_interface import Compiler, QuantumParser - from catalyst.python_interface.conversion import parse_generic_to_xdsl_module except (ImportError, ModuleNotFoundError): deps_available = False diff --git a/frontend/test/pytest/test_jit_behaviour.py b/frontend/test/pytest/test_jit_behaviour.py index 6d33ec22eb..2c9737daee 100644 --- a/frontend/test/pytest/test_jit_behaviour.py +++ b/frontend/test/pytest/test_jit_behaviour.py @@ -905,7 +905,7 @@ def g(x: float): def test_mlir_opt_using_xdsl_passes(self, backend): """Test mlir opt using xDSL passes.""" # pylint: disable-next=import-outside-toplevel - from pennylane.compiler.python_compiler.transforms import iterative_cancel_inverses_pass + from catalyst.python_interface.transforms import iterative_cancel_inverses_pass @qjit @iterative_cancel_inverses_pass From 58b13e4e0c8b62240afd8b281379b8467fd0ef30 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 18 Nov 2025 10:55:21 -0500 Subject: [PATCH 15/38] Try change to graphviz installation --- .github/workflows/check-catalyst.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index 43915978d9..bdc4515a2a 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -474,7 +474,7 @@ jobs: python3 -m pip install cuda-quantum==0.6.0 python3 -m pip install oqc-qcaas-client # Install graphviz for testing the mlir-op-graph integration - sudo apt-get install -y graphviz + sudo apt-get install -y graphviz && sudo apt-get update make frontend - name: Get Cached LLVM Build From 637d2cf8bb2a98dbab27d514b488410fd6a46dd4 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 18 Nov 2025 11:10:26 -0500 Subject: [PATCH 16/38] Try installing graphviz with pip --- .github/workflows/check-catalyst.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index bdc4515a2a..a1046d087e 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -474,7 +474,8 @@ jobs: python3 -m pip install cuda-quantum==0.6.0 python3 -m pip install oqc-qcaas-client # Install graphviz for testing the mlir-op-graph integration - sudo apt-get install -y graphviz && sudo apt-get update + # sudo apt-get install -y graphviz && sudo apt-get update + python3 -m pip install graphviz make frontend - name: Get Cached LLVM Build From 64e0fd4ed2c90fb10821e582f83f3f96d2f7bc87 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 18 Nov 2025 11:29:23 -0500 Subject: [PATCH 17/38] Try installing graphviz with both apt and pip --- .github/workflows/check-catalyst.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index a1046d087e..8f3ee26079 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -474,7 +474,7 @@ jobs: python3 -m pip install cuda-quantum==0.6.0 python3 -m pip install oqc-qcaas-client # Install graphviz for testing the mlir-op-graph integration - # sudo apt-get install -y graphviz && sudo apt-get update + sudo apt-get install -y graphviz python3 -m pip install graphviz make frontend From dc15d3a3eb171f5d70f50a8482745d1b279d41d1 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 18 Nov 2025 13:19:37 -0500 Subject: [PATCH 18/38] Add utils file to remove conftest imports --- frontend/test/pytest/conftest.py | 4 ---- .../test/pytest/device/test_decomposition.py | 5 +++-- .../visualization/test_mlir_graph.py | 2 +- frontend/test/pytest/test_custom_devices.py | 2 +- .../pytest/test_measurement_transforms.py | 2 +- .../test/pytest/test_measurements_results.py | 2 +- frontend/test/pytest/test_preprocess.py | 2 +- frontend/test/pytest/utils.py | 22 +++++++++++++++++++ 8 files changed, 30 insertions(+), 11 deletions(-) create mode 100644 frontend/test/pytest/utils.py diff --git a/frontend/test/pytest/conftest.py b/frontend/test/pytest/conftest.py index 685bb20f8e..a37fbd5959 100644 --- a/frontend/test/pytest/conftest.py +++ b/frontend/test/pytest/conftest.py @@ -16,16 +16,12 @@ """ import os -import pathlib from tempfile import TemporaryDirectory from textwrap import dedent import pennylane as qml import pytest -TEST_PATH = os.path.dirname(__file__) -CONFIG_CUSTOM_DEVICE = pathlib.Path(f"{TEST_PATH}/../custom_device/custom_device.toml") - @pytest.fixture(scope="function") def create_temporary_toml_file(request) -> str: diff --git a/frontend/test/pytest/device/test_decomposition.py b/frontend/test/pytest/device/test_decomposition.py index 2889545f40..9a022e55e0 100644 --- a/frontend/test/pytest/device/test_decomposition.py +++ b/frontend/test/pytest/device/test_decomposition.py @@ -22,13 +22,14 @@ import pennylane as qml import pytest from pennylane.devices.capabilities import DeviceCapabilities, OperatorProperties +from utils import CONFIG_CUSTOM_DEVICE from catalyst import CompileError, ctrl, qjit from catalyst.compiler import get_lib_path from catalyst.device.decomposition import catalyst_decomposer -TEST_PATH = os.path.dirname(__file__) -CONFIG_CUSTOM_DEVICE = pathlib.Path(f"{TEST_PATH}/../../custom_device/custom_device.toml") +# TEST_PATH = os.path.dirname(__file__) +# CONFIG_CUSTOM_DEVICE = pathlib.Path(f"{TEST_PATH}/../../custom_device/custom_device.toml") class TestGateAliases: diff --git a/frontend/test/pytest/python_interface/visualization/test_mlir_graph.py b/frontend/test/pytest/python_interface/visualization/test_mlir_graph.py index 6e564637b8..9753bcf88a 100644 --- a/frontend/test/pytest/python_interface/visualization/test_mlir_graph.py +++ b/frontend/test/pytest/python_interface/visualization/test_mlir_graph.py @@ -143,7 +143,7 @@ def _(x, y, w1, w2): tmp_path, { "QNode_level_0_no_transforms.svg", - "QNode_level_1_after_remove-chained-self-inverse.svg", + "QNode_level_1_after_cancel-inverses.svg", "QNode_level_2_after_merge-rotations.svg", }, ) diff --git a/frontend/test/pytest/test_custom_devices.py b/frontend/test/pytest/test_custom_devices.py index b757ff63e7..afeb101a8a 100644 --- a/frontend/test/pytest/test_custom_devices.py +++ b/frontend/test/pytest/test_custom_devices.py @@ -16,7 +16,7 @@ import pennylane as qml import pytest -from conftest import CONFIG_CUSTOM_DEVICE +from utils import CONFIG_CUSTOM_DEVICE from catalyst import measure, qjit from catalyst.compiler import get_lib_path diff --git a/frontend/test/pytest/test_measurement_transforms.py b/frontend/test/pytest/test_measurement_transforms.py index 2ccfee0c7c..e2aae968e6 100644 --- a/frontend/test/pytest/test_measurement_transforms.py +++ b/frontend/test/pytest/test_measurement_transforms.py @@ -24,10 +24,10 @@ import numpy as np import pennylane as qml import pytest -from conftest import CONFIG_CUSTOM_DEVICE from pennylane.devices import Device from pennylane.devices.capabilities import OperatorProperties from pennylane.transforms import split_non_commuting, split_to_single_terms +from utils import CONFIG_CUSTOM_DEVICE from catalyst import qjit from catalyst.compiler import get_lib_path diff --git a/frontend/test/pytest/test_measurements_results.py b/frontend/test/pytest/test_measurements_results.py index f4938afd22..1f5df00e1e 100644 --- a/frontend/test/pytest/test_measurements_results.py +++ b/frontend/test/pytest/test_measurements_results.py @@ -18,8 +18,8 @@ import numpy as np import pennylane as qml import pytest -from conftest import CONFIG_CUSTOM_DEVICE from jax import numpy as jnp +from utils import CONFIG_CUSTOM_DEVICE from catalyst import CompileError, qjit from catalyst.device import get_device_capabilities diff --git a/frontend/test/pytest/test_preprocess.py b/frontend/test/pytest/test_preprocess.py index a918ec09d4..5e89db1ca9 100644 --- a/frontend/test/pytest/test_preprocess.py +++ b/frontend/test/pytest/test_preprocess.py @@ -19,10 +19,10 @@ import numpy as np import pennylane as qml import pytest -from conftest import CONFIG_CUSTOM_DEVICE from pennylane.devices import Device, NullQubit from pennylane.devices.capabilities import DeviceCapabilities, OperatorProperties from pennylane.tape import QuantumScript +from utils import CONFIG_CUSTOM_DEVICE from catalyst import CompileError, ctrl, qjit from catalyst.api_extensions.control_flow import ( diff --git a/frontend/test/pytest/utils.py b/frontend/test/pytest/utils.py new file mode 100644 index 0000000000..68435015f6 --- /dev/null +++ b/frontend/test/pytest/utils.py @@ -0,0 +1,22 @@ +# Copyright 2023 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Pytest utilities for the Catalyst test suite. +""" + +import os +import pathlib + +TEST_PATH = os.path.dirname(__file__) +CONFIG_CUSTOM_DEVICE = pathlib.Path(f"{TEST_PATH}/../custom_device/custom_device.toml") From f919e8d5c4b8156f4ec7750ba81881eb6d10576f Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 18 Nov 2025 13:20:58 -0500 Subject: [PATCH 19/38] Remove unused imports --- frontend/test/pytest/device/test_decomposition.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/frontend/test/pytest/device/test_decomposition.py b/frontend/test/pytest/device/test_decomposition.py index 9a022e55e0..5f3594de1f 100644 --- a/frontend/test/pytest/device/test_decomposition.py +++ b/frontend/test/pytest/device/test_decomposition.py @@ -14,8 +14,6 @@ """Unit test module for catalyst/device/decomposition.py""" -import os -import pathlib import platform import numpy as np From 005ba9fb9bc800da4baf9629f1f112137986d22d Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 18 Nov 2025 13:37:17 -0500 Subject: [PATCH 20/38] Add graphviz dependencies to lightning.kokkos testing workflow --- .github/workflows/check-catalyst.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index 8f3ee26079..d622074d73 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -559,6 +559,9 @@ jobs: sudo apt-get install -y libasan6 make python3 --version | grep ${{ needs.constants.outputs.primary_python_version }} python3 -m pip install -r requirements.txt + # Install graphviz for testing the mlir-op-graph integration + sudo apt-get install -y graphviz + python3 -m pip install graphviz make frontend - name: Get Cached LLVM Build From 61171897f64a3ff75f5739a3da6339c4a8e3d84c Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 18 Nov 2025 13:39:14 -0500 Subject: [PATCH 21/38] Update cookbook --- .../doc/unified_compiler_cookbook.rst | 12 ++++++------ .../doc/xdsl_dummy_quantum_subroutines.rst | 4 ++-- .../python_interface/doc/xdsl_post_processing.rst | 2 +- .../python_interface/doc/xdsl_utils_tutorial.rst | 4 ++-- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst b/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst index b2f88b6060..1525c030ee 100644 --- a/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst +++ b/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst @@ -351,7 +351,7 @@ dialect), and check in the body of the method if the op is a .. code-block:: python from xdsl import pattern_rewriter - from pennylane.compiler.python_compiler.quantum_dialect import CustomOp + from catalyst.python_interface.dialects.quantum import CustomOp class MyPattern(pattern_rewriter.RewritePattern): """Dummy class for example.""" @@ -397,7 +397,7 @@ updates all ``Hadamard``\ s with ``PauliX``\ s: from xdsl import pattern_rewriter from xdsl.dialects import builtin - from pennylane.compiler.python_compiler.dialects.quantum import CustomOp + from catalyst.python_interface.dialects.quantum import CustomOp class HToXPattern(pattern_rewriter.RewritePattern): """Dummy class for example.""" @@ -441,8 +441,8 @@ replaces all ``Hadamard``\ s with ``PauliX``\ s from xdsl import passes, pattern_rewriter from xdsl.dialects import builtin - from pennylane.compiler.python_compiler.dialects.quantum import CustomOp - from pennylane.compiler.python_compiler import compiler_transform + from catalyst.python_interface.dialects.quantum import CustomOp + from catalyst.python_interface import compiler_transform class HToXPattern(pattern_rewriter.RewritePattern): """Dummy class for example.""" @@ -503,7 +503,7 @@ the “PennyLane integration” section below. .. code-block:: python import pennylane as qml - from pennylane.compiler.python_compiler.conversion import xdsl_from_qjit + from catalyst.python_interface.conversion import xdsl_from_qjit dev = qml.device("lightning.qubit", wires=3) @@ -985,7 +985,7 @@ currently accessible as .. code-block:: python - from pennylane.compiler.python_compiler import compiler_transform + from catalyst.python_interface import compiler_transform class MyPass(xdsl.passes.ModulePass): """MyPass that does something""" diff --git a/frontend/catalyst/python_interface/doc/xdsl_dummy_quantum_subroutines.rst b/frontend/catalyst/python_interface/doc/xdsl_dummy_quantum_subroutines.rst index 1a951c5cc7..a1e3a53324 100644 --- a/frontend/catalyst/python_interface/doc/xdsl_dummy_quantum_subroutines.rst +++ b/frontend/catalyst/python_interface/doc/xdsl_dummy_quantum_subroutines.rst @@ -3,8 +3,8 @@ from dataclasses import dataclass import pennylane as qml - from pennylane.compiler.python_compiler.conversion import xdsl_from_qjit - from pennylane.compiler.python_compiler.dialects.quantum import CustomOp, QubitType + from catalyst.python_interface.conversion import xdsl_from_qjit + from catalyst.python_interface.dialects.quantum import CustomOp, QubitType from xdsl import context, passes, pattern_rewriter from xdsl.builder import ImplicitBuilder diff --git a/frontend/catalyst/python_interface/doc/xdsl_post_processing.rst b/frontend/catalyst/python_interface/doc/xdsl_post_processing.rst index 3a0924fd23..b9ca0ac3ae 100644 --- a/frontend/catalyst/python_interface/doc/xdsl_post_processing.rst +++ b/frontend/catalyst/python_interface/doc/xdsl_post_processing.rst @@ -7,7 +7,7 @@ Simple tutorial for injecting functions into xDSL modules import jax import pennylane as qml - from pennylane.compiler.python_compiler.conversion import inline_module, xdsl_from_qjit, xdsl_module + from catalyst.python_interface.conversion import inline_module, xdsl_from_qjit, xdsl_module from xdsl import context, passes, pattern_rewriter from xdsl.dialects import builtin, func diff --git a/frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst b/frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst index 532a384e1a..8c0163a3de 100644 --- a/frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst +++ b/frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst @@ -2,13 +2,13 @@ Python compiler utilities ========================= All utilities we care about are in the -``pennylane.compiler.python_compiler.conversion`` submodule. +``catalyst.python_interface.conversion`` submodule. .. code-block:: python import pennylane as qml - from pennylane.compiler.python_compiler.conversion import ( + from catalyst.python_interface.conversion import ( inline_jit_to_module, inline_module, xdsl_from_qjit, From 6961324bd23491793520421ab1da2d3f3ab351d6 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 18 Nov 2025 14:24:42 -0500 Subject: [PATCH 22/38] Migrate all changelog entries from PennyLane --- doc/releases/changelog-dev.md | 70 ++++++++++++++++++++++++++++++++--- 1 file changed, 65 insertions(+), 5 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 8e09946363..793111384f 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -7,6 +7,48 @@

Improvements 🛠

+* Add an experimental `outline_state_evolution_pass` xDSL pass to `catalyst.python_interface.transforms`, + which moves all quantum gate operations to a private callable. + [(#8367)](https://github.com/PennyLaneAI/pennylane/pull/8367) + +* A new experimental `split_non_commuting_pass` compiler pass has been added to + `catalyst.python_interface.transforms`. This pass splits quantum functions that + measure observables on the same wires into multiple function executions, where + each execution measures observables on different wires (using the "wires" grouping + strategy). The original function is replaced with calls to these generated functions, + and the results are combined appropriately. + [(#8531)](https://github.com/PennyLaneAI/pennylane/pull/8531) + +* Add the `PCPhaseOp` operation to the xDSL Quantum dialect. + [(#8621)](https://github.com/PennyLaneAI/pennylane/pull/8621) + +* Users can now apply xDSL passes without the need to pass the `pass_plugins` argument to + the `qjit` decorator. + [(#8572)](https://github.com/PennyLaneAI/pennylane/pull/8572) + [(#8573)](https://github.com/PennyLaneAI/pennylane/pull/8573) + [(#2169)](https://github.com/PennyLaneAI/catalyst/pull/2169) + [(#2183)](https://github.com/PennyLaneAI/catalyst/pull/2183) + +* The :meth:`catalyst.python_interface.transforms.convert_to_mbqc_formalism_pass` now + supports :class:`~xdsl.dialects.scf.IndexSwitchOp` in IR and ignores regions that have no body. + [(#8632)](https://github.com/PennyLaneAI/pennylane/pull/8632) + +* The `convert_to_mbqc_formalism` compilation pass now outlines the operations to represent a gate + in the MBQC formalism into subroutines in order to reduce the IR size for large programs. + [(#8619)](https://github.com/PennyLaneAI/pennylane/pull/8619) + +* The :meth:`catalyst.python_interface.Compiler.run` method now accepts a string as input, + which is parsed and transformed with xDSL. + [(#8587)](https://github.com/PennyLaneAI/pennylane/pull/8587) + +* An `is_xdsl_pass` function has been added to the `catalyst.python_interface.pass_api` module. + This function checks if a pass name corresponds to an xDSL implemented pass. + [(#8572)](https://github.com/PennyLaneAI/pennylane/pull/8572) + +* A new `catalyst.python_interface.utils` submodule has been added, containing general-purpose utilities for + working with xDSL. This includes a function that extracts the concrete value of scalar, constant SSA values. + [(#8514)](https://github.com/PennyLaneAI/pennylane/pull/8514) + * Pass instrumentation can be applied to each pass within the `NamedSequenceOp` transform sequence for a qnode. [(#1978)](https://github.com/PennyLaneAI/catalyst/pull/1978) @@ -35,11 +77,6 @@ * `qml.grad` and `qml.jacobian` can now be used with `qjit` when program capture is enabled. [(#2078)](https://github.com/PennyLaneAI/catalyst/pull/2078) -* xDSL passes are now automatically detected when using the `qjit` decorator. - This removes the need to pass the `pass_plugins` argument to the `qjit` decorator. - [(#2169)](https://github.com/PennyLaneAI/catalyst/pull/2169) - [(#2183)](https://github.com/PennyLaneAI/catalyst/pull/2183) - * The ``mlir_opt`` property now correctly handles xDSL passes by automatically detecting when the Python compiler is being used and routing through it appropriately. [(#2190)](https://github.com/PennyLaneAI/catalyst/pull/2190) @@ -72,6 +109,20 @@

Bug fixes 🐛

+* The experimental xDSL :func:`~catalyst.python_interface.transforms.measurements_from_samples_pass` + pass has been updated to support `shots` defined by an `arith.constant` operation. + [(#8460)](https://github.com/PennyLaneAI/pennylane/pull/8460) + +* The experimental xDSL :func:`~catalyst.python_interface.transforms.diagonalize_measurements` + pass has been updated to fix a bug that included the wrong SSA value for final qubit insertion + and deallocation at the end of the circuit. A clear error is now also raised when there are + observables with overlapping wires. + [(#8383)](https://github.com/PennyLaneAI/pennylane/pull/8383) + +* Fixes a bug in the constructor of the xDSL Quantum dialect's `QubitUnitaryOp` that + prevented an instance from being constructed. + [(#8456)](https://github.com/PennyLaneAI/pennylane/pull/8456) + * Fixes an issue where a heap-to-stack allocation conversion pass was causing SIGSEGV issues during program execution at runtime. [(#2172)](https://github.com/PennyLaneAI/catalyst/pull/2172) @@ -122,6 +173,10 @@

Internal changes ⚙️

+* Migrated the `pennylane.compiler.python_compiler` submodule from PennyLane to Catalyst. + It is now available as `catalyst.python_interface`. + [(#2199)](https://github.com/PennyLaneAI/catalyst/pull/2199) + * Replaces the deprecated `shape_dtype_to_ir_type` function with the `RankedTensorType.get` method. [(#2159)](https://github.com/PennyLaneAI/catalyst/pull/2159) @@ -173,6 +228,11 @@

Documentation 📝

+* Added a "Unified Compiler Cookbook" RST file, along with tutorials, to `catalyst.python_interface.doc`, + which provides a quickstart guide for getting started with xDSL and its integration with PennyLane and + Catalyst. + [(#8571)](https://github.com/PennyLaneAI/pennylane/pull/8571) + * A typo in the code example for :func:`~.passes.ppr_to_ppm` has been corrected. [(#2136)](https://github.com/PennyLaneAI/catalyst/pull/2136) From a868bc386d380301bcc5f3a2ad47cb72e9ab6b1d Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 18 Nov 2025 14:25:31 -0500 Subject: [PATCH 23/38] Add EOF new line to .codecov.yml --- .codecov.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.codecov.yml b/.codecov.yml index dc9a2b56c6..057bd2b172 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -9,4 +9,4 @@ coverage: patch: true ignore: - - "frontend/catalyst/python_interface" \ No newline at end of file + - "frontend/catalyst/python_interface" From 8f00b2f67fa7f53444724fdcc9c0de22b951d8fb Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 18 Nov 2025 14:32:26 -0500 Subject: [PATCH 24/38] change changelog entry slightly --- doc/releases/changelog-dev.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 793111384f..b135975266 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -174,7 +174,7 @@

Internal changes ⚙️

* Migrated the `pennylane.compiler.python_compiler` submodule from PennyLane to Catalyst. - It is now available as `catalyst.python_interface`. + It is now accessible as `catalyst.python_interface`. [(#2199)](https://github.com/PennyLaneAI/catalyst/pull/2199) * Replaces the deprecated `shape_dtype_to_ir_type` function with the `RankedTensorType.get` method. From 2e9c1fc813ef198fb07edf19a742a9b6d9dcbd71 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 20 Nov 2025 17:16:43 -0500 Subject: [PATCH 25/38] Create pattern match function --- .../python_interface/transforms/__init__.py | 3 + .../transforms/pattern_match.py | 218 ++++++++++++++++++ 2 files changed, 221 insertions(+) create mode 100644 frontend/catalyst/python_interface/transforms/pattern_match.py diff --git a/frontend/catalyst/python_interface/transforms/__init__.py b/frontend/catalyst/python_interface/transforms/__init__.py index 905893b2bf..006fedb511 100644 --- a/frontend/catalyst/python_interface/transforms/__init__.py +++ b/frontend/catalyst/python_interface/transforms/__init__.py @@ -23,6 +23,7 @@ null_decompose_graph_state_pass, outline_state_evolution_pass, ) +from .pattern_match import pattern_match from .quantum import ( CombineGlobalPhasesPass, DiagonalizeFinalMeasurementsPass, @@ -61,4 +62,6 @@ "outline_state_evolution_pass", "null_decompose_graph_state_pass", "NullDecomposeGraphStatePass", + # Pattern matching + "pattern_match", ] diff --git a/frontend/catalyst/python_interface/transforms/pattern_match.py b/frontend/catalyst/python_interface/transforms/pattern_match.py new file mode 100644 index 0000000000..b6f2fc0b9a --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/pattern_match.py @@ -0,0 +1,218 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Function for applying high-level pattern-matching to xDSL modules.""" + +from functools import partial, wraps +from inspect import signature +from typing import Callable + +import pennylane as qml +from xdsl.builder import ImplicitBuilder +from xdsl.context import Context +from xdsl.dialects import builtin, func, pdl +from xdsl.ir import Block, Operation, Region +from xdsl.passes import PassPipeline +from xdsl.traits import SymbolTable +from xdsl.transforms.apply_pdl import ApplyPDLPass + +from catalyst.jax_primitives import decomposition_rule +from catalyst.jit import QJIT +from catalyst.python_interface import QuantumParser +from catalyst.python_interface.conversion import xdsl_from_qjit +from catalyst.python_interface.dialects import quantum + +qml.capture.enable() + + +def _rewrite_mod( + mod: builtin.ModuleOp, pattern: Callable, rewrite: Callable, num_wires: int +) -> None: + + args = list(range(num_wires)) + + @xdsl_from_qjit + @qml.qjit + @qml.qnode(qml.device("lightning.qubit", wires=num_wires)) + def mod_fn(): + @decomposition_rule(is_qreg=False, num_params=0) + def pattern__(*args__): + return pattern(*args__) + + @decomposition_rule(is_qreg=False, num_params=0) + def rewrite__(*args__): + return rewrite(*args__) + + pattern__(*args) + rewrite__(*args) + + return qml.state() + + dummy_mod = mod_fn() + + qnode_mod: builtin.ModuleOp = [ + o for o in dummy_mod.body.ops if isinstance(o, builtin.ModuleOp) + ][0] + pattern_fn: func.FuncOp = SymbolTable.lookup_symbol(qnode_mod, "pattern__") + rewrite_fn: func.FuncOp = SymbolTable.lookup_symbol(qnode_mod, "rewrite__") + + pattern_fn.body.block.erase_op(pattern_fn.body.block.last_op) + rewrite_fn.body.block.erase_op(rewrite_fn.body.block.last_op) + + def create_pdl_op( + op: Operation, + cur_values: list[pdl.ValueType], + value_map: dict[quantum.QubitSSAValue, int], + pdl_value_map: dict[pdl.ValueType, int], + type_map: dict[quantum.QubitSSAValue, pdl.TypeType], + ): + attr_vals = [] + attr_names = [] + for attr_name, attr in (op.properties | op.attributes).items(): + attr_names.append(builtin.StringAttr(attr_name)) + attr_vals.append(pdl.AttributeOp(attr)) + + operand_vals = [] + res_type_vals = [] + for val in op.operands: + pdl_val = cur_values[value_map[val]] + operand_vals.append(pdl_val) + res_type_vals.append(type_map[val]) + + pdl_op = pdl.OperationOp( + op_name=op.name, + attribute_value_names=attr_names, + attribute_values=attr_vals, + operand_values=operand_vals, + type_values=res_type_vals, + ) + + pdl_results = [ + pdl.ResultOp(index=i, parent=pdl_op).results[0] for i in range(len(op.results)) + ] + for i, (result, pdl_result) in enumerate(zip(op.results, pdl_results, strict=True)): + value_idx = value_map[op.operands[i]] + value_map[result] = value_idx + pdl_value_map[pdl_result] = value_idx + type_map[result] = res_type_vals[i] + cur_values[value_idx] = pdl_result + + return pdl_op + + pdl_ops: list[pdl.OperationOp] = [] + starting_values: list[pdl.ValueType] = [] + cur_values: list[pdl.ValueType] = [] + value_map: dict[quantum.QubitSSAValue, int] = {} + pdl_value_map: dict[pdl.ValueType, int] = {} + type_map: dict[quantum.QubitSSAValue, pdl.TypeType] = {} + + pattern_block: Block = Block() + with ImplicitBuilder(pattern_block): + + # Initialization + for i, (arg, arg_type) in enumerate( + zip(pattern_fn.body.block.args, pattern_fn.body.block.arg_types) + ): + # Create pdl types + pdl_type = pdl.TypeOp(arg_type).results[0] + type_map[arg] = pdl_type + + # Create pdl values + pdl_val = pdl.OperandOp(value_type=pdl_type).results[0] + starting_values.append(pdl_val) + cur_values.append(pdl_val) + value_map[arg] = i + pdl_value_map[pdl_val] = i + + # Create pdl ops + for op in pattern_fn.body.ops: + pdl_op = create_pdl_op( + op=op, + cur_values=cur_values, + value_map=value_map, + pdl_value_map=pdl_value_map, + type_map=type_map, + ) + pdl_ops.append(pdl_op) + + rewrite_block: Block = Block() + with ImplicitBuilder(rewrite_block): + + rw_value_map: dict[quantum.QubitSSAValue, int] = {} + rw_pdl_ops: list[pdl.OperationOp] = [] + rw_terminal_values: list[pdl.ValueType] = [] + + # Rewrite block initialization + for i, (p_arg, r_arg) in enumerate( + zip(pattern_fn.body.block.args, rewrite_fn.body.block.args, strict=True) + ): + rw_terminal_values.append(starting_values[i]) + rw_value_map[r_arg] = i + type_map[r_arg] = type_map[p_arg] + + # Create pdl ops for operations in the rewrite pattern + for op in rewrite_fn.body.ops: + pdl_op = create_pdl_op( + op=op, + cur_values=rw_terminal_values, + value_map=rw_value_map, + pdl_value_map=pdl_value_map, + type_map=type_map, + ) + rw_pdl_ops.append(pdl_op) + + # Replace all operations in the original pattern with values generated by + # the rewrite pattern + for op, pdl_op in tuple(zip(pattern_fn.body.ops, pdl_ops, strict=True))[::-1]: + repl_vals = [] + for pdl_val in pdl_op.operand_values: + if isinstance(pdl_val.op, pdl.ResultOp): + repl_vals.append(pdl_val) + else: + idx = pdl_value_map[pdl_val] + repl_vals.append(rw_terminal_values[idx]) + + _ = pdl.ReplaceOp(pdl_op, repl_values=repl_vals) + + rewrite_op = pdl.RewriteOp(pdl_ops[-1].results[0], body=Region(rewrite_block)) + pattern_block.add_op(rewrite_op) + pattern_op = pdl.PatternOp(benefit=1, sym_name="temp", body=Region(pattern_block)) + mod.body.block.add_op(pattern_op) + + ctx = Context() + _ = QuantumParser(ctx, "") + + PassPipeline((ApplyPDLPass(),)).apply(ctx, mod) + mod.body.block.erase_op(pattern_op) + + +def pattern_match(func: QJIT = None, patterns: dict[Callable, Callable] = {}): + """Apply pattern matching to q QJIT-ed workflow.""" + if func is None: + return partial(pattern_match, patterns=patterns) + + @wraps(func) + def wrapper(*args, **kwargs): + mod = xdsl_from_qjit(func)(*args, **kwargs) + + for pattern, rewrite in patterns.items(): + p_nargs = len(signature(pattern).parameters) + r_nargs = len(signature(rewrite).parameters) + if p_nargs != r_nargs: + raise ValueError("Pattern and match must have the same number of qubits as inputs") + + _rewrite_mod(mod, pattern, rewrite, p_nargs) + + return mod + + return wrapper From 7bbecc38aaed4a2a253c80b9a8e5dc8255315b2a Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 20 Nov 2025 17:18:13 -0500 Subject: [PATCH 26/38] Add example --- .../transforms/pattern_match_example.py | 111 ++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 frontend/catalyst/python_interface/transforms/pattern_match_example.py diff --git a/frontend/catalyst/python_interface/transforms/pattern_match_example.py b/frontend/catalyst/python_interface/transforms/pattern_match_example.py new file mode 100644 index 0000000000..cf30ead289 --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/pattern_match_example.py @@ -0,0 +1,111 @@ +import pennylane as qml + +import catalyst + +from .pattern_match import pattern_match + +qml.capture.enable() + + +def pattern1(w1, w2): + qml.CZ([w1, w2]) + qml.PauliZ(w2) + + +def rewrite1(w1, w2): + qml.PauliX(w1) + qml.S(w1) + qml.T(w2) + + +def pattern2(w1): + qml.T(w1) + + +def rewrite2(w1): + qml.PauliX(w1) + qml.S(w1) + qml.H(w1) + + +@pattern_match(patterns={pattern1: rewrite1, pattern2: rewrite2}) +@catalyst.qjit +@qml.qnode(qml.device("lightning.qubit", wires=4)) +def workflow(): + qml.CZ([0, 1]) + qml.Z(1) + + qml.T(0) + + qml.CNOT([1, 2]) + qml.PauliY(2) + qml.PauliX(3) + return qml.state() + + +xmod = workflow() +print(xmod) +""" +Output: + +builtin.module @workflow { + func.func public @jit_workflow() -> (tensor<16xcomplex>) attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_workflow::@workflow() : () -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> + } + builtin.module @module_workflow { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @workflow() -> (tensor<16xcomplex>) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<4> : tensor}> : () -> tensor + %3 = quantum.alloc(4) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %7 = tensor.extract %6[] : tensor + %8 = quantum.extract %3[%7] : !quantum.reg -> !quantum.bit + %9, %10 = quantum.custom "CZ"() %5, %8 : !quantum.bit, !quantum.bit + %11 = quantum.custom "PauliZ"() %10 : !quantum.bit + %12 = quantum.custom "PauliX"() %9 : !quantum.bit + %13 = quantum.custom "S"() %12 : !quantum.bit + %14 = quantum.custom "Hadamard"() %13 : !quantum.bit + %15 = "stablehlo.constant"() <{value = dense<2> : tensor}> : () -> tensor + %16 = tensor.extract %15[] : tensor + %17 = quantum.extract %3[%16] : !quantum.reg -> !quantum.bit + %18, %19 = quantum.custom "CNOT"() %11, %17 : !quantum.bit, !quantum.bit + %20 = quantum.custom "PauliY"() %19 : !quantum.bit + %21 = "stablehlo.constant"() <{value = dense<3> : tensor}> : () -> tensor + %22 = tensor.extract %21[] : tensor + %23 = quantum.extract %3[%22] : !quantum.reg -> !quantum.bit + %24 = quantum.custom "PauliX"() %23 : !quantum.bit + %25 = tensor.extract %0[] : tensor + %26 = quantum.insert %3[%25], %14 : !quantum.reg, !quantum.bit + %27 = tensor.extract %6[] : tensor + %28 = quantum.insert %26[%27], %18 : !quantum.reg, !quantum.bit + %29 = tensor.extract %15[] : tensor + %30 = quantum.insert %28[%29], %20 : !quantum.reg, !quantum.bit + %31 = tensor.extract %21[] : tensor + %32 = quantum.insert %30[%31], %24 : !quantum.reg, !quantum.bit + %33 = quantum.compbasis qreg %32 : !quantum.obs + %34 = quantum.state %33 : tensor<16xcomplex> + quantum.dealloc %32 : !quantum.reg + quantum.device_release + func.return %34 : tensor<16xcomplex> + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } +} +""" From dc1846ad860b9d619a0588b1583a34de29e4e10d Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 20 Nov 2025 17:20:02 -0500 Subject: [PATCH 27/38] Move example output to mlir file --- .../transforms/pattern_match_example.mlir | 60 +++++++++++++++++ .../transforms/pattern_match_example.py | 65 +------------------ 2 files changed, 61 insertions(+), 64 deletions(-) create mode 100644 frontend/catalyst/python_interface/transforms/pattern_match_example.mlir diff --git a/frontend/catalyst/python_interface/transforms/pattern_match_example.mlir b/frontend/catalyst/python_interface/transforms/pattern_match_example.mlir new file mode 100644 index 0000000000..9880cbfd66 --- /dev/null +++ b/frontend/catalyst/python_interface/transforms/pattern_match_example.mlir @@ -0,0 +1,60 @@ +builtin.module @workflow { + func.func public @jit_workflow() -> (tensor<16xcomplex>) attributes {llvm.emit_c_interface} { + %0 = catalyst.launch_kernel @module_workflow::@workflow() : () -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> + } + builtin.module @module_workflow { + builtin.module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) { + transform.yield + } + } + func.func public @workflow() -> (tensor<16xcomplex>) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { + %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor + %1 = tensor.extract %0[] : tensor + quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] + %2 = "stablehlo.constant"() <{value = dense<4> : tensor}> : () -> tensor + %3 = quantum.alloc(4) : !quantum.reg + %4 = tensor.extract %0[] : tensor + %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit + %6 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor + %7 = tensor.extract %6[] : tensor + %8 = quantum.extract %3[%7] : !quantum.reg -> !quantum.bit + %9, %10 = quantum.custom "CZ"() %5, %8 : !quantum.bit, !quantum.bit + %11 = quantum.custom "PauliZ"() %10 : !quantum.bit + %12 = quantum.custom "PauliX"() %9 : !quantum.bit + %13 = quantum.custom "S"() %12 : !quantum.bit + %14 = quantum.custom "Hadamard"() %13 : !quantum.bit + %15 = "stablehlo.constant"() <{value = dense<2> : tensor}> : () -> tensor + %16 = tensor.extract %15[] : tensor + %17 = quantum.extract %3[%16] : !quantum.reg -> !quantum.bit + %18, %19 = quantum.custom "CNOT"() %11, %17 : !quantum.bit, !quantum.bit + %20 = quantum.custom "PauliY"() %19 : !quantum.bit + %21 = "stablehlo.constant"() <{value = dense<3> : tensor}> : () -> tensor + %22 = tensor.extract %21[] : tensor + %23 = quantum.extract %3[%22] : !quantum.reg -> !quantum.bit + %24 = quantum.custom "PauliX"() %23 : !quantum.bit + %25 = tensor.extract %0[] : tensor + %26 = quantum.insert %3[%25], %14 : !quantum.reg, !quantum.bit + %27 = tensor.extract %6[] : tensor + %28 = quantum.insert %26[%27], %18 : !quantum.reg, !quantum.bit + %29 = tensor.extract %15[] : tensor + %30 = quantum.insert %28[%29], %20 : !quantum.reg, !quantum.bit + %31 = tensor.extract %21[] : tensor + %32 = quantum.insert %30[%31], %24 : !quantum.reg, !quantum.bit + %33 = quantum.compbasis qreg %32 : !quantum.obs + %34 = quantum.state %33 : tensor<16xcomplex> + quantum.dealloc %32 : !quantum.reg + quantum.device_release + func.return %34 : tensor<16xcomplex> + } + } + func.func @setup() { + quantum.init + func.return + } + func.func @teardown() { + quantum.finalize + func.return + } +} \ No newline at end of file diff --git a/frontend/catalyst/python_interface/transforms/pattern_match_example.py b/frontend/catalyst/python_interface/transforms/pattern_match_example.py index cf30ead289..e7b5de2cb6 100644 --- a/frontend/catalyst/python_interface/transforms/pattern_match_example.py +++ b/frontend/catalyst/python_interface/transforms/pattern_match_example.py @@ -45,67 +45,4 @@ def workflow(): xmod = workflow() print(xmod) -""" -Output: - -builtin.module @workflow { - func.func public @jit_workflow() -> (tensor<16xcomplex>) attributes {llvm.emit_c_interface} { - %0 = catalyst.launch_kernel @module_workflow::@workflow() : () -> tensor<16xcomplex> - func.return %0 : tensor<16xcomplex> - } - builtin.module @module_workflow { - builtin.module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) { - transform.yield - } - } - func.func public @workflow() -> (tensor<16xcomplex>) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { - %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor - %1 = tensor.extract %0[] : tensor - quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] - %2 = "stablehlo.constant"() <{value = dense<4> : tensor}> : () -> tensor - %3 = quantum.alloc(4) : !quantum.reg - %4 = tensor.extract %0[] : tensor - %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit - %6 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor - %7 = tensor.extract %6[] : tensor - %8 = quantum.extract %3[%7] : !quantum.reg -> !quantum.bit - %9, %10 = quantum.custom "CZ"() %5, %8 : !quantum.bit, !quantum.bit - %11 = quantum.custom "PauliZ"() %10 : !quantum.bit - %12 = quantum.custom "PauliX"() %9 : !quantum.bit - %13 = quantum.custom "S"() %12 : !quantum.bit - %14 = quantum.custom "Hadamard"() %13 : !quantum.bit - %15 = "stablehlo.constant"() <{value = dense<2> : tensor}> : () -> tensor - %16 = tensor.extract %15[] : tensor - %17 = quantum.extract %3[%16] : !quantum.reg -> !quantum.bit - %18, %19 = quantum.custom "CNOT"() %11, %17 : !quantum.bit, !quantum.bit - %20 = quantum.custom "PauliY"() %19 : !quantum.bit - %21 = "stablehlo.constant"() <{value = dense<3> : tensor}> : () -> tensor - %22 = tensor.extract %21[] : tensor - %23 = quantum.extract %3[%22] : !quantum.reg -> !quantum.bit - %24 = quantum.custom "PauliX"() %23 : !quantum.bit - %25 = tensor.extract %0[] : tensor - %26 = quantum.insert %3[%25], %14 : !quantum.reg, !quantum.bit - %27 = tensor.extract %6[] : tensor - %28 = quantum.insert %26[%27], %18 : !quantum.reg, !quantum.bit - %29 = tensor.extract %15[] : tensor - %30 = quantum.insert %28[%29], %20 : !quantum.reg, !quantum.bit - %31 = tensor.extract %21[] : tensor - %32 = quantum.insert %30[%31], %24 : !quantum.reg, !quantum.bit - %33 = quantum.compbasis qreg %32 : !quantum.obs - %34 = quantum.state %33 : tensor<16xcomplex> - quantum.dealloc %32 : !quantum.reg - quantum.device_release - func.return %34 : tensor<16xcomplex> - } - } - func.func @setup() { - quantum.init - func.return - } - func.func @teardown() { - quantum.finalize - func.return - } -} -""" +# To view output, check out pattern_match_example.mlir From d008136eff8b4be68264370c98816b249f7e34ad Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 20 Nov 2025 17:21:10 -0500 Subject: [PATCH 28/38] EOF new line --- .../python_interface/transforms/pattern_match_example.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/transforms/pattern_match_example.mlir b/frontend/catalyst/python_interface/transforms/pattern_match_example.mlir index 9880cbfd66..e381f1aefc 100644 --- a/frontend/catalyst/python_interface/transforms/pattern_match_example.mlir +++ b/frontend/catalyst/python_interface/transforms/pattern_match_example.mlir @@ -57,4 +57,4 @@ builtin.module @workflow { quantum.finalize func.return } -} \ No newline at end of file +} From 694ecc94406e8fc77df441c6f76b8827e2481243 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 21 Nov 2025 11:40:25 -0500 Subject: [PATCH 29/38] Put impl inside a pass --- .../transforms/pattern_match.py | 294 ++++++++++-------- 1 file changed, 161 insertions(+), 133 deletions(-) diff --git a/frontend/catalyst/python_interface/transforms/pattern_match.py b/frontend/catalyst/python_interface/transforms/pattern_match.py index b6f2fc0b9a..60478673cb 100644 --- a/frontend/catalyst/python_interface/transforms/pattern_match.py +++ b/frontend/catalyst/python_interface/transforms/pattern_match.py @@ -22,7 +22,7 @@ from xdsl.context import Context from xdsl.dialects import builtin, func, pdl from xdsl.ir import Block, Operation, Region -from xdsl.passes import PassPipeline +from xdsl.passes import ModulePass, PassPipeline from xdsl.traits import SymbolTable from xdsl.transforms.apply_pdl import ApplyPDLPass @@ -32,56 +32,176 @@ from catalyst.python_interface.conversion import xdsl_from_qjit from catalyst.python_interface.dialects import quantum -qml.capture.enable() - -def _rewrite_mod( - mod: builtin.ModuleOp, pattern: Callable, rewrite: Callable, num_wires: int -) -> None: - - args = list(range(num_wires)) - - @xdsl_from_qjit - @qml.qjit - @qml.qnode(qml.device("lightning.qubit", wires=num_wires)) - def mod_fn(): - @decomposition_rule(is_qreg=False, num_params=0) - def pattern__(*args__): - return pattern(*args__) - - @decomposition_rule(is_qreg=False, num_params=0) - def rewrite__(*args__): - return rewrite(*args__) - - pattern__(*args) - rewrite__(*args) - - return qml.state() - - dummy_mod = mod_fn() - - qnode_mod: builtin.ModuleOp = [ - o for o in dummy_mod.body.ops if isinstance(o, builtin.ModuleOp) - ][0] - pattern_fn: func.FuncOp = SymbolTable.lookup_symbol(qnode_mod, "pattern__") - rewrite_fn: func.FuncOp = SymbolTable.lookup_symbol(qnode_mod, "rewrite__") - - pattern_fn.body.block.erase_op(pattern_fn.body.block.last_op) - rewrite_fn.body.block.erase_op(rewrite_fn.body.block.last_op) - - def create_pdl_op( +class PatternMatchPass(ModulePass): + """Module pass to apply Python-level pattern-matching to xDSL modules.""" + + name: str = "pattern-match" + + def __init__(self, patterns: dict[Callable, Callable] = {}): + self.patterns = patterns + + def apply(self, ctx: Context, op: builtin.ModuleOp): + """Apply the provided patterns to the input module.""" + + # Load default dialects + for dialect in QuantumParser.default_dialects: + if ctx.get_optional_dialect(dialect.name) is None: + ctx.load_dialect(dialect) + + pdl_pass = ApplyPDLPass() + for pat, rw in self.patterns.items(): + pat_fn, rw_fn = self._pattern_to_xdsl(pat, rw) + pattern_op = self._create_pdl_pattern(pat_fn, rw_fn) + + op.body.block.add_op(pattern_op) + pdl_pass.apply(ctx, op) + op.body.block.erase_op(pattern_op) + + def _pattern_to_xdsl(self, py_pattern: Callable, py_rewrite: Callable) -> func.FuncOp: + """Create xDSL ``func.FuncOp``\ s from Python pattern and rewrite functions.""" + + n_args = len(signature(py_pattern).parameters) + if len(signature(py_rewrite).parameters) != n_args: + raise ValueError("Pattern and match must have the same number of qubits as inputs") + + @xdsl_from_qjit + @qml.qjit + @qml.qnode(qml.device("lightning.qubit", wires=n_args)) + def mod_fn(): + @decomposition_rule(is_qreg=False, num_params=0) + def pattern__(*args__): + return py_pattern(*args__) + + @decomposition_rule(is_qreg=False, num_params=0) + def rewrite__(*args__): + return py_rewrite(*args__) + + pattern__(*range(n_args)) + rewrite__(*range(n_args)) + + return qml.state() + + qnode_mod: builtin.ModuleOp = [ + o for o in mod_fn().body.ops if isinstance(o, builtin.ModuleOp) + ][0] + pattern_fn: func.FuncOp = SymbolTable.lookup_symbol(qnode_mod, "pattern__") + rewrite_fn: func.FuncOp = SymbolTable.lookup_symbol(qnode_mod, "rewrite__") + + # Erase return op from the func bodies. These do not need to be matched. + # This will make the funcs invalid, but that is fine since they will be + # discarded after the rewriting is done. + pattern_fn.body.block.erase_op(pattern_fn.body.block.last_op) + rewrite_fn.body.block.erase_op(rewrite_fn.body.block.last_op) + + return pattern_fn, rewrite_fn + + def _create_pdl_pattern(self, pattern_fn: func.FuncOp, rewrite_fn: func.FuncOp): + """Use xDSL functions containing patterns to match and rewrite to create a PDL PatternOp.""" + pdl_ops: list[pdl.OperationOp] = [] + starting_values: list[pdl.ValueType] = [] + cur_values: list[pdl.ValueType] = [] + value_map: dict[quantum.QubitSSAValue, int] = {} + pdl_value_map: dict[pdl.ValueType, int] = {} + type_map: dict[quantum.QubitSSAValue, pdl.TypeType] = {} + + # Create block containing the pattern we want to match. This will be the main + # body of the pdl.PatternOp we're building. + pattern_block: Block = Block() + with ImplicitBuilder(pattern_block): + + # Initialization + for i, (arg, arg_type) in enumerate( + zip(pattern_fn.body.block.args, pattern_fn.body.block.arg_types) + ): + # Create pdl types + pdl_type = pdl.TypeOp(arg_type).results[0] + type_map[arg] = pdl_type + + # Create pdl values + pdl_val = pdl.OperandOp(value_type=pdl_type).results[0] + starting_values.append(pdl_val) + cur_values.append(pdl_val) + value_map[arg] = i + pdl_value_map[pdl_val] = i + + # Create pdl ops + for op in pattern_fn.body.ops: + pdl_op = self._create_pdl_operation( + op=op, + cur_values=cur_values, + value_map=value_map, + pdl_value_map=pdl_value_map, + type_map=type_map, + ) + pdl_ops.append(pdl_op) + + # Create block containing the rewrite pattern. This will be the block inside + # a pdl.RewriteOp, which will be the terminator op of the pdl.PatternOp that + # we're building. + rewrite_block: Block = Block() + with ImplicitBuilder(rewrite_block): + + rw_value_map: dict[quantum.QubitSSAValue, int] = {} + rw_pdl_ops: list[pdl.OperationOp] = [] + rw_terminal_values: list[pdl.ValueType] = [] + + # Rewrite block initialization + for i, (p_arg, r_arg) in enumerate( + zip(pattern_fn.body.block.args, rewrite_fn.body.block.args, strict=True) + ): + rw_terminal_values.append(starting_values[i]) + rw_value_map[r_arg] = i + type_map[r_arg] = type_map[p_arg] + + # Create pdl ops for operations in the rewrite pattern + for op in rewrite_fn.body.ops: + pdl_op = self._create_pdl_operation( + op=op, + cur_values=rw_terminal_values, + value_map=rw_value_map, + pdl_value_map=pdl_value_map, + type_map=type_map, + ) + rw_pdl_ops.append(pdl_op) + + # Replace all operations in the original pattern with values generated by + # the rewrite pattern + for op, pdl_op in tuple(zip(pattern_fn.body.ops, pdl_ops, strict=True))[::-1]: + repl_vals = [] + for pdl_val in pdl_op.operand_values: + if isinstance(pdl_val.op, pdl.ResultOp): + repl_vals.append(pdl_val) + else: + idx = pdl_value_map[pdl_val] + repl_vals.append(rw_terminal_values[idx]) + + _ = pdl.ReplaceOp(pdl_op, repl_values=repl_vals) + + rewrite_op = pdl.RewriteOp(pdl_ops[-1].results[0], body=Region(rewrite_block)) + pattern_block.add_op(rewrite_op) + pattern_op = pdl.PatternOp(benefit=1, sym_name="temp", body=Region(pattern_block)) + return pattern_op + + def _create_pdl_operation( + self, op: Operation, cur_values: list[pdl.ValueType], value_map: dict[quantum.QubitSSAValue, int], pdl_value_map: dict[pdl.ValueType, int], type_map: dict[quantum.QubitSSAValue, pdl.TypeType], ): + """Create a pdl.OperationOp corresponding to the input xDSL operation. This method must be + called from within an ``ImplicitBuilder`` context.""" + # Create operations corresponding to the operation attributes and properties. attr_vals = [] attr_names = [] for attr_name, attr in (op.properties | op.attributes).items(): attr_names.append(builtin.StringAttr(attr_name)) attr_vals.append(pdl.AttributeOp(attr)) + # Find values corresponding to the operands. For now, assume all operands and results + # are qubits. operand_vals = [] res_type_vals = [] for val in op.operands: @@ -97,6 +217,8 @@ def create_pdl_op( type_values=res_type_vals, ) + # A pdl.OperationOp returns a pdl.OperationType. To use the values corresponding to + # the results of the original operation, we must use pdl.ResultOp to extract them. pdl_results = [ pdl.ResultOp(index=i, parent=pdl_op).results[0] for i in range(len(op.results)) ] @@ -109,92 +231,6 @@ def create_pdl_op( return pdl_op - pdl_ops: list[pdl.OperationOp] = [] - starting_values: list[pdl.ValueType] = [] - cur_values: list[pdl.ValueType] = [] - value_map: dict[quantum.QubitSSAValue, int] = {} - pdl_value_map: dict[pdl.ValueType, int] = {} - type_map: dict[quantum.QubitSSAValue, pdl.TypeType] = {} - - pattern_block: Block = Block() - with ImplicitBuilder(pattern_block): - - # Initialization - for i, (arg, arg_type) in enumerate( - zip(pattern_fn.body.block.args, pattern_fn.body.block.arg_types) - ): - # Create pdl types - pdl_type = pdl.TypeOp(arg_type).results[0] - type_map[arg] = pdl_type - - # Create pdl values - pdl_val = pdl.OperandOp(value_type=pdl_type).results[0] - starting_values.append(pdl_val) - cur_values.append(pdl_val) - value_map[arg] = i - pdl_value_map[pdl_val] = i - - # Create pdl ops - for op in pattern_fn.body.ops: - pdl_op = create_pdl_op( - op=op, - cur_values=cur_values, - value_map=value_map, - pdl_value_map=pdl_value_map, - type_map=type_map, - ) - pdl_ops.append(pdl_op) - - rewrite_block: Block = Block() - with ImplicitBuilder(rewrite_block): - - rw_value_map: dict[quantum.QubitSSAValue, int] = {} - rw_pdl_ops: list[pdl.OperationOp] = [] - rw_terminal_values: list[pdl.ValueType] = [] - - # Rewrite block initialization - for i, (p_arg, r_arg) in enumerate( - zip(pattern_fn.body.block.args, rewrite_fn.body.block.args, strict=True) - ): - rw_terminal_values.append(starting_values[i]) - rw_value_map[r_arg] = i - type_map[r_arg] = type_map[p_arg] - - # Create pdl ops for operations in the rewrite pattern - for op in rewrite_fn.body.ops: - pdl_op = create_pdl_op( - op=op, - cur_values=rw_terminal_values, - value_map=rw_value_map, - pdl_value_map=pdl_value_map, - type_map=type_map, - ) - rw_pdl_ops.append(pdl_op) - - # Replace all operations in the original pattern with values generated by - # the rewrite pattern - for op, pdl_op in tuple(zip(pattern_fn.body.ops, pdl_ops, strict=True))[::-1]: - repl_vals = [] - for pdl_val in pdl_op.operand_values: - if isinstance(pdl_val.op, pdl.ResultOp): - repl_vals.append(pdl_val) - else: - idx = pdl_value_map[pdl_val] - repl_vals.append(rw_terminal_values[idx]) - - _ = pdl.ReplaceOp(pdl_op, repl_values=repl_vals) - - rewrite_op = pdl.RewriteOp(pdl_ops[-1].results[0], body=Region(rewrite_block)) - pattern_block.add_op(rewrite_op) - pattern_op = pdl.PatternOp(benefit=1, sym_name="temp", body=Region(pattern_block)) - mod.body.block.add_op(pattern_op) - - ctx = Context() - _ = QuantumParser(ctx, "") - - PassPipeline((ApplyPDLPass(),)).apply(ctx, mod) - mod.body.block.erase_op(pattern_op) - def pattern_match(func: QJIT = None, patterns: dict[Callable, Callable] = {}): """Apply pattern matching to q QJIT-ed workflow.""" @@ -204,15 +240,7 @@ def pattern_match(func: QJIT = None, patterns: dict[Callable, Callable] = {}): @wraps(func) def wrapper(*args, **kwargs): mod = xdsl_from_qjit(func)(*args, **kwargs) - - for pattern, rewrite in patterns.items(): - p_nargs = len(signature(pattern).parameters) - r_nargs = len(signature(rewrite).parameters) - if p_nargs != r_nargs: - raise ValueError("Pattern and match must have the same number of qubits as inputs") - - _rewrite_mod(mod, pattern, rewrite, p_nargs) - + PatternMatchPass(patterns=patterns).apply(Context(), mod) return mod return wrapper From ecd4d7f7aec41963c7b1ebea50ca860e213721cf Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 24 Nov 2025 11:13:38 -0500 Subject: [PATCH 30/38] Create scaffolding for serializing and deserializing patterns --- .../transforms/pattern_match.py | 103 +++++++++++------- 1 file changed, 66 insertions(+), 37 deletions(-) diff --git a/frontend/catalyst/python_interface/transforms/pattern_match.py b/frontend/catalyst/python_interface/transforms/pattern_match.py index 60478673cb..206d4c29f2 100644 --- a/frontend/catalyst/python_interface/transforms/pattern_match.py +++ b/frontend/catalyst/python_interface/transforms/pattern_match.py @@ -13,16 +13,17 @@ # limitations under the License. """Function for applying high-level pattern-matching to xDSL modules.""" +import tempfile +from collections.abc import Callable, Sequence from functools import partial, wraps -from inspect import signature -from typing import Callable +from inspect import getsource, signature import pennylane as qml from xdsl.builder import ImplicitBuilder from xdsl.context import Context from xdsl.dialects import builtin, func, pdl from xdsl.ir import Block, Operation, Region -from xdsl.passes import ModulePass, PassPipeline +from xdsl.passes import ModulePass from xdsl.traits import SymbolTable from xdsl.transforms.apply_pdl import ApplyPDLPass @@ -31,15 +32,19 @@ from catalyst.python_interface import QuantumParser from catalyst.python_interface.conversion import xdsl_from_qjit from catalyst.python_interface.dialects import quantum +from catalyst.python_interface.pass_api import compiler_transform class PatternMatchPass(ModulePass): """Module pass to apply Python-level pattern-matching to xDSL modules.""" name: str = "pattern-match" + _patterns: dict[Callable, Callable] + _pdl_patterns: tuple[pdl.PatternOp, ...] - def __init__(self, patterns: dict[Callable, Callable] = {}): - self.patterns = patterns + def __init__(self, patterns: dict[Callable, Callable] | None = None): + self._patterns = patterns or {} + self._pdl_patterns = () def apply(self, ctx: Context, op: builtin.ModuleOp): """Apply the provided patterns to the input module.""" @@ -49,11 +54,14 @@ def apply(self, ctx: Context, op: builtin.ModuleOp): if ctx.get_optional_dialect(dialect.name) is None: ctx.load_dialect(dialect) - pdl_pass = ApplyPDLPass() - for pat, rw in self.patterns.items(): - pat_fn, rw_fn = self._pattern_to_xdsl(pat, rw) - pattern_op = self._create_pdl_pattern(pat_fn, rw_fn) + if self._patterns and not self._pdl_patterns: + for pat, rw in self._patterns.items(): + pat_fn, rw_fn = self._pattern_to_xdsl(pat, rw) + pattern_op = self._create_pdl_pattern(pat_fn, rw_fn) + self._pdl_patterns += (pattern_op,) + pdl_pass = ApplyPDLPass() + for pattern_op in self._pdl_patterns: op.body.block.add_op(pattern_op) pdl_pass.apply(ctx, op) op.body.block.erase_op(pattern_op) @@ -63,34 +71,32 @@ def _pattern_to_xdsl(self, py_pattern: Callable, py_rewrite: Callable) -> func.F n_args = len(signature(py_pattern).parameters) if len(signature(py_rewrite).parameters) != n_args: - raise ValueError("Pattern and match must have the same number of qubits as inputs") + raise ValueError("Search and rewrite patterns must have the same number of arguments.") + # Rename functions so that their names in the xDSL module are known + pattern_name = "__pattern" + rewrite_name = "__rewrite" + py_pattern.__name__ = pattern_name + py_rewrite.__name__ = rewrite_name + args = range(n_args) + + # Lower the functions and extract them from the IR @xdsl_from_qjit @qml.qjit - @qml.qnode(qml.device("lightning.qubit", wires=n_args)) + @qml.qnode(qml.device("null.qubit", wires=n_args)) def mod_fn(): - @decomposition_rule(is_qreg=False, num_params=0) - def pattern__(*args__): - return py_pattern(*args__) - - @decomposition_rule(is_qreg=False, num_params=0) - def rewrite__(*args__): - return py_rewrite(*args__) - - pattern__(*range(n_args)) - rewrite__(*range(n_args)) + decomposition_rule(py_pattern, is_qreg=False, num_params=0)(*args) + decomposition_rule(py_rewrite, is_qreg=False, num_params=0)(*args) return qml.state() - qnode_mod: builtin.ModuleOp = [ - o for o in mod_fn().body.ops if isinstance(o, builtin.ModuleOp) - ][0] - pattern_fn: func.FuncOp = SymbolTable.lookup_symbol(qnode_mod, "pattern__") - rewrite_fn: func.FuncOp = SymbolTable.lookup_symbol(qnode_mod, "rewrite__") + mod = mod_fn() + pattern_fn: func.FuncOp = SymbolTable.lookup_symbol(mod, pattern_name) + rewrite_fn: func.FuncOp = SymbolTable.lookup_symbol(mod, rewrite_name) # Erase return op from the func bodies. These do not need to be matched. # This will make the funcs invalid, but that is fine since they will be - # discarded after the rewriting is done. + # discarded after the PDL patterns are created. pattern_fn.body.block.erase_op(pattern_fn.body.block.last_op) rewrite_fn.body.block.erase_op(rewrite_fn.body.block.last_op) @@ -231,16 +237,39 @@ def _create_pdl_operation( return pdl_op + @classmethod + def create_from_serialized_options(cls, **options): + """Create a pass instance using serialized patterns.""" + paths = options["pattern_paths"] + py_patterns = _patterns_from_paths(paths) + pass_instance = cls(patterns=py_patterns) + return pass_instance + + +pattern_match = compiler_transform(PatternMatchPass) + + +@pattern_match.custom_serialize_options +def _(*, patterns: dict[Callable, Callable] = {}): + paths = [] + + for pat, rw in patterns.items(): + # Create source for pattern and rewrite functions + cur_program = _create_pattern_source(pat, rw) + + # Create tempfile with search pattern + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as cur_file: + cur_file.write(cur_program) + paths.append(cur_file.name) + + valued_options = {"pattern_paths": paths} + return (), valued_options + -def pattern_match(func: QJIT = None, patterns: dict[Callable, Callable] = {}): - """Apply pattern matching to q QJIT-ed workflow.""" - if func is None: - return partial(pattern_match, patterns=patterns) +def _create_pattern_source(pattern: Callable, rewrite: Callable) -> str: + """Create a program represented as a string that encodes the ``pattern`` and ``rewrite`` + functions along with necessary locals and globals.""" - @wraps(func) - def wrapper(*args, **kwargs): - mod = xdsl_from_qjit(func)(*args, **kwargs) - PatternMatchPass(patterns=patterns).apply(Context(), mod) - return mod - return wrapper +def _patterns_from_paths(paths: Sequence[str]) -> dict[Callable, Callable]: + """Create pattern and rewrite functions using source files specified by ``paths``.""" From 6fd7b6fa739b114312527a53eeff0a7c1c5717dd Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 25 Nov 2025 11:12:21 -0500 Subject: [PATCH 31/38] Add dev comments --- .../python_interface/transforms/pattern_match.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/transforms/pattern_match.py b/frontend/catalyst/python_interface/transforms/pattern_match.py index 206d4c29f2..eb6fbbade4 100644 --- a/frontend/catalyst/python_interface/transforms/pattern_match.py +++ b/frontend/catalyst/python_interface/transforms/pattern_match.py @@ -238,7 +238,7 @@ def _create_pdl_operation( return pdl_op @classmethod - def create_from_serialized_options(cls, **options): + def create_from_serialized_options(cls, *_, **options): """Create a pass instance using serialized patterns.""" paths = options["pattern_paths"] py_patterns = _patterns_from_paths(paths) @@ -269,7 +269,16 @@ def _(*, patterns: dict[Callable, Callable] = {}): def _create_pattern_source(pattern: Callable, rewrite: Callable) -> str: """Create a program represented as a string that encodes the ``pattern`` and ``rewrite`` functions along with necessary locals and globals.""" + # Collect imports + # Collect closure vars + # Collect called functions + # Collect function source + # Add variables pointing to pattern and rewrite functions def _patterns_from_paths(paths: Sequence[str]) -> dict[Callable, Callable]: """Create pattern and rewrite functions using source files specified by ``paths``.""" + # Read files at paths + # Exec + # Collect locals + # Find references to pattern and rewrite functions From 4976b7d33dd94cde423ae494a09c8dd027c96e74 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 25 Nov 2025 11:13:06 -0500 Subject: [PATCH 32/38] Remove bloat --- .../python_interface/transforms/pattern_match.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/frontend/catalyst/python_interface/transforms/pattern_match.py b/frontend/catalyst/python_interface/transforms/pattern_match.py index eb6fbbade4..fe8b11a24a 100644 --- a/frontend/catalyst/python_interface/transforms/pattern_match.py +++ b/frontend/catalyst/python_interface/transforms/pattern_match.py @@ -268,17 +268,8 @@ def _(*, patterns: dict[Callable, Callable] = {}): def _create_pattern_source(pattern: Callable, rewrite: Callable) -> str: """Create a program represented as a string that encodes the ``pattern`` and ``rewrite`` - functions along with necessary locals and globals.""" - # Collect imports - # Collect closure vars - # Collect called functions - # Collect function source - # Add variables pointing to pattern and rewrite functions + functions.""" def _patterns_from_paths(paths: Sequence[str]) -> dict[Callable, Callable]: """Create pattern and rewrite functions using source files specified by ``paths``.""" - # Read files at paths - # Exec - # Collect locals - # Find references to pattern and rewrite functions From 4c5db3d02e0346325a1ea7ae81f309c451bf8a72 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 25 Nov 2025 11:18:04 -0500 Subject: [PATCH 33/38] Add empty returns --- frontend/catalyst/python_interface/transforms/pattern_match.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/frontend/catalyst/python_interface/transforms/pattern_match.py b/frontend/catalyst/python_interface/transforms/pattern_match.py index fe8b11a24a..c012c696b6 100644 --- a/frontend/catalyst/python_interface/transforms/pattern_match.py +++ b/frontend/catalyst/python_interface/transforms/pattern_match.py @@ -269,7 +269,9 @@ def _(*, patterns: dict[Callable, Callable] = {}): def _create_pattern_source(pattern: Callable, rewrite: Callable) -> str: """Create a program represented as a string that encodes the ``pattern`` and ``rewrite`` functions.""" + return "" def _patterns_from_paths(paths: Sequence[str]) -> dict[Callable, Callable]: """Create pattern and rewrite functions using source files specified by ``paths``.""" + return {} From 4cbe02a206a27e956ec09a4e708899644336b5c7 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Wed, 26 Nov 2025 16:32:03 -0500 Subject: [PATCH 34/38] Remove reference to 'remove-chained-self-inverses' --- .../catalyst/python_interface/doc/unified_compiler_cookbook.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst b/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst index 1525c030ee..b59cb83e48 100644 --- a/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst +++ b/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst @@ -709,7 +709,7 @@ module @circuit { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) { %0 = transform.apply_registered_pass "merge-rotations" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> - %1 = transform.apply_registered_pass "remove-chained-self-inverse" to %0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + %1 = transform.apply_registered_pass "cancel-inverses" to %0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> transform.yield } } From c5f08c392531506ebd85ec375a963a550fd48fc0 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 27 Nov 2025 14:03:46 -0500 Subject: [PATCH 35/38] Remove leftover references to pennylane.compiler.python_compiler --- .../python_interface/doc/unified_compiler_cookbook.rst | 10 +++++----- .../catalyst/python_interface/visualization/draw.py | 2 +- .../python_interface/dialects/test_catalyst_dialect.py | 2 +- .../python_interface/dialects/test_mbqc_dialect.py | 2 +- .../python_interface/dialects/test_qec_dialect.py | 2 +- .../python_interface/dialects/test_quantum_dialect.py | 2 +- .../dialects/test_transform_dialect.py | 2 +- .../pytest/python_interface/test_python_compiler.py | 2 +- 8 files changed, 12 insertions(+), 12 deletions(-) diff --git a/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst b/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst index b59cb83e48..01f3295fa5 100644 --- a/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst +++ b/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst @@ -640,7 +640,7 @@ the “PennyLane integration” section below. PennyLane integration ===================== -This section will cover the API in the ``qml.compiler.python_compiler`` +This section will cover the API in the ``catalyst.python_interface`` submodule. Lowering to MLIR @@ -981,7 +981,7 @@ Implications/notes ``compiler_transform`` is the function used to register xDSL ``ModulePass``\ es to be used with ``qjit``-ed workflows. It is currently accessible as -``qml.compiler.python_compiler.compiler_transform``. +``catalyst.python_interface.compiler_transform``. .. code-block:: python @@ -1031,7 +1031,7 @@ is compiled! Conversion utilities -------------------- -The ``python_compiler.conversion`` submodule provides several utilities +The ``catalyst.python_interface.conversion`` submodule provides several utilities for creating xDSL modules from Python functions. There are many more utilities in the submodule, but I will focus on the most important ones, and provide examples of how to use them. @@ -1230,7 +1230,7 @@ To use FileCheck with ``pytest``, we use the ```filecheck`` Python package `__, which allows us to use assertions for testing in a way that ``pytest`` can understand. All of the ``filecheck`` API has been captured inside two fixtures available -within the ``tests/python_compiler`` folder: +within the ``tests/python_interface`` folder: - ``run_filecheck``: This fixture is for unit testing. One can specify a program along with filecheck directives as a multi-line string. @@ -1331,7 +1331,7 @@ Key blockers ============ There are several blockers that are currently disabling developers from -taking full advantage of the ``python_compiler`` submodule. These +taking full advantage of the ``catalyst.python_interface`` submodule. These include: * Lack of support for quantum subroutines. This impacts pattern diff --git a/frontend/catalyst/python_interface/visualization/draw.py b/frontend/catalyst/python_interface/visualization/draw.py index ac26977404..fde4771723 100644 --- a/frontend/catalyst/python_interface/visualization/draw.py +++ b/frontend/catalyst/python_interface/visualization/draw.py @@ -91,7 +91,7 @@ def wrapper(*args, **kwargs): warnings.warn( "The `draw` function does not yet support dynamic arguments.\n" "To visualize the circuit with dynamic parameters or wires, please use the\n" - "`compiler.python_compiler.visualization.generate_mlir_graph` function instead.", + "`catalyst.python_interface.visualization.generate_mlir_graph` function instead.", UserWarning, ) mlir_module = _get_mlir_module(qnode, args, kwargs) diff --git a/frontend/test/pytest/python_interface/dialects/test_catalyst_dialect.py b/frontend/test/pytest/python_interface/dialects/test_catalyst_dialect.py index 4f5c149ac0..dc0e40fe99 100644 --- a/frontend/test/pytest/python_interface/dialects/test_catalyst_dialect.py +++ b/frontend/test/pytest/python_interface/dialects/test_catalyst_dialect.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit test module for pennylane/compiler/python_compiler/dialects/catalyst.py.""" +"""Unit test module for the xDSL Catalyst dialect.""" import pytest diff --git a/frontend/test/pytest/python_interface/dialects/test_mbqc_dialect.py b/frontend/test/pytest/python_interface/dialects/test_mbqc_dialect.py index 6e776e5175..9f3af60e3e 100644 --- a/frontend/test/pytest/python_interface/dialects/test_mbqc_dialect.py +++ b/frontend/test/pytest/python_interface/dialects/test_mbqc_dialect.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit test module for pennylane/compiler/python_compiler/dialects/mbqc.py.""" +"""Unit tests for the xDSL MBQC dialect.""" import pytest diff --git a/frontend/test/pytest/python_interface/dialects/test_qec_dialect.py b/frontend/test/pytest/python_interface/dialects/test_qec_dialect.py index 02ed4a6f97..a8b486aeb0 100644 --- a/frontend/test/pytest/python_interface/dialects/test_qec_dialect.py +++ b/frontend/test/pytest/python_interface/dialects/test_qec_dialect.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit test module for pennylane/compiler/python_compiler/dialects/qec.py.""" +"""Unit tests for the xDSL QEC dialect.""" import pytest diff --git a/frontend/test/pytest/python_interface/dialects/test_quantum_dialect.py b/frontend/test/pytest/python_interface/dialects/test_quantum_dialect.py index d83561a193..31629387b3 100644 --- a/frontend/test/pytest/python_interface/dialects/test_quantum_dialect.py +++ b/frontend/test/pytest/python_interface/dialects/test_quantum_dialect.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit test module for pennylane/compiler/python_compiler/dialects/quantum.py.""" +"""Unit tests for the xDSL Quantum dialect.""" import pytest diff --git a/frontend/test/pytest/python_interface/dialects/test_transform_dialect.py b/frontend/test/pytest/python_interface/dialects/test_transform_dialect.py index 64aba9a85b..4f41831c3f 100644 --- a/frontend/test/pytest/python_interface/dialects/test_transform_dialect.py +++ b/frontend/test/pytest/python_interface/dialects/test_transform_dialect.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit test module for pennylane/compiler/python_compiler/transform.py.""" +"""Unit tests for the xDSL Transform dialect.""" from dataclasses import dataclass diff --git a/frontend/test/pytest/python_interface/test_python_compiler.py b/frontend/test/pytest/python_interface/test_python_compiler.py index 12c329f829..86b8b06f84 100644 --- a/frontend/test/pytest/python_interface/test_python_compiler.py +++ b/frontend/test/pytest/python_interface/test_python_compiler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit test module for pennylane/compiler/python_compiler/impl.py""" +"""Unit tests for the unified compiler's entry point.""" from dataclasses import dataclass From fe763571078ada5d0f8e634ffa1f3066f0ea042a Mon Sep 17 00:00:00 2001 From: Jake Zaia <23638795+jzaia18@users.noreply.github.com> Date: Mon, 1 Dec 2025 10:12:35 -0500 Subject: [PATCH 36/38] Streamline circuit inspection utilities (#2237) **Context:** Migrates the `chore/xdsl-utils` branch from PennyLane. **Description of the Change:** Renames the visualization module to inspection and improves several xDSL parsing utilities within the python interface. **Benefits:** Simplifies and augments existing utility functions within the ~visualization~ inspection module **Possible Drawbacks:** **Related GitHub Issues:** [sc-104851] --- doc/releases/changelog-dev.md | 9 +++- .../catalyst/python_interface/__init__.py | 2 +- .../{visualization => inspection}/__init__.py | 2 +- .../collector.py | 0 .../{visualization => inspection}/draw.py | 19 ++----- .../mlir_graph.py | 4 +- .../xdsl_conversion.py | 50 +++++++++++++++++-- .../test_draw_unified_compiler.py | 6 +-- .../test_mlir_graph.py | 2 +- 9 files changed, 65 insertions(+), 29 deletions(-) rename frontend/catalyst/python_interface/{visualization => inspection}/__init__.py (90%) rename frontend/catalyst/python_interface/{visualization => inspection}/collector.py (100%) rename frontend/catalyst/python_interface/{visualization => inspection}/draw.py (81%) rename frontend/catalyst/python_interface/{visualization => inspection}/mlir_graph.py (97%) rename frontend/catalyst/python_interface/{visualization => inspection}/xdsl_conversion.py (86%) rename frontend/test/pytest/python_interface/{visualization => inspection}/test_draw_unified_compiler.py (99%) rename frontend/test/pytest/python_interface/{visualization => inspection}/test_mlir_graph.py (99%) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 694390e266..7f26d74e78 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -11,7 +11,7 @@ which moves all quantum gate operations to a private callable. [(#8367)](https://github.com/PennyLaneAI/pennylane/pull/8367) -* A new experimental `split_non_commuting_pass` compiler pass has been added to +* A new experimental `split_non_commuting_pass` compiler pass has been added to `catalyst.python_interface.transforms`. This pass splits quantum functions that measure observables on the same wires into multiple function executions, where each execution measures observables on different wires (using the "wires" grouping @@ -29,7 +29,7 @@ [(#2169)](https://github.com/PennyLaneAI/catalyst/pull/2169) [(#2183)](https://github.com/PennyLaneAI/catalyst/pull/2183) -* The :meth:`catalyst.python_interface.transforms.convert_to_mbqc_formalism_pass` now +* The :meth:`catalyst.python_interface.transforms.convert_to_mbqc_formalism_pass` now supports :class:`~xdsl.dialects.scf.IndexSwitchOp` in IR and ignores regions that have no body. [(#8632)](https://github.com/PennyLaneAI/pennylane/pull/8632) @@ -195,6 +195,11 @@

Internal changes ⚙️

+* The `catalyst.python_interface.visualization` module has been renamed to + `catalyst.python_interface.inspection`, and various utility functions within this module + have been streamlined. + [(#2237)](https://github.com/PennyLaneAI/catalyst/pull/2237) + * Migrated the `pennylane.compiler.python_compiler` submodule from PennyLane to Catalyst. It is now accessible as `catalyst.python_interface`. [(#2199)](https://github.com/PennyLaneAI/catalyst/pull/2199) diff --git a/frontend/catalyst/python_interface/__init__.py b/frontend/catalyst/python_interface/__init__.py index 85fb8ceef5..d2dc368784 100644 --- a/frontend/catalyst/python_interface/__init__.py +++ b/frontend/catalyst/python_interface/__init__.py @@ -14,9 +14,9 @@ """Python Compiler API for integration of Catalyst with xDSL.""" from .compiler import Compiler +from .inspection import QMLCollector from .parser import QuantumParser from .pass_api import compiler_transform -from .visualization import QMLCollector __all__ = [ "Compiler", diff --git a/frontend/catalyst/python_interface/visualization/__init__.py b/frontend/catalyst/python_interface/inspection/__init__.py similarity index 90% rename from frontend/catalyst/python_interface/visualization/__init__.py rename to frontend/catalyst/python_interface/inspection/__init__.py index 0a550976e5..9b5209ce2c 100644 --- a/frontend/catalyst/python_interface/visualization/__init__.py +++ b/frontend/catalyst/python_interface/inspection/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Visualization functions for Catalyst and xDSL transformations. +Circuit inspection & visualization functions for Catalyst and xDSL transformations. """ diff --git a/frontend/catalyst/python_interface/visualization/collector.py b/frontend/catalyst/python_interface/inspection/collector.py similarity index 100% rename from frontend/catalyst/python_interface/visualization/collector.py rename to frontend/catalyst/python_interface/inspection/collector.py diff --git a/frontend/catalyst/python_interface/visualization/draw.py b/frontend/catalyst/python_interface/inspection/draw.py similarity index 81% rename from frontend/catalyst/python_interface/visualization/draw.py rename to frontend/catalyst/python_interface/inspection/draw.py index fde4771723..5bee87b522 100644 --- a/frontend/catalyst/python_interface/visualization/draw.py +++ b/frontend/catalyst/python_interface/inspection/draw.py @@ -21,33 +21,20 @@ from pennylane.tape import QuantumScript -from catalyst import qjit -from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath from catalyst.python_interface.compiler import Compiler from .collector import QMLCollector +from .xdsl_conversion import get_mlir_module if TYPE_CHECKING: from pennylane.typing import Callable from pennylane.workflow.qnode import QNode - from xdsl.dialects.builtin import ModuleOp # TODO: This caching mechanism should be improved, # because now it relies on a mutable global state _cache_store: dict[Callable, dict[int, tuple[str, str]]] = {} -def _get_mlir_module(qnode: QNode, args, kwargs) -> ModuleOp: - """Ensure the QNode is compiled and return its MLIR module.""" - if hasattr(qnode, "mlir_module") and qnode.mlir_module is not None: - return qnode.mlir_module - - func = getattr(qnode, "user_function", qnode) - jitted_qnode = qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(func) - jitted_qnode.jit_compile(args, **kwargs) - return jitted_qnode.mlir_module - - def draw(qnode: QNode, *, level: None | int = None) -> Callable: """ Draw the QNode at the specified level. @@ -91,10 +78,10 @@ def wrapper(*args, **kwargs): warnings.warn( "The `draw` function does not yet support dynamic arguments.\n" "To visualize the circuit with dynamic parameters or wires, please use the\n" - "`catalyst.python_interface.visualization.generate_mlir_graph` function instead.", + "`catalyst.python_interface.inspection.generate_mlir_graph` function instead.", UserWarning, ) - mlir_module = _get_mlir_module(qnode, args, kwargs) + mlir_module = get_mlir_module(qnode, args, kwargs) Compiler.run(mlir_module, callback=_draw_callback) if not cache: diff --git a/frontend/catalyst/python_interface/visualization/mlir_graph.py b/frontend/catalyst/python_interface/inspection/mlir_graph.py similarity index 97% rename from frontend/catalyst/python_interface/visualization/mlir_graph.py rename to frontend/catalyst/python_interface/inspection/mlir_graph.py index 79196185ba..1e537085bf 100644 --- a/frontend/catalyst/python_interface/visualization/mlir_graph.py +++ b/frontend/catalyst/python_interface/inspection/mlir_graph.py @@ -28,7 +28,7 @@ from catalyst.compiler import CompileError, _get_catalyst_cli_cmd from catalyst.python_interface.compiler import Compiler -from .draw import _get_mlir_module +from .xdsl_conversion import get_mlir_module if TYPE_CHECKING: from pennylane import QNode @@ -115,7 +115,7 @@ def wrapper(*args, **kwargs): # with the args and kwargs provided by the user. # TODO: we could integrate the callback mechanism within `qjit`, # so that we wouldn't need to recompile the qnode twice. - mlir_module = _get_mlir_module(qnode, args, kwargs) + mlir_module = get_mlir_module(qnode, args, kwargs) Compiler.run(mlir_module, callback=_mlir_graph_callback) return wrapper diff --git a/frontend/catalyst/python_interface/visualization/xdsl_conversion.py b/frontend/catalyst/python_interface/inspection/xdsl_conversion.py similarity index 86% rename from frontend/catalyst/python_interface/visualization/xdsl_conversion.py rename to frontend/catalyst/python_interface/inspection/xdsl_conversion.py index 917a292e94..7a83570016 100644 --- a/frontend/catalyst/python_interface/visualization/xdsl_conversion.py +++ b/frontend/catalyst/python_interface/inspection/xdsl_conversion.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""This file contains the implementation of the QMLCollector class, -which collects and maps PennyLane operations and measurements from xDSL.""" +"""This file contains utility functions for parsing PennyLane objects from xDSL.""" from __future__ import annotations @@ -27,10 +26,14 @@ from pennylane.ops import __all__ as ops_all from pennylane.ops import measure from xdsl.dialects.builtin import DenseIntOrFPElementsAttr, IntegerAttr, IntegerType +from xdsl.dialects.scf import ForOp from xdsl.dialects.tensor import ExtractOp as TensorExtractOp from xdsl.ir import SSAValue -from catalyst.python_interface.dialects.quantum import ( +from catalyst.jit import QJIT, qjit +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath + +from ..dialects.quantum import ( CustomOp, ExtractOp, GlobalPhaseOp, @@ -44,6 +47,8 @@ if TYPE_CHECKING: from pennylane.measurements import MeasurementProcess + from pennylane.workflow.qnode import QNode + from xdsl.dialects.builtin import ModuleOp has_jax = True try: @@ -52,6 +57,24 @@ has_jax = False +def get_mlir_module(qnode: QNode | QJIT, args, kwargs) -> ModuleOp: + """Ensure the QNode is compiled and return its MLIR module.""" + if hasattr(qnode, "mlir_module") and qnode.mlir_module is not None: + return qnode.mlir_module + + if isinstance(qnode, QJIT): + compile_options = qnode.compile_options + compile_options.autograph = False # Autograph has already been applied for `user_function` + compile_options.pass_plugins.add(getXDSLPluginAbsolutePath()) + + jitted_qnode = QJIT(qnode.user_function, compile_options) + else: + jitted_qnode = qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(qnode) + + jitted_qnode.jit_compile(args, **kwargs) + return jitted_qnode.mlir_module + + from_str_to_PL_gate = { name: getattr(ops, name) for name in ops_all @@ -146,6 +169,9 @@ def resolve_constant_params(ssa: SSAValue) -> float | int: case "arith.constant": return op.value.value.data # Catalyst + case "arith.index_cast": + return resolve_constant_params(op.input) + case "stablehlo.constant": return _extract_dense_constant_value(op) @@ -165,6 +191,24 @@ def resolve_constant_params(ssa: SSAValue) -> float | int: raise NotImplementedError(f"Cannot resolve parameters for operation: {op}") +def count_static_loop_iterations(for_op: ForOp) -> int: + """ + Calculates static loop iterations for a given ForOp. + + Requires that the loop bounds and step are constant values. + """ + + lower_bound = resolve_constant_params(for_op.lb) + upper_bound = resolve_constant_params(for_op.ub) + step = resolve_constant_params(for_op.step) + + if upper_bound <= lower_bound: + return 0 + + num_elements = upper_bound - lower_bound + return (num_elements + step - 1) // step + + def dispatch_wires_extract(op: ExtractOp): """Dispatch the wire resolution for the given extract operation.""" if op.idx_attr is not None: # used by Catalyst diff --git a/frontend/test/pytest/python_interface/visualization/test_draw_unified_compiler.py b/frontend/test/pytest/python_interface/inspection/test_draw_unified_compiler.py similarity index 99% rename from frontend/test/pytest/python_interface/visualization/test_draw_unified_compiler.py rename to frontend/test/pytest/python_interface/inspection/test_draw_unified_compiler.py index 26fa8d47ea..9ebaf92bb1 100644 --- a/frontend/test/pytest/python_interface/visualization/test_draw_unified_compiler.py +++ b/frontend/test/pytest/python_interface/inspection/test_draw_unified_compiler.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Unit test module for the draw function in the Python Compiler visualization module.""" +"""Unit test module for the draw function in the Python Compiler inspection module.""" import pytest @@ -23,16 +23,16 @@ import pennylane as qml from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from catalyst.python_interface.inspection import draw from catalyst.python_interface.transforms import ( iterative_cancel_inverses_pass, merge_rotations_pass, ) -from catalyst.python_interface.visualization import draw @pytest.mark.usefixtures("use_capture") class Testdraw: - """Unit tests for the draw function in the Python Compiler visualization module.""" + """Unit tests for the draw function in the Python Compiler inspection module.""" @pytest.fixture def transforms_circuit(self): diff --git a/frontend/test/pytest/python_interface/visualization/test_mlir_graph.py b/frontend/test/pytest/python_interface/inspection/test_mlir_graph.py similarity index 99% rename from frontend/test/pytest/python_interface/visualization/test_mlir_graph.py rename to frontend/test/pytest/python_interface/inspection/test_mlir_graph.py index 9753bcf88a..9b5a485ec6 100644 --- a/frontend/test/pytest/python_interface/visualization/test_mlir_graph.py +++ b/frontend/test/pytest/python_interface/inspection/test_mlir_graph.py @@ -23,11 +23,11 @@ import pennylane as qml from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from catalyst.python_interface.inspection import generate_mlir_graph from catalyst.python_interface.transforms import ( iterative_cancel_inverses_pass, merge_rotations_pass, ) -from catalyst.python_interface.visualization import generate_mlir_graph @pytest.fixture(autouse=True) From a69c9ba435822125f2cf280c4131a9705732656d Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 1 Dec 2025 16:23:54 -0500 Subject: [PATCH 37/38] Remove more references to python_compiler --- frontend/catalyst/compiler.py | 14 +++++++------- frontend/catalyst/python_interface/__init__.py | 2 +- .../catalyst/python_interface/dialects/catalyst.py | 2 +- .../catalyst/python_interface/dialects/mbqc.py | 2 +- frontend/catalyst/python_interface/dialects/qec.py | 2 +- .../doc/unified_compiler_cookbook.rst | 9 ++++----- .../python_interface/doc/xdsl_utils_tutorial.rst | 4 ++-- .../dialects/test_stablehlo_dialect.py | 2 +- .../inspection/test_draw_unified_compiler.py | 4 ++-- .../python_interface/test_python_compiler.py | 6 +++--- .../mbqc/test_xdsl_decompose_graph_state.py | 2 +- .../quantum/test_xdsl_measurements_from_samples.py | 2 +- frontend/test/pytest/test_xdsl_passes.py | 2 +- 13 files changed, 26 insertions(+), 27 deletions(-) diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index 6790b5af31..d1796902c3 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -381,12 +381,12 @@ def to_llvmir(*args, stdin=None, options: Optional[CompileOptions] = None): def to_mlir_opt( - *args, stdin=None, options: Optional[CompileOptions] = None, using_unified_compiler=False + *args, stdin=None, options: Optional[CompileOptions] = None, using_python_compiler=False ): """echo ${input} | catalyst --tool=opt *args *opts -""" - # Check if we need to use Python compiler for xDSL passes - if using_unified_compiler: - # Use Python compiler path for xDSL passes + # Check if we need to use the Python interface for xDSL passes + if using_python_compiler: + # Use the Python interface path for xDSL passes # pylint: disable-next=import-outside-toplevel from catalyst.python_interface import Compiler as PythonCompiler @@ -547,8 +547,8 @@ def check_nested_operations(op): return False @debug_logger - def is_using_unified_compiler(self, mlir_module=None): - """Returns true if we need the Python compiler path. + def is_using_python_compiler(self, mlir_module=None): + """Returns true if we need the Python interface path. This happens when: 1. xDSL plugin is explicitly loaded (legacy), OR @@ -606,7 +606,7 @@ def _create_xdsl_pass_save_callback(self, workspace): os.makedirs(user_transform_dir, exist_ok=True) class SavePassIRCallback: - """Callback to save IR after each pass in python_compiler.""" + """Callback to save IR after each pass in python_interface.""" def __init__(self, transform_dir): self.transform_dir = transform_dir diff --git a/frontend/catalyst/python_interface/__init__.py b/frontend/catalyst/python_interface/__init__.py index d2dc368784..5eda3e4454 100644 --- a/frontend/catalyst/python_interface/__init__.py +++ b/frontend/catalyst/python_interface/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Python Compiler API for integration of Catalyst with xDSL.""" +"""Unified Compiler API for integration of Catalyst with xDSL.""" from .compiler import Compiler from .inspection import QMLCollector diff --git a/frontend/catalyst/python_interface/dialects/catalyst.py b/frontend/catalyst/python_interface/dialects/catalyst.py index b42b256c37..a1cf627272 100644 --- a/frontend/catalyst/python_interface/dialects/catalyst.py +++ b/frontend/catalyst/python_interface/dialects/catalyst.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -This file contains the Catalyst dialect for the Python compiler. +This file contains the Catalyst dialect for the unified compiler. This file was originally ported automatically by xDSL (using the ``xdsl-tblgen`` tool) and modified manually to support the unified compiler. diff --git a/frontend/catalyst/python_interface/dialects/mbqc.py b/frontend/catalyst/python_interface/dialects/mbqc.py index b608e037fb..f08ae517d7 100644 --- a/frontend/catalyst/python_interface/dialects/mbqc.py +++ b/frontend/catalyst/python_interface/dialects/mbqc.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -This module contains the definition of the MBQC dialect for the Python compiler. +This module contains the definition of the MBQC dialect for the unified compiler. The MBQC dialect is a set of operations and types used to represent measurement-based quantum-computing instructions in the xDSL framework. diff --git a/frontend/catalyst/python_interface/dialects/qec.py b/frontend/catalyst/python_interface/dialects/qec.py index be5ba288f3..b88e05203f 100644 --- a/frontend/catalyst/python_interface/dialects/qec.py +++ b/frontend/catalyst/python_interface/dialects/qec.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -This module contains the QEC dialect for the Python compiler. +This module contains the QEC dialect for the unified compiler. The QEC dialect is a set of operations and types used to represent quantum error correction instructions in the xDSL framework. diff --git a/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst b/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst index 01f3295fa5..25e2d23c57 100644 --- a/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst +++ b/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst @@ -1060,7 +1060,7 @@ Useful patterns =============== Now that we have gone over compilers, xDSL, and how it’s being used in -PennyLane, let’s take a look at some common patterns that might be +Catalyst, let’s take a look at some common patterns that might be useful. Post-processing functions @@ -1102,7 +1102,7 @@ I’ll use tapes to provide details below about some common cases: All of the above are very non-trivial. I will leave out code examples for now, as that may be unnecessarily time consuming. If we get to the stage where we need to write a transform that splits into multiple -tapes, we can revisit this section and the Python compiler/compilation +tapes, we can revisit this section and the unified compiler/compilation team can assist in developing such transforms. Note @@ -1119,9 +1119,8 @@ Writing tests ============= **Note to readers**: this section is written based on how the testing -infrastructure for the Python compiler exists in PennyLane. However, the -Python compiler may be getting moved to Catalyst, in which case, the -infrastructure would likely change. +infrastructure for the unified compiler exists in Catalyst currently. +However, this infrastructure may change in the future. FileCheck --------- diff --git a/frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst b/frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst index 8c0163a3de..2e39c07404 100644 --- a/frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst +++ b/frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst @@ -1,4 +1,4 @@ -Python compiler utilities +Unified compiler utilities ========================= All utilities we care about are in the @@ -343,7 +343,7 @@ have been inlined into ``qjit_mod2``, just like we wanted. One might wonder why we need both ``inline_jit_to_module`` and ``inline_module``. At this stage, the intention is just to enable as -much functionality for users of the Python compiler as possible. There +much functionality for users of the unified compiler as possible. There may be use cases where users may want to create their own module manually instead of having one be created automatically by JAX or Catalyst. diff --git a/frontend/test/pytest/python_interface/dialects/test_stablehlo_dialect.py b/frontend/test/pytest/python_interface/dialects/test_stablehlo_dialect.py index c63ed80e44..f7a121b930 100644 --- a/frontend/test/pytest/python_interface/dialects/test_stablehlo_dialect.py +++ b/frontend/test/pytest/python_interface/dialects/test_stablehlo_dialect.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit test module for pennylane/compiler/python_compiler/dialects/stablehlo.py.""" +"""Unit test module for catalyst/python_interface/dialects/stablehlo.py.""" # pylint: disable=line-too-long import pytest diff --git a/frontend/test/pytest/python_interface/inspection/test_draw_unified_compiler.py b/frontend/test/pytest/python_interface/inspection/test_draw_unified_compiler.py index 9ebaf92bb1..f344671745 100644 --- a/frontend/test/pytest/python_interface/inspection/test_draw_unified_compiler.py +++ b/frontend/test/pytest/python_interface/inspection/test_draw_unified_compiler.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Unit test module for the draw function in the Python Compiler inspection module.""" +"""Unit test module for the draw function in the unified compiler inspection module.""" import pytest @@ -32,7 +32,7 @@ @pytest.mark.usefixtures("use_capture") class Testdraw: - """Unit tests for the draw function in the Python Compiler inspection module.""" + """Unit tests for the draw function in the unified compiler inspection module.""" @pytest.fixture def transforms_circuit(self): diff --git a/frontend/test/pytest/python_interface/test_python_compiler.py b/frontend/test/pytest/python_interface/test_python_compiler.py index 86b8b06f84..996e560064 100644 --- a/frontend/test/pytest/python_interface/test_python_compiler.py +++ b/frontend/test/pytest/python_interface/test_python_compiler.py @@ -261,7 +261,7 @@ def program(): class TestCatalystIntegration: - """Tests for integration of the Python compiler with Catalyst""" + """Tests for integration of the unified compiler with Catalyst""" @pytest.mark.usefixtures("use_capture") def test_integration_catalyst_no_passes_with_capture(self): @@ -333,7 +333,7 @@ def f(x): @pytest.mark.usefixtures("use_capture") def test_integration_catalyst_mixed_passes_with_capture(self, capsys): - """Test that both Catalyst and Python compiler passes can be used with qjit + """Test that both Catalyst and unified compiler passes can be used with qjit when capture is enabled.""" assert capture_enabled() @@ -354,7 +354,7 @@ def f(x): assert captured.out.strip() == "hello world" def test_integration_catalyst_mixed_passes_no_capture(self, capsys): - """Test that both Catalyst and Python compiler passes can be used with qjit + """Test that both Catalyst and unified compiler passes can be used with qjit when capture is disabled.""" assert not capture_enabled() diff --git a/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_decompose_graph_state.py b/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_decompose_graph_state.py index 9b48c93835..069b874b15 100644 --- a/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_decompose_graph_state.py +++ b/frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_decompose_graph_state.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit and integration tests for the Python compiler `decompose-graph-state` transform. +"""Unit and integration tests for the unified compiler `decompose-graph-state` transform. FileCheck notation hint: diff --git a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_measurements_from_samples.py b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_measurements_from_samples.py index 8889bb08ac..1ca411d342 100644 --- a/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_measurements_from_samples.py +++ b/frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_measurements_from_samples.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit and integration tests for the Python compiler `measurements_from_samples` transform.""" +"""Unit and integration tests for the unified compiler `measurements_from_samples` transform.""" # pylint: disable=wrong-import-position,line-too-long diff --git a/frontend/test/pytest/test_xdsl_passes.py b/frontend/test/pytest/test_xdsl_passes.py index 2490b80f48..598960f191 100644 --- a/frontend/test/pytest/test_xdsl_passes.py +++ b/frontend/test/pytest/test_xdsl_passes.py @@ -169,7 +169,7 @@ class TestXDSLPassesIntegration: def test_xdsl_passes_integration(self): """Test the xDSL passes integration.""" # pylint: disable-next=import-outside-toplevel - from pennylane.compiler.python_compiler.transforms import merge_rotations_pass + from catalyst.python_interface.transforms import merge_rotations_pass @qjit(keep_intermediate="changed", verbose=True) def workflow(): From 9d208f92dcc3d173cf03c7920a07113cc0b3621b Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 1 Dec 2025 16:27:54 -0500 Subject: [PATCH 38/38] Address more code review comments --- frontend/catalyst/compiler.py | 8 ++++---- frontend/test/pytest/device/test_decomposition.py | 3 --- .../inspection/test_draw_unified_compiler.py | 1 - .../{test_python_compiler.py => test_unified_compiler.py} | 4 ++-- 4 files changed, 6 insertions(+), 10 deletions(-) rename frontend/test/pytest/python_interface/{test_python_compiler.py => test_unified_compiler.py} (99%) diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index d1796902c3..9b56ee8fff 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -388,9 +388,9 @@ def to_mlir_opt( if using_python_compiler: # Use the Python interface path for xDSL passes # pylint: disable-next=import-outside-toplevel - from catalyst.python_interface import Compiler as PythonCompiler + from catalyst.python_interface import Compiler as UnifiedCompiler - compiler = PythonCompiler() + compiler = UnifiedCompiler() stdin = compiler.run(stdin, callback=None) # These are the options that may affect compilation @@ -666,10 +666,10 @@ def run(self, mlir_module, *args, **kwargs): # We keep this module here to keep xDSL requirement optional # Only move this is it has been decided that xDSL is no longer optional. # pylint: disable-next=import-outside-toplevel - from catalyst.python_interface import Compiler as PythonCompiler + from catalyst.python_interface import Compiler as UnifiedCompiler callback = self._create_xdsl_pass_save_callback(workspace) - compiler = PythonCompiler() + compiler = UnifiedCompiler() ir = compiler.run(ir, callback=callback) return self.run_from_ir( diff --git a/frontend/test/pytest/device/test_decomposition.py b/frontend/test/pytest/device/test_decomposition.py index 5f3594de1f..fe315edad0 100644 --- a/frontend/test/pytest/device/test_decomposition.py +++ b/frontend/test/pytest/device/test_decomposition.py @@ -26,9 +26,6 @@ from catalyst.compiler import get_lib_path from catalyst.device.decomposition import catalyst_decomposer -# TEST_PATH = os.path.dirname(__file__) -# CONFIG_CUSTOM_DEVICE = pathlib.Path(f"{TEST_PATH}/../../custom_device/custom_device.toml") - class TestGateAliases: """Test the decomposition of gates wich are in fact supported via aliased or equivalent diff --git a/frontend/test/pytest/python_interface/inspection/test_draw_unified_compiler.py b/frontend/test/pytest/python_interface/inspection/test_draw_unified_compiler.py index f344671745..81da302f34 100644 --- a/frontend/test/pytest/python_interface/inspection/test_draw_unified_compiler.py +++ b/frontend/test/pytest/python_interface/inspection/test_draw_unified_compiler.py @@ -376,7 +376,6 @@ def circuit(): assert drawing == expected_drawing - @pytest.mark.jax @pytest.mark.parametrize( "ops, expected", [ diff --git a/frontend/test/pytest/python_interface/test_python_compiler.py b/frontend/test/pytest/python_interface/test_unified_compiler.py similarity index 99% rename from frontend/test/pytest/python_interface/test_python_compiler.py rename to frontend/test/pytest/python_interface/test_unified_compiler.py index 996e560064..13398fe3dd 100644 --- a/frontend/test/pytest/python_interface/test_python_compiler.py +++ b/frontend/test/pytest/python_interface/test_unified_compiler.py @@ -333,7 +333,7 @@ def f(x): @pytest.mark.usefixtures("use_capture") def test_integration_catalyst_mixed_passes_with_capture(self, capsys): - """Test that both Catalyst and unified compiler passes can be used with qjit + """Test that both MLIR and xDSL passes can be used with qjit when capture is enabled.""" assert capture_enabled() @@ -354,7 +354,7 @@ def f(x): assert captured.out.strip() == "hello world" def test_integration_catalyst_mixed_passes_no_capture(self, capsys): - """Test that both Catalyst and unified compiler passes can be used with qjit + """Test that both MLIR and xDSL passes can be used with qjit when capture is disabled.""" assert not capture_enabled()