diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml
index 999e5ad137..d622074d73 100644
--- a/.github/workflows/check-catalyst.yaml
+++ b/.github/workflows/check-catalyst.yaml
@@ -473,6 +473,9 @@ jobs:
# macOS requirements.txt
python3 -m pip install cuda-quantum==0.6.0
python3 -m pip install oqc-qcaas-client
+ # Install graphviz for testing the mlir-op-graph integration
+ sudo apt-get install -y graphviz
+ python3 -m pip install graphviz
make frontend
- name: Get Cached LLVM Build
@@ -556,6 +559,9 @@ jobs:
sudo apt-get install -y libasan6 make
python3 --version | grep ${{ needs.constants.outputs.primary_python_version }}
python3 -m pip install -r requirements.txt
+ # Install graphviz for testing the mlir-op-graph integration
+ sudo apt-get install -y graphviz
+ python3 -m pip install graphviz
make frontend
- name: Get Cached LLVM Build
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 8e09946363..b135975266 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -7,6 +7,48 @@
Improvements π
+* Add an experimental `outline_state_evolution_pass` xDSL pass to `catalyst.python_interface.transforms`,
+ which moves all quantum gate operations to a private callable.
+ [(#8367)](https://github.com/PennyLaneAI/pennylane/pull/8367)
+
+* A new experimental `split_non_commuting_pass` compiler pass has been added to
+ `catalyst.python_interface.transforms`. This pass splits quantum functions that
+ measure observables on the same wires into multiple function executions, where
+ each execution measures observables on different wires (using the "wires" grouping
+ strategy). The original function is replaced with calls to these generated functions,
+ and the results are combined appropriately.
+ [(#8531)](https://github.com/PennyLaneAI/pennylane/pull/8531)
+
+* Add the `PCPhaseOp` operation to the xDSL Quantum dialect.
+ [(#8621)](https://github.com/PennyLaneAI/pennylane/pull/8621)
+
+* Users can now apply xDSL passes without the need to pass the `pass_plugins` argument to
+ the `qjit` decorator.
+ [(#8572)](https://github.com/PennyLaneAI/pennylane/pull/8572)
+ [(#8573)](https://github.com/PennyLaneAI/pennylane/pull/8573)
+ [(#2169)](https://github.com/PennyLaneAI/catalyst/pull/2169)
+ [(#2183)](https://github.com/PennyLaneAI/catalyst/pull/2183)
+
+* The :meth:`catalyst.python_interface.transforms.convert_to_mbqc_formalism_pass` now
+ supports :class:`~xdsl.dialects.scf.IndexSwitchOp` in IR and ignores regions that have no body.
+ [(#8632)](https://github.com/PennyLaneAI/pennylane/pull/8632)
+
+* The `convert_to_mbqc_formalism` compilation pass now outlines the operations to represent a gate
+ in the MBQC formalism into subroutines in order to reduce the IR size for large programs.
+ [(#8619)](https://github.com/PennyLaneAI/pennylane/pull/8619)
+
+* The :meth:`catalyst.python_interface.Compiler.run` method now accepts a string as input,
+ which is parsed and transformed with xDSL.
+ [(#8587)](https://github.com/PennyLaneAI/pennylane/pull/8587)
+
+* An `is_xdsl_pass` function has been added to the `catalyst.python_interface.pass_api` module.
+ This function checks if a pass name corresponds to an xDSL implemented pass.
+ [(#8572)](https://github.com/PennyLaneAI/pennylane/pull/8572)
+
+* A new `catalyst.python_interface.utils` submodule has been added, containing general-purpose utilities for
+ working with xDSL. This includes a function that extracts the concrete value of scalar, constant SSA values.
+ [(#8514)](https://github.com/PennyLaneAI/pennylane/pull/8514)
+
* Pass instrumentation can be applied to each pass within the `NamedSequenceOp` transform sequence for a qnode.
[(#1978)](https://github.com/PennyLaneAI/catalyst/pull/1978)
@@ -35,11 +77,6 @@
* `qml.grad` and `qml.jacobian` can now be used with `qjit` when program capture is enabled.
[(#2078)](https://github.com/PennyLaneAI/catalyst/pull/2078)
-* xDSL passes are now automatically detected when using the `qjit` decorator.
- This removes the need to pass the `pass_plugins` argument to the `qjit` decorator.
- [(#2169)](https://github.com/PennyLaneAI/catalyst/pull/2169)
- [(#2183)](https://github.com/PennyLaneAI/catalyst/pull/2183)
-
* The ``mlir_opt`` property now correctly handles xDSL passes by automatically
detecting when the Python compiler is being used and routing through it appropriately.
[(#2190)](https://github.com/PennyLaneAI/catalyst/pull/2190)
@@ -72,6 +109,20 @@
Bug fixes π
+* The experimental xDSL :func:`~catalyst.python_interface.transforms.measurements_from_samples_pass`
+ pass has been updated to support `shots` defined by an `arith.constant` operation.
+ [(#8460)](https://github.com/PennyLaneAI/pennylane/pull/8460)
+
+* The experimental xDSL :func:`~catalyst.python_interface.transforms.diagonalize_measurements`
+ pass has been updated to fix a bug that included the wrong SSA value for final qubit insertion
+ and deallocation at the end of the circuit. A clear error is now also raised when there are
+ observables with overlapping wires.
+ [(#8383)](https://github.com/PennyLaneAI/pennylane/pull/8383)
+
+* Fixes a bug in the constructor of the xDSL Quantum dialect's `QubitUnitaryOp` that
+ prevented an instance from being constructed.
+ [(#8456)](https://github.com/PennyLaneAI/pennylane/pull/8456)
+
* Fixes an issue where a heap-to-stack allocation conversion pass was causing SIGSEGV issues
during program execution at runtime.
[(#2172)](https://github.com/PennyLaneAI/catalyst/pull/2172)
@@ -122,6 +173,10 @@
Internal changes βοΈ
+* Migrated the `pennylane.compiler.python_compiler` submodule from PennyLane to Catalyst.
+ It is now accessible as `catalyst.python_interface`.
+ [(#2199)](https://github.com/PennyLaneAI/catalyst/pull/2199)
+
* Replaces the deprecated `shape_dtype_to_ir_type` function with the `RankedTensorType.get` method.
[(#2159)](https://github.com/PennyLaneAI/catalyst/pull/2159)
@@ -173,6 +228,11 @@
Documentation π
+* Added a "Unified Compiler Cookbook" RST file, along with tutorials, to `catalyst.python_interface.doc`,
+ which provides a quickstart guide for getting started with xDSL and its integration with PennyLane and
+ Catalyst.
+ [(#8571)](https://github.com/PennyLaneAI/pennylane/pull/8571)
+
* A typo in the code example for :func:`~.passes.ppr_to_ppm` has been corrected.
[(#2136)](https://github.com/PennyLaneAI/catalyst/pull/2136)
diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py
index a61f4ccfe3..71537e8e53 100644
--- a/frontend/catalyst/compiler.py
+++ b/frontend/catalyst/compiler.py
@@ -376,14 +376,14 @@ def to_llvmir(*args, stdin=None, options: Optional[CompileOptions] = None):
def to_mlir_opt(
- *args, stdin=None, options: Optional[CompileOptions] = None, using_python_compiler=False
+ *args, stdin=None, options: Optional[CompileOptions] = None, using_unified_compiler=False
):
"""echo ${input} | catalyst --tool=opt *args *opts -"""
# Check if we need to use Python compiler for xDSL passes
- if using_python_compiler:
+ if using_unified_compiler:
# Use Python compiler path for xDSL passes
# pylint: disable-next=import-outside-toplevel
- from pennylane.compiler.python_compiler import Compiler as PythonCompiler
+ from catalyst.python_interface import Compiler as PythonCompiler
compiler = PythonCompiler()
stdin = compiler.run(stdin, callback=None)
@@ -542,7 +542,7 @@ def check_nested_operations(op):
return False
@debug_logger
- def is_using_python_compiler(self, mlir_module=None):
+ def is_using_unified_compiler(self, mlir_module=None):
"""Returns true if we need the Python compiler path.
This happens when:
@@ -598,11 +598,11 @@ def run(self, mlir_module, *args, **kwargs):
(str): filename of shared object
"""
- if self.is_using_python_compiler(mlir_module):
+ if self.is_using_unified_compiler(mlir_module):
# We keep this module here to keep xDSL requirement optional
# Only move this is it has been decided that xDSL is no longer optional.
# pylint: disable-next=import-outside-toplevel
- from pennylane.compiler.python_compiler import Compiler as PythonCompiler
+ from catalyst.python_interface import Compiler as PythonCompiler
compiler = PythonCompiler()
mlir_module = compiler.run(mlir_module)
diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py
index 8d7e27acfe..749a8b3125 100644
--- a/frontend/catalyst/jax_primitives_utils.py
+++ b/frontend/catalyst/jax_primitives_utils.py
@@ -378,9 +378,7 @@ def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipelin
try:
# pylint: disable=import-outside-toplevel
- from pennylane.compiler.python_compiler.pass_api import (
- is_xdsl_pass,
- )
+ from catalyst.python_interface.pass_api import is_xdsl_pass
if is_xdsl_pass(_pass.name):
uses_xdsl_passes = True
diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py
index 81f873917c..62b0acbde2 100644
--- a/frontend/catalyst/jit.py
+++ b/frontend/catalyst/jit.py
@@ -582,13 +582,13 @@ def mlir_opt(self):
"""Obtain the MLIR representation after optimization"""
if not self.mlir_module:
return None
- using_python_compiler = self.compiler.is_using_python_compiler(self.mlir_module)
+ using_unified_compiler = self.compiler.is_using_unified_compiler(self.mlir_module)
stdin = self.mlir_module.operation.get_asm(
- print_generic_op_form=using_python_compiler,
+ print_generic_op_form=using_unified_compiler,
enable_debug_info=self.compile_options.use_nameloc,
)
return to_mlir_opt(
- stdin=stdin, options=self.compile_options, using_python_compiler=using_python_compiler
+ stdin=stdin, options=self.compile_options, using_unified_compiler=using_unified_compiler
)
@debug_logger
diff --git a/frontend/catalyst/python_interface/__init__.py b/frontend/catalyst/python_interface/__init__.py
new file mode 100644
index 0000000000..85fb8ceef5
--- /dev/null
+++ b/frontend/catalyst/python_interface/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Python Compiler API for integration of Catalyst with xDSL."""
+
+from .compiler import Compiler
+from .parser import QuantumParser
+from .pass_api import compiler_transform
+from .visualization import QMLCollector
+
+__all__ = [
+ "Compiler",
+ "compiler_transform",
+ "QuantumParser",
+ "QMLCollector",
+]
diff --git a/frontend/catalyst/python_interface/compiler.py b/frontend/catalyst/python_interface/compiler.py
new file mode 100644
index 0000000000..f98ff2b8d8
--- /dev/null
+++ b/frontend/catalyst/python_interface/compiler.py
@@ -0,0 +1,83 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This file contains the implementation of the PennyLane-xDSL integration API."""
+
+
+import io
+
+from jax._src.interpreters import mlir
+from jaxlib.mlir.dialects import stablehlo
+from jaxlib.mlir.ir import Context as jaxContext
+from jaxlib.mlir.ir import Module as jaxModule
+from pennylane.typing import Callable
+from xdsl.context import Context as xContext
+from xdsl.dialects.builtin import ModuleOp
+from xdsl.passes import ModulePass, PassPipeline
+from xdsl.printer import Printer
+
+from catalyst.python_interface.parser import QuantumParser
+from catalyst.python_interface.pass_api import ApplyTransformSequence
+
+
+# pylint: disable=too-few-public-methods
+class Compiler:
+ """Compiler namespace"""
+
+ @staticmethod
+ def run(
+ module: jaxModule | str,
+ callback: Callable[[ModulePass, ModuleOp, ModulePass], None] | None = None,
+ ) -> jaxModule | str:
+ """Runs the apply-transform-sequence pass.
+
+ The apply-transform-sequence pass is a "meta-pass". In other words,
+ it is a pass that runs other passes.
+
+ Args:
+ module: Either a Jax MLIR module or MLIR IR as a string
+ callback: Optional callback function called between passes
+
+ Returns:
+ jaxModule | str: jaxModule if the input was a jaxModule, else a string.
+ """
+ # Convert to generic text format
+ is_jax_module = isinstance(module, jaxModule)
+ if is_jax_module:
+ gentxtmod = module.operation.get_asm(
+ binary=False, print_generic_op_form=True, assume_verified=True
+ )
+ else:
+ gentxtmod = module
+
+ # Parse and transform with xDSL
+ ctx = xContext(allow_unregistered=True)
+ parser = QuantumParser(ctx, gentxtmod)
+ # xmod is modified in place
+ xmod = parser.parse_module()
+ pipeline = PassPipeline((ApplyTransformSequence(callback=callback),))
+ pipeline.apply(ctx, xmod)
+
+ # Convert back to string
+ buffer = io.StringIO()
+ Printer(stream=buffer, print_generic_format=True).print_op(xmod)
+
+ # Convert back to jaxModule if input was jaxModule
+ if is_jax_module:
+ with jaxContext() as jctx:
+ jctx.allow_unregistered_dialects = True
+ jctx.append_dialect_registry(mlir.upstream_dialects)
+ stablehlo.register_dialect(jctx) # pylint: disable=no-member
+ newmod: jaxModule = jaxModule.parse(buffer.getvalue())
+ return newmod
+ return buffer.getvalue()
diff --git a/frontend/catalyst/python_interface/conversion.py b/frontend/catalyst/python_interface/conversion.py
new file mode 100644
index 0000000000..975a46b3cb
--- /dev/null
+++ b/frontend/catalyst/python_interface/conversion.py
@@ -0,0 +1,171 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for converting to xDSL module."""
+
+from collections.abc import Callable, Sequence
+from functools import wraps
+from typing import TypeAlias
+
+from jax._src.lib import _jax
+from jaxlib.mlir.dialects import stablehlo as jstablehlo
+from jaxlib.mlir.ir import Context as jContext
+from jaxlib.mlir.ir import Module as jModule
+from xdsl.context import Context as xContext
+from xdsl.dialects import builtin as xbuiltin
+from xdsl.dialects import func as xfunc
+from xdsl.ir import Dialect as xDialect
+from xdsl.traits import SymbolTable as xSymbolTable
+
+from catalyst import QJIT
+from catalyst.python_interface.parser import QuantumParser
+
+JaxJittedFunction: TypeAlias = _jax.PjitFunction # pylint: disable=c-extension-no-member
+
+
+def _mlir_module_inline(func: JaxJittedFunction, *args, **kwargs) -> jModule:
+ """Get the MLIR module from a jax.jitted function"""
+ return func.lower(*args, **kwargs).compiler_ir()
+
+
+def mlir_module(func: JaxJittedFunction) -> Callable[..., jModule]:
+ """Returns a wrapper that creates an MLIR module from a jax.jitted function."""
+
+ @wraps(func)
+ def wrapper(*args, **kwargs) -> jModule:
+ return _mlir_module_inline(func, *args, **kwargs)
+
+ return wrapper
+
+
+def _generic_str_inline(func: JaxJittedFunction, *args, **kwargs) -> str: # pragma: no cover
+ """Create the generic textual representation for a jax.jitted function"""
+ lowered = func.lower(*args, **kwargs)
+ mod = lowered.compiler_ir()
+ return mod.operation.get_asm(binary=False, print_generic_op_form=True, assume_verified=True)
+
+
+def generic_str(func: JaxJittedFunction) -> Callable[..., str]: # pragma: no cover
+ """Returns a wrapper that creates the generic textual representation for a
+ jax.jitted function."""
+
+ @wraps(func)
+ def wrapper(*args, **kwargs) -> str:
+ return _generic_str_inline(func, *args, **kwargs)
+
+ return wrapper
+
+
+def parse_generic_to_xdsl_module(
+ program: str, extra_dialects: Sequence[xDialect] | None = None
+) -> xbuiltin.ModuleOp: # pragma: no cover
+ """Parses a generic MLIR program string to an xDSL module."""
+ ctx = xContext(allow_unregistered=True)
+ parser = QuantumParser(ctx, program, extra_dialects=extra_dialects)
+ moduleOp: xbuiltin.ModuleOp = parser.parse_module()
+ return moduleOp
+
+
+def parse_generic_to_mlir_module(program: str) -> jModule: # pragma: no cover
+ """Parses a generic MLIR program string to an MLIR module."""
+ with jContext() as ctx:
+ ctx.allow_unregistered_dialects = True
+ jstablehlo.register_dialect(ctx) # pylint: disable=no-member
+ return jModule.parse(program)
+
+
+def mlir_from_docstring(func: Callable) -> jModule: # pragma: no cover
+ """Returns a wrapper that parses an MLIR program string located in the docstring
+ into an MLIR module."""
+
+ @wraps(func)
+ def wrapper(*_, **__):
+ return parse_generic_to_mlir_module(func.__doc__)
+
+ return wrapper
+
+
+def _xdsl_module_inline(
+ func: JaxJittedFunction, *args, **kwargs
+) -> xbuiltin.ModuleOp: # pragma: no cover
+ """Get the xDSL module from a jax.jitted function"""
+ generic_repr = _generic_str_inline(func, *args, **kwargs)
+ return parse_generic_to_xdsl_module(generic_repr)
+
+
+def xdsl_from_docstring(func: Callable) -> xbuiltin.ModuleOp: # pragma: no cover
+ """Returns a wrapper that parses an MLIR program string located in the docstring
+ into an xDSL module."""
+
+ @wraps(func)
+ def wrapper(*_, **__):
+ return parse_generic_to_xdsl_module(func.__doc__)
+
+ return wrapper
+
+
+def xdsl_module(func: JaxJittedFunction) -> Callable[..., xbuiltin.ModuleOp]: # pragma: no cover
+ """Returns a wrapper that creates an xDSL module from a jax.jitted function."""
+
+ @wraps(func)
+ def wrapper(*args, **kwargs) -> xbuiltin.ModuleOp:
+ return _xdsl_module_inline(func, *args, **kwargs)
+
+ return wrapper
+
+
+def inline_module(
+ from_mod: xbuiltin.ModuleOp, to_mod: xbuiltin.ModuleOp, change_main_to: str = None
+) -> None:
+ """Inline the contents of one xDSL module into another xDSL module. The inlined body is appended
+ to the end of ``to_mod``.
+
+ If ``from_mod`` has a ``main`` function, its name is changed to ``change_main_to`` if specified.
+ """
+ if change_main_to:
+ main = xSymbolTable.lookup_symbol(from_mod, "main")
+ if main is not None:
+ assert isinstance(main, xfunc.FuncOp)
+ main.properties["sym_name"] = xbuiltin.StringAttr(change_main_to)
+
+ for op in from_mod.body.ops:
+ xSymbolTable.insert_or_update(to_mod, op.clone())
+
+
+def inline_jit_to_module(func: JaxJittedFunction, mod: xbuiltin.ModuleOp) -> Callable[..., None]:
+ """Inline a ``jax.jit``-ed Python function to an xDSL module. The inlined body is appended
+ to the end of ``mod`` in-place. The name of the entry point function of ``func`` is the same
+ as the name of ``func``."""
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ func_mod = _xdsl_module_inline(func, *args, **kwargs)
+ inline_module(func_mod, mod, change_main_to=func.__name__)
+
+ return wrapper
+
+
+def xdsl_from_qjit(func: QJIT) -> Callable[..., xbuiltin.ModuleOp]:
+ """Decorator to convert QJIT-ed functions into xDSL modules."""
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ func.jaxpr, *_ = func.capture(args, **kwargs)
+ _mlir_module = func.generate_ir()
+ _generic_str = _mlir_module.operation.get_asm(
+ binary=False, print_generic_op_form=True, assume_verified=True
+ )
+ return parse_generic_to_xdsl_module(_generic_str)
+
+ return wrapper
diff --git a/frontend/catalyst/python_interface/dialects/__init__.py b/frontend/catalyst/python_interface/dialects/__init__.py
new file mode 100644
index 0000000000..1b0ab7d1a9
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This submodule contains xDSL dialects for the unified compiler."""
+
+from .catalyst import Catalyst
+from .mbqc import MBQC
+from .qec import QEC
+from .quantum import Quantum
+from .stablehlo import StableHLO
+from .transform import Transform
+
+__all__ = ["Catalyst", "MBQC", "Quantum", "QEC", "StableHLO", "Transform"]
diff --git a/frontend/catalyst/python_interface/dialects/catalyst.py b/frontend/catalyst/python_interface/dialects/catalyst.py
new file mode 100644
index 0000000000..b42b256c37
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/catalyst.py
@@ -0,0 +1,268 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This file contains the Catalyst dialect for the Python compiler.
+
+This file was originally ported automatically by xDSL (using the ``xdsl-tblgen`` tool)
+and modified manually to support the unified compiler.
+
+The catalyst dialect serves as a standard library for the Catalyst compiler.
+It contains data structures that support core compiler functionality.
+"""
+
+from typing import ClassVar
+
+from xdsl.dialects.builtin import (
+ I64,
+ ArrayAttr,
+ DenseArrayBase,
+ DictionaryAttr,
+ FlatSymbolRefAttrConstr,
+ FunctionType,
+ IntegerAttr,
+ IntegerType,
+ MemRefType,
+ StringAttr,
+ SymbolRefAttr,
+ UnitAttr,
+ i32,
+)
+from xdsl.ir import AttributeCovT, Dialect, Generic, ParametrizedAttribute, TypeAttribute
+from xdsl.irdl import (
+ AnyAttr,
+ IRDLOperation,
+ ParsePropInAttrDict,
+ VarConstraint,
+ base,
+ irdl_attr_definition,
+ irdl_op_definition,
+ operand_def,
+ opt_operand_def,
+ opt_prop_def,
+ prop_def,
+ region_def,
+ result_def,
+ var_operand_def,
+ var_result_def,
+)
+
+
+@irdl_attr_definition
+class ArrayListType(Generic[AttributeCovT], ParametrizedAttribute, TypeAttribute):
+ """A dynamically resizable array"""
+
+ name = "catalyst.arraylist"
+
+ element_type: AttributeCovT
+
+
+@irdl_op_definition
+class AssertionOp(IRDLOperation):
+ """Asserts condition at runtime."""
+
+ name = "catalyst.assert"
+
+ assertion = operand_def(IntegerType(1))
+
+ error = prop_def(StringAttr)
+
+
+@irdl_op_definition
+class CallbackCallOp(IRDLOperation):
+ """CallbackCallOp operation."""
+
+ name = "catalyst.callback_call"
+
+ assembly_format = """
+ $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
+ """
+
+ callee = prop_def(FlatSymbolRefAttrConstr)
+
+ inputs = var_operand_def()
+
+ arg_attrs = opt_prop_def(ArrayAttr[DictionaryAttr])
+
+ res_attrs = opt_prop_def(ArrayAttr[DictionaryAttr])
+
+ callback_results = var_result_def()
+
+ irdl_options = [ParsePropInAttrDict()]
+
+
+@irdl_op_definition
+class CallbackOp(IRDLOperation):
+ """Operation denoting a symbol to refer to user callbacks."""
+
+ name = "catalyst.callback"
+
+ sym_name = prop_def(StringAttr)
+
+ function_type = prop_def(FunctionType)
+
+ id = prop_def(IntegerAttr[I64])
+
+ argc = prop_def(IntegerAttr[I64])
+
+ resc = prop_def(IntegerAttr[I64])
+
+ arg_attrs = opt_prop_def(ArrayAttr[DictionaryAttr])
+
+ res_attrs = opt_prop_def(ArrayAttr[DictionaryAttr])
+
+ body = region_def()
+
+
+@irdl_op_definition
+class CustomCallOp(IRDLOperation):
+ """CustomCall operation"""
+
+ name = "catalyst.custom_call"
+
+ assembly_format = """
+ `fn` `(`$call_target_name`)` `(` $inputs `)`
+ attr-dict `:` functional-type(operands, results)
+ """
+
+ inputs = var_operand_def()
+
+ call_target_name = prop_def(StringAttr)
+
+ number_original_arg = opt_prop_def(DenseArrayBase.constr(i32))
+
+ custom_results = var_result_def()
+
+ irdl_options = [ParsePropInAttrDict()]
+
+
+@irdl_op_definition
+class LaunchKernelOp(IRDLOperation):
+ """LaunchKernelOp operation."""
+
+ name = "catalyst.launch_kernel"
+
+ assembly_format = """
+ $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
+ """
+
+ callee = prop_def(SymbolRefAttr)
+
+ inputs = var_operand_def()
+
+ arg_attrs = opt_prop_def(ArrayAttr[DictionaryAttr])
+
+ res_attrs = opt_prop_def(ArrayAttr[DictionaryAttr])
+
+ kernel_results = var_result_def()
+
+ irdl_options = [ParsePropInAttrDict()]
+
+
+@irdl_op_definition
+class ListDeallocOp(IRDLOperation):
+ """Deallocate the underlying memory of an arraylist."""
+
+ name = "catalyst.list_dealloc"
+
+ assembly_format = """ $list attr-dict `:` type($list) """
+
+ list = operand_def(ArrayListType)
+
+
+@irdl_op_definition
+class ListInitOp(IRDLOperation):
+ """Initialize a dynamically resizable arraylist."""
+
+ name = "catalyst.list_init"
+
+ assembly_format = """ attr-dict `:` type($list) """
+
+ list = result_def(ArrayListType)
+
+
+@irdl_op_definition
+class ListLoadDataOp(IRDLOperation):
+ """Get the underlying memref storing the data of an array list."""
+
+ name = "catalyst.list_load_data"
+
+ assembly_format = """ $list attr-dict `:` type($list) `->` type($data) """
+
+ list = operand_def(ArrayListType)
+
+ data = result_def(MemRefType)
+
+
+@irdl_op_definition
+class ListPopOp(IRDLOperation):
+ """Remove an element from the end of an array list and return it."""
+
+ name = "catalyst.list_pop"
+
+ assembly_format = """ $list attr-dict `:` type($list) """
+
+ T: ClassVar = VarConstraint("T", AnyAttr())
+
+ list = operand_def(base(ArrayListType[T]))
+
+ result = result_def(T)
+
+
+@irdl_op_definition
+class ListPushOp(IRDLOperation):
+ """Append an element to the end of an array list."""
+
+ name = "catalyst.list_push"
+
+ assembly_format = """ $value `,` $list attr-dict `:` type($list) """
+
+ T: ClassVar = VarConstraint("T", AnyAttr())
+
+ value = operand_def(T)
+
+ list = operand_def(base(ArrayListType[T]))
+
+
+@irdl_op_definition
+class PrintOp(IRDLOperation):
+ """Prints numeric values or constant strings at runtime."""
+
+ name = "catalyst.print"
+
+ val = opt_operand_def()
+
+ const_val = opt_prop_def(StringAttr)
+
+ print_descriptor = prop_def(UnitAttr)
+
+
+Catalyst = Dialect(
+ "catalyst",
+ [
+ AssertionOp,
+ CallbackCallOp,
+ CallbackOp,
+ CustomCallOp,
+ LaunchKernelOp,
+ ListDeallocOp,
+ ListInitOp,
+ ListLoadDataOp,
+ ListPopOp,
+ ListPushOp,
+ PrintOp,
+ ],
+ [
+ ArrayListType,
+ ],
+)
diff --git a/frontend/catalyst/python_interface/dialects/mbqc.py b/frontend/catalyst/python_interface/dialects/mbqc.py
new file mode 100644
index 0000000000..b608e037fb
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/mbqc.py
@@ -0,0 +1,174 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This module contains the definition of the MBQC dialect for the Python compiler.
+
+The MBQC dialect is a set of operations and types used to represent measurement-based
+quantum-computing instructions in the xDSL framework.
+
+It was initially generated by xDSL (using the ``xdsl-tblgen`` tool) starting from the
+catalyst/mlir/include/MBQC/IR/MBQCDialect.td file in the catalyst repository.
+
+For detailed documentation on the operations contained in this dialect, please refer to the MBQC
+dialect documentation in Catalyst.
+"""
+
+from typing import TypeAlias
+
+from xdsl.dialects.builtin import (
+ I32,
+ AnyAttr,
+ Float64Type,
+ IntegerAttr,
+ IntegerType,
+ StringAttr,
+ i1,
+)
+from xdsl.ir import Dialect, EnumAttribute, Operation, SpacedOpaqueSyntaxAttribute, SSAValue
+from xdsl.irdl import (
+ IRDLOperation,
+ irdl_attr_definition,
+ irdl_op_definition,
+ operand_def,
+ opt_prop_def,
+ prop_def,
+ result_def,
+)
+from xdsl.utils.exceptions import VerifyException
+from xdsl.utils.str_enum import StrEnum # StrEnum is standard in Python>=3.11
+
+from catalyst.python_interface.xdsl_extras import MemRefConstraint, TensorConstraint
+
+from .quantum import QubitType, QuregType
+
+QubitSSAValue: TypeAlias = SSAValue[QubitType]
+
+
+class MeasurementPlaneEnum(StrEnum):
+ """Enum containing supported measurement-plane attributes"""
+
+ XY = "XY"
+ YZ = "YZ"
+ ZX = "ZX"
+
+
+@irdl_attr_definition
+class MeasurementPlaneAttr(EnumAttribute[MeasurementPlaneEnum], SpacedOpaqueSyntaxAttribute):
+ """Planes in the Bloch sphere representation with support for arbitrary-basis measurements"""
+
+ name = "mbqc.measurement_plane"
+
+
+@irdl_op_definition
+class MeasureInBasisOp(IRDLOperation):
+ """A parametric single-qubit projective measurement in an arbitrary basis."""
+
+ name = "mbqc.measure_in_basis"
+
+ assembly_format = """
+ `[` $plane `,` $angle `]` $in_qubit (`postselect` $postselect^)? attr-dict `:` type(results)
+ """
+
+ in_qubit = operand_def(QubitType)
+
+ plane = prop_def(MeasurementPlaneAttr)
+
+ angle = operand_def(Float64Type())
+
+ postselect = opt_prop_def(IntegerAttr[I32])
+
+ mres = result_def(IntegerType(1))
+
+ out_qubit = result_def(QubitType)
+
+ def __init__(
+ self,
+ in_qubit: QubitSSAValue | Operation,
+ plane: MeasurementPlaneAttr,
+ angle: SSAValue[Float64Type],
+ postselect: int | IntegerAttr | None = None,
+ ):
+ properties = {"plane": plane}
+
+ if isinstance(postselect, int):
+ postselect = IntegerAttr.from_int_and_width(postselect, 32)
+
+ if postselect is not None:
+ properties["postselect"] = postselect
+
+ super().__init__(
+ operands=(in_qubit, angle),
+ properties=properties,
+ result_types=(IntegerType(1), QubitType()),
+ )
+
+ def verify_(self):
+ """Verify operation when rewriting."""
+ if self.postselect is None:
+ return
+
+ if self.postselect.value.data not in [0, 1]: # pylint: disable=no-member
+ raise VerifyException("'postselect' must be 0 or 1.")
+
+
+@irdl_op_definition
+class GraphStatePrepOp(IRDLOperation):
+ """Allocate resources for a new graph state."""
+
+ name = "mbqc.graph_state_prep"
+
+ assembly_format = """
+ `(` $adj_matrix `:` type($adj_matrix) `)` `[` `init` $init_op `,` `entangle` $entangle_op `]` attr-dict `:` type(results)
+ """
+
+ adj_matrix = operand_def(
+ TensorConstraint(element_type=i1, rank=1) | MemRefConstraint(element_type=i1, rank=1)
+ )
+
+ init_op = prop_def(StringAttr)
+
+ entangle_op = prop_def(StringAttr)
+
+ qreg = result_def(QuregType)
+
+ def __init__(
+ self, adj_matrix: AnyAttr, init_op: str | StringAttr, entangle_op: str | StringAttr
+ ):
+ if isinstance(init_op, str):
+ init_op = StringAttr(data=init_op)
+
+ if isinstance(entangle_op, str):
+ entangle_op = StringAttr(data=entangle_op)
+
+ properties = {"init_op": init_op, "entangle_op": entangle_op}
+
+ qreg = QuregType()
+
+ super().__init__(
+ operands=(adj_matrix,),
+ result_types=(qreg,),
+ properties=properties,
+ )
+
+
+MBQC = Dialect(
+ "mbqc",
+ [
+ MeasureInBasisOp,
+ GraphStatePrepOp,
+ ],
+ [
+ MeasurementPlaneAttr,
+ ],
+)
diff --git a/frontend/catalyst/python_interface/dialects/qec.py b/frontend/catalyst/python_interface/dialects/qec.py
new file mode 100644
index 0000000000..be5ba288f3
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/qec.py
@@ -0,0 +1,216 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This module contains the QEC dialect for the Python compiler.
+
+The QEC dialect is a set of operations and types used to represent quantum error correction
+instructions in the xDSL framework.
+
+It was initially generated by xDSL (using the ``xdsl-tblgen`` tool) starting from the
+catalyst/mlir/include/QEC/IR/QECDialect.td file in the catalyst repository.
+"""
+
+from xdsl.dialects.builtin import I16, ArrayAttr, IntegerAttr, IntegerType, StringAttr, i16
+from xdsl.dialects.utils import AbstractYieldOperation
+from xdsl.ir import Attribute, Dialect, EnumAttribute, SpacedOpaqueSyntaxAttribute
+from xdsl.irdl import (
+ AttrSizedOperandSegments,
+ IRDLOperation,
+ irdl_attr_definition,
+ irdl_op_definition,
+ lazy_traits_def,
+ operand_def,
+ opt_operand_def,
+ opt_prop_def,
+ prop_def,
+ region_def,
+ result_def,
+ traits_def,
+ var_operand_def,
+ var_result_def,
+)
+from xdsl.traits import HasParent, IsTerminator, Pure, SingleBlockImplicitTerminator
+from xdsl.utils.str_enum import StrEnum
+
+from .quantum import QubitType
+
+
+class LogicalInitKind(StrEnum):
+ """The initial state of a logical qubit such as |0β©, |1β©, |+β©, |ββ©, |Yβ©, |-Yβ©, |mβ©, or |mΜ
β©."""
+
+ Zero = "zero" # |0β© Non-magic state
+ One = "one" # |1β© Non-magic state
+ Plus = "plus" # |+β© = (|0β© + |1β©) / sqrt(2) Non-magic state
+ Minus = "minus" # |-β© = (|0β© - |1β©) / sqrt(2) Non-magic state
+ PlusI = "plus_i" # |Yβ© = (|0β© + i|1β©) / sqrt(2) Non-magic / Magic state
+ MinusI = "minus_i" # |-Yβ© = (|0β© - i|1β©) / sqrt(2) Non-magic / Magic state
+ Magic = "magic" # |mβ© = |0β© + e^{iΟ/4}|1β© Magic state
+ MagicConj = "magic_conj" # |mΜ
β© = |0β© + e^{-iΟ/4}|1β© Magic state
+
+
+@irdl_attr_definition
+class LogicalInit(EnumAttribute[LogicalInitKind], SpacedOpaqueSyntaxAttribute):
+ """The initial state of a logical qubit such as |0β©, |1β©, |+β©, |ββ©, |Yβ©, |-Yβ©, |mβ©, or |mΜ
β©."""
+
+ name = "qec.enum"
+
+
+# Type alias for a product of Pauli operators, aka a Pauli word.
+PauliWord = ArrayAttr[StringAttr]
+
+
+@irdl_op_definition
+class YieldOp(AbstractYieldOperation[Attribute]):
+ """Return results from a layer region"""
+
+ name = "qec.yield"
+
+ traits = lazy_traits_def(lambda: (IsTerminator(), HasParent(LayerOp), Pure()))
+
+
+@irdl_op_definition
+class FabricateOp(IRDLOperation):
+ """Fabricate axillary qubits from qubit factories."""
+
+ name = "qec.fabricate"
+
+ assembly_format = """
+ $init_state attr-dict `:` type($out_qubits)
+ """
+
+ init_state = prop_def(LogicalInit)
+
+ out_qubits = var_result_def(QubitType)
+
+
+@irdl_op_definition
+class LayerOp(IRDLOperation):
+ """A layer operation"""
+
+ name = "qec.layer"
+
+ initArgs = var_operand_def()
+
+ results = var_result_def()
+
+ region = region_def("single_block")
+
+ traits = traits_def(SingleBlockImplicitTerminator(YieldOp))
+
+ # TODO: add a custom parse and print for this operation
+
+
+@irdl_op_definition
+class PPMeasurementOp(IRDLOperation):
+ """Pauli Product Measurement on qubits."""
+
+ name = "qec.ppm"
+
+ assembly_format = """
+ $pauli_product (`(` $rotation_sign^ `)`)? $in_qubits (`cond` `(` $condition^ `)`)? attr-dict `:` type(results)
+ """
+
+ irdl_options = [AttrSizedOperandSegments(as_property=True)]
+
+ pauli_product = prop_def(PauliWord)
+
+ rotation_sign = opt_prop_def(IntegerAttr[I16], default_value=IntegerAttr(1, i16))
+
+ in_qubits = var_operand_def(QubitType)
+
+ condition = opt_operand_def(IntegerType(1))
+
+ mres = result_def(IntegerType(1))
+
+ out_qubits = var_result_def(QubitType)
+
+
+@irdl_op_definition
+class PPRotationOp(IRDLOperation):
+ """Pauli Product Rotation on qubits."""
+
+ name = "qec.ppr"
+
+ assembly_format = """
+ $pauli_product `(` $rotation_kind `)` $in_qubits attr-dict (`cond` `(` $condition^ `)`)? `:` type($out_qubits)
+ """
+
+ irdl_options = [AttrSizedOperandSegments(as_property=True)]
+
+ pauli_product = prop_def(PauliWord)
+
+ rotation_kind = prop_def(IntegerAttr[IntegerType(16)])
+
+ in_qubits = var_operand_def(QubitType)
+
+ condition = opt_operand_def(IntegerType(1))
+
+ out_qubits = var_result_def(QubitType)
+
+
+@irdl_op_definition
+class PrepareStateOp(IRDLOperation):
+ """Initialize existing qubits into a given state."""
+
+ name = "qec.prepare"
+
+ assembly_format = """
+ $init_state $in_qubits attr-dict `:` type($out_qubits)
+ """
+
+ init_state = prop_def(LogicalInit)
+
+ in_qubits = var_operand_def(QubitType)
+
+ out_qubits = var_result_def(QubitType)
+
+
+@irdl_op_definition
+class SelectPPMeasurementOp(IRDLOperation):
+ """Multiplexed Pauli product measurement."""
+
+ name = "qec.select.ppm"
+
+ assembly_format = """
+ `(` $select_switch `,` $pauli_product_0 `,` $pauli_product_1 `)` $in_qubits attr-dict `:` type(results)
+ """
+
+ select_switch = operand_def(IntegerType(1))
+
+ pauli_product_0 = prop_def(PauliWord)
+
+ pauli_product_1 = prop_def(PauliWord)
+
+ in_qubits = var_operand_def(QubitType)
+
+ mres = result_def(IntegerType(1))
+
+ out_qubits = var_result_def(QubitType)
+
+
+QEC = Dialect(
+ "qec",
+ [
+ FabricateOp,
+ LayerOp,
+ PPMeasurementOp,
+ PPRotationOp,
+ PrepareStateOp,
+ SelectPPMeasurementOp,
+ YieldOp,
+ ],
+ [
+ LogicalInit,
+ ],
+)
diff --git a/frontend/catalyst/python_interface/dialects/quantum.py b/frontend/catalyst/python_interface/dialects/quantum.py
new file mode 100644
index 0000000000..8b549f5440
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/quantum.py
@@ -0,0 +1,1138 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This file contains the definition of the Quantum dialect for the unified compiler.
+
+The Quantum dialect is a set of operations and types used to represent quantum computations
+in the xDSL framework.
+
+It was initially generated by xDSL (using the ``xdsl-tblgen`` tool)
+starting from the catalyst/mlir/include/Quantum/IR/QuantumOps.td file in the catalyst repository.
+"""
+# pylint: disable=too-many-lines
+
+from collections.abc import Sequence
+from typing import TypeAlias
+
+from xdsl.dialects.builtin import (
+ I32,
+ I64,
+ ComplexType,
+ Float64Type,
+ FloatAttr,
+ IntegerAttr,
+ IntegerType,
+ MemRefType,
+ StringAttr,
+ TensorType,
+ UnitAttr,
+ i1,
+ i64,
+)
+from xdsl.ir import (
+ Block,
+ Dialect,
+ EnumAttribute,
+ Operation,
+ ParametrizedAttribute,
+ Region,
+ SpacedOpaqueSyntaxAttribute,
+ SSAValue,
+ StrEnum,
+ TypeAttribute,
+)
+from xdsl.irdl import (
+ AtLeast,
+ AttrSizedOperandSegments,
+ AttrSizedResultSegments,
+ IntSetConstraint,
+ IRDLOperation,
+ ParsePropInAttrDict,
+ SameVariadicResultSize,
+ irdl_attr_definition,
+ irdl_op_definition,
+ lazy_traits_def,
+ operand_def,
+ opt_operand_def,
+ opt_prop_def,
+ opt_result_def,
+ prop_def,
+ region_def,
+ result_def,
+ traits_def,
+ var_operand_def,
+ var_result_def,
+)
+from xdsl.traits import (
+ HasParent,
+ IsTerminator,
+ NoMemoryEffect,
+ Pure,
+ ReturnLike,
+ SingleBlockImplicitTerminator,
+)
+
+from catalyst.python_interface.xdsl_extras import MemRefConstraint, TensorConstraint
+
+################################################################
+######################## ATTRIBUTES ############################
+################################################################
+
+
+@irdl_attr_definition
+class ObservableType(ParametrizedAttribute, TypeAttribute):
+ """A quantum observable for use in measurements."""
+
+ name = "quantum.obs"
+
+
+@irdl_attr_definition
+class QubitType(ParametrizedAttribute, TypeAttribute):
+ """A value-semantic qubit (state)."""
+
+ name = "quantum.bit"
+
+
+@irdl_attr_definition
+class QuregType(ParametrizedAttribute, TypeAttribute):
+ """An array of value-semantic qubits (i.e. quantum register)."""
+
+ name = "quantum.reg"
+
+
+@irdl_attr_definition
+class ResultType(ParametrizedAttribute, TypeAttribute):
+ """A quantum measurement result."""
+
+ name = "quantum.res"
+
+
+class NamedObservable(StrEnum):
+ """Known named observables"""
+
+ Identity = "Identity"
+ PauliX = "PauliX"
+ PauliY = "PauliY"
+ PauliZ = "PauliZ"
+ Hadamard = "Hadamard"
+
+
+@irdl_attr_definition
+class NamedObservableAttr(EnumAttribute[NamedObservable], SpacedOpaqueSyntaxAttribute):
+ """Known named observables"""
+
+ name = "quantum.named_observable"
+
+
+################################################################
+######################## OPERATIONS ############################
+################################################################
+
+
+QubitSSAValue: TypeAlias = SSAValue[QubitType]
+QuregSSAValue: TypeAlias = SSAValue[QuregType]
+ObservableSSAValue: TypeAlias = SSAValue[ObservableType]
+
+
+@irdl_op_definition
+class AdjointOp(IRDLOperation):
+ """Calculate the adjoint of the enclosed operations"""
+
+ name = "quantum.adjoint"
+
+ assembly_format = """
+ `(` $qreg `)` attr-dict `:` type(operands) $region
+ """
+
+ qreg = operand_def(QuregType)
+
+ out_qreg = result_def(QuregType)
+
+ region = region_def("single_block")
+
+ traits = lazy_traits_def(lambda: (NoMemoryEffect(), SingleBlockImplicitTerminator(YieldOp)))
+
+ def __init__(
+ self,
+ qreg: QuregSSAValue | Operation,
+ region: Region | Sequence[Operation] | Sequence[Block],
+ ):
+ super().__init__(operands=(qreg,), result_types=(QuregType(),), regions=(region,))
+
+
+@irdl_op_definition
+class AllocOp(IRDLOperation):
+ """Allocate n qubits into a quantum register."""
+
+ name = "quantum.alloc"
+
+ assembly_format = """
+ `(` ($nqubits^):($nqubits_attr)? `)` attr-dict `:` type(results)
+ """
+
+ nqubits = opt_operand_def(i64)
+
+ nqubits_attr = opt_prop_def(IntegerAttr.constr(type=I64, value=AtLeast(0)))
+
+ qreg = result_def(QuregType)
+
+ def __init__(self, nqubits):
+ if isinstance(nqubits, int):
+ nqubits = IntegerAttr.from_int_and_width(nqubits, 64)
+
+ if isinstance(nqubits, IntegerAttr):
+ operands = (None,)
+ properties = {"nqubits_attr": nqubits}
+ else:
+ operands = (nqubits,)
+ properties = {}
+
+ super().__init__(operands=operands, properties=properties, result_types=(QuregType(),))
+
+
+@irdl_op_definition
+class AllocQubitOp(IRDLOperation):
+ """Allocate a single qubit."""
+
+ name = "quantum.alloc_qb"
+
+ assembly_format = """attr-dict `:` type(results)"""
+
+ qubit = result_def(QubitType)
+
+ def __init__(self):
+ super().__init__(
+ result_types=(QubitType(),),
+ )
+
+
+@irdl_op_definition
+class ComputationalBasisOp(IRDLOperation):
+ """Define a pseudo-obeservable of the computational basis for use in measurements"""
+
+ name = "quantum.compbasis"
+
+ assembly_format = """
+ (`qubits` $qubits^)? (`qreg` $qreg^)? attr-dict `:` type(results)
+ """
+
+ irdl_options = [AttrSizedOperandSegments(as_property=True)]
+
+ qubits = var_operand_def(QubitType)
+
+ qreg = opt_operand_def(QuregType)
+
+ obs = result_def(ObservableType)
+
+
+@irdl_op_definition
+class CountsOp(IRDLOperation):
+ """Compute sample counts for the given observable for the current state"""
+
+ name = "quantum.counts"
+
+ assembly_format = """
+ $obs ( `shape` $dynamic_shape^ )?
+ ( `in` `(` $in_eigvals^ `:` type($in_eigvals) `,` $in_counts `:` type($in_counts) `)` )?
+ attr-dict ( `:` type($eigvals)^ `,` type($counts) )?
+ """
+
+ irdl_options = [
+ AttrSizedOperandSegments(as_property=True),
+ SameVariadicResultSize(),
+ ]
+
+ obs = operand_def(ObservableType)
+
+ dynamic_shape = opt_operand_def(i64)
+
+ in_eigvals = opt_operand_def(MemRefConstraint(element_type=Float64Type(), rank=1))
+
+ in_counts = opt_operand_def(MemRefConstraint(element_type=i64, rank=1))
+
+ eigvals = opt_result_def(TensorConstraint(element_type=Float64Type(), rank=1))
+
+ counts = opt_result_def(TensorConstraint(element_type=i64, rank=1))
+
+
+@irdl_op_definition
+class CustomOp(IRDLOperation):
+ """A generic quantum gate on n qubits with m floating point parameters."""
+
+ name = "quantum.custom"
+
+ assembly_format = """
+ $gate_name `(` $params `)` $in_qubits
+ (`adj` $adjoint^)?
+ attr-dict
+ ( `ctrls` `(` $in_ctrl_qubits^ `)` )?
+ ( `ctrlvals` `(` $in_ctrl_values^ `)` )?
+ `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )?
+ """
+
+ irdl_options = [
+ AttrSizedOperandSegments(as_property=True),
+ AttrSizedResultSegments(as_property=True),
+ ]
+
+ params = var_operand_def(Float64Type())
+
+ in_qubits = var_operand_def(QubitType)
+
+ gate_name = prop_def(StringAttr)
+
+ adjoint = opt_prop_def(UnitAttr)
+
+ in_ctrl_qubits = var_operand_def(QubitType)
+
+ in_ctrl_values = var_operand_def(i1)
+
+ out_qubits = var_result_def(QubitType)
+
+ out_ctrl_qubits = var_result_def(QubitType)
+
+ traits = traits_def(NoMemoryEffect())
+
+ # pylint: disable=too-many-arguments
+ def __init__(
+ self,
+ *,
+ gate_name: str | StringAttr,
+ params: SSAValue[Float64Type] | Sequence[SSAValue[Float64Type]] | None = None,
+ in_qubits: QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation],
+ in_ctrl_qubits: (
+ QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation] | None
+ ) = None,
+ in_ctrl_values: (
+ SSAValue[IntegerType]
+ | Operation
+ | Sequence[SSAValue[IntegerType]]
+ | Sequence[Operation]
+ | None
+ ) = None,
+ adjoint: UnitAttr | bool = False,
+ ):
+ params = () if params is None else params
+ in_ctrl_qubits = () if in_ctrl_qubits is None else in_ctrl_qubits
+ in_ctrl_values = () if in_ctrl_values is None else in_ctrl_values
+
+ if not isinstance(params, Sequence):
+ params = (params,)
+ if not isinstance(in_qubits, Sequence):
+ in_qubits = (in_qubits,)
+ if not isinstance(in_ctrl_qubits, Sequence):
+ in_ctrl_qubits = (in_ctrl_qubits,)
+ if not isinstance(in_ctrl_values, Sequence):
+ in_ctrl_values = (in_ctrl_values,)
+
+ if isinstance(gate_name, str):
+ gate_name = StringAttr(data=gate_name)
+
+ out_qubits = tuple(QubitType() for _ in in_qubits)
+ out_ctrl_qubits = tuple(QubitType() for _ in in_ctrl_qubits)
+ properties = {"gate_name": gate_name}
+ if adjoint:
+ properties["adjoint"] = UnitAttr()
+
+ super().__init__(
+ operands=(params, in_qubits, in_ctrl_qubits, in_ctrl_values),
+ result_types=(out_qubits, out_ctrl_qubits),
+ properties=properties,
+ )
+
+
+@irdl_op_definition
+class DeallocOp(IRDLOperation):
+ """Deallocate a quantum register."""
+
+ name = "quantum.dealloc"
+
+ assembly_format = """
+ $qreg attr-dict `:` type(operands)
+ """
+
+ qreg = operand_def(QuregType)
+
+ def __init__(self, qreg: QuregSSAValue | Operation):
+ super().__init__(operands=(qreg,))
+
+
+@irdl_op_definition
+class DeallocQubitOp(IRDLOperation):
+ """Deallocate a single qubit."""
+
+ name = "quantum.dealloc_qb"
+
+ assembly_format = """$qubit attr-dict `:` type(operands)"""
+
+ qubit = operand_def(QubitType)
+
+ def __init__(self, qubit: QubitSSAValue | Operation):
+ super().__init__(
+ operands=(qubit,),
+ )
+
+
+@irdl_op_definition
+class DeviceInitOp(IRDLOperation):
+ """Initialize a quantum device."""
+
+ name = "quantum.device"
+
+ assembly_format = """
+ (`shots` `(` $shots^ `)`)? `[` $lib `,` $device_name `,` $kwargs `]` attr-dict
+ """
+
+ irdl_options = [ParsePropInAttrDict()]
+
+ shots = opt_operand_def(i64)
+
+ auto_qubit_management = opt_prop_def(UnitAttr)
+
+ lib = prop_def(StringAttr)
+
+ device_name = prop_def(StringAttr)
+
+ kwargs = prop_def(StringAttr)
+
+
+@irdl_op_definition
+class DeviceReleaseOp(IRDLOperation):
+ """Release the active quantum device."""
+
+ name = "quantum.device_release"
+
+ assembly_format = "attr-dict"
+
+
+@irdl_op_definition
+class ExpvalOp(IRDLOperation):
+ """Compute the expectation value of the given observable for the current state"""
+
+ name = "quantum.expval"
+
+ assembly_format = "$obs attr-dict `:` type(results)"
+
+ obs = operand_def(ObservableType)
+
+ expval = result_def(Float64Type())
+
+ def __init__(self, obs: ObservableSSAValue | Operation):
+ super().__init__(operands=(obs,), result_types=(Float64Type(),))
+
+
+@irdl_op_definition
+class ExtractOp(IRDLOperation):
+ """Extract a qubit value from a register."""
+
+ name = "quantum.extract"
+
+ assembly_format = """
+ $qreg `[` ($idx^):($idx_attr)? `]` attr-dict `:` type($qreg) `->` type(results)
+ """
+
+ qreg = operand_def(QuregType)
+
+ idx = opt_operand_def(i64)
+
+ idx_attr = opt_prop_def(IntegerAttr.constr(type=i64, value=AtLeast(0)))
+
+ qubit = result_def(QubitType)
+
+ traits = traits_def(NoMemoryEffect())
+
+ def __init__(
+ self,
+ qreg: QuregSSAValue | Operation,
+ idx: int | SSAValue[IntegerType] | Operation | IntegerAttr,
+ ):
+ if isinstance(idx, int):
+ idx = IntegerAttr.from_int_and_width(idx, 64)
+
+ if isinstance(idx, IntegerAttr):
+ operands = (qreg, None)
+ properties = {"idx_attr": idx}
+ else:
+ operands = (qreg, idx)
+ properties = {}
+
+ super().__init__(
+ operands=operands,
+ result_types=(QubitType(),),
+ properties=properties,
+ )
+
+
+@irdl_op_definition
+class FinalizeOp(IRDLOperation):
+ """Teardown the quantum runtime."""
+
+ name = "quantum.finalize"
+
+ assembly_format = "attr-dict"
+
+
+@irdl_op_definition
+class GlobalPhaseOp(IRDLOperation):
+ """Global Phase."""
+
+ name = "quantum.gphase"
+
+ assembly_format = """
+ `(` $params `)`
+ attr-dict
+ ( `ctrls` `(` $in_ctrl_qubits^ `)` )?
+ ( `ctrlvals` `(` $in_ctrl_values^ `)` )?
+ `:` type(results)
+ """
+
+ irdl_options = [AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()]
+
+ params = operand_def(Float64Type())
+
+ adjoint = opt_prop_def(UnitAttr)
+
+ in_ctrl_qubits = var_operand_def(QubitType)
+
+ in_ctrl_values = var_operand_def(i1)
+
+ out_ctrl_qubits = var_result_def(QubitType)
+
+ def __init__(
+ self,
+ *,
+ params: float | SSAValue[Float64Type],
+ in_ctrl_qubits: (
+ QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation] | None
+ ) = None,
+ in_ctrl_values: (
+ SSAValue[IntegerType]
+ | Operation
+ | Sequence[SSAValue[IntegerType]]
+ | Sequence[Operation]
+ | None
+ ) = None,
+ ):
+ if isinstance(params, float):
+ params = FloatAttr(data=params, type=Float64Type())
+ in_ctrl_qubits = () if in_ctrl_qubits is None else in_ctrl_qubits
+ in_ctrl_values = () if in_ctrl_values is None else in_ctrl_values
+
+ if not isinstance(in_ctrl_qubits, Sequence):
+ in_ctrl_qubits = (in_ctrl_qubits,)
+ if not isinstance(in_ctrl_values, Sequence):
+ in_ctrl_values = (in_ctrl_values,)
+
+ out_ctrl_qubits = tuple(QubitType() for _ in in_ctrl_qubits)
+
+ super().__init__(
+ operands=(params, in_ctrl_qubits, in_ctrl_values),
+ result_types=(out_ctrl_qubits,),
+ )
+
+
+@irdl_op_definition
+class HamiltonianOp(IRDLOperation):
+ """Define a Hamiltonian observable for use in measurements"""
+
+ name = "quantum.hamiltonian"
+
+ assembly_format = """
+ `(` $coeffs `:` type($coeffs) `)` $terms attr-dict `:` type(results)
+ """
+
+ coeffs = operand_def(
+ TensorConstraint(element_type=Float64Type(), rank=1)
+ | (MemRefConstraint(element_type=Float64Type(), rank=1))
+ )
+
+ terms = var_operand_def(ObservableType)
+
+ obs = result_def(ObservableType)
+
+
+@irdl_op_definition
+class HermitianOp(IRDLOperation):
+ """Define a Hermitian observable for use in measurements"""
+
+ name = "quantum.hermitian"
+
+ assembly_format = """
+ `(` $matrix `:` type($matrix) `)` $qubits attr-dict `:` type(results)
+ """
+
+ matrix = operand_def(
+ TensorConstraint(element_type=ComplexType(Float64Type()), rank=2)
+ | MemRefConstraint(element_type=ComplexType(Float64Type()), rank=2)
+ )
+
+ qubits = var_operand_def(QubitType)
+
+ obs = result_def(ObservableType)
+
+
+@irdl_op_definition
+class InitializeOp(IRDLOperation):
+ """Initialize the quantum runtime."""
+
+ name = "quantum.init"
+
+ assembly_format = "attr-dict"
+
+
+@irdl_op_definition
+class InsertOp(IRDLOperation):
+ """Update the qubit value of a register."""
+
+ name = "quantum.insert"
+
+ assembly_format = """
+ $in_qreg `[` ($idx^):($idx_attr)? `]` `,` $qubit attr-dict `:` type($in_qreg) `,` type($qubit)
+ """
+
+ in_qreg = operand_def(QuregType)
+
+ idx = opt_operand_def(i64)
+
+ idx_attr = opt_prop_def(IntegerAttr.constr(type=i64, value=AtLeast(0)))
+
+ qubit = operand_def(QubitType)
+
+ out_qreg = result_def(QuregType)
+
+ traits = traits_def(NoMemoryEffect())
+
+ def __init__(
+ self,
+ in_qreg: QuregSSAValue | Operation,
+ idx: SSAValue[IntegerType] | Operation | int | IntegerAttr,
+ qubit: QubitSSAValue | Operation,
+ ):
+ if isinstance(idx, int):
+ idx = IntegerAttr.from_int_and_width(idx, 64)
+
+ if isinstance(idx, IntegerAttr):
+ operands = (in_qreg, None, qubit)
+ properties = {"idx_attr": idx}
+ else:
+ operands = (in_qreg, idx, qubit)
+ properties = {}
+
+ super().__init__(operands=operands, properties=properties, result_types=(QuregType(),))
+
+
+@irdl_op_definition
+class MeasureOp(IRDLOperation):
+ """A single-qubit projective measurement in the computational basis."""
+
+ name = "quantum.measure"
+
+ assembly_format = """
+ $in_qubit (`postselect` $postselect^)? attr-dict `:` type(results)
+ """
+
+ in_qubit = operand_def(QubitType)
+
+ postselect = opt_prop_def(
+ IntegerAttr.constr(type=I32, value=IntSetConstraint(frozenset((0, 1))))
+ )
+
+ mres = result_def(i1)
+
+ out_qubit = result_def(QubitType)
+
+ def __init__(
+ self, in_qubit: QubitSSAValue | Operation, postselect: int | IntegerAttr | None = None
+ ):
+ if isinstance(postselect, int):
+ postselect = IntegerAttr.from_int_and_width(postselect, 32)
+
+ if postselect is None:
+ properties = {}
+ else:
+ properties = {"postselect": postselect}
+
+ super().__init__(
+ operands=(in_qubit,), properties=properties, result_types=(i1, QubitType())
+ )
+
+
+@irdl_op_definition
+class MultiRZOp(IRDLOperation):
+ """Apply an arbitrary multi Z rotation"""
+
+ name = "quantum.multirz"
+
+ assembly_format = """
+ `(` $theta `)` $in_qubits
+ (`adj` $adjoint^)?
+ attr-dict
+ ( `ctrls` `(` $in_ctrl_qubits^ `)` )?
+ ( `ctrlvals` `(` $in_ctrl_values^ `)` )?
+ `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )?
+ """
+
+ irdl_options = [
+ AttrSizedOperandSegments(as_property=True),
+ AttrSizedResultSegments(as_property=True),
+ ]
+
+ theta = operand_def(Float64Type())
+
+ in_qubits = var_operand_def(QubitType)
+
+ adjoint = opt_prop_def(UnitAttr)
+
+ in_ctrl_qubits = var_operand_def(QubitType)
+
+ in_ctrl_values = var_operand_def(i1)
+
+ out_qubits = var_result_def(QubitType)
+
+ out_ctrl_qubits = var_result_def(QubitType)
+
+ traits = traits_def(NoMemoryEffect())
+
+ # pylint: disable=too-many-arguments
+ def __init__(
+ self,
+ *,
+ theta: SSAValue[Float64Type],
+ in_qubits: QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation],
+ in_ctrl_qubits: (
+ QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation] | None
+ ) = None,
+ in_ctrl_values: (
+ SSAValue[IntegerType]
+ | Operation
+ | Sequence[SSAValue[IntegerType]]
+ | Sequence[Operation]
+ | None
+ ) = None,
+ adjoint: UnitAttr | bool = False,
+ ):
+ in_ctrl_qubits = () if in_ctrl_qubits is None else in_ctrl_qubits
+ in_ctrl_values = () if in_ctrl_values is None else in_ctrl_values
+
+ if not isinstance(in_qubits, Sequence):
+ in_qubits = (in_qubits,)
+ if not isinstance(in_ctrl_qubits, Sequence):
+ in_ctrl_qubits = (in_ctrl_qubits,)
+ if not isinstance(in_ctrl_values, Sequence):
+ in_ctrl_values = (in_ctrl_values,)
+
+ out_qubits = tuple(QubitType() for _ in in_qubits)
+ out_ctrl_qubits = tuple(QubitType() for _ in in_ctrl_qubits)
+ properties = {"adjoint": UnitAttr()} if adjoint else {}
+
+ super().__init__(
+ operands=(theta, in_qubits, in_ctrl_qubits, in_ctrl_values),
+ result_types=(out_qubits, out_ctrl_qubits),
+ properties=properties,
+ )
+
+
+@irdl_op_definition
+class NamedObsOp(IRDLOperation):
+ """Define a Named observable for use in measurements"""
+
+ name = "quantum.namedobs"
+
+ assembly_format = """
+ $qubit `[` $type `]` attr-dict `:` type(results)
+ """
+
+ qubit = operand_def(QubitType)
+
+ type = prop_def(NamedObservableAttr)
+
+ obs = result_def(ObservableType)
+
+ def __init__(self, qubit: QubitSSAValue | Operation, obs_type: NamedObservableAttr):
+ super().__init__(
+ operands=(qubit,), properties={"type": obs_type}, result_types=(ObservableType(),)
+ )
+
+
+@irdl_op_definition
+class NumQubitsOp(IRDLOperation):
+ """Get the number of currently allocated qubits."""
+
+ name = "quantum.num_qubits"
+
+ assembly_format = """
+ attr-dict `:` type(results)
+ """
+
+ num_qubits = result_def(i64)
+
+
+@irdl_op_definition
+class PCPhaseOp(IRDLOperation):
+ """Apply a projector-controlled phase gate"""
+
+ name = "quantum.pcphase"
+
+ assembly_format = """
+ `(` $theta `,` $dim `)` $in_qubits
+ (`adj` $adjoint^)?
+ attr-dict
+ ( `ctrls` `(` $in_ctrl_qubits^ `)` )?
+ ( `ctrlvals` `(` $in_ctrl_values^ `)` )?
+ `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )?
+ """
+
+ irdl_options = [
+ AttrSizedOperandSegments(as_property=True),
+ AttrSizedResultSegments(as_property=True),
+ ]
+
+ theta = operand_def(Float64Type())
+
+ dim = operand_def(Float64Type())
+
+ in_qubits = var_operand_def(QubitType)
+
+ adjoint = opt_prop_def(UnitAttr)
+
+ in_ctrl_qubits = var_operand_def(QubitType)
+
+ in_ctrl_values = var_operand_def(i1)
+
+ out_qubits = var_result_def(QubitType)
+
+ out_ctrl_qubits = var_result_def(QubitType)
+
+ traits = traits_def(NoMemoryEffect())
+
+ # pylint: disable=too-many-arguments
+ def __init__(
+ self,
+ *,
+ theta: SSAValue[Float64Type],
+ dim: SSAValue[Float64Type],
+ in_qubits: QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation],
+ in_ctrl_qubits: (
+ QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation] | None
+ ) = None,
+ in_ctrl_values: (
+ SSAValue[IntegerType]
+ | Operation
+ | Sequence[SSAValue[IntegerType]]
+ | Sequence[Operation]
+ | None
+ ) = None,
+ adjoint: UnitAttr | bool = False,
+ ):
+ in_ctrl_qubits = () if in_ctrl_qubits is None else in_ctrl_qubits
+ in_ctrl_values = () if in_ctrl_values is None else in_ctrl_values
+
+ if not isinstance(in_qubits, Sequence):
+ in_qubits = (in_qubits,)
+ if not isinstance(in_ctrl_qubits, Sequence):
+ in_ctrl_qubits = (in_ctrl_qubits,)
+ if not isinstance(in_ctrl_values, Sequence):
+ in_ctrl_values = (in_ctrl_values,)
+
+ out_qubits = tuple(QubitType() for _ in in_qubits)
+ out_ctrl_qubits = tuple(QubitType() for _ in in_ctrl_qubits)
+ properties = {"adjoint": UnitAttr()} if adjoint else {}
+
+ super().__init__(
+ operands=(theta, dim, in_qubits, in_ctrl_qubits, in_ctrl_values),
+ result_types=(out_qubits, out_ctrl_qubits),
+ properties=properties,
+ )
+
+
+@irdl_op_definition
+class ProbsOp(IRDLOperation):
+ """Compute computational basis probabilities for the current state"""
+
+ name = "quantum.probs"
+
+ assembly_format = """
+ $obs ( `shape` $dynamic_shape^ )?
+ ( `in` `(` $state_in^ `:` type($state_in) `)` )?
+ attr-dict ( `:` type($probabilities)^ )?
+ """
+
+ irdl_options = [AttrSizedOperandSegments(as_property=True)]
+
+ obs = operand_def(ObservableType)
+
+ dynamic_shape = opt_operand_def(i64)
+
+ state_in = opt_operand_def(MemRefConstraint(element_type=Float64Type(), rank=1))
+
+ probabilities = opt_result_def(TensorConstraint(element_type=Float64Type(), rank=1))
+
+
+@irdl_op_definition
+class QubitUnitaryOp(IRDLOperation):
+ """Apply an arbitrary fixed unitary matrix"""
+
+ name = "quantum.unitary"
+
+ assembly_format = """
+ `(` $matrix `:` type($matrix) `)` $in_qubits
+ (`adj` $adjoint^)?
+ attr-dict
+ ( `ctrls` `(` $in_ctrl_qubits^ `)` )?
+ ( `ctrlvals` `(` $in_ctrl_values^ `)` )?
+ `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )?
+ """
+
+ irdl_options = [
+ AttrSizedOperandSegments(as_property=True),
+ AttrSizedResultSegments(as_property=True),
+ ]
+
+ matrix = operand_def(
+ (TensorConstraint(element_type=ComplexType(Float64Type()), rank=2))
+ | (MemRefConstraint(element_type=ComplexType(Float64Type()), rank=2))
+ )
+
+ in_qubits = var_operand_def(QubitType)
+
+ adjoint = opt_prop_def(UnitAttr)
+
+ in_ctrl_qubits = var_operand_def(QubitType)
+
+ in_ctrl_values = var_operand_def(i1)
+
+ out_qubits = var_result_def(QubitType)
+
+ out_ctrl_qubits = var_result_def(QubitType)
+
+ traits = traits_def(NoMemoryEffect())
+
+ # pylint: disable=too-many-arguments
+ def __init__(
+ self,
+ *,
+ matrix: SSAValue[TensorType | MemRefType],
+ in_qubits: QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation],
+ in_ctrl_qubits: (
+ QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation] | None
+ ) = None,
+ in_ctrl_values: (
+ SSAValue[IntegerType]
+ | Operation
+ | Sequence[SSAValue[IntegerType]]
+ | Sequence[Operation]
+ | None
+ ) = None,
+ adjoint: UnitAttr | bool = False,
+ ):
+ in_ctrl_qubits = () if in_ctrl_qubits is None else in_ctrl_qubits
+ in_ctrl_values = () if in_ctrl_values is None else in_ctrl_values
+
+ if not isinstance(in_qubits, Sequence):
+ in_qubits = (in_qubits,)
+ if not isinstance(in_ctrl_qubits, Sequence):
+ in_ctrl_qubits = (in_ctrl_qubits,)
+ if not isinstance(in_ctrl_values, Sequence):
+ in_ctrl_values = (in_ctrl_values,)
+
+ out_qubits = tuple(QubitType() for _ in in_qubits)
+ out_ctrl_qubits = tuple(QubitType() for _ in in_ctrl_qubits)
+ properties = {}
+ if adjoint:
+ properties["adjoint"] = UnitAttr()
+
+ super().__init__(
+ operands=(matrix, in_qubits, in_ctrl_qubits, in_ctrl_values),
+ result_types=(out_qubits, out_ctrl_qubits),
+ properties=properties,
+ )
+
+
+@irdl_op_definition
+class SampleOp(IRDLOperation):
+ """Sample eigenvalues from the given observable for the current state"""
+
+ name = "quantum.sample"
+
+ assembly_format = """
+ $obs ( `shape` $dynamic_shape^ )?
+ ( `in` `(` $in_data^ `:` type($in_data) `)` )?
+ attr-dict ( `:` type($samples)^ )?
+ """
+
+ irdl_options = [AttrSizedOperandSegments(as_property=True)]
+
+ obs = operand_def(ObservableType)
+
+ dynamic_shape = var_operand_def(i64)
+
+ in_data = opt_operand_def(MemRefConstraint(element_type=Float64Type(), rank=(1, 2)))
+
+ samples = opt_result_def(TensorConstraint(element_type=Float64Type(), rank=(1, 2)))
+
+
+@irdl_op_definition
+class SetBasisStateOp(IRDLOperation):
+ """Set basis state."""
+
+ name = "quantum.set_basis_state"
+
+ assembly_format = """
+ `(` $basis_state`)` $in_qubits attr-dict `:` functional-type(operands, results)
+ """
+
+ basis_state = operand_def(
+ (TensorConstraint(element_type=i1, rank=1)) | (MemRefConstraint(element_type=i1, rank=1))
+ )
+
+ in_qubits = var_operand_def(QubitType)
+
+ out_qubits = var_result_def(QubitType)
+
+
+@irdl_op_definition
+class SetStateOp(IRDLOperation):
+ """Set state to a complex vector."""
+
+ name = "quantum.set_state"
+
+ assembly_format = """
+ `(` $in_state `)` $in_qubits attr-dict `:` functional-type(operands, results)
+ """
+
+ in_state = operand_def(
+ (TensorConstraint(element_type=ComplexType(Float64Type()), rank=1))
+ | (MemRefConstraint(element_type=ComplexType(Float64Type()), rank=1))
+ )
+
+ in_qubits = var_operand_def(QubitType)
+
+ out_qubits = var_result_def(QubitType)
+
+
+@irdl_op_definition
+class StateOp(IRDLOperation):
+ """Return the current statevector"""
+
+ name = "quantum.state"
+
+ assembly_format = """
+ $obs ( `shape` $dynamic_shape^ )?
+ ( `in` `(` $state_in^ `:` type($state_in) `)` )?
+ attr-dict ( `:` type($state)^ )?
+ """
+
+ irdl_options = [AttrSizedOperandSegments(as_property=True)]
+
+ obs = operand_def(ObservableType)
+
+ dynamic_shape = opt_operand_def(i64)
+
+ state_in = opt_operand_def(MemRefConstraint(element_type=ComplexType(Float64Type()), rank=1))
+
+ state = opt_result_def(TensorConstraint(element_type=ComplexType(Float64Type()), rank=1))
+
+
+@irdl_op_definition
+class TensorOp(IRDLOperation):
+ """Define a tensor product of observables for use in measurements"""
+
+ name = "quantum.tensor"
+
+ assembly_format = """
+ $terms attr-dict `:` type(results)
+ """
+
+ terms = var_operand_def(ObservableType)
+
+ obs = result_def(ObservableType)
+
+
+@irdl_op_definition
+class VarianceOp(IRDLOperation):
+ """Compute the variance of the given observable for the current state"""
+
+ name = "quantum.var"
+
+ assembly_format = """
+ $obs attr-dict `:` type(results)
+ """
+
+ obs = operand_def(ObservableType)
+
+ variance = result_def(Float64Type())
+
+ def __init__(self, obs: ObservableSSAValue | Operation):
+ super().__init__(operands=(obs,), result_types=(Float64Type(),))
+
+
+@irdl_op_definition
+class YieldOp(IRDLOperation):
+ """Return results from quantum program regions"""
+
+ name = "quantum.yield"
+
+ assembly_format = """
+ attr-dict ($retvals^ `:` type($retvals))?
+ """
+
+ retvals = var_operand_def(QuregType)
+
+ traits = traits_def(HasParent(AdjointOp), IsTerminator(), Pure(), ReturnLike())
+
+
+Quantum = Dialect(
+ "quantum",
+ [
+ AdjointOp,
+ AllocOp,
+ AllocQubitOp,
+ ComputationalBasisOp,
+ CountsOp,
+ CustomOp,
+ DeallocOp,
+ DeallocQubitOp,
+ DeviceInitOp,
+ DeviceReleaseOp,
+ ExpvalOp,
+ ExtractOp,
+ FinalizeOp,
+ GlobalPhaseOp,
+ HamiltonianOp,
+ HermitianOp,
+ InitializeOp,
+ InsertOp,
+ MeasureOp,
+ MultiRZOp,
+ NamedObsOp,
+ NumQubitsOp,
+ PCPhaseOp,
+ ProbsOp,
+ QubitUnitaryOp,
+ SampleOp,
+ SetBasisStateOp,
+ SetStateOp,
+ StateOp,
+ TensorOp,
+ VarianceOp,
+ YieldOp,
+ ],
+ [
+ ObservableType,
+ QubitType,
+ QuregType,
+ ResultType,
+ NamedObservableAttr,
+ ],
+)
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/__init__.py b/frontend/catalyst/python_interface/dialects/stablehlo/__init__.py
new file mode 100644
index 0000000000..64380529ac
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/__init__.py
@@ -0,0 +1,161 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+StableHLO dialect package for PennyLane's compiler infrastructure.
+
+This package contains organized elementwise operations and other StableHLO-related
+functionality.
+"""
+
+from .attributes import (
+ CustomCallApiVersion,
+ CustomCallApiVersionAttr,
+ GatherDimensionNumbers,
+ OutputOperandAlias,
+ ResultAccuracyModeAttr,
+ ScatterDimensionNumbers,
+)
+from .control_flow import (
+ IfOp,
+ OptimizationBarrierOp,
+ WhileOp,
+)
+from .data_movement import (
+ BroadcastInDimOp,
+ ConcatenateOp,
+ DynamicSliceOp,
+ GatherOp,
+ ReshapeOp,
+ ScatterOp,
+ SliceOp,
+)
+
+# Import the main StableHLO dialect
+from .dialect import StableHLO
+from .dynamism import (
+ DynamicBroadcastInDimOp,
+)
+from .elementwise_binary import (
+ ComplexOp,
+ DivideOp,
+ MaximumOp,
+ MinimumOp,
+ PowerOp,
+ RemainderOp,
+)
+from .elementwise_other import (
+ ClampOp,
+ CompareOp,
+ ConstantOp,
+ MapOp,
+ ReducePrecisionOp,
+ SelectOp,
+)
+
+# Import all elementwise operations explicitly
+from .elementwise_unary import (
+ ConvertOp,
+ CosineOp,
+ ExponentialMinusOneOp,
+ ExponentialOp,
+ FloorOp,
+ ImagOp,
+ IsFiniteOp,
+ LogisticOp,
+ LogOp,
+ LogPlusOneOp,
+ NegateOp,
+ RealOp,
+ RoundNearestAfzOp,
+ RoundNearestEvenOp,
+ RsqrtOp,
+ SignOp,
+ SineOp,
+ SqrtOp,
+ TanhOp,
+ TanOp,
+)
+from .extensibility import (
+ CustomCallOp,
+)
+from .reduction import (
+ ReduceOp,
+)
+
+# Export all operations and the dialect for external use
+__all__ = [
+ # Main dialect
+ "StableHLO",
+ # Elementwise unary operations
+ "ConvertOp",
+ "CosineOp",
+ "ExponentialMinusOneOp",
+ "ExponentialOp",
+ "FloorOp",
+ "ImagOp",
+ "IsFiniteOp",
+ "LogOp",
+ "LogPlusOneOp",
+ "LogisticOp",
+ "NegateOp",
+ "RealOp",
+ "RoundNearestAfzOp",
+ "RoundNearestEvenOp",
+ "RsqrtOp",
+ "SignOp",
+ "SineOp",
+ "SqrtOp",
+ "TanOp",
+ "TanhOp",
+ # Elementwise binary operations
+ "ComplexOp",
+ "DivideOp",
+ "MaximumOp",
+ "MinimumOp",
+ "PowerOp",
+ "RemainderOp",
+ # Elementwise other operations
+ "ClampOp",
+ "CompareOp",
+ "ConstantOp",
+ "MapOp",
+ "ReducePrecisionOp",
+ "SelectOp",
+ # Control flow operations
+ "IfOp",
+ "WhileOp",
+ "OptimizationBarrierOp",
+ # Data movement operations
+ "BroadcastInDimOp",
+ "ConcatenateOp",
+ "DynamicSliceOp",
+ "GatherOp",
+ "ReshapeOp",
+ "ScatterOp",
+ "SliceOp",
+ # Dynamism operations
+ "DynamicBroadcastInDimOp",
+ # Reduction operations
+ "ReduceOp",
+ # Extensibility operations
+ "CustomCallOp",
+ # Attributes
+ "GatherDimensionNumbers",
+ "ResultAccuracyModeAttr",
+ "ScatterDimensionNumbers",
+ "CustomCallApiVersion",
+ "CustomCallApiVersionAttr",
+ "OutputOperandAlias",
+]
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/attributes.py b/frontend/catalyst/python_interface/dialects/stablehlo/attributes.py
new file mode 100644
index 0000000000..7618d77afc
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/attributes.py
@@ -0,0 +1,394 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+StableHLO attribute definitions for PennyLane's compiler infrastructure.
+
+This module provides attribute definitions based on the StableHLO specification
+(https://github.com/openxla/stablehlo/blob/main/docs/spec.md), including
+attributes for StableHLO operations.
+"""
+
+# pylint: disable=line-too-long
+
+from collections.abc import Sequence
+
+from xdsl.dialects.builtin import I64, ArrayAttr, IntegerAttr, i64
+from xdsl.ir import (
+ Attribute,
+ EnumAttribute,
+ ParametrizedAttribute,
+ SpacedOpaqueSyntaxAttribute,
+ StrEnum,
+)
+from xdsl.irdl import irdl_attr_definition
+from xdsl.parser import AttrParser
+from xdsl.printer import Printer
+
+
+# Utility functions for dimension array parsing/printing
+def parse_dims(parser: AttrParser) -> ArrayAttr[IntegerAttr[I64]]:
+ """Parse dimension array in [1, 2, 3] format"""
+ value = parser.parse_comma_separated_list(
+ AttrParser.Delimiter.SQUARE,
+ lambda: IntegerAttr(parser.parse_integer(), i64),
+ )
+ return ArrayAttr(value)
+
+
+def print_dims(printer: Printer, dims: ArrayAttr[IntegerAttr[I64]]):
+ """Print dimension array in [1, 2, 3] format"""
+ printer.print_string("[")
+ printer.print_list(
+ dims.data,
+ lambda dim: printer.print_string(f"{dim.value.data}"),
+ )
+ printer.print_string("]")
+
+
+class ResultAccuracyMode(StrEnum):
+ """
+ XLA result accuracy mode.
+ """
+
+ DEFAULT = "DEFAULT"
+ HIGH = "HIGHEST"
+ HIGHEST = "TOLERANCE"
+
+
+@irdl_attr_definition
+class ResultAccuracyModeAttr(EnumAttribute[ResultAccuracyMode], SpacedOpaqueSyntaxAttribute):
+ """
+ XLA result accuracy mode.
+
+ See external [documentation](https://github.com/openxla/stablehlo/blob/7c50d4efeaea30bff6aa5e46c7f71170f5aa06af/stablehlo/dialect/StablehloEnums.td#L49-L70).
+ """
+
+ name = "stablehlo.result_accuracy_mode"
+
+
+@irdl_attr_definition
+class GatherDimensionNumbers(ParametrizedAttribute):
+ """
+ XLA gather dimension numbers.
+
+ This attribute models the dimension information for gather operations.
+ See external [documentation](https://github.com/openxla/stablehlo/blob/b075e948092d8a27ed0be48f4f8dbaa6df7e2e3e/stablehlo/dialect/StablehloAttrs.td#L42).
+ """
+
+ name = "stablehlo.gather"
+
+ offset_dims: ArrayAttr[IntegerAttr[I64]]
+ collapsed_slice_dims: ArrayAttr[IntegerAttr[I64]]
+ operand_batching_dims: ArrayAttr[IntegerAttr[I64]]
+ start_indices_batching_dims: ArrayAttr[IntegerAttr[I64]]
+ start_index_map: ArrayAttr[IntegerAttr[I64]]
+ index_vector_dim: IntegerAttr[I64]
+
+ def print_parameters(self, printer: Printer) -> None:
+ """Print gather dimension numbers in structured format"""
+ with printer.in_angle_brackets():
+ with printer.indented():
+ # Print offset_dims
+ printer.print_string("\noffset_dims = ")
+ print_dims(printer, self.offset_dims)
+ printer.print_string(",")
+
+ # Print collapsed_slice_dims
+ printer.print_string("\ncollapsed_slice_dims = ")
+ print_dims(printer, self.collapsed_slice_dims)
+ printer.print_string(",")
+
+ # Print operand_batching_dims
+ printer.print_string("\noperand_batching_dims = ")
+ print_dims(printer, self.operand_batching_dims)
+ printer.print_string(",")
+
+ # Print start_indices_batching_dims
+ printer.print_string("\nstart_indices_batching_dims = ")
+ print_dims(printer, self.start_indices_batching_dims)
+ printer.print_string(",")
+
+ # Print start_index_map
+ printer.print_string("\nstart_index_map = ")
+ print_dims(printer, self.start_index_map)
+ printer.print_string(",")
+
+ # Print index_vector_dim
+ printer.print_string(f"\nindex_vector_dim = {self.index_vector_dim.value.data}")
+ printer.print_string("\n")
+
+ @classmethod
+ def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]:
+ """Parse gather dimension numbers from structured format"""
+ with parser.in_angle_brackets():
+ # Initialize default values for all fields
+ offset_dims = ArrayAttr([])
+ collapsed_slice_dims = ArrayAttr([])
+ operand_batching_dims = ArrayAttr([])
+ start_indices_batching_dims = ArrayAttr([])
+ start_index_map = ArrayAttr([])
+ index_vector_dim = IntegerAttr(0, i64)
+
+ # Try to parse offset_dims
+ if parser.parse_optional_characters("offset_dims") is not None:
+ parser.parse_punctuation("=")
+ offset_dims = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ # Try to parse collapsed_slice_dims
+ if parser.parse_optional_characters("collapsed_slice_dims") is not None:
+ parser.parse_punctuation("=")
+ collapsed_slice_dims = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ # Try to parse operand_batching_dims
+ if parser.parse_optional_characters("operand_batching_dims") is not None:
+ parser.parse_punctuation("=")
+ operand_batching_dims = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ # Try to parse start_indices_batching_dims
+ if parser.parse_optional_characters("start_indices_batching_dims") is not None:
+ parser.parse_punctuation("=")
+ start_indices_batching_dims = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ # Try to parse start_index_map
+ if parser.parse_optional_characters("start_index_map") is not None:
+ parser.parse_punctuation("=")
+ start_index_map = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ # Try to parse index_vector_dim
+ if parser.parse_optional_characters("index_vector_dim") is not None:
+ parser.parse_punctuation("=")
+ index_vector_dim = IntegerAttr(parser.parse_integer(), i64)
+
+ return (
+ offset_dims,
+ collapsed_slice_dims,
+ operand_batching_dims,
+ start_indices_batching_dims,
+ start_index_map,
+ index_vector_dim,
+ )
+
+
+@irdl_attr_definition
+class ScatterDimensionNumbers(ParametrizedAttribute):
+ """
+ XLA scatter dimension numbers.
+
+ This attribute models the dimension information for scatter operations.
+ See external [documentation](https://github.com/openxla/stablehlo/blob/b075e948092d8a27ed0be48f4f8dbaa6df7e2e3e/stablehlo/dialect/StablehloAttrs.td#L28).
+ """
+
+ name = "stablehlo.scatter"
+
+ update_window_dims: ArrayAttr[IntegerAttr[I64]]
+ inserted_window_dims: ArrayAttr[IntegerAttr[I64]]
+ input_batching_dims: ArrayAttr[IntegerAttr[I64]]
+ scatter_indices_batching_dims: ArrayAttr[IntegerAttr[I64]]
+ scatter_dims_to_operand_dims: ArrayAttr[IntegerAttr[I64]]
+ index_vector_dim: IntegerAttr[I64]
+
+ def print_parameters(self, printer: Printer) -> None:
+ """Print scatter dimension numbers in structured format"""
+ with printer.in_angle_brackets():
+ with printer.indented():
+ # Print update_window_dims
+ printer.print_string("\nupdate_window_dims = ")
+ print_dims(printer, self.update_window_dims)
+ printer.print_string(",")
+
+ # Print inserted_window_dims
+ printer.print_string("\ninserted_window_dims = ")
+ print_dims(printer, self.inserted_window_dims)
+ printer.print_string(",")
+
+ # Print input_batching_dims
+ printer.print_string("\ninput_batching_dims = ")
+ print_dims(printer, self.input_batching_dims)
+ printer.print_string(",")
+
+ # Print scatter_indices_batching_dims
+ printer.print_string("\nscatter_indices_batching_dims = ")
+ print_dims(printer, self.scatter_indices_batching_dims)
+ printer.print_string(",")
+
+ # Print scatter_dims_to_operand_dims
+ printer.print_string("\nscatter_dims_to_operand_dims = ")
+ print_dims(printer, self.scatter_dims_to_operand_dims)
+ printer.print_string(",")
+
+ # Print index_vector_dim
+ printer.print_string(f"\nindex_vector_dim = {self.index_vector_dim.value.data}")
+ printer.print_string("\n")
+
+ @classmethod
+ def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]:
+ """Parse scatter dimension numbers from structured format"""
+ with parser.in_angle_brackets():
+ # Initialize default values for all fields
+ update_window_dims = ArrayAttr([])
+ inserted_window_dims = ArrayAttr([])
+ input_batching_dims = ArrayAttr([])
+ scatter_indices_batching_dims = ArrayAttr([])
+ scatter_dims_to_operand_dims = ArrayAttr([])
+ index_vector_dim = IntegerAttr(0, i64)
+
+ # Try to parse update_window_dims
+ if parser.parse_optional_characters("update_window_dims") is not None:
+ parser.parse_punctuation("=")
+ update_window_dims = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ # Try to parse inserted_window_dims
+ if parser.parse_optional_characters("inserted_window_dims") is not None:
+ parser.parse_punctuation("=")
+ inserted_window_dims = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ # Try to parse input_batching_dims
+ if parser.parse_optional_characters("input_batching_dims") is not None:
+ parser.parse_punctuation("=")
+ input_batching_dims = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ # Try to parse scatter_indices_batching_dims
+ if parser.parse_optional_characters("scatter_indices_batching_dims") is not None:
+ parser.parse_punctuation("=")
+ scatter_indices_batching_dims = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ # Try to parse scatter_dims_to_operand_dims
+ if parser.parse_optional_characters("scatter_dims_to_operand_dims") is not None:
+ parser.parse_punctuation("=")
+ scatter_dims_to_operand_dims = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ # Try to parse index_vector_dim
+ if parser.parse_optional_characters("index_vector_dim") is not None:
+ parser.parse_punctuation("=")
+ index_vector_dim = IntegerAttr(parser.parse_integer(), i64)
+
+ return (
+ update_window_dims,
+ inserted_window_dims,
+ input_batching_dims,
+ scatter_indices_batching_dims,
+ scatter_dims_to_operand_dims,
+ index_vector_dim,
+ )
+
+
+# ===== CustomCall and layout-related attributes =====
+
+
+class CustomCallApiVersion(StrEnum):
+ """StableHLO CustomCall API version."""
+
+ API_VERSION_UNSPECIFIED = "API_VERSION_UNSPECIFIED"
+ API_VERSION_ORIGINAL = "API_VERSION_ORIGINAL"
+ API_VERSION_STATUS_RETURNING = "API_VERSION_STATUS_RETURNING"
+ API_VERSION_STATUS_RETURNING_UNIFIED = "API_VERSION_STATUS_RETURNING_UNIFIED"
+ API_VERSION_TYPED_FFI = "API_VERSION_TYPED_FFI"
+
+
+@irdl_attr_definition
+class CustomCallApiVersionAttr(EnumAttribute[CustomCallApiVersion], SpacedOpaqueSyntaxAttribute):
+ """StableHLO custom call API version attribute.
+
+ Mirrors StableHLO enum for CustomCall API versions.
+ """
+
+ name = "stablehlo.custom_call_api_version"
+
+
+@irdl_attr_definition
+class OutputOperandAlias(ParametrizedAttribute):
+ """
+ This attribute captures the alias relationship of the output to one of the
+ operands for a ``CustomCall`` op, denoted by ``operand_index``. The
+ ``output_tuple_indices`` and ``operand_tuple_indices`` are used to index into
+ output and operand types. These indices lists are empty if the corresponding
+ types are not tuple types, and can be arbitrarily long in case of
+ arbitrarily nested tuple types.
+
+ See https://www.tensorflow.org/xla/aliasing.
+
+ Example when used as array with in stablehlo.custom-call:
+
+ ```mlir
+ %0 = "stablehlo.custom_call"(%arg0, %arg1) {
+ // other attributes
+ output_operand_alias = [
+ #stablehlo.output_operand_alias
+ ]
+ } : (tuple, tensor<2x3xf32>>, tensor<5x5xf32>) -> tuple>
+
+ The output and the 0th operand are both tuples. The aliasing shows the
+ relationship between the 0th element in output tuple with the 1st element in
+ the 0th operand. And both of them are of the same type: ``tensor<2x3xf32>``.
+ ```
+ """
+
+ name = "stablehlo.output_operand_alias"
+
+ output_tuple_indices: ArrayAttr[IntegerAttr[I64]]
+ operand_index: IntegerAttr[I64]
+ operand_tuple_indices: ArrayAttr[IntegerAttr[I64]]
+
+ def print_parameters(self, printer: Printer) -> None:
+ """Print the OutputOperandAlias attribute."""
+ with printer.in_angle_brackets():
+ with printer.indented():
+ printer.print_string("\noutput_tuple_indices = ")
+ print_dims(printer, self.output_tuple_indices)
+ printer.print_string(",")
+
+ printer.print_string("\noperand_index = ")
+ printer.print_string(f"{self.operand_index.value.data}")
+ printer.print_string(",")
+
+ printer.print_string("\noperand_tuple_indices = ")
+ print_dims(printer, self.operand_tuple_indices)
+ printer.print_string("\n")
+
+ @classmethod
+ def parse_parameters(cls, parser: AttrParser):
+ """Parse the OutputOperandAlias attribute."""
+ with parser.in_angle_brackets():
+ output_tuple_indices = ArrayAttr([])
+ operand_index = IntegerAttr(0, i64)
+ operand_tuple_indices = ArrayAttr([])
+
+ if parser.parse_optional_characters("output_tuple_indices") is not None:
+ parser.parse_punctuation("=")
+ output_tuple_indices = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ if parser.parse_optional_characters("operand_index") is not None:
+ parser.parse_punctuation("=")
+ operand_index = IntegerAttr(parser.parse_integer(), i64)
+ parser.parse_optional_punctuation(",")
+
+ if parser.parse_optional_characters("operand_tuple_indices") is not None:
+ parser.parse_punctuation("=")
+ operand_tuple_indices = parse_dims(parser)
+
+ return (output_tuple_indices, operand_index, operand_tuple_indices)
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/control_flow.py b/frontend/catalyst/python_interface/dialects/stablehlo/control_flow.py
new file mode 100644
index 0000000000..c9edced70d
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/control_flow.py
@@ -0,0 +1,160 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Control flow operations for the StableHLO dialect.
+"""
+
+from typing import TypeVar
+
+from xdsl.dialects.builtin import AnyTensorType
+from xdsl.irdl import (
+ IRDLOperation,
+ irdl_op_definition,
+ operand_def,
+ region_def,
+ traits_def,
+ var_operand_def,
+ var_result_def,
+)
+from xdsl.traits import (
+ Pure,
+ RecursivelySpeculatable,
+ RecursiveMemoryEffect,
+ SingleBlockImplicitTerminator,
+)
+from xdsl_jax.dialects.stablehlo import ReturnOp
+
+# Import our custom StableHLO types
+from .types import HLO_PredTensor, HLO_TensorOrPerAxisQuantizedTensorOrToken, HLO_TensorOrToken
+
+# Generic type variables for templating
+T_IN = TypeVar("T_IN", bound=AnyTensorType)
+T_OUT = TypeVar("T_OUT", bound=AnyTensorType)
+
+
+@irdl_op_definition
+class IfOp(IRDLOperation):
+ """
+ Produces the output from executing exactly one branch from `true_branch` or
+ `false_branch` depending on the value of `pred`.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#if
+
+ Example:
+ %result = "stablehlo.if"(%pred) ({
+ "stablehlo.return"(%result_true_branch) : (tensor) -> ()
+ }, {
+ "stablehlo.return"(%result_false_branch) : (tensor) -> ()
+ }) : (tensor) -> tensor
+ """
+
+ name = "stablehlo.if"
+
+ pred = operand_def(HLO_PredTensor)
+
+ res = var_result_def(HLO_TensorOrPerAxisQuantizedTensorOrToken)
+
+ true_branch = region_def("single_block")
+
+ false_branch = region_def("single_block")
+
+ traits = traits_def(
+ RecursiveMemoryEffect(),
+ RecursivelySpeculatable(),
+ SingleBlockImplicitTerminator(ReturnOp),
+ # TODO: InferTypeOpInterface
+ # TODO: OpAsmOpInterface
+ )
+
+ # TODO: Add custom assembly format
+
+
+# pylint: disable=line-too-long
+@irdl_op_definition
+class WhileOp(IRDLOperation):
+ """
+ Produces the output from executing `body` function 0 or more times while the
+ `cond` function outputs `true`.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#while
+
+ Example:
+ ```mlir
+ %results0, %results1 = stablehlo.while(%arg0 = %init_i, %arg1 = %init_sum) : tensor, tensor
+ cond {
+ %cond = stablehlo.compare LT, %arg0, %ten : (tensor, tensor) -> tensor
+ stablehlo.return %cond : tensor
+ } do {
+ %new_sum = stablehlo.add %arg1, %one : tensor
+ %new_i = stablehlo.add %arg0, %one : tensor
+ stablehlo.return %new_i, %new_sum : tensor, tensor
+ }
+ """
+
+ name = "stablehlo.while"
+
+ operand = var_operand_def(HLO_TensorOrPerAxisQuantizedTensorOrToken)
+
+ res = var_result_def(HLO_TensorOrPerAxisQuantizedTensorOrToken)
+
+ cond = region_def("single_block")
+
+ body = region_def("single_block")
+
+ traits = traits_def(
+ RecursiveMemoryEffect(),
+ RecursivelySpeculatable(),
+ SingleBlockImplicitTerminator(ReturnOp),
+ # TODO: InferTypeOpInterface
+ # TODO: OpAsmOpInterface
+ )
+
+
+# pylint: disable=line-too-long
+@irdl_op_definition
+class OptimizationBarrierOp(IRDLOperation):
+ """
+ Ensures that the operations that produce the `operand` are executed before any
+ operations that depend on the `result` and prevents compiler transformations
+ from moving operations across the barrier. Other than that, the operation is
+ an identity, i.e. `result` = `operand`.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#optimization_barrier
+
+ Example:
+ ```mlir
+ %result0, %result1 = stablehlo.optimization_barrier %operand0, %operand1 : tensor, tensor
+ ```
+ """
+
+ name = "stablehlo.optimization_barrier"
+
+ operand = var_operand_def(HLO_TensorOrToken)
+
+ res = var_result_def(HLO_TensorOrToken)
+
+ traits = traits_def(
+ Pure(),
+ # TODO: HLO_PairwiseSameOperandAndResultType
+ # TODO: InferTypeOpInterface
+ )
+
+ # TODO: Add custom assembly format
+ # assembly_format = """
+ # attr-dict ($operand^ `:` custom(type($operand), type($result))):(`(` `)`)?
+ # """
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/data_movement.py b/frontend/catalyst/python_interface/dialects/stablehlo/data_movement.py
new file mode 100644
index 0000000000..68ce2bddc7
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/data_movement.py
@@ -0,0 +1,416 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Data movement operations for the StableHLO dialect.
+"""
+
+from xdsl.dialects.builtin import BoolAttr, DenseArrayBase, IntegerAttr, TensorType, i64
+from xdsl.irdl import (
+ IRDLOperation,
+ irdl_op_definition,
+ operand_def,
+ opt_prop_def,
+ prop_def,
+ region_def,
+ result_def,
+ traits_def,
+ var_operand_def,
+ var_result_def,
+)
+from xdsl.irdl.attributes import eq
+from xdsl.irdl.constraints import AtLeast
+from xdsl.irdl.operations import SameVariadicOperandSize
+from xdsl.traits import (
+ ConditionallySpeculatable,
+ NoMemoryEffect,
+ Pure,
+ RecursiveMemoryEffect,
+)
+from xdsl.utils.exceptions import VerifyException
+from xdsl.utils.type import get_element_type_or_self
+
+from catalyst.python_interface.xdsl_extras import (
+ AllMatchSameOperatorTrait,
+ SameOperandsAndResultElementType,
+ TensorConstraint,
+)
+
+from .attributes import GatherDimensionNumbers, ScatterDimensionNumbers
+from .types import HLO_AnyIntegerOrIndexTensor, HLO_AnyTensor, HLO_Int, HLO_IntTensor, HLO_Tensor
+
+
+# pylint: disable=line-too-long
+@irdl_op_definition
+class BroadcastInDimOp(IRDLOperation):
+ """
+ Expands the dimensions and/or rank of an input tensor by duplicating the
+ data in the ``operand`` tensor and produces a ``result`` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#broadcast_in_dim
+
+ Example:
+ ```mlir
+ %result = stablehlo.broadcast_in_dim %operand, dims = [2, 1] : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
+ ```
+ """
+
+ name = "stablehlo.broadcast_in_dim"
+ operand = operand_def(HLO_AnyTensor)
+ broadcast_dimensions = prop_def(DenseArrayBase.constr(i64))
+ result = result_def(HLO_AnyTensor)
+
+ assembly_format = """
+ $operand `,` `dims` `=` $broadcast_dimensions
+ attr-dict `:` functional-type(operands, results)
+ """
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ # TODO: HLO_SpeculatableIfAllInputsStatic,
+ # TODO: HLO_CompatibleOperandsAndResultElementType,
+ )
+
+ def verify_(self) -> None:
+ """Verify non-quantized broadcast_in_dim constraints."""
+ o_type = self.operand_types[0]
+ r_type = self.result_types[0]
+
+ # These are constrained to tensors by the op definition
+ assert isinstance(o_type, TensorType) and isinstance(r_type, TensorType)
+
+ # broadcast_in_dim_c2: broadcast_dimensions size == operand rank
+ dims = tuple(self.broadcast_dimensions.get_values()) # pylint: disable=no-member
+ operand_rank = o_type.get_num_dims()
+ if len(dims) != operand_rank:
+ raise VerifyException(
+ "broadcast_dimensions size ("
+ f"{len(dims)}"
+ ") does not match operand rank ("
+ f"{operand_rank}"
+ ")"
+ )
+
+ # broadcast_in_dim_c4: broadcast_dimensions should not have duplicates
+ if len(set(dims)) != len(dims):
+ raise VerifyException("broadcast_dimensions should not have duplicates")
+
+ # Result rank and per-dimension checks
+ result_rank = r_type.get_num_dims()
+ o_shape = o_type.get_shape()
+ r_shape = r_type.get_shape()
+
+ for i, dim_index in enumerate(dims):
+ # broadcast_in_dim_c3: each dim index in bounds of result rank
+ if dim_index < 0 or dim_index >= result_rank:
+ raise VerifyException(
+ "broadcast_dimensions contains invalid value "
+ f"{dim_index} for result with rank {result_rank}"
+ )
+
+ # If operand dim is static, enforce broadcast_in_dim_c5
+ if o_shape[i] != -1:
+ dim_size = o_shape[i]
+ result_dim_size = r_shape[dim_index]
+ if dim_size not in (1, result_dim_size):
+ raise VerifyException(
+ "size of operand dimension "
+ f"{i} ({dim_size}) is not equal to 1 or size of result dimension "
+ f"{dim_index} ({result_dim_size})"
+ )
+
+
+# pylint: disable=line-too-long
+@irdl_op_definition
+class ConcatenateOp(IRDLOperation):
+ """
+ Concatenates a variadic number of tensors in ``inputs`` along ``dimension``
+ dimension in the same order as the given arguments and produces a ``result``
+ tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#concatenate
+
+ Example:
+ ```mlir
+ %result = stablehlo.concatenate %input0, %input1, dim = 0 : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
+ ```
+ """
+
+ name = "stablehlo.concatenate"
+
+ inputs = var_operand_def(HLO_Tensor)
+ result = result_def(HLO_Tensor)
+ dimension = prop_def(IntegerAttr.constr(type=eq(i64), value=AtLeast(0)))
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ ConditionallySpeculatable(),
+ SameOperandsAndResultElementType(),
+ # InferTypeOpInterface(),
+ )
+
+ # TODO: Implement CustomDirective
+ # assembly_format = """
+ # custom($inputs) `dim` `=` $dimension attr-dict `:` functional-type(operands, results)
+ # """
+
+
+@irdl_op_definition
+class DynamicSliceOp(IRDLOperation):
+ """
+ Extracts a slice from the ``operand`` using dynamically-computed starting
+ indices and produces a ``result`` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_slice
+
+ Example:
+ ```mlir
+ %result = stablehlo.dynamic_slice %operand, %start_indices0, %start_indices1, sizes = [2, 2]
+ : (tensor<4x4xi32>, tensor, tensor) -> tensor<2x2xi32>
+ ```
+ """
+
+ name = "stablehlo.dynamic_slice"
+ operand = operand_def(HLO_Tensor)
+ start_indices = var_operand_def(TensorConstraint(element_type=HLO_Int, rank=0))
+ slice_sizes = prop_def(DenseArrayBase.constr(i64))
+ result = result_def(HLO_Tensor)
+
+ # TODO: Implement CustomDirective
+ # assembly_format = """
+ # $operand `,` custom($start_indices)
+ # `sizes` `=` $slice_sizes attr-dict `:` functional-type(operands, results)
+ # """
+
+ traits = traits_def(
+ Pure(),
+ AllMatchSameOperatorTrait(
+ ("operand", "result"), lambda x: get_element_type_or_self(x.type), "element type"
+ ),
+ # TODO: InferTensorType(),
+ )
+
+
+# pylint: disable=line-too-long
+@irdl_op_definition
+class GatherOp(IRDLOperation):
+ """
+ Gathers slices from ``operand`` tensor from offsets specified in
+ ``start_indices`` and produces a ``result`` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#gather
+
+ Example:
+ ```mlir
+ %result = "stablehlo.gather"(%operand, %start_indices) {
+ dimension_numbers = #stablehlo.gather<
+ offset_dims = [3, 4],
+ collapsed_slice_dims = [1],
+ operand_batching_dims = [0],
+ start_indices_batching_dims = [1],
+ start_index_map = [2, 1],
+ index_vector_dim = 3>,
+ slice_sizes = array,
+ indices_are_sorted = false
+ } : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi64>
+ ```
+ """
+
+ name = "stablehlo.gather"
+ operand = operand_def(HLO_Tensor)
+ start_indices = operand_def(HLO_IntTensor)
+ dimension_numbers = prop_def(GatherDimensionNumbers)
+ slice_sizes = prop_def(DenseArrayBase.constr(i64))
+ indices_are_sorted = opt_prop_def(BoolAttr, default_value=BoolAttr.from_bool(False))
+ result = result_def(HLO_Tensor)
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ ConditionallySpeculatable(),
+ AllMatchSameOperatorTrait(
+ ("operand", "result"), lambda x: get_element_type_or_self(x.type), "element type"
+ ),
+ # TODO: InferTensorTypeWithReify(),
+ )
+
+ # TODO: Implement CustomDirective
+ # assembly_format = """
+ # custom($inputs) `dim` `=` $dimension attr-dict `:` functional-type(operands, results)
+ # """
+
+
+@irdl_op_definition
+class ReshapeOp(IRDLOperation):
+ """
+ Performs reshape of ``operand`` tensor to a ``result`` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reshape
+
+ Example:
+ ```mlir
+ %result = stablehlo.reshape %operand : (tensor<2xf32>) -> tensor<1x2xf32>
+ ```
+ """
+
+ name = "stablehlo.reshape"
+ operand = operand_def(HLO_AnyTensor)
+ result = result_def(HLO_AnyTensor)
+
+ assembly_format = """
+ operands attr-dict `:` functional-type(operands, results)
+ """
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ ConditionallySpeculatable(),
+ # TODO: HLO_CompatibleOperandsAndResultElementType,
+ )
+
+ def verify_(self) -> None:
+ """Verify that the operation has the same shape for all operands and results."""
+ o_type = self.operand_types[0]
+ r_type = self.result_types[0]
+
+ # These are constrained to tensors by the op definition
+ assert isinstance(o_type, TensorType) and isinstance(r_type, TensorType)
+
+ # If o_type or r_type is dynamically shaped there is nothing to verify.
+ if not o_type.has_static_shape() or not r_type.has_static_shape():
+ return
+
+ # If the operand type is statically shaped (not required) the number of
+ # elements must match that of the result type.
+ num_operand_elements = 1
+ for dim in o_type.get_shape():
+ num_operand_elements *= dim
+
+ num_result_elements = 1
+ for dim in r_type.get_shape():
+ num_result_elements *= dim
+
+ if num_result_elements != num_operand_elements:
+ raise VerifyException(
+ "number of output elements ("
+ f"{num_result_elements}"
+ ") doesn't match expected number of elements ("
+ f"{num_operand_elements}"
+ ")"
+ )
+
+
+@irdl_op_definition
+class ScatterOp(IRDLOperation):
+ """
+ Produces ``results`` tensors which are equal to ``inputs`` tensors except that
+ several slices specified by ``scatter_indices`` are updated with the values
+ ``updates`` using ``update_computation``.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter
+
+ Example:
+ ```mlir
+ %result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
+ ^bb0(%arg0: tensor, %arg1: tensor):
+ %0 = stablehlo.add %arg0, %arg1 : tensor
+ stablehlo.return %0 : tensor
+ }) {
+ scatter_dimension_numbers = #stablehlo.scatter<
+ update_window_dims = [3, 4],
+ inserted_window_dims = [1],
+ input_batching_dims = [0],
+ scatter_indices_batching_dims = [1],
+ scatter_dims_to_operand_dims = [2, 1],
+ index_vector_dim = 3>,
+ indices_are_sorted = false,
+ unique_indices = false
+ } : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
+ ```
+ """
+
+ name = "stablehlo.scatter"
+ inputs = var_operand_def(HLO_Tensor)
+ scatter_indices = operand_def(HLO_AnyIntegerOrIndexTensor)
+ updates = var_operand_def(HLO_Tensor)
+ scatter_dimension_numbers = prop_def(ScatterDimensionNumbers)
+ indices_are_sorted = opt_prop_def(BoolAttr, default_value=BoolAttr.from_bool(False))
+ unique_indices = opt_prop_def(BoolAttr, default_value=BoolAttr.from_bool(False))
+ result = var_result_def(HLO_Tensor)
+ update_computation = region_def("single_block")
+ # TODO: The MLIR implementation doesn't have the SingleBlockImplicitTerminator trait,
+ # However, it is checked to have a terminator in the verifier,
+ # which does not specifically check the terminator to be stablehlo.return.
+
+ traits = traits_def(
+ RecursiveMemoryEffect(),
+ ConditionallySpeculatable(),
+ # TODO: InferTypeOpInterface(),
+ )
+
+ irdl_options = [SameVariadicOperandSize()]
+
+ # TODO: MLIR has a custom verifier for the scatter operation.
+
+
+@irdl_op_definition
+class SliceOp(IRDLOperation):
+ """
+ Extracts a slice from the ``operand`` using statically-computed starting
+ indices and produces a ``result`` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#slice
+
+ Example:
+ ```mlir
+ %result = stablehlo.slice %operand [1:3, 4:8:2]
+ : (tensor<3x8xi64>) -> tensor<2x2xi64>
+
+ // Same in generic form: the `1:3` above is mapped to the first entry in
+ // `start_indices` and `limit_indices`, while `strides` is implicitly 1.
+ // The `4:8:2` above is parsed into the second entry of `start_indices`,
+ // `limit_indices` and `strides` respectively.
+ %result = "stablehlo.slice" (%operand) {
+ start_indices = array,
+ limit_indices = array,
+ strides = array
+ } : (tensor<3x8xi64>) -> tensor<2x2xi64>
+ ```
+ """
+
+ name = "stablehlo.slice"
+
+ operand = operand_def(HLO_Tensor)
+ start_indices = prop_def(DenseArrayBase.constr(i64))
+ limit_indices = prop_def(DenseArrayBase.constr(i64))
+ strides = prop_def(DenseArrayBase.constr(i64))
+ result = result_def(HLO_Tensor)
+
+ # TODO: Implement CustomDirective
+ # assembly_format = """
+ # $operand custom($start_indices, $limit_indices, $strides)
+ # attr-dict `:` functional-type(operands, results)
+ # """
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ ConditionallySpeculatable(),
+ AllMatchSameOperatorTrait(("start_indices", "limit_indices", "strides"), len, "size"),
+ SameOperandsAndResultElementType(),
+ )
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/dialect.py b/frontend/catalyst/python_interface/dialects/stablehlo/dialect.py
new file mode 100644
index 0000000000..ce880bba2b
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/dialect.py
@@ -0,0 +1,207 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Extended StableHLO dialect that dynamically includes all upstream operations
+plus custom operations for PennyLane's compiler infrastructure.
+
+This module automatically imports all operations and attributes from the upstream
+xdsl_jax.dialects.stablehlo and adds custom ones without needing to hardcode
+the upstream operation list.
+"""
+
+import xdsl_jax.dialects.stablehlo as xstablehlo
+from xdsl.ir import Dialect
+
+from .attributes import (
+ CustomCallApiVersionAttr,
+ GatherDimensionNumbers,
+ OutputOperandAlias,
+ ResultAccuracyModeAttr,
+ ScatterDimensionNumbers,
+)
+from .control_flow import (
+ IfOp,
+ OptimizationBarrierOp,
+ WhileOp,
+)
+from .data_movement import (
+ BroadcastInDimOp,
+ ConcatenateOp,
+ DynamicSliceOp,
+ GatherOp,
+ ReshapeOp,
+ ScatterOp,
+ SliceOp,
+)
+from .dynamism import (
+ DynamicBroadcastInDimOp,
+)
+from .elementwise_binary import (
+ ComplexOp,
+ DivideOp,
+ MaximumOp,
+ MinimumOp,
+ PowerOp,
+ RemainderOp,
+)
+from .elementwise_other import (
+ ClampOp,
+ CompareOp,
+ ConstantOp,
+ MapOp,
+ ReducePrecisionOp,
+ SelectOp,
+)
+
+# Import all elementwise operations from organized files
+from .elementwise_unary import (
+ ConvertOp,
+ CosineOp,
+ ExponentialMinusOneOp,
+ ExponentialOp,
+ FloorOp,
+ ImagOp,
+ IsFiniteOp,
+ LogisticOp,
+ LogOp,
+ LogPlusOneOp,
+ NegateOp,
+ RealOp,
+ RoundNearestAfzOp,
+ RoundNearestEvenOp,
+ RsqrtOp,
+ SignOp,
+ SineOp,
+ SqrtOp,
+ TanhOp,
+ TanOp,
+)
+from .extensibility import (
+ CustomCallOp,
+)
+from .reduction import (
+ ReduceOp,
+)
+from .types import UniformQuantizedPerAxisType, UniformQuantizedType
+
+# Operations to add to the dialect
+OPERATIONS = [
+ ClampOp,
+ CompareOp,
+ ComplexOp,
+ ConstantOp,
+ ConvertOp,
+ CosineOp,
+ DivideOp,
+ ExponentialMinusOneOp,
+ ExponentialOp,
+ FloorOp,
+ ImagOp,
+ IsFiniteOp,
+ LogOp,
+ LogPlusOneOp,
+ LogisticOp,
+ MapOp,
+ MaximumOp,
+ MinimumOp,
+ NegateOp,
+ PowerOp,
+ RealOp,
+ ReducePrecisionOp,
+ RemainderOp,
+ RoundNearestAfzOp,
+ RoundNearestEvenOp,
+ RsqrtOp,
+ SelectOp,
+ SignOp,
+ SineOp,
+ SqrtOp,
+ TanOp,
+ TanhOp,
+ # Data movement operations
+ BroadcastInDimOp,
+ ConcatenateOp,
+ DynamicSliceOp,
+ GatherOp,
+ ReshapeOp,
+ ScatterOp,
+ SliceOp,
+ # Control flow operations
+ IfOp,
+ WhileOp,
+ OptimizationBarrierOp,
+ # Dynamism operations
+ DynamicBroadcastInDimOp,
+ # Reduction operations
+ ReduceOp,
+ # Extensibility operations
+ CustomCallOp,
+]
+
+# Attributes to add to the dialect
+ATTRIBUTES = [
+ CustomCallApiVersionAttr,
+ GatherDimensionNumbers,
+ ResultAccuracyModeAttr,
+ OutputOperandAlias,
+ ScatterDimensionNumbers,
+ UniformQuantizedPerAxisType,
+ UniformQuantizedType,
+]
+
+# Operations/attributes from upstream that should be deleted/replaced in the local version
+UPSTREAM_OPERATIONS_TO_DELETE = [
+ xstablehlo.ConstantOp,
+]
+UPSTREAM_ATTRIBUTES_TO_DELETE = []
+
+
+def filter_and_extend_upstream(upstream_list, to_delete, to_add):
+ """Filter out operations/attributes from upstream list and add new ones.
+
+ Args:
+ upstream_list: List of operations/attributes to filter
+ to_delete: List of operations/attributes to remove
+ to_add: List of operations/attributes to add
+
+ Returns:
+ Modified list of operations/attributes
+ """
+ filtered_ops = list(upstream_list)
+
+ # Remove operations that should be deleted
+ for op_to_delete in to_delete:
+ if op_to_delete in filtered_ops:
+ filtered_ops.remove(op_to_delete)
+
+ # Add new operations
+ filtered_ops.extend(to_add)
+
+ return filtered_ops
+
+
+all_operations = filter_and_extend_upstream(
+ xstablehlo.StableHLO.operations, UPSTREAM_OPERATIONS_TO_DELETE, OPERATIONS
+)
+all_attributes = filter_and_extend_upstream(
+ xstablehlo.StableHLO.attributes, UPSTREAM_ATTRIBUTES_TO_DELETE, ATTRIBUTES
+)
+
+# Create the extended StableHLO dialect by dynamically getting upstream components
+StableHLO = Dialect(
+ "stablehlo",
+ all_operations,
+ all_attributes,
+)
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py b/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py
new file mode 100644
index 0000000000..e38b4986f7
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py
@@ -0,0 +1,198 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Dynamism operations for the StableHLO dialect.
+"""
+
+from xdsl.dialects.builtin import DenseArrayBase, TensorType, i64
+from xdsl.irdl import (
+ IRDLOperation,
+ ParsePropInAttrDict,
+ irdl_op_definition,
+ operand_def,
+ opt_prop_def,
+ prop_def,
+ result_def,
+ traits_def,
+)
+from xdsl.traits import (
+ ConditionallySpeculatable,
+ NoMemoryEffect,
+)
+from xdsl.utils.exceptions import VerifyException
+
+from catalyst.python_interface.xdsl_extras import TensorConstraint
+
+from .types import HLO_AnyTensor, HLO_DimensionValue
+
+
+@irdl_op_definition
+class DynamicBroadcastInDimOp(IRDLOperation):
+ """
+ This operation is functionally identical to
+ [broadcast_in_dim](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#broadcast_in_dim)
+ op, but the result shape is specified dynamically via ``output_dimensions``.
+
+ It also accepts optional attributes to express static knowledge about the
+ expanding behavior of dimensions. If not specified, all dimensions are
+ assumed to be possibly expanding. The sets of dimensions that are known to
+ be expanding and the set of dimensions that are known to be non-expanding
+ must be disjoint and they must be a subset of the operand's dimensions.
+
+ See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_broadcast_in_dim
+
+ Example:
+ ```mlir
+ %operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
+ %output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
+ %result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
+ broadcast_dimensions = array,
+ known_expanding_dimensions = array,
+ known_nonexpanding_dimensions = array
+ } : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
+ ```
+ """
+
+ name = "stablehlo.dynamic_broadcast_in_dim"
+
+ operand = operand_def(HLO_AnyTensor)
+ output_dimensions = operand_def(TensorConstraint(element_type=HLO_DimensionValue, rank=1))
+ broadcast_dimensions = prop_def(DenseArrayBase.constr(i64))
+ known_expanding_dimensions = opt_prop_def(DenseArrayBase.constr(i64))
+ known_nonexpanding_dimensions = opt_prop_def(DenseArrayBase.constr(i64))
+ result = result_def(HLO_AnyTensor)
+
+ assembly_format = (
+ "$operand `,` $output_dimensions `,` `dims` `=` $broadcast_dimensions "
+ "attr-dict `:` functional-type(operands, results)"
+ )
+
+ traits = traits_def(
+ ConditionallySpeculatable(),
+ NoMemoryEffect(),
+ # TODO: InferShapedTypeOpInterface(),
+ )
+
+ irdl_options = [ParsePropInAttrDict()]
+
+ # pylint: disable=too-many-branches
+ def verify_(self):
+ """Verify the operation."""
+ operand_ty = self.operand_types[0]
+ result_ty = self.result_types[0]
+ bcast_dims = tuple(self.broadcast_dimensions.get_values()) # pylint: disable=no-member
+
+ # Operand and result must be tensors
+ assert isinstance(operand_ty, TensorType) and isinstance(result_ty, TensorType)
+
+ self._verify_rank_constraints(bcast_dims, operand_ty, result_ty)
+
+ # dynamic_broadcast_in_dim_c4: broadcast_dimensions should not have duplicates
+ if len(set(bcast_dims)) != len(bcast_dims):
+ raise VerifyException("broadcast_dimensions should not have duplicates")
+
+ self._verify_per_dimension_bounds(bcast_dims, operand_ty, result_ty)
+
+ self._verify_expansion_hints(operand_ty)
+
+ def _verify_rank_constraints(self, bcast_dims, operand_ty, result_ty):
+ """Verify then operand and result tensors against the rank constraints."""
+
+ operand_rank = operand_ty.get_num_dims()
+ result_rank = result_ty.get_num_dims()
+
+ # dynamic_broadcast_in_dim_c2: broadcast_dimensions size == operand rank
+ if len(bcast_dims) != operand_rank:
+ raise VerifyException(
+ "broadcast_dimensions size ("
+ f"{len(bcast_dims)}"
+ ") does not match operand rank ("
+ f"{operand_rank}"
+ ")"
+ )
+
+ # dynamic_broadcast_in_dim_c3: result rank >= operand rank
+ if result_rank < operand_rank:
+ raise VerifyException(
+ "result rank ("
+ f"{result_rank}"
+ ") is less than operand rank ("
+ f"{operand_rank}"
+ ")"
+ )
+
+ # dynamic_broadcast_in_dim_c7: output_dimensions shape compatible with result rank
+ out_dims_ty = self.output_dimensions.type # pylint: disable=no-member
+ assert isinstance(out_dims_ty, TensorType)
+ # Must be rank-1 tensor (enforced by type constraint), and length must match result
+ # rank when statically known
+ out_shape = out_dims_ty.get_shape()
+ if len(out_shape) != 1:
+ raise VerifyException("output_dimensions must be a 1D tensor")
+ if out_shape[0] != -1 and out_shape[0] != result_rank:
+ raise VerifyException(
+ "length of output_dimensions ("
+ f"{out_shape[0]}"
+ ") is not compatible with result rank ("
+ f"{result_rank}"
+ ")"
+ )
+
+ def _verify_per_dimension_bounds(self, bcast_dims, operand_ty, result_ty):
+ """Verify compatibility of operand and result dimensions."""
+ # dynamic_broadcast_in_dim_c5: bounds and per-dimension compatibility
+ operand_shape = operand_ty.get_shape()
+ result_shape = result_ty.get_shape()
+ result_rank = result_ty.get_num_dims()
+
+ for i, dim_index in enumerate(bcast_dims):
+ if dim_index < 0 or dim_index >= result_rank:
+ raise VerifyException(
+ "broadcast_dimensions contains invalid value "
+ f"{dim_index} for result with rank {result_rank}"
+ )
+ op_dim = operand_shape[i]
+ res_dim = result_shape[dim_index]
+ # If operand dim is static and not size-1, require compatibility with result dim
+ if op_dim not in (-1, 1):
+ if res_dim not in (-1, op_dim):
+ raise VerifyException(
+ "size of operand dimension "
+ f"{i} ({op_dim}) is not compatible with size of result dimension "
+ f"{dim_index} ({res_dim})"
+ )
+
+ def _verify_expansion_hints(self, operand_ty):
+ """Verify the operation's expansion hints."""
+ # dynamic_broadcast_in_dim_c8: no duplicate expansion hints across both lists
+ operand_rank = operand_ty.get_num_dims()
+
+ hints = []
+ if self.known_expanding_dimensions is not None:
+ hints.extend(self.known_expanding_dimensions.get_values()) # pylint: disable=no-member
+ if self.known_nonexpanding_dimensions is not None:
+ hints.extend(
+ self.known_nonexpanding_dimensions.get_values() # pylint: disable=no-member
+ )
+ if len(set(hints)) != len(hints):
+ raise VerifyException("duplicate expansion hint for at least one operand dimension")
+
+ # dynamic_broadcast_in_dim_c9/c10: each hint must reference a valid operand dimension
+ for h in set(hints):
+ if h < 0 or h >= operand_rank:
+ raise VerifyException(
+ "hint for expanding dimension "
+ f"{h} does not refer to a valid operand dimension"
+ )
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_binary.py b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_binary.py
new file mode 100644
index 0000000000..8180c7a208
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_binary.py
@@ -0,0 +1,214 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Binary elementwise operations for the StableHLO dialect.
+"""
+
+import abc
+from typing import Generic, TypeVar
+
+from xdsl.dialects.builtin import AnyTensorType, ComplexType, Float32Type, Float64Type, TensorType
+from xdsl.ir import Attribute, SSAValue
+from xdsl.irdl import (
+ IRDLOperation,
+ irdl_op_definition,
+ operand_def,
+ result_def,
+ traits_def,
+)
+from xdsl.traits import NoMemoryEffect
+
+from catalyst.python_interface.xdsl_extras import (
+ Elementwise,
+ SameOperandsAndResultShape,
+ SameOperandsElementType,
+)
+
+from .types import (
+ HLO_ComplexTensor,
+ HLO_Fp32Or64Tensor,
+ HLO_IntFpOrComplexOrQuantizedIntTensor,
+ HLO_Tensor,
+)
+
+# Type aliases
+F32Or64Type = Float32Type | Float64Type
+F32Or64TensorType = TensorType[F32Or64Type]
+ComplexTensorType = TensorType[ComplexType]
+
+# Generic type variables for templating
+T_LHS = TypeVar("T_LHS", bound=AnyTensorType)
+T_RHS = TypeVar("T_RHS", bound=AnyTensorType)
+T_OUT = TypeVar("T_OUT", bound=AnyTensorType)
+
+
+class ElementwiseBinaryOperation(IRDLOperation, abc.ABC, Generic[T_LHS, T_RHS, T_OUT]):
+ """
+ Templated base class for elementwise binary operations.
+
+ This class provides a flexible template for binary operations that can work
+ with different tensor types.
+
+ For more information about the semantics, see:
+ https://openxla.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
+ """
+
+ lhs = operand_def(T_LHS)
+ rhs = operand_def(T_RHS)
+ result = result_def(T_OUT)
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ SameOperandsAndResultShape(),
+ Elementwise(),
+ # TODO: HLO_SpeculatableIfAllInputsStatic(),
+ )
+
+ # TODO: Implement CustomDirective
+ # assembly_format = """
+ # $lhs `,` $rhs attr-dict
+ # `:` custom(type($lhs), type($rhs), type($result))
+ # """
+
+ def __init__(self, lhs: SSAValue, rhs: SSAValue, result_type: Attribute | None = None):
+ if result_type is None:
+ result_type = lhs.type
+ super().__init__(operands=(lhs, rhs), result_types=(result_type,))
+
+
+@irdl_op_definition
+class ComplexOp(
+ ElementwiseBinaryOperation[HLO_Fp32Or64Tensor, HLO_Fp32Or64Tensor, HLO_ComplexTensor]
+):
+ """
+ Performs element-wise conversion to a complex value from a pair of real and
+ imaginary values, `lhs` and `rhs`, and produces a `result` tensor.
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#complex
+ Example:
+ ```mlir
+ %result = stablehlo.complex %lhs, %rhs : tensor<2xcomplex>
+ ```
+ """
+
+ name = "stablehlo.complex"
+
+ # assembly_format = """
+ # operands attr-dict
+ # `:` custom(type($lhs), type($rhs), type($result))
+ # """
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ SameOperandsElementType(),
+ SameOperandsAndResultShape(),
+ # TODO: HLO_SpeculatableIfAllInputsStatic(),
+ )
+
+
+@irdl_op_definition
+class DivideOp(
+ ElementwiseBinaryOperation[
+ HLO_IntFpOrComplexOrQuantizedIntTensor,
+ HLO_IntFpOrComplexOrQuantizedIntTensor,
+ HLO_IntFpOrComplexOrQuantizedIntTensor,
+ ]
+):
+ """
+ Performs element-wise division of dividend `lhs` and divisor `rhs` tensors
+ and produces a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#divide
+
+ Example:
+ ```mlir
+ %result = stablehlo.divide %lhs, %rhs : tensor<4xf32>
+ ```
+ """
+
+ name = "stablehlo.divide"
+
+
+@irdl_op_definition
+class MaximumOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]):
+ """
+ Performs element-wise max operation on tensors `lhs` and `rhs` and produces
+ a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#maximum
+
+ Example:
+ ```mlir
+ %result = stablehlo.maximum %lhs, %rhs : tensor<4xf32>
+ ```
+ """
+
+ name = "stablehlo.maximum"
+
+
+@irdl_op_definition
+class MinimumOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]):
+ """
+ Performs element-wise min operation on tensors `lhs` and `rhs` and produces a
+ `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#minimum
+
+ Example:
+ ```mlir
+ %result = stablehlo.minimum %lhs, %rhs : tensor<4xf32>
+ ```
+ """
+
+ name = "stablehlo.minimum"
+
+
+@irdl_op_definition
+class PowerOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]):
+ """
+ Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and
+ produces a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power
+
+ Example:
+ ```mlir
+ %result = stablehlo.power %lhs, %rhs : tensor<6xf64>
+ ```
+ """
+
+ name = "stablehlo.power"
+
+
+@irdl_op_definition
+class RemainderOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]):
+ """
+ Performs element-wise remainder of dividend `lhs` and divisor `rhs` tensors
+ and produces a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#remainder
+
+ Example:
+ ```mlir
+ %result = stablehlo.remainder %lhs, %rhs : tensor<4xi64>
+ ```
+ """
+
+ name = "stablehlo.remainder"
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_other.py b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_other.py
new file mode 100644
index 0000000000..18527a028d
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_other.py
@@ -0,0 +1,236 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Other elementwise operations for the StableHLO dialect.
+"""
+
+import xdsl_jax.dialects.stablehlo as xstablehlo
+from xdsl.dialects.builtin import (
+ AnyFloat,
+ DenseArrayBase,
+ DenseIntOrFPElementsAttr,
+ IntegerAttr,
+ TensorType,
+ i32,
+ i64,
+)
+from xdsl.irdl import (
+ IRDLOperation,
+ attr_def,
+ irdl_op_definition,
+ operand_def,
+ opt_attr_def,
+ prop_def,
+ result_def,
+ traits_def,
+ var_operand_def,
+ var_region_def,
+)
+from xdsl.irdl.attributes import eq
+from xdsl.irdl.constraints import AtLeast
+from xdsl.traits import NoMemoryEffect, RecursiveMemoryEffect, SingleBlockImplicitTerminator
+
+from catalyst.python_interface.xdsl_extras import Elementwise, SameOperandsAndResultShape
+
+from .types import HLO_AnyTensor, HLO_FpOrQuantizedIntTensor, HLO_PredTensor, HLO_Tensor
+
+# Type aliases
+FloatTensorType = TensorType[AnyFloat]
+
+
+@irdl_op_definition
+class ClampOp(IRDLOperation):
+ """Element-wise clamp with min and max bounds.
+
+ See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#clamp
+ """
+
+ name = "stablehlo.clamp"
+
+ min = operand_def(HLO_Tensor)
+ operand = operand_def(HLO_Tensor)
+ max = operand_def(HLO_Tensor)
+ result = result_def(HLO_Tensor)
+
+ # TODO: Implement CustomDirective
+ # assembly_format = """
+ # $min `,` $operand `,` $max attr-dict
+ # `:` custom(type($min), type($operand), type($max), type($result))
+ # """
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ # TODO: HLO_SpeculatableIfAllInputsStatic(),
+ # TODO: HLO_CompatibleOperandsAndResultElementType(),
+ # TODO: HLO_BroadcastingElementwise(),
+ # TODO: InferTensorType(),
+ # TODO: InferShapedTypeOpInterface(),
+ )
+
+
+@irdl_op_definition
+class CompareOp(IRDLOperation):
+ """Element-wise compare with direction and type attributes."""
+
+ name = "stablehlo.compare"
+
+ assembly_format = """
+ $comparison_direction `,` $lhs `,` $rhs (`,` $comparison_type^)? attr-dict `:` functional-type(operands, results)
+ """
+
+ lhs = operand_def(HLO_Tensor)
+ rhs = operand_def(HLO_Tensor)
+ result = result_def(HLO_PredTensor)
+ comparison_direction = attr_def(xstablehlo.ComparisonDirectionAttr)
+ comparison_type = opt_attr_def(xstablehlo.ComparisonTypeAttr)
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ Elementwise(),
+ SameOperandsAndResultShape(),
+ # TODO: HLO_SpeculatableIfAllInputsStatic(),
+ # TODO: HLO_CompatibleOperandsElementType(),
+ # TODO: InferTensorTypeWithReify(),
+ )
+
+
+@irdl_op_definition
+class MapOp(IRDLOperation):
+ """
+ Applies a map function `computation` to `inputs` along the `dimensions` and
+ produces a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#map
+
+ Example:
+ ```mlir
+ %result = "stablehlo.map"(%input0, %input1) ({
+ ^bb0(%arg0: tensor, %arg1: tensor):
+ %0 = stablehlo.multiply %arg0, %arg1 : tensor
+ stablehlo.return %0 : tensor
+ }) {
+ dimensions = array
+ } : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
+ ```
+ """
+
+ name = "stablehlo.map"
+
+ inputs = var_operand_def(HLO_Tensor)
+ result = result_def(HLO_Tensor)
+ dimensions = attr_def(DenseArrayBase.constr(i64))
+ computation = var_region_def("single_block")
+
+ traits = traits_def(
+ RecursiveMemoryEffect(),
+ SameOperandsAndResultShape(),
+ SingleBlockImplicitTerminator(xstablehlo.ReturnOp),
+ # TODO: HLO_RecursivelySpeculatableIfAllInputsStatic(),
+ # TODO: InferTypeOpInterface
+ # TODO: InferShapedTypeOpInterface(),
+ )
+
+
+@irdl_op_definition
+class ReducePrecisionOp(IRDLOperation):
+ """
+ Performs element-wise conversion of `operand` to another floating-point type
+ that uses `exponent_bits` and `mantissa_bits` and back to the original
+ floating-point type and produces an `output` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_precision
+
+ Example:
+ ```mlir
+ %output = stablehlo.reduce_precision %operand, format = e5m10 : tensor<6xf64>
+ ```
+ """
+
+ name = "stablehlo.reduce_precision"
+
+ # TODO: Implement CustomDirective
+ # assembly_format = """
+ # $operand `,` `format` `=` custom($exponent_bits, $mantissa_bits)
+ # attr-dict `:` custom(type($operand), type($output))
+ # """
+
+ operand = operand_def(HLO_FpOrQuantizedIntTensor)
+ result = result_def(HLO_FpOrQuantizedIntTensor)
+
+ exponent_bits = attr_def(IntegerAttr.constr(type=eq(i32), value=AtLeast(1)))
+ mantissa_bits = attr_def(IntegerAttr.constr(type=eq(i32), value=AtLeast(0)))
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ Elementwise(),
+ # TODO: HLO_CompatibleOperandsAndResultType(),
+ # TODO: HLO_SpeculatableIfStaticDimInOutputIsStaticInInput(),
+ )
+
+
+@irdl_op_definition
+class SelectOp(IRDLOperation):
+ """
+ Produces a `result` tensor where each element is selected from `on_true` or
+ `on_false` tensor based on the value of the corresponding element of `pred`.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#select
+
+ Example:
+ ```mlir
+ %result = stablehlo.select %pred, %on_true, %on_false : tensor<2x2xi1>, tensor<2x2xi32>
+ ```
+ """
+
+ name = "stablehlo.select"
+
+ # assembly_format = """
+ # operands attr-dict `:`
+ # custom(type($pred), type($on_true), type($on_false), type($result))
+ # """
+
+ pred = operand_def(HLO_PredTensor)
+ on_true = operand_def(HLO_Tensor)
+ on_false = operand_def(HLO_Tensor)
+ result = result_def(HLO_Tensor)
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ )
+
+
+@irdl_op_definition
+class ConstantOp(IRDLOperation):
+ """
+ Produces an ``output`` tensor from a constant ``value``.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant
+
+ Example:
+ ```mlir
+ %output = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
+ """
+
+ name = "stablehlo.constant"
+
+ value = prop_def(DenseIntOrFPElementsAttr)
+ output = result_def(HLO_AnyTensor)
+
+ def __init__(self, value: DenseIntOrFPElementsAttr):
+ super().__init__(properties={"value": value}, result_types=(value.type,))
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_unary.py b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_unary.py
new file mode 100644
index 0000000000..a8a6e6ca86
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_unary.py
@@ -0,0 +1,552 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Unary elementwise operations for the StableHLO dialect.
+"""
+
+import abc
+from typing import Generic, TypeVar
+
+from xdsl.dialects.builtin import (
+ I1,
+ AnyFloat,
+ AnyTensorType,
+ ComplexType,
+ TensorType,
+)
+from xdsl.ir import Attribute, SSAValue
+from xdsl.irdl import (
+ IRDLOperation,
+ irdl_op_definition,
+ operand_def,
+ opt_attr_def,
+ result_def,
+ traits_def,
+)
+from xdsl.traits import NoMemoryEffect
+
+from catalyst.python_interface.xdsl_extras import Elementwise, SameOperandsAndResultShape
+
+from .attributes import ResultAccuracyMode, ResultAccuracyModeAttr
+from .types import (
+ HLO_FloatTensor,
+ HLO_FpComplexOrQuantizedIntTensor,
+ HLO_FpOrComplexTensor,
+ HLO_FpOrQuantizedIntTensor,
+ HLO_IntFpOrComplexOrQuantizedIntTensor,
+ HLO_NonQuantizedTensor,
+ HLO_PredTensor,
+ HLO_SIntFpComplexOrQuantizedIntTensor,
+)
+
+# Type aliases
+I1TensorType = TensorType[I1]
+FloatTensorType = TensorType[AnyFloat]
+FloatOrComplexType = AnyFloat | ComplexType
+FloatOrComplexTensorType = TensorType[FloatOrComplexType]
+ComplexTensorType = TensorType[ComplexType]
+
+# Generic type variables for templating
+T_IN = TypeVar("T_IN", bound=AnyTensorType)
+T_OUT = TypeVar("T_OUT", bound=AnyTensorType)
+
+
+class ElementwiseUnaryOperation(IRDLOperation, abc.ABC, Generic[T_IN, T_OUT]):
+ """
+ Templated base class for elementwise unary operations.
+
+ This class provides a flexible template for unary operations that can work
+ with different tensor types.
+
+ For more informtation about the semantics, see:
+ https://openxla.org/xla/operation_semantics#element-wise_unary_functions
+ """
+
+ operand = operand_def(T_IN)
+ result = result_def(T_OUT)
+
+ # TODO: Implement CustomDirective
+ # assembly_format = """
+ # $operand attr-dict `:` custom(type($operand), type($result))
+ # """
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ SameOperandsAndResultShape(),
+ Elementwise(),
+ # TODO: InferShapedTypeOpInterface(),
+ # TODO: HLO_SpeculatableIfStaticDimInOutputIsStaticInInput(),
+ )
+
+ def __init__(self, operand: SSAValue, result_type: Attribute | None = None):
+ if result_type is None:
+ result_type = operand.type
+ super().__init__(operands=(operand,), result_types=(result_type,))
+
+
+@irdl_op_definition
+class ConvertOp(ElementwiseUnaryOperation[HLO_NonQuantizedTensor, HLO_NonQuantizedTensor]):
+ """
+ Performs an element-wise conversion from one element type to another on
+ `operand` tensor and produces a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convert
+
+ Example:
+ ```mlir
+ %result = stablehlo.convert %operand : (tensor<3xi64>) -> tensor<3xcomplex>
+ ```
+ """
+
+ name = "stablehlo.convert"
+
+ traits = traits_def(SameOperandsAndResultShape())
+
+
+@irdl_op_definition
+class CosineOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise cosine operation on `operand` tensor and produces a
+ `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cosine
+
+ Example:
+ ```mlir
+ %result = stablehlo.cosine %operand : tensor<2xf32>
+ ```
+ """
+
+ name = "stablehlo.cosine"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
+ # TODO: implement HLO_CompatibleOperandsAndResultType()
+ # traits = traits_def(
+ # HLO_CompatibleOperandsAndResultType()
+ # )
+
+
+@irdl_op_definition
+class ExponentialMinusOneOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise exponential minus one operation on `operand` tensor
+ and produces a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential_minus_one
+
+ Example:
+ ```mlir
+ %result = stablehlo.exponential_minus_one %operand : tensor<2xf64>
+ ```
+ """
+
+ name = "stablehlo.exponential_minus_one"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
+
+ # TODO: implement HLO_CompatibleOperandsAndResultType()
+ # traits = traits_def(
+ # HLO_CompatibleOperandsAndResultType()
+ # )
+
+
+@irdl_op_definition
+class ExponentialOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise exponential operation on `operand` tensor and produces
+ a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential
+
+ Example:
+ ```mlir
+ %result = stablehlo.exponential %operand : tensor<2x2xf64>
+ ```
+ """
+
+ name = "stablehlo.exponential"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
+
+ # TODO: implement HLO_CompatibleOperandsAndResultType()
+ # traits = traits_def(
+ # HLO_CompatibleOperandsAndResultType()
+ # )
+
+
+@irdl_op_definition
+class FloorOp(ElementwiseUnaryOperation[HLO_FpOrQuantizedIntTensor, HLO_FpOrQuantizedIntTensor]):
+ """
+ Performs element-wise floor of `operand` tensor and produces a `result`
+ tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#floor
+
+ Example:
+ ```mlir
+ %result = stablehlo.floor %operand : tensor<2xf32>
+ ```
+ """
+
+ name = "stablehlo.floor"
+
+
+@irdl_op_definition
+class ImagOp(ElementwiseUnaryOperation[HLO_FpOrComplexTensor, HLO_FloatTensor]):
+ """
+ Extracts the imaginary part, element-wise, from the `operand` and produces a
+ `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#imag
+
+ Example:
+ ```mlir
+ %result = stablehlo.imag %operand : (tensor<2xcomplex>) -> tensor<2xf32>
+ ```
+ """
+
+ name = "stablehlo.imag"
+
+
+@irdl_op_definition
+class IsFiniteOp(ElementwiseUnaryOperation[HLO_FpOrQuantizedIntTensor, HLO_PredTensor]):
+ """
+ Performs element-wise check whether the value in `x` is finite (i.e. is
+ neither +Inf, -Inf, nor NaN) and produces a `y` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#is_finite
+
+ Example:
+ ```mlir
+ %y = stablehlo.is_finite %x : (tensor<7xf64>) -> tensor<7xi1>
+ ```
+ """
+
+ name = "stablehlo.is_finite"
+
+
+@irdl_op_definition
+class LogOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise logarithm operation on `operand` tensor and produces a
+ `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#log
+
+ Example:
+ ```mlir
+ %result = stablehlo.log %operand : tensor<2x2xf64>
+ ```
+ """
+
+ name = "stablehlo.log"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
+
+
+@irdl_op_definition
+class LogPlusOneOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise logarithm plus one operation on `operand` tensor and
+ produces a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#log_plus_one
+
+ Example:
+ ```mlir
+ %result = stablehlo.log_plus_one %operand : tensor<5xf64>
+ ```
+ """
+
+ name = "stablehlo.log_plus_one"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
+
+
+@irdl_op_definition
+class LogisticOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise logistic operation on `operand` tensor and produces a
+ `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#logistic
+
+ Example:
+ ```mlir
+ %result = stablehlo.logistic %operand : tensor<2x2xf64>
+ ```
+ """
+
+ name = "stablehlo.logistic"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
+
+
+@irdl_op_definition
+class NegateOp(
+ ElementwiseUnaryOperation[
+ HLO_IntFpOrComplexOrQuantizedIntTensor, HLO_IntFpOrComplexOrQuantizedIntTensor
+ ]
+):
+ """
+ Performs element-wise negation of `operand` tensor and produces a `result`
+ tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#negate
+
+ Example:
+ ```mlir
+ %result = stablehlo.negate %operand : tensor<2x3xi32>
+ ```
+ """
+
+ name = "stablehlo.negate"
+
+
+@irdl_op_definition
+class RealOp(ElementwiseUnaryOperation[HLO_FpOrComplexTensor, HLO_FloatTensor]):
+ """
+ Extracts the real part, element-wise, from the `operand` and produces a
+ `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#real
+
+ Example:
+ ```mlir
+ %result = stablehlo.real %operand : tensor<2xcomplex> : tensor<2xf32>
+ ```
+ """
+
+ name = "stablehlo.real"
+
+
+@irdl_op_definition
+class RoundNearestAfzOp(
+ ElementwiseUnaryOperation[HLO_FpOrQuantizedIntTensor, HLO_FpOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise rounding towards the nearest integer, breaking ties
+ away from zero, on the `operand` tensor and produces a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_afz
+
+ Example:
+ ```mlir
+ %result = stablehlo.round_nearest_afz %operand : tensor<5xf64>
+ ```
+ """
+
+ name = "stablehlo.round_nearest_afz"
+
+
+@irdl_op_definition
+class RoundNearestEvenOp(
+ ElementwiseUnaryOperation[HLO_FpOrQuantizedIntTensor, HLO_FpOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise rounding towards the nearest integer, breaking ties
+ towards the even integer, on the `operand` tensor and produces a `result`
+ tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_even
+
+ Example:
+ ```mlir
+ %result = stablehlo.round_nearest_even %operand : tensor<5xf64>
+ ```
+ """
+
+ name = "stablehlo.round_nearest_even"
+
+
+@irdl_op_definition
+class RsqrtOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise reciprocal square root operation on `operand` tensor
+ and produces a `result` tensor, implementing the `rSqrt` operation from the
+ IEEE-754 specification.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rsqrt
+
+ Example:
+ ```mlir
+ %result = stablehlo.rsqrt %operand : tensor<2x2xf32>
+ ```
+ """
+
+ name = "stablehlo.rsqrt"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
+
+
+@irdl_op_definition
+class SignOp(
+ ElementwiseUnaryOperation[
+ HLO_SIntFpComplexOrQuantizedIntTensor, HLO_SIntFpComplexOrQuantizedIntTensor
+ ]
+):
+ """
+ Returns the sign of the `operand` element-wise and produces a `result`
+ tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sign
+
+ Example:
+ ```mlir
+ %result = stablehlo.sign %operand : tensor<5xf64>
+ ```
+ """
+
+ name = "stablehlo.sign"
+
+
+@irdl_op_definition
+class SineOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise sine operation on `operand` tensor and produces a
+ `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sine
+
+ Example:
+ ```mlir
+ %result = stablehlo.sine %operand : tensor<2xf32>
+ ```
+ """
+
+ name = "stablehlo.sine"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
+
+
+@irdl_op_definition
+class SqrtOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise square root operation on `operand` tensor and produces
+ a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sqrt
+
+ Example:
+ ```mlir
+ %result = stablehlo.sqrt %operand : tensor<2x2xf32>
+ ```
+ """
+
+ name = "stablehlo.sqrt"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
+
+
+@irdl_op_definition
+class TanOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise tangent operation on `operand` tensor and
+ produces a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tan
+
+ Example:
+ ```mlir
+ %result = stablehlo.tan %operand : tensor<2x2xf64>
+ ```
+ """
+
+ name = "stablehlo.tan"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
+
+
+@irdl_op_definition
+class TanhOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise hyperbolic tangent operation on `operand` tensor and
+ produces a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tanh
+
+ Example:
+ ```mlir
+ %result = stablehlo.tanh %operand : tensor<2xf32>
+ ```
+ """
+
+ name = "stablehlo.tanh"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/extensibility.py b/frontend/catalyst/python_interface/dialects/stablehlo/extensibility.py
new file mode 100644
index 0000000000..e6f8f0542d
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/extensibility.py
@@ -0,0 +1,167 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Dynamism operations for the StableHLO dialect.
+"""
+
+
+from xdsl.dialects.builtin import (
+ ArrayAttr,
+ BoolAttr,
+ DenseIntElementsAttr,
+ DictionaryAttr,
+ FlatSymbolRefAttr,
+ StringAttr,
+ TensorType,
+ TupleType,
+)
+from xdsl.ir import Attribute
+from xdsl.irdl import (
+ AnyAttr,
+ IRDLOperation,
+ irdl_op_definition,
+ opt_prop_def,
+ prop_def,
+ traits_def,
+ var_operand_def,
+ var_result_def,
+)
+from xdsl.traits import (
+ MemoryEffect,
+)
+from xdsl.utils.exceptions import VerifyException
+
+from .attributes import CustomCallApiVersion, CustomCallApiVersionAttr, OutputOperandAlias
+
+
+@irdl_op_definition
+class CustomCallOp(IRDLOperation):
+ """
+ Encapsulates an implementation-defined operation ``call_target_name`` that
+ takes ``inputs`` and ``called_computations`` and produces ``results``.
+
+ Depending on the API version there are two ways to pass extra bits of static
+ information to the external function:
+ 1. Use ``API_VERSION_TYPED_FFI`` which allows passing a dictionary attribute.
+ 2. Use a previous API version with a ``StringAttr`` to encode backend config.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#custom_call
+
+ Example:
+ ```mlir
+ %results = stablehlo.custom_call @foo(%input0) {
+ backend_config = {bar = 42 : i32},
+ api_version = 4 : i32,
+ called_computations = [@foo]
+ } : (tensor) -> tensor
+ ```
+ """
+
+ name = "stablehlo.custom_call"
+
+ inputs = var_operand_def(AnyAttr())
+ call_target_name = prop_def(StringAttr)
+ has_side_effect = prop_def(BoolAttr, default_value=BoolAttr.from_bool(False))
+ backend_config = opt_prop_def(DictionaryAttr | StringAttr)
+ api_version = prop_def(
+ CustomCallApiVersionAttr,
+ default_value=CustomCallApiVersionAttr(CustomCallApiVersion.API_VERSION_ORIGINAL),
+ )
+ called_computations = opt_prop_def(ArrayAttr[FlatSymbolRefAttr], default_value=ArrayAttr([]))
+ operand_layouts = opt_prop_def(ArrayAttr[DenseIntElementsAttr])
+ result_layouts = opt_prop_def(ArrayAttr[DenseIntElementsAttr])
+ output_operand_aliases = prop_def(ArrayAttr[OutputOperandAlias])
+
+ result = var_result_def(AnyAttr())
+
+ traits = traits_def(
+ MemoryEffect(),
+ )
+
+ # TODO: Implement CustomDirective
+ # assembly_format = """
+ # custom($call_target_name) `(` $inputs `)`
+ # attr-dict `:` functional-type(operands, results)
+ # """
+
+ def verify_(self) -> None:
+ """Verify the CustomCallOp."""
+ # If both operand and result layout attributes are not specified then nothing to verify.
+ if self.operand_layouts is None and self.result_layouts is None:
+ return
+
+ # Layout constraints for either both operands & results or none should be specified.
+ if (self.operand_layouts is None) != (self.result_layouts is None):
+ raise VerifyException(
+ "Layout attributes should be specified for either both operands and results "
+ "or none."
+ )
+
+ assert self.operand_layouts is not None and self.result_layouts is not None
+
+ def verify_types_and_layouts(
+ types: tuple[Attribute, ...], layouts: ArrayAttr, value_name: str
+ ):
+ if len(types) != len(layouts.data):
+ raise VerifyException(
+ "Number of "
+ f"{value_name}s must match the number of {value_name} layouts, "
+ f"{len(types)} != {len(layouts.data)}"
+ )
+
+ for index, (ty, layout_attr) in enumerate(zip(types, layouts.data)):
+ # Tuple types are not fully supported with layout constraints yet
+ if isinstance(ty, TupleType):
+ raise VerifyException(
+ "Tuple types are not fully supported with layout constraints yet"
+ )
+
+ try:
+ dims = list(layout_attr.get_values())
+ except Exception as exc:
+ raise VerifyException("invalid layout attribute") from exc
+
+ # For non-tensor types, layout must be empty
+ if not isinstance(ty, TensorType):
+ if len(dims) == 0:
+ continue
+ raise VerifyException(
+ "Only tensor types can have non-empty layout: "
+ f"{value_name} #{index} of type {ty} has layout {dims}"
+ )
+
+ # For ranked tensors, require permutation of [0, rank)
+ rank = ty.get_num_dims()
+ if rank != len(dims) or sorted(dims) != list(range(rank)):
+ raise VerifyException(
+ f"incorrect layout {dims} for type {ty}, layout must be a permutation "
+ f"of [0, {rank})"
+ )
+
+ # Operand types
+ operand_types: tuple[Attribute, ...] = tuple(op.type for op in self.operands)
+
+ # Result types: if single tuple result, use its element types
+ if len(self.result_types) == 1 and isinstance(self.result_types[0], TupleType):
+ tuple_ty: TupleType = self.result_types[0]
+ result_types = tuple(tuple_ty.types.data)
+ else:
+ result_types = tuple(self.result_types)
+
+ # Verify that operands and operand layouts match.
+ verify_types_and_layouts(operand_types, self.operand_layouts, "operand")
+ # Verify that results and result layouts match.
+ verify_types_and_layouts(result_types, self.result_layouts, "result")
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py b/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py
new file mode 100644
index 0000000000..7e9bdcfdf8
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py
@@ -0,0 +1,169 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Dynamism operations for the StableHLO dialect.
+"""
+
+from xdsl.dialects.builtin import DenseArrayBase, i64
+from xdsl.irdl import (
+ IRDLOperation,
+ irdl_op_definition,
+ prop_def,
+ region_def,
+ traits_def,
+ var_operand_def,
+ var_result_def,
+)
+from xdsl.irdl.operations import SameVariadicOperandSize
+from xdsl.traits import (
+ RecursiveMemoryEffect,
+ SingleBlockImplicitTerminator,
+)
+from xdsl.utils.exceptions import VerifyException
+from xdsl_jax.dialects import stablehlo as xstablehlo
+
+from .types import HLO_Tensor
+
+
+@irdl_op_definition
+class ReduceOp(IRDLOperation):
+ """
+ Applies a reduction function ``body`` to ``inputs`` and ``init_values`` along the
+ ``dimensions`` and produces a ``result`` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce
+
+ Example:
+ ```mlir
+ %result = "stablehlo.reduce"(%input, %init_value) ({
+ ^bb0(%arg0: tensor, %arg1: tensor):
+ %0 = stablehlo.add %arg0, %arg1 : tensor
+ stablehlo.return %0 : tensor
+ }) {
+ dimensions = array
+ } : (tensor<1x6xi64>, tensor) -> tensor<1xi64>
+ ```
+ """
+
+ name = "stablehlo.reduce"
+
+ inputs = var_operand_def(HLO_Tensor)
+ init_values = var_operand_def(HLO_Tensor)
+ dimensions = prop_def(DenseArrayBase.constr(i64))
+ result = var_result_def(HLO_Tensor)
+ body = region_def("single_block")
+
+ irdl_options = [SameVariadicOperandSize()]
+
+ traits = traits_def(
+ RecursiveMemoryEffect(),
+ # TODO: InferShapedTypeOpInterface(),
+ # TODO: HLO_RecursivelySpeculatableIfAllInputsStatic,
+ # TODO: InferTensorTypeWithReify(),
+ SingleBlockImplicitTerminator(xstablehlo.ReturnOp),
+ )
+
+ # pylint: disable=no-member
+ # pylint: disable=too-many-branches
+ def verify_(self):
+ """Verify the ReduceOp."""
+ # Gather shaped operand/result types
+ input_types = [op.type for op in self.inputs]
+ init_types = [op.type for op in self.init_values]
+
+ self._verify_input_and_init_types(input_types, init_types)
+
+ self._verify_reducer_region(input_types)
+
+ def _verify_input_and_init_types(self, input_types, init_types):
+ """Verify the types of the inputs and init values."""
+
+ # Basic structural checks mirroring verifyReduceOpInputsAndInferShape
+ if len(input_types) == 0:
+ raise VerifyException("expected at least 1 input for reduce")
+ if len(input_types) != len(init_types):
+ raise VerifyException("number of inputs must match number of init_values")
+
+ # reduce_c1/c4/c5/i3: verify inputs and infer shape compatibility
+ dims_attr = self.dimensions
+ dims = tuple(dims_attr.get_values()) if dims_attr is not None else tuple()
+
+ # All inputs must have equal rank; dimensions must be within rank and unique
+ # and not empty.
+ ranks = []
+ for t in input_types:
+ # Tensors by op definition
+ assert hasattr(t, "get_num_dims")
+ ranks.append(t.get_num_dims())
+ rank0 = ranks[0]
+ if any(r != rank0 for r in ranks):
+ raise VerifyException("all inputs must have the same rank")
+
+ if len(dims) == 0:
+ raise VerifyException("dimensions cannot be empty for reduce")
+ if len(set(dims)) != len(dims):
+ raise VerifyException("dimensions should not have duplicates")
+ if any(d < 0 or d >= rank0 for d in dims):
+ raise VerifyException("dimensions contains an invalid value")
+
+ # Element type compatibility between each input and its init value
+ for it, iv in zip(input_types, init_types):
+ it_elem = it.get_element_type()
+ iv_elem = iv.get_element_type()
+ if it_elem != iv_elem:
+ raise VerifyException("input and init_value must have the same element type")
+
+ def _verify_reducer_region(self, input_types):
+ """Verify the operation's reducer region."""
+
+ # reduce_c2/c6: verify reducer region shape
+ # Expect block with arity 2 * number of inputs, with matching tensor element types
+ # and 0D tensors
+ if len(self.body.blocks) != 1:
+ raise VerifyException("reducer must have a single block")
+ block = self.body.blocks[0]
+
+ expected_args = 2 * len(input_types)
+ if len(block.args) != expected_args:
+ raise VerifyException(
+ f"reducer must take {expected_args} arguments, got {len(block.args)}"
+ )
+
+ # Each pair (arg_i, arg_{i+N}) must be 0D tensors of the input element type
+ for i, it in enumerate(input_types):
+ it_elem = it.get_element_type()
+ acc = block.args[i]
+ val = block.args[i + len(input_types)]
+ for a in (acc, val):
+ a_ty = a.type
+ if not hasattr(a_ty, "get_num_dims") or a_ty.get_num_dims() != 0:
+ raise VerifyException("reducer arguments must be rank-0 tensors")
+ if a_ty.get_element_type() != it_elem:
+ raise VerifyException(
+ "reducer argument element types must match input element type"
+ )
+
+ # Region must terminate with exactly len(inputs) results
+ ret = block.ops.last
+ if len(ret.operands) != len(input_types):
+ raise VerifyException("reducer must return exactly one value per input")
+ for i, it in enumerate(input_types):
+ it_elem = it.get_element_type()
+ rty = ret.operands[i].type
+ if not hasattr(rty, "get_num_dims") or rty.get_num_dims() != 0:
+ raise VerifyException("reducer return values must be rank-0 tensors")
+ if rty.get_element_type() != it_elem:
+ raise VerifyException("reducer return element types must match input element type")
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/types.py b/frontend/catalyst/python_interface/dialects/stablehlo/types.py
new file mode 100644
index 0000000000..a27f4fc8f2
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/types.py
@@ -0,0 +1,247 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+StableHLO type definitions for PennyLane's compiler infrastructure.
+
+This module provides type definitions based on the StableHLO specification
+(https://github.com/openxla/stablehlo/blob/main/docs/spec.md), including
+token types and other necessary type definitions for StableHLO operations.
+"""
+
+from typing import TypeAlias
+
+from xdsl.dialects.builtin import (
+ AnyFloatConstr,
+ ComplexType,
+ Float32Type,
+ Float64Type,
+ IndexType,
+ IntAttr,
+ IntAttrConstraint,
+ IntegerType,
+ ParametrizedAttribute,
+ Signedness,
+ SignednessAttr,
+ TensorType,
+ i1,
+)
+from xdsl.irdl import eq, irdl_attr_definition
+from xdsl.irdl.attributes import EqAttrConstraint, ParamAttrConstraint
+from xdsl.irdl.constraints import IntSetConstraint
+from xdsl_jax.dialects.stablehlo import TokenType
+
+from catalyst.python_interface.xdsl_extras.constraints import (
+ NestedTupleOfConstraint,
+)
+
+
+def _create_param_constrained_type(
+ base_attr: type, widths: list[int], signedness: Signedness | None = None
+):
+ """Create an integer type constrained using ParamAttrConstraint with IntSetConstraint."""
+ width_constraint = IntAttrConstraint(IntSetConstraint(frozenset(widths)))
+
+ if signedness is None:
+ signedness_constraint = None
+ else:
+ signedness_constraint = EqAttrConstraint(SignednessAttr(signedness))
+
+ return ParamAttrConstraint(base_attr, [width_constraint, signedness_constraint])
+
+
+# =============================================================================
+# Core StableHLO types constraints
+# =============================================================================
+
+HLO_Pred = eq(i1)
+HLO_PredTensor: TypeAlias = TensorType[HLO_Pred]
+
+# NOTE: IntegerType is defined in the StableHLO spec as:
+# IntegerType ::= SignedIntegerType | UnsignedIntegerType,
+# but the MLIR implementation is using signless integers instead of signed,
+# and there is a TODO to fix it.
+
+_HLO_INT_WIDTHS = [2, 4, 8, 16, 32, 64]
+HLO_SignedInt = _create_param_constrained_type(IntegerType, _HLO_INT_WIDTHS, Signedness.SIGNED)
+HLO_UnsignedInt = _create_param_constrained_type(IntegerType, _HLO_INT_WIDTHS, Signedness.UNSIGNED)
+HLO_SignlessInt = _create_param_constrained_type(IntegerType, _HLO_INT_WIDTHS, None)
+
+HLO_Int: TypeAlias = HLO_UnsignedInt | HLO_SignlessInt
+HLO_IntTensor: TypeAlias = TensorType[HLO_Int]
+
+_HLO_INT_OR_PRED_WIDTHS = [1, 2, 4, 8, 16, 32, 64]
+HLO_IntOrPred = _create_param_constrained_type(IntegerType, _HLO_INT_OR_PRED_WIDTHS, None)
+
+
+HLO_AnyIntegerOrIndex: TypeAlias = IntegerType | IndexType
+HLO_AnyIntegerOrIndexTensor: TypeAlias = TensorType.constr(HLO_AnyIntegerOrIndex)
+
+HLO_DimensionValue: TypeAlias = HLO_Int | IndexType
+
+# Constraint variants for use in unions with ParamAttrConstraint
+HLO_Float: TypeAlias = AnyFloatConstr
+HLO_Float32Or64: TypeAlias = Float32Type | Float64Type
+HLO_FloatTensor: TypeAlias = TensorType.constr(HLO_Float)
+HLO_Fp32Or64Tensor: TypeAlias = TensorType.constr(HLO_Float32Or64)
+
+# Complex as a constraint over element types {f32,f64}
+HLO_Complex: TypeAlias = ComplexType[HLO_Float32Or64]
+HLO_ComplexTensor: TypeAlias = TensorType.constr(HLO_Complex)
+
+# =============================================================================
+# Quantized element type definitions
+# =============================================================================
+
+
+@irdl_attr_definition
+class UniformQuantizedType(ParametrizedAttribute):
+ """
+ Placeholder for StableHLO per-tensor uniform quantized types.
+
+ Parameterized by width to support different quantized integer widths
+ (e.g., 8-bit, 16-bit quantization).
+ """
+
+ name = "stablehlo.uniform_quantized"
+ width: IntAttr
+ signedness: SignednessAttr
+
+
+@irdl_attr_definition
+class UniformQuantizedPerAxisType(ParametrizedAttribute):
+ """
+ Placeholder for StableHLO per-axis uniform quantized types.
+
+ Parameterized by width to support different quantized integer widths
+ (e.g., 8-bit, 16-bit quantization).
+ """
+
+ name = "stablehlo.uniform_quantized_per_axis"
+ width: IntAttr
+ signedness: SignednessAttr
+
+
+# =============================================================================
+# StableHLO quantized type aliases
+# =============================================================================
+
+_HLO_QUANTIZED_WIDTHS = [2, 4, 8, 16, 32]
+
+# Constraint-based types for operation definitions
+HLO_QuantizedSignedInt = _create_param_constrained_type(
+ UniformQuantizedType, _HLO_QUANTIZED_WIDTHS, Signedness.SIGNED
+)
+HLO_QuantizedUnsignedInt = _create_param_constrained_type(
+ UniformQuantizedType, _HLO_QUANTIZED_WIDTHS, Signedness.UNSIGNED
+)
+HLO_QuantizedAnySignednessInt = _create_param_constrained_type(
+ UniformQuantizedType, _HLO_QUANTIZED_WIDTHS, None
+)
+HLO_QuantizedInt: TypeAlias = HLO_QuantizedSignedInt | HLO_QuantizedUnsignedInt
+
+HLO_PerAxisQuantizedSignedInt = _create_param_constrained_type(
+ UniformQuantizedPerAxisType, _HLO_QUANTIZED_WIDTHS, Signedness.SIGNED
+)
+HLO_PerAxisQuantizedUnsignedInt = _create_param_constrained_type(
+ UniformQuantizedPerAxisType, _HLO_QUANTIZED_WIDTHS, Signedness.UNSIGNED
+)
+HLO_PerAxisQuantizedAnySignednessInt = _create_param_constrained_type(
+ UniformQuantizedPerAxisType, _HLO_QUANTIZED_WIDTHS, None
+)
+HLO_PerAxisQuantizedInt: TypeAlias = HLO_PerAxisQuantizedSignedInt | HLO_PerAxisQuantizedUnsignedInt
+
+# =============================================================================
+# Main tensor type definitions
+# =============================================================================
+
+HLO_Tensor: TypeAlias = TensorType[HLO_Float | HLO_Complex | HLO_IntOrPred | HLO_QuantizedInt]
+HLO_NonQuantizedTensor: TypeAlias = TensorType[HLO_Float | HLO_Complex | HLO_IntOrPred]
+
+# Note: There is a discrepancy between the StableHLO spec and the MLIR implementation.
+# The spec does not allow unranked tensors, but the MLIR implementation
+# defines it as a tensor of any type and rank. There is a TODO to fix this in MLIR.
+# Therefore, we use the correct ranked tensor type.
+HLO_AnyTensor: TypeAlias = TensorType[
+ HLO_Float | HLO_Complex | HLO_IntOrPred | HLO_QuantizedInt | HLO_PerAxisQuantizedInt
+]
+HLO_TensorOrToken: TypeAlias = HLO_Tensor | TokenType
+HLO_TensorOrPerAxisQuantizedTensorOrToken: TypeAlias = HLO_AnyTensor | TokenType
+
+# HLO_AnyTuple : NestedTupleOf<[HLO_AnyTensor, HLO_Token]>
+HLO_AnyTuple = NestedTupleOfConstraint([HLO_AnyTensor, TokenType])
+
+HLO_CustomCallValue: TypeAlias = HLO_Tensor | TokenType | HLO_AnyTuple
+
+# =============================================================================
+# HLO combined type definitions
+# =============================================================================
+
+HLO_PredOrIntTensor: TypeAlias = TensorType.constr(HLO_IntOrPred)
+
+HLO_FpOrComplexTensor: TypeAlias = TensorType.constr(HLO_Float | HLO_Complex)
+HLO_FpOrQuantizedIntTensor: TypeAlias = TensorType.constr(HLO_Float | HLO_QuantizedInt)
+HLO_FpComplexOrQuantizedIntTensor: TypeAlias = TensorType.constr(
+ HLO_Float | HLO_Complex | HLO_QuantizedInt
+)
+HLO_IntFpOrComplexOrQuantizedIntTensor: TypeAlias = TensorType.constr(
+ HLO_Int | HLO_Float | HLO_Complex | HLO_QuantizedInt
+)
+HLO_SIntFpComplexOrQuantizedIntTensor: TypeAlias = TensorType.constr(
+ HLO_SignedInt | HLO_Float | HLO_Complex | HLO_QuantizedInt
+)
+
+
+__all__ = [
+ # Core types
+ "HLO_Pred",
+ "HLO_PredTensor",
+ "HLO_Int",
+ "HLO_IntTensor",
+ "HLO_AnyIntegerOrIndex",
+ "HLO_AnyIntegerOrIndexTensor",
+ "HLO_DimensionValue",
+ "HLO_Float",
+ "HLO_Float32Or64",
+ "HLO_FloatTensor",
+ "HLO_Fp32Or64Tensor",
+ "HLO_ComplexTensor",
+ "HLO_SignedInt",
+ "HLO_UnsignedInt",
+ "HLO_SignlessInt",
+ "HLO_QuantizedSignedInt",
+ "HLO_QuantizedUnsignedInt",
+ "HLO_QuantizedAnySignednessInt",
+ "HLO_QuantizedInt",
+ "HLO_PerAxisQuantizedSignedInt",
+ "HLO_PerAxisQuantizedUnsignedInt",
+ "HLO_PerAxisQuantizedAnySignednessInt",
+ "HLO_PerAxisQuantizedInt",
+ # Quantized types
+ "UniformQuantizedType",
+ "UniformQuantizedPerAxisType",
+ "HLO_Tensor",
+ "HLO_NonQuantizedTensor",
+ "HLO_AnyTensor",
+ "HLO_TensorOrToken",
+ "HLO_TensorOrPerAxisQuantizedTensorOrToken",
+ "HLO_CustomCallValue",
+ # Combined types
+ "HLO_PredOrIntTensor",
+ "HLO_FpOrComplexTensor",
+ "HLO_FpOrQuantizedIntTensor",
+ "HLO_FpComplexOrQuantizedIntTensor",
+ "HLO_IntFpOrComplexOrQuantizedIntTensor",
+ "HLO_SIntFpComplexOrQuantizedIntTensor",
+]
diff --git a/frontend/catalyst/python_interface/dialects/transform.py b/frontend/catalyst/python_interface/dialects/transform.py
new file mode 100644
index 0000000000..32449e7d5e
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/transform.py
@@ -0,0 +1,122 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This file contains an updated version of the transform dialect.
+As of the time of writing, xDSL uses the MLIR released with LLVM's
+version 20.1.7. However, https://github.com/PennyLaneAI/catalyst/pull/1916
+will be updating MLIR where the transform dialect has the
+`apply_registered_pass` operation re-defined.
+
+See the following changelog on the above PR
+
+ Things related to transform.apply_registered_pass op:
+
+ It now takes in a dynamic_options
+
+ [MLIR][Transform] Allow ApplyRegisteredPassOp to take options as
+ a param llvm/llvm-project#142683. We don't need to use this as all our pass options are static.
+ https://github.com/llvm/llvm-project/pull/142683
+
+ The options it takes in are now dictionaries instead of strings
+ [MLIR][Transform] apply_registered_pass op's options as a dict llvm/llvm-project#143159
+ https://github.com/llvm/llvm-project/pull/143159
+
+This file will re-define the apply_registered_pass operation in xDSL
+and the transform dialect.
+
+Once xDSL moves to a newer version of MLIR, these changes should
+be contributed upstream.
+"""
+
+from xdsl.dialects.builtin import Dialect
+from xdsl.dialects.transform import ApplyRegisteredPassOp as xApplyRegisteredPassOp
+from xdsl.dialects.transform import (
+ DictionaryAttr,
+ StringAttr,
+)
+from xdsl.dialects.transform import Transform as xTransform
+from xdsl.dialects.transform import (
+ TransformHandleType,
+ irdl_op_definition,
+ operand_def,
+ prop_def,
+ result_def,
+)
+from xdsl.ir import Attribute, SSAValue
+from xdsl.irdl import IRDLOperation, ParsePropInAttrDict
+
+
+# pylint: disable=line-too-long
+@irdl_op_definition
+class ApplyRegisteredPassOp(IRDLOperation):
+ """
+ See external [documentation](https://mlir.llvm.org/docs/Dialects/Transform/#transformapply_registered_pass-transformapplyregisteredpassop).
+ """
+
+ name = "transform.apply_registered_pass"
+
+ options = prop_def(DictionaryAttr, default_value=DictionaryAttr({}))
+ pass_name = prop_def(StringAttr)
+ target = operand_def(TransformHandleType)
+ result = result_def(TransformHandleType)
+ # While this assembly format doesn't match
+ # the one in upstream MLIR,
+ # this is because xDSL currently lacks CustomDirectives
+ # https://mlir.llvm.org/docs/DefiningDialects/Operations/#custom-directives
+ # https://github.com/xdslproject/xdsl/pull/4829
+ # However, storing the property in the attribute should still work
+ # specially when parsing and printing in generic format.
+ # Which is how Catalyst and XDSL currently communicate at the moment.
+ # TODO: Add test.
+ assembly_format = "$pass_name `to` $target attr-dict `:` functional-type(operands, results)"
+ irdl_options = [ParsePropInAttrDict()]
+
+ def __init__(
+ self,
+ pass_name: str | StringAttr,
+ target: SSAValue,
+ options: dict[str | StringAttr, Attribute | str | bool | int] | None = None,
+ ):
+ if isinstance(pass_name, str):
+ pass_name = StringAttr(pass_name)
+
+ if isinstance(options, dict):
+ options = DictionaryAttr(options)
+
+ super().__init__(
+ properties={
+ "pass_name": pass_name,
+ "options": options,
+ },
+ operands=[target],
+ result_types=[target.type],
+ )
+
+
+# Copied over from xDSL's sources
+# the main difference will be the use
+# of a different ApplyRegisteredPassOp
+operations = list(xTransform.operations)
+del operations[operations.index(xApplyRegisteredPassOp)]
+operations.append(ApplyRegisteredPassOp)
+
+Transform = Dialect(
+ "transform",
+ [
+ *operations,
+ ],
+ [
+ *xTransform.attributes,
+ ],
+)
diff --git a/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst b/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst
new file mode 100644
index 0000000000..1525c030ee
--- /dev/null
+++ b/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst
@@ -0,0 +1,1376 @@
+Unified Compiler Cookbook
+=========================
+
+**Note:** The cookbook is developed with the following package versions,
+on Python 3.12.11:
+
+.. code-block:: bash
+
+ jax==0.6.2
+ jaxlib==0.6.2
+ numpy==2.3.1
+ pennylane==0.44.0-dev19
+ pennylane-lightning==0.43.0
+ pennylane-catalyst==0.14.0-dev15
+ xdsl==0.53.0
+ xdsl-jax==git+https://github.com/xdslproject/xdsl-jax.git@895f7c13e8d0f02bbe99d7fb9ebcaafea4ea629f#egg=xdsl_jax
+
+Note that ``xdsl-jax`` does not currently have a release published on
+PyPI, so it needs to be installed from GitHub by running the following:
+
+.. code-block:: bash
+
+ pip install git+https://github.com/xdslproject/xdsl-jax.git
+
+Motivation
+==========
+
+As we approach FTQC, quantum compilation becomes a more and more
+important research area. Catalyst uses MLIR as its intermediate
+representation, which is also the layer in which a majority of the
+optimizations happen. However, quantum compilation researchers are not
+likely to be accustomed with MLIR or C++.
+
+So, the motivation of the βUnified Compilerβ is to provide a Python
+layer in which compilation passes can be implemented and applied.
+Additionally, we also want to enable researchers to use abstractions
+that they are familiar with. Weβre aiming to do this using xDSL, which
+is a reimplementation of MLIR in Python.
+
+This document is meant to be a quickstart for users that are interested
+in developing compiler passes for Catalyst, but are not familiar with
+MLIR or xDSL.
+
+MLIR basics
+===========
+
+If readers are already familiar to some degree with MLIR, this section
+can be skipped.
+
+SSA
+---
+
+βIn compiler design, static single assignment form (often abbreviated as
+SSA form or simply SSA) is a type of intermediate representation (IR)
+where each variable is assigned exactly onceβ [`1 <#references>`__].
+
+SSA is powerful because it allows us to define chains of uses and
+definitions, i.e, we can keep track of which operation created a
+variable (since each variable only gets created once), and all
+operations that use that variable. These chains of uses and definitions,
+if used well, can make transformations quite performant, and in some
+cases, very simple to implement, but they can also make the IR harder to
+parse.
+
+IR structure
+------------
+
+MLIR represents programs using a hierarchical, graph-like structure. In
+this structure, nodes are called *operations*, and edges are called
+*values*. Each value is the result of exactly one operation or argument
+for a block (more on that later), and has a type that is defined by the
+type system (more on that later also) [`2 <#references>`__].
+
+The IR is recursively nested - operations may contain regions, which
+contain blocks, which contain operations. More concretely, an operation
+may contain zero or more regions, each of which must contain one or more
+blocks, each of which holds a list of arguments and an ordered list of
+operations that may use those arguments [`3 <#references>`__].
+
+Operations
+~~~~~~~~~~
+
+Operations are basic units of execution. Operations are fully
+extensible, i.e., there is no fixed list of operations. Operations can
+return zero or more results, take in zero or more operands, declare
+properties and attributes, and can have zero or more successors and
+regions [`2 <#references>`__].
+
+Regions
+~~~~~~~
+
+A region is an ordered list of blocks. Its semantics are defined by the
+operation that contains them [`2 <#references>`__]. For example, an
+``scf.IfOp``, which represents a conditional operation, contains two
+regions, one for the true branch, and another for the false branch, and
+the way we interpret these regions is dependent on the fact that they
+belong to the ``scf.IfOp``.
+
+Blocks
+~~~~~~
+
+Blocks are lists of operations. The operations inside blocks are
+executed in order. Blocks take a list of block arguments, annotated in a
+function-like way. The first block in a region is special, and is called
+the βentry blockβ. The block arguments of the entry block are also
+arguments to its outer region [`2 <#references>`__].
+
+Values
+~~~~~~
+
+In MLIR, there are computable values with a type, a single defining
+operation, and zero or more uses [`4 <#references>`__]. These values are
+either the results of operations or block arguments, and adhere to SSA
+semantics.
+
+Dialects
+--------
+
+MLIR uses dialects to allow developers to define a set of high level
+operations, attributes, and types, which can be used in the intermediate
+representation to represent custom subroutines, etc. and can be
+converted into lower level representations using interpretation rules
+that can be defined. For example, the MLIR ``arith`` dialect defines
+operations for arithmetic computations, the ``linalg`` dialect defines
+linear algebra operations, etc. We use dialects to define quantum
+instructions such as gates, measurements, and attributes such as qubits,
+quantum registers.
+
+Def-use and use-def chains
+--------------------------
+
+Frequently, we might have a value and we want to determine which
+instructions use this value. The list of all users of a value is the
+*def-use chain*. Conversely, we might have an instruction and we want to
+determine which values it uses. The list of values used by an
+instruction is the *use-def chain* [`5 <#references>`__]. These chains
+allow us to iterate through operations topologically, which can be very
+powerful when implementing passes.
+
+xDSL API
+========
+
+Now that weβre familiar with the high-level constructs that MLIR uses,
+letβs go over what their xDSL actual implementation looks like.
+
+``SSAValue``
+------------
+
+``SSAValue`` is the class used to represent variables. SSA values may
+used by operations as operands, and returned as results. The three key
+properties that this class provides are listed below:
+
+- ``type``: The type of the value. Since SSA values are variables, their
+ value is not known at compile time, but their type is.
+- ``uses``: A set of all operations that use a given ``SSAValue`` as an
+ operand. The operations are wrapped around a convenience class called
+ ``Use``, and the corresponding operation can be accessed using
+ ``Use.operation``.
+- ``owner``: The operation or block that defined a given ``SSAValue``
+ (more on that later)
+
+``SSAValue`` has two subclasses, which will be seen a lot when
+inspecting xDSL modules. These are:
+
+- ``OpResult``: This subclass represents SSA values that are defined by
+ an operation
+
+ .. code-block::
+
+ c = add a b
+
+ In the above pseudocode, ``c`` is defined by the operation ``add``,
+ and in xDSL, will be represented as an ``OpResult`` instance.
+
+- ``BlockArgument``: This subclass represents SSA values that are block
+ arguments
+
+``Attribute``
+-------------
+
+An attribute defines a compile-time value. These can be used to define
+types, and can also be used to define properties of operations that are
+known at compile time. For example:
+
+- ``CustomOp``, which is an operation from the ``Quantum`` dialect (more
+ on that later) has a ``gate_name`` that must be a string, represented
+ by the ``StringAttr`` attribute. ``StringAttr`` has a reference to the
+ concrete string representing the gate name, and we can access this
+ concrete value at compile-time using ``CustomOp.gate_name.data``.
+- ``CustomOp`` also takes qubits as inputs and outputs, which are of
+ type ``QubitType``. The ``QubitType`` class inherits ``Attribute``,
+ but we use it to declare a type that can be used to define SSA values.
+
+Below is the definition of ``QubitType`` to illustrate:
+
+.. code-block:: python
+
+ # TypeAttribute just means that this attribute can be used to represent
+ # the types of the operands and results of an operation, but not its
+ # properties.
+ # ParametrizedAttribute means that this attribute can take parameters as
+ # input (although QubitType doesn't have any parameters).
+ @irdl_attr_definition
+ class QubitType(ParametrizedAttribute, TypeAttribute):
+ """A value-semantic qubit (state)."""
+
+ name = "quantum.bit"
+
+``Operation``
+-------------
+
+The ``Operation`` class is used to represent operations, which are basic
+units of execution. All instructions in a module are operations. In
+fact, the modules themselves are operations. Operations contain several
+fields used to define their form and function:
+
+- Operands: operands are runtime values that the operation consumes.
+ Note that when defining an operation, we only declare the *types* of
+ the operands, not the actual operands. Only when we construct an
+ instance of an operation do we provide actual ``SSAValue``\ s as the
+ operands, and these must adhere to the type system.
+- Properties: properties are compile-time values used to define the
+ semantics of an operation. For example, the ``CustomOp`` operation
+ that is used to define quantum gates in Catalyst has a property called
+ ``gate_name``, which is a string specifier of the gateβs name. This
+ name directly impacts how the operation should be interpreted, and its
+ value is known at compile-time.
+- Attributes: Operation attributes are stored as a dictionary,
+ containing more compile-time values. Generally, these donβt get used
+ much in xDSL, but they serve a purpose very similar to properties.
+- Result types: operations may return values, and if so, the types of
+ the return values must be defined.
+
+Below is the definition of ``CustomOp`` from the ``Quantum`` dialect,
+which represents general quantum gates to illustrate what defining
+operations looks like:
+
+.. code-block:: python
+
+ @irdl_op_definition
+ class CustomOp(IRDLOperation):
+ """A generic quantum gate on n qubits with m floating point parameters."""
+
+ name = "quantum.custom"
+
+ # assembly_format defines what the operation should look like
+ # when pretty-printed in the IR
+ assembly_format = """
+ $gate_name `(` $params `)` $in_qubits
+ (`adj` $adjoint^)?
+ attr-dict
+ ( `ctrls` `(` $in_ctrl_qubits^ `)` )?
+ ( `ctrlvals` `(` $in_ctrl_values^ `)` )?
+ `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )?
+ """
+
+ # These options are used because we have operands whose lengths aren't
+ # known (eg. different instances of CustomOp may have different number
+ # of qubits depending on the gate they represent (2 for CNOT, 1 for
+ # RX, etc.). These options basically say, "when the operation instance
+ # is initialized, create 2 properties that store the length of each of
+ # the different groups of operands and results.
+ irdl_options = [
+ AttrSizedOperandSegments(as_property=True),
+ AttrSizedResultSegments(as_property=True),
+ ]
+
+ # var_operand_def means that the length of this operand
+ # can vary.
+ params = var_operand_def(EqAttrConstraint(Float64Type()))
+
+ in_qubits = var_operand_def(BaseAttr(QubitType))
+
+ # prop_def means that gate_name is a required property
+ gate_name = prop_def(BaseAttr(StringAttr))
+
+ # opt_prop_def means adjoint is an optional property. Additionally,
+ # it's type is a UnitAttr(), which essentially means that a given
+ # instance of CustomOp is an adjoint gate iff it has an adjoint property.
+ # The value of the property is irrelevant; it only gets meaning from its
+ # existance.
+ adjoint = opt_prop_def(EqAttrConstraint(UnitAttr()))
+
+ in_ctrl_qubits = var_operand_def(BaseAttr(QubitType))
+
+ in_ctrl_values = var_operand_def(EqAttrConstraint(IntegerType(1)))
+
+ # var_result_def means that the length of out_qubits can vary
+ out_qubits = var_result_def(BaseAttr(QubitType))
+
+ out_ctrl_qubits = var_result_def(BaseAttr(QubitType))
+
+.. _dialects-1:
+
+Dialects
+--------
+
+The ``Dialect`` class in xDSL is used as a container around a list of
+operations, types, and attributes, and it also declares a name for the
+dialect. At the time of writing this, there are currently 4 custom
+dialects available in the xDSL layer of Catalyst:
+
+- ``Quantum``: this dialect contains operations and attributes necessary
+ for general qubit-level operations, such as gates, measurements,
+ qubits, etc.
+- ``Catalyst``: this dialect contains operations and attributes for
+ classical computing features unavailable out of the box with xDSL/MLIR
+- ``QEC``: This dialect contains operations and attributes useful for
+ QEC, such as PPRs/PPMs.
+- ``MBQC``: This dialect contains operations and attributes for
+ representing MBQC formalism.
+
+Pass API
+--------
+
+xDSL provides an API for defining and applying transformations on
+programs (or modules), which is described below:
+
+``ModulePass``
+~~~~~~~~~~~~~~
+
+``ModulePass`` is used to create rewrite passes over an IR module. It is
+the parent class used to define compiler passes (or transforms).
+``ModulePass`` has two key fields that must be implemented
+
+- ``name``: This is the name that is used to reference the pass.
+- ``apply``: This method takes a ``ModuleOp`` as input and applies the
+ rewrite pattern of the pass to the module. Note that this mutates the
+ input module *in-place* rather than returning a new, transformed
+ module.
+
+``RewritePattern``
+~~~~~~~~~~~~~~~~~~
+
+``RewritePattern`` is the class that provides the API for pattern
+matching. The most important method of this class is
+``match_and_rewrite``. The first argument to this method is an
+operation, and it must be type-hinted using the specific operation weβre
+trying to match. This type hint gets used by xDSL to match the operation
+weβre trying to rewrite. The second argument is a ``PatternRewriter``
+instance. I will cover this class in detail below, but it is essentially
+the class that provides the API for rewriting the operations that weβre
+matching.
+
+For example, if I wanted to match all Hadamard gates, I would use
+``CustomOp`` from the ``Quantum`` dialect in the type hint for the first
+argument (since there is no ``Hadamard`` operation in the ``Quantum``
+dialect), and check in the body of the method if the op is a
+``Hadamard``:
+
+.. code-block:: python
+
+ from xdsl import pattern_rewriter
+ from catalyst.python_interface.dialects.quantum import CustomOp
+
+ class MyPattern(pattern_rewriter.RewritePattern):
+ """Dummy class for example."""
+
+ # This decorator is what xDSL uses to match operations
+ # based on the type hint.
+ @pattern_rewriter.op_type_rewrite_pattern
+ def match_and_rewrite(
+ self, op: CustomOp, rewriter: pattern_rewriter.PatternRewriter
+ ):
+ if op.gate_name.data != "Hadamard":
+ # If not Hadamard, we do nothing
+ return
+
+ # Do whatever we need
+
+``PatternRewriter``
+~~~~~~~~~~~~~~~~~~~
+
+``PatternRewriter`` is the class that provides the API for rewriting the
+IR. It includes several methods for replacing/removing/updating
+operations, replacing values, replacing uses of a value with another
+value, etc. In most cases, any rewriting that users want to do must be
+done through this API rather than manually, as it includes state
+management for keeping track of whether any changes were made, which is
+necessary for the worklist algorithm (more on that later).
+
+Some key methods are:
+
+- ``replace_op``: Replaces one operation with another.
+- ``replace_all_uses_with``: Replaces all uses of one value with
+ another.
+- ``erase_op``: Erases an operation. If this operation returns any
+ values, all uses of these values must be updated accordingly before
+ the erasure.
+- ``notify_op_modified``: Method to notify the rewriter that a change
+ was made to an operation manually.
+
+The example below shows us implementing a ``RewritePattern`` that
+updates all ``Hadamard``\ s with ``PauliX``\ s:
+
+.. code-block:: python
+
+ from xdsl import pattern_rewriter
+ from xdsl.dialects import builtin
+ from catalyst.python_interface.dialects.quantum import CustomOp
+
+ class HToXPattern(pattern_rewriter.RewritePattern):
+ """Dummy class for example."""
+
+ @pattern_rewriter.op_type_rewrite_pattern
+ def match_and_rewrite(
+ self, op: CustomOp, rewriter: pattern_rewriter.PatternRewriter
+ ):
+ if op.gate_name.data != "Hadamard":
+ # If not Hadamard, we do nothing
+ return
+
+ # Update the gate name to PauliX and notify the rewriter that
+ # the op was manually updated
+ op.gate_name = builtin.StringAttr("PauliX")
+ rewriter.notify_op_modified(op)
+
+ # Alternatively, we could also create a new CustomOp for
+ # the PauliX from scratch, and replace the Hadamard with
+ # the new op:
+ # new_op = CustomOp(
+ # gate_name="PauliX",
+ # in_qubits=op.in_qubits,
+ # )
+ # rewriter.replace_op(op, new_op)
+
+``PatternRewriteWalker``
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+``PatternRewriteWalker`` walks over the IR in depth-first order, and
+applies a provided ``RewritePattern`` to it. By default, it implements a
+worklist algorithm that keeps iterating over the operations and matching
+and rewriting until a steady state is reached (i.e.Β no new changes are
+detected; this is why ``PatternRewriter`` needs to keep track of whether
+any changes were made).
+
+Putting everything together, we can create a ``ModulePass`` that
+replaces all ``Hadamard``\ s with ``PauliX``\ s
+
+.. code-block:: python
+
+ from xdsl import passes, pattern_rewriter
+ from xdsl.dialects import builtin
+ from catalyst.python_interface.dialects.quantum import CustomOp
+ from catalyst.python_interface import compiler_transform
+
+ class HToXPattern(pattern_rewriter.RewritePattern):
+ """Dummy class for example."""
+
+ @pattern_rewriter.op_type_rewrite_pattern
+ def match_and_rewrite(
+ self, op: CustomOp, rewriter: pattern_rewriter.PatternRewriter
+ ):
+ if op.gate_name.data != "Hadamard":
+ # If not Hadamard, we do nothing
+ return
+
+ # Update the gate name to PauliX and notify the rewriter that
+ # the op was manually updated
+ op.gate_name = builtin.StringAttr("PauliX")
+ rewriter.notify_op_modified(op)
+
+ class HToXPass(passes.ModulePass):
+ """Pass that replaces Hadamards with PauliXs"""
+
+ name = "h-to-x"
+
+ def apply(self, ctx, module):
+ """Apply the iterative pass."""
+ walker = pattern_rewriter.PatternRewriteWalker(
+ pattern=HToXPattern()
+ )
+ walker.rewrite_module(module)
+
+ # We will cover this later
+ h_to_x_pass = compiler_transform(HToXPass)
+
+``PassPipeline``
+~~~~~~~~~~~~~~~~
+
+``PassPipeline`` is a meta-pass that takes a sequence of
+``ModulePass``\ es as input, and applies them to the input module. The
+following example shows how a sequence of ``ModulePass``\ s can be
+applied to a module ``mod``:
+
+.. code-block:: python
+
+ from xdsl import passes
+
+ pipeline = passes.PassPipeline((Pass1(), Pass2(), Pass3()))
+ pipeline.apply(xdsl.context.Context(), mod)
+
+To complete the example weβve been building in this section, letβs put
+it all together and implement a ``PassPipeline`` to apply the
+``HToXPass`` to an xDSL module.
+
+Letβs first create the module to which we want to apply the pass. For
+this, we will use the ``xdsl_from_qjit`` utility, which is described in
+the βPennyLane integrationβ section below.
+
+- **Creating the module**
+
+ .. code-block:: python
+
+ import pennylane as qml
+ from catalyst.python_interface.conversion import xdsl_from_qjit
+
+ dev = qml.device("lightning.qubit", wires=3)
+
+ @xdsl_from_qjit
+ @qml.qjit(target="mlir")
+ @qml.qnode(dev)
+ def circuit():
+ qml.Hadamard(0)
+ qml.Hadamard(1)
+ qml.Hadamard(2)
+ return qml.state()
+
+ >>> mod = circuit()
+ >>> print(mod)
+ builtin.module @circuit {
+ func.func public @jit_circuit() -> (tensor<8xcomplex>) attributes {llvm.emit_c_interface} {
+ %0 = catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<8xcomplex>
+ func.return %0 : tensor<8xcomplex>
+ }
+ builtin.module @module_circuit {
+ builtin.module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) {
+ transform.yield
+ }
+ }
+ func.func public @circuit() -> (tensor<8xcomplex>) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} {
+ %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor
+ %1 = tensor.extract %0[] : tensor
+ quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
+ %2 = "stablehlo.constant"() <{value = dense<3> : tensor}> : () -> tensor
+ %3 = quantum.alloc(3) : !quantum.reg
+ %4 = tensor.extract %0[] : tensor
+ %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit
+ %6 = quantum.custom "Hadamard"() %5 : !quantum.bit
+ %7 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor
+ %8 = tensor.extract %7[] : tensor
+ %9 = quantum.extract %3[%8] : !quantum.reg -> !quantum.bit
+ %10 = quantum.custom "Hadamard"() %9 : !quantum.bit
+ %11 = "stablehlo.constant"() <{value = dense<2> : tensor}> : () -> tensor
+ %12 = tensor.extract %11[] : tensor
+ %13 = quantum.extract %3[%12] : !quantum.reg -> !quantum.bit
+ %14 = quantum.custom "Hadamard"() %13 : !quantum.bit
+ %15 = tensor.extract %0[] : tensor
+ %16 = quantum.insert %3[%15], %6 : !quantum.reg, !quantum.bit
+ %17 = tensor.extract %7[] : tensor
+ %18 = quantum.insert %16[%17], %10 : !quantum.reg, !quantum.bit
+ %19 = tensor.extract %11[] : tensor
+ %20 = quantum.insert %18[%19], %14 : !quantum.reg, !quantum.bit
+ %21 = quantum.compbasis qreg %20 : !quantum.obs
+ %22 = quantum.state %21 : tensor<8xcomplex>
+ quantum.dealloc %20 : !quantum.reg
+ quantum.device_release
+ func.return %22 : tensor<8xcomplex>
+ }
+ }
+ func.func @setup() {
+ quantum.init
+ func.return
+ }
+ func.func @teardown() {
+ quantum.finalize
+ func.return
+ }
+ }
+
+- **Transforming the module**
+
+ In the above module, there are 3 ``CustomOp``\ s, each with gate name
+ ``Hadamard``. Letβs try applying our pass to it. Bear in mind that
+ passes update modules in-place:
+
+ .. code-block:: python
+
+ from xdsl import passes
+
+ pipeline = passes.PassPipeline((HToXPass(),))
+ pipeline.apply(xdsl.context.Context(), mod)
+
+ >>> print(mod)
+ builtin.module @circuit {
+ func.func public @jit_circuit() -> (tensor<8xcomplex>) attributes {llvm.emit_c_interface} {
+ %0 = catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<8xcomplex>
+ func.return %0 : tensor<8xcomplex>
+ }
+ builtin.module @module_circuit {
+ builtin.module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) {
+ transform.yield
+ }
+ }
+ func.func public @circuit() -> (tensor<8xcomplex>) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} {
+ %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor
+ %1 = tensor.extract %0[] : tensor
+ quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
+ %2 = "stablehlo.constant"() <{value = dense<3> : tensor}> : () -> tensor
+ %3 = quantum.alloc(3) : !quantum.reg
+ %4 = tensor.extract %0[] : tensor
+ %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit
+ %6 = quantum.custom "PauliX"() %5 : !quantum.bit
+ %7 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor
+ %8 = tensor.extract %7[] : tensor
+ %9 = quantum.extract %3[%8] : !quantum.reg -> !quantum.bit
+ %10 = quantum.custom "PauliX"() %9 : !quantum.bit
+ %11 = "stablehlo.constant"() <{value = dense<2> : tensor}> : () -> tensor
+ %12 = tensor.extract %11[] : tensor
+ %13 = quantum.extract %3[%12] : !quantum.reg -> !quantum.bit
+ %14 = quantum.custom "PauliX"() %13 : !quantum.bit
+ %15 = tensor.extract %0[] : tensor
+ %16 = quantum.insert %3[%15], %6 : !quantum.reg, !quantum.bit
+ %17 = tensor.extract %7[] : tensor
+ %18 = quantum.insert %16[%17], %10 : !quantum.reg, !quantum.bit
+ %19 = tensor.extract %11[] : tensor
+ %20 = quantum.insert %18[%19], %14 : !quantum.reg, !quantum.bit
+ %21 = quantum.compbasis qreg %20 : !quantum.obs
+ %22 = quantum.state %21 : tensor<8xcomplex>
+ quantum.dealloc %20 : !quantum.reg
+ quantum.device_release
+ func.return %22 : tensor<8xcomplex>
+ }
+ }
+ func.func @setup() {
+ quantum.init
+ func.return
+ }
+ func.func @teardown() {
+ quantum.finalize
+ func.return
+ }
+ }
+
+ Great! We can see that all the ``Hadamard``\ s have been replaced with
+ ``PauliX``\ s, just how we wanted.
+
+PennyLane integration
+=====================
+
+This section will cover the API in the ``qml.compiler.python_compiler``
+submodule.
+
+Lowering to MLIR
+----------------
+
+Catalyst compiles programs using the following workflow, when program
+capture is enabled:
+
+- Trace user function to create ``ClosedJaxpr`` (plxpr)
+- Convert plxpr to value-semantic jaxpr
+- Lower value-semantic jaxpr to MLIR
+- Apply passes defined in a static pipeline to the MLIR
+
+ - These passes include optimizations, further lowering to more
+ elementary dialects, and lowering to LLVM
+
+- Generate machine code
+
+The integration with the xDSL layer happens after we lower to MLIR. We
+currently rely on JAXβs API to lower to MLIR. This has the special
+effect of lowering to a specific dialect called StableHLO, which is used
+to represent all arithmetic operations present in the program.
+
+Once lowered to MLIR, if the original ``qjit`` decorator specified the
+xDSL pass plugin, we pass control over to the xDSL layer, which applies
+all transforms that were requested by the user. We can request the use
+of the xDSL plugin like so:
+
+.. code-block:: python
+
+ from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath
+
+ @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])
+ ...
+
+.. _ir-structure-1:
+
+IR structure
+~~~~~~~~~~~~
+
+The lowered MLIR has a lot of structure to aid developers, which is
+described below using the following example:
+
+.. code-block:: python
+
+ import pennylane as qml
+
+ qml.capture.enable()
+ dev = qml.device("lightning.qubit", wires=1)
+
+ @qml.qjit
+ @qml.transforms.cancel_inverses
+ @qml.transforms.merge_rotations
+ @qml.qnode(dev)
+ def circuit():
+ qml.X(0)
+ return qml.state()
+
+>>> print(circuit.mlir)
+module @circuit {
+ func.func public @jit_circuit() -> tensor<2xcomplex> attributes {llvm.emit_c_interface} {
+ %0 = catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<2xcomplex>
+ return %0 : tensor<2xcomplex>
+ }
+ module @module_circuit {
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) {
+ %0 = transform.apply_registered_pass "merge-rotations" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module">
+ %1 = transform.apply_registered_pass "remove-chained-self-inverse" to %0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module">
+ transform.yield
+ }
+ }
+ func.func public @circuit() -> tensor<2xcomplex> attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} {
+ %c0_i64 = arith.constant 0 : i64
+ quantum.device shots(%c0_i64) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
+ %0 = quantum.alloc( 1) : !quantum.reg
+ %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
+ %out_qubits = quantum.custom "PauliX"() %1 : !quantum.bit
+ %2 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit
+ %3 = quantum.compbasis qreg %2 : !quantum.obs
+ %4 = quantum.state %3 : tensor<2xcomplex>
+ quantum.dealloc %2 : !quantum.reg
+ quantum.device_release
+ return %4 : tensor<2xcomplex>
+ }
+ }
+ func.func @setup() {
+ quantum.init
+ return
+ }
+ func.func @teardown() {
+ quantum.finalize
+ return
+ }
+}
+
+- The program is represented as a module with a function inside it. This
+ function contains non-QNode user code (which there is none in our
+ example for simplicity), and calls to QNodes using the
+ ``catalyst.launch_kernel`` operation.
+- QNodes are represented as modules as wellβeach QNode has its own
+ module. These modules have 4 key components:
+
+ - The ``module attributes {transform.with_named_sequence}`` contains
+ the transforms that the user requested for the QNode. In our
+ example, those are the ``merge_rotations`` and ``cancel_inverses``
+ transforms.
+ - The ``@circuit`` function represents the body of the QNode. Inside
+ this body, we initialize a device using the specified shots,
+ allocate a quantum register, and then apply both quantum and
+ classical instructions. Note that quantum instructions can *only* be
+ present inside this function. Note that this function will contain
+ an attribute called ``qnode``.
+
+SSA in the ``Quantum`` dialect
+------------------------------
+
+In the ``Quantum`` dialect, as mentioned earlier, gates are represented
+using the ``CustomOp`` operation. This operation accepts ``in_qubits``
+and ``in_ctrl_qubits``, which are variable length sequences of
+``QubitType`` attributes, which correspond to wire indices.
+``CustomOp``\ s also return ``out_qubits`` and ``out_ctrl_qubits`` of
+the same type.
+
+Before a ``QubitType`` can be used by any gates, it must be extracted
+from the quantum register, or a ``QregType``. Quantum registers can be
+thought of as a sequence containing all valid wire indices that are
+available to be used by gates/measurements. Qubits can be extracted from
+and inserted into a quantum register using the ``ExtractOp`` and
+``InsertOp`` operations. An ``AllocOp`` is used to allocate a quantum
+register with the user-provided number of device wires.
+
+Letβs take a look at a very simple example. Below, I have a very simple
+circuit with two gates applied to the same wire. Letβs take a look at
+its MLIR representation:
+
+- **Example**
+
+ .. code-block:: python
+
+ dev = qml.device("lightning.qubit", wires=3)
+
+ @qml.qjit(target="mlir")
+ @qml.qnode(dev)
+ def circuit():
+ qml.X(0)
+ qml.H(0)
+ return qml.state()
+
+ >>> print(circuit.mlir)
+ module @circuit {
+ func.func public @jit_circuit() -> tensor<8xcomplex> attributes {llvm.emit_c_interface} {
+ %0 = catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<8xcomplex>
+ return %0 : tensor<8xcomplex>
+ }
+ module @module_circuit {
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) {
+ transform.yield
+ }
+ }
+ func.func public @circuit() -> tensor<8xcomplex> attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} {
+ %c0_i64 = arith.constant 0 : i64
+ quantum.device shots(%c0_i64) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
+ %0 = quantum.alloc( 3) : !quantum.reg
+ %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
+ %out_qubits = quantum.custom "PauliX"() %1 : !quantum.bit
+ %out_qubits_0 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit
+ %2 = quantum.insert %0[ 0], %out_qubits_0 : !quantum.reg, !quantum.bit
+ %3 = quantum.compbasis qreg %2 : !quantum.obs
+ %4 = quantum.state %3 : tensor<8xcomplex>
+ quantum.dealloc %2 : !quantum.reg
+ quantum.device_release
+ return %4 : tensor<8xcomplex>
+ }
+ }
+ func.func @setup() {
+ quantum.init
+ return
+ }
+ func.func @teardown() {
+ quantum.finalize
+ return
+ }
+ }
+
+**Notes**
+
+- ``quantum.alloc`` initializes a quantum register (``%0``) with 3
+ wires, which is the number of wires used to create the PennyLane
+ device.
+- ``quantum.extract`` extracts a qubit corresponding to wire index 0
+ (``%1``)
+
+ - ``%1`` is used as input by the ``X`` gate.
+
+- The ``X`` gate returns a qubit (``%out_qubits``), which is used by the
+ ``quantum.custom`` corresponding to the ``H`` gate.
+
+ - Note that this is different from how the circuit is defined in
+ Python. Instead of just using ``%1`` again, the ``H`` gate uses
+ ``%out_qubits``.
+
+- The ``H`` gate returns a new qubit (``%out_qubits_0``). This qubit is
+ consumed by ``quantum.insert``, which inserts updates the quantum
+ register to essentially say that ``%out_qubits_0`` should be the new
+ qubit that corresponds to wire index 0.
+- In this example, note that ``%1``, ``%out_qubits``, and
+ ``%out_qubits_0`` all correspond to wire index 0. This has cool
+ implications, that I will discuss below.
+
+Dynamic wires
+~~~~~~~~~~~~~
+
+The lowering rules for Catalyst automatically handle dynamic wires. When
+a new dynamic wire, ``w``, is used, all wires used before it are first
+inserted into the quantum register using ``InsertOp``. Only then does
+the ``QubitType`` corresponding to ``w`` get extracted from the quantum
+register using ``ExtractOp``. Consider the following example:
+
+- **Example**
+
+ .. code-block:: python
+
+ import pennylane as qml
+
+ qml.capture.enable()
+
+ dev = qml.device("lightning.qubit", wires=3)
+
+ @qml.qjit(target="mlir")
+ @qml.qnode(dev)
+ def circuit(w1: int, w2: int):
+ qml.X(0)
+ qml.Y(w1)
+ qml.Z(w1)
+ qml.S(w2)
+ qml.T(w1)
+ qml.H(0)
+ return qml.state()
+
+ >>> print(circuit.mlir)
+ module @circuit {
+ func.func public @jit_circuit(%arg0: tensor, %arg1: tensor) -> tensor<8xcomplex> attributes {llvm.emit_c_interface} {
+ %0 = catalyst.launch_kernel @module_circuit::@circuit(%arg0, %arg1) : (tensor, tensor) -> tensor<8xcomplex>
+ return %0 : tensor<8xcomplex>
+ }
+ module @module_circuit {
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) {
+ transform.yield
+ }
+ }
+ func.func public @circuit(%arg0: tensor, %arg1: tensor) -> tensor<8xcomplex> attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} {
+ %c0_i64 = arith.constant 0 : i64
+ quantum.device shots(%c0_i64) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
+ %0 = quantum.alloc( 3) : !quantum.reg
+ %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
+ %out_qubits = quantum.custom "PauliX"() %1 : !quantum.bit
+ %2 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit
+ %extracted = tensor.extract %arg0[] : tensor
+ %3 = quantum.extract %2[%extracted] : !quantum.reg -> !quantum.bit
+ %out_qubits_0 = quantum.custom "PauliY"() %3 : !quantum.bit
+ %out_qubits_1 = quantum.custom "PauliZ"() %out_qubits_0 : !quantum.bit
+ %extracted_2 = tensor.extract %arg0[] : tensor
+ %4 = quantum.insert %2[%extracted_2], %out_qubits_1 : !quantum.reg, !quantum.bit
+ %extracted_3 = tensor.extract %arg1[] : tensor
+ %5 = quantum.extract %4[%extracted_3] : !quantum.reg -> !quantum.bit
+ %out_qubits_4 = quantum.custom "S"() %5 : !quantum.bit
+ %extracted_5 = tensor.extract %arg1[] : tensor
+ %6 = quantum.insert %4[%extracted_5], %out_qubits_4 : !quantum.reg, !quantum.bit
+ %extracted_6 = tensor.extract %arg0[] : tensor
+ %7 = quantum.extract %6[%extracted_6] : !quantum.reg -> !quantum.bit
+ %out_qubits_7 = quantum.custom "T"() %7 : !quantum.bit
+ %extracted_8 = tensor.extract %arg0[] : tensor
+ %8 = quantum.insert %6[%extracted_8], %out_qubits_7 : !quantum.reg, !quantum.bit
+ %9 = quantum.extract %8[ 0] : !quantum.reg -> !quantum.bit
+ %out_qubits_9 = quantum.custom "Hadamard"() %9 : !quantum.bit
+ %10 = quantum.insert %8[ 0], %out_qubits_9 : !quantum.reg, !quantum.bit
+ %11 = quantum.compbasis qreg %10 : !quantum.obs
+ %12 = quantum.state %11 : tensor<8xcomplex>
+ quantum.dealloc %10 : !quantum.reg
+ quantum.device_release
+ return %12 : tensor<8xcomplex>
+ }
+ }
+ func.func @setup() {
+ quantum.init
+ return
+ }
+ func.func @teardown() {
+ quantum.finalize
+ return
+ }
+ }
+
+ **Notes**
+
+ - If a qubit is inserted into the quantum register, it is not reused
+ again, and a new qubit corresponding to the same wire label must be
+ extracted from the register.
+ - When a dynamic wire is going to be used for the first time, all
+ qubits that have been used previously without being re-inserted into
+ the quantum register must be inserted.
+ - If a dynamic wire is reused before any other wires, then it does not
+ need to be inserted to and extracted from the quantum register
+ again.
+ - To use static wires after dynamic wires, the dynamic wires are again
+ re-inserted into the quantum register.
+ - All of the above make it such that dynamic wires essentially create
+ barriers that break the qubit def-use chains. This causes some
+ functionality to be lost, but makes sure that weβre using qubits
+ safely.
+
+Implications/notes
+~~~~~~~~~~~~~~~~~~
+
+- Operations on the same wires can be tracked quickly using the chain of
+ definitions and uses of the qubits (def-use chains).
+- Operations on dynamic wires (i.e wires whose values are only known at
+ run-time) are handled automatically and we donβt need to worry about
+ managing how we work around them. For context, this is an issue that
+ we found in the plxpr variant of ``cancel_inverses``, where two
+ operations on the same wire that were separated by another operation
+ on a dynamic wire would get cancelled, which is incorrect
+ (`source `__).
+- One thing to keep in mind is that qubits (``QubitType``) and quantum
+ registers (``QregType``) are there so that we conform to SSA semantics
+ in a way that works well for our purposes. They get meaning from how
+ we choose to interpret them. We could just as easily have defined the
+ ``Quantum`` dialect and Catalystβs lowering rules to use wire indices
+ the same way we do in Python, but we may lose capabilities that MLIR
+ and xDSL enable through the SSA form.
+
+``compiler_transform``
+----------------------
+
+``compiler_transform`` is the function used to register xDSL
+``ModulePass``\ es to be used with ``qjit``-ed workflows. It is
+currently accessible as
+``qml.compiler.python_compiler.compiler_transform``.
+
+.. code-block:: python
+
+ from catalyst.python_interface import compiler_transform
+
+ class MyPass(xdsl.passes.ModulePass):
+ """MyPass that does something"""
+
+ name = "my-pass"
+
+ def apply(self, ctx, module):
+ # Apply the pass to module
+ return
+
+ my_pass = compiler_transform(MyPass)
+
+ # Program capture must be enabled to use the compiler transform
+ # as a decorator
+ qml.capture.enable()
+ dev = qml.device("lightning.qubit", wires=1)
+
+ @qml.qjit(
+ pass_plugins=[catalyst.passes.xdsl_plugin.getXDSLPluginAbsolutePath()]
+ )
+ @my_pass
+ @qml.qnode(dev)
+ def circuit(x):
+ qml.RX(x, 0)
+ return qml.expval(qml.Z(0))
+
+ circuit(1.5)
+
+The ``compiler_transform`` function returns an object that gives easy
+access to the underlying ``ModulePass``, as well as its name as seen by
+the compiler.
+
+>>> my_pass.module_pass
+__main__.MyPass
+>>> my_pass.name
+'my-pass'
+
+Additionally, we donβt need to manually apply passes using
+``PassPipeline`` when decorating QNodes with registered compiler
+transforms. Those transforms get applied automatically when the workflow
+is compiled!
+
+Conversion utilities
+--------------------
+
+The ``python_compiler.conversion`` submodule provides several utilities
+for creating xDSL modules from Python functions. There are many more
+utilities in the submodule, but I will focus on the most important ones,
+and provide examples of how to use them.
+
+- ``xdsl_module(jitted_fn: Callable) -> Callable[Any, [xdsl.dialects.builtin.ModuleOp]]``:
+ Create a wrapper around a ``jax.jit``-ed function that returns an xDSL
+ module
+- ``xdsl_from_qjit(qjitted_fn: QJIT) -> Callable[..., xbuiltin.ModuleOp]``:
+ Create a wrapper around a ``qjit``-ed function that returns an xDSL
+ module. This is currently not merged to ``master``
+- ``inline_module(from_mod: xbuiltin.ModuleOp, to_mod: xbuiltin.ModuleOp, change_main_to: str = None) -> None``:
+ This function takes two modules as input, and inlines the whole body
+ of the first module into the second. Additionally, if
+ ``change_main_to`` is provided, it looks for a function named
+ ``main``, and updates its name to ``change_main_to``
+- ``inline_jit_to_module(func: JaxJittedFunction, mod: xbuiltin.ModuleOp, *args, **kwargs) -> None``:
+ This function takes a ``jax.jit``-ed function, converts it into an
+ xDSL module, and then inlines the contents of the xDSL module into
+ ``mod``. Note that this function does not return anything; instead, it
+ modifies ``mod`` in-place.
+
+Check out the :doc:`xDSL conversion utilities tutorial <./xdsl_utils_tutorial>` to see examples of how each of
+the utilities can be used.
+
+Useful patterns
+===============
+
+Now that we have gone over compilers, xDSL, and how itβs being used in
+PennyLane, letβs take a look at some common patterns that might be
+useful.
+
+Post-processing functions
+-------------------------
+
+Post-processing functions are purely classical, so we can leverage the
+``xdsl_module`` utility function to create xDSL modules from Python
+code, and inject it into the modules we are rewriting as needed. The
+:doc:`xDSL post-processing tutorial <./xdsl_post_processing>` shows an
+example where we perform very simple post-processing on a QNode that
+returns an expectation value by squaring the expectation value.
+
+Splitting tapes
+---------------
+
+When implementing passes that may transform our module such that
+multiple device executions are required (akin to tape transforms that
+return multiple tapes), there are various strategies that can be used.
+Because there is no guarantee of shared structure between the tapes,
+there is no one perfect strategy that can be used for all transforms.
+Iβll use tapes to provide details below about some common cases:
+
+- If all the tapes are identical (eg. ``dynamic_one_shot``), then the
+ entire execution can be put inside a ``for`` loop, with
+ post-processing on the execution outputs done how I showed in the
+ above notebook.
+- If all gates are identical, but measurements are different (eg.
+ ``split_non_commuting``), we can capture all gates in a single
+ function, and then use a ``for`` loop that iterates over the number of
+ tapes. Within this ``for`` loop, we would call the aforementioned
+ function, which would evolve the state, and apply a different
+ measurement inside each iteration of the loop. Post-processing can be
+ handled same as above.
+- If there is little/no shared structure between the tapes, we would
+ need separate functions for each of the transformed tapes. We would
+ need to call each function one by one, and then use the results for
+ post-processing.
+
+All of the above are very non-trivial. I will leave out code examples
+for now, as that may be unnecessarily time consuming. If we get to the
+stage where we need to write a transform that splits into multiple
+tapes, we can revisit this section and the Python compiler/compilation
+team can assist in developing such transforms.
+
+Note
+~~~~
+
+Currently, we donβt have a consistent API for implementing transforms
+that split QNodes into multiple device executions (eg.
+``split_non_commuting``, etc.). Thus, it is *very* hard to implement
+transforms that do that. We cannot guarantee that one transform that
+does split QNodes into multiple device executions will work well when
+there are other transforms in the pipeline.
+
+Writing tests
+=============
+
+**Note to readers**: this section is written based on how the testing
+infrastructure for the Python compiler exists in PennyLane. However, the
+Python compiler may be getting moved to Catalyst, in which case, the
+infrastructure would likely change.
+
+FileCheck
+---------
+
+`FileCheck `__ is a
+pattern matching file verifier developed by the LLVM project and is
+widely used to test MLIR. Since xDSL and MLIR are syntactically
+identical, we can use the string representation of xDSL modules for
+testing with FileCheck.
+
+Below is an example of how an MLIR/xDSL string may look when populated
+with directives that FileCheck can use for testing. In this example, we
+have a void function ``test_func`` that takes no arguments, creates two
+qubits (corresponding to different wires), and applies a ``PauliX`` to
+each of these wires. We can see that there are 4 comments starting with
+``CHECK`` - these comments are what FileCheck uses to match the expected
+string against the actual string.
+
+The number of ``CHECK`` comments does not need to be the same as the
+number of operations in the program - we can see below that there is no
+``CHECK`` for ``func.func`` and ``return``.
+
+``CHECK`` statements can assign parameters that can be reused by later
+``CHECK``\ s. Below, the first two ``CHECK``\ s create ``q0`` and
+``q1``, which are matched using the regular expression ``%.+``, which
+matches any expressions that start with ``%`` and contain at least one
+character after it. The latter 2 ``CHECK``\ s then use ``q0`` and ``q1``
+as the input qubits for the assertion.
+
+``CHECK`` statements can contain partial statements. Below, the last two
+``CHECK``\ s donβt include the outputs of the ``quantum.custom``
+operations.
+
+.. code-block:: python
+
+ program = """
+ func.func @test_func() {
+ // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit
+ // CHECK: [[q1:%.+]] = "test.op"() : () -> !quantum.bit
+ %0 = "test.op"() : () -> !quantum.bit
+ %1 = "test.op"() : () -> !quantum.bit
+ // CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit
+ // CHECK: quantum.custom "PauliX"() [[q1]] : !quantum.bit
+ %2 = quantum.custom "PauliX"() %0 : !quantum.bit
+ %3 = quantum.custom "PauliX"() %0 : !quantum.bit
+ return
+ }
+ """
+
+FileCheck essentially uses the MLIR generated by a program and asserts
+it against directives that can be specified by the user. Commonly used
+directives are:
+
+- ``CHECK``: This is used to check if a specified pattern is found. Its
+ syntax is very simple: they are fixed strings that must occur in
+ order, and horizontal whitespace is ignored by default.
+- ``CHECK-NOT``: This directive is used to check that the provided
+ string does not occur between two matches, before the first match, or
+ after the last match.
+- ``CHECK-DAG``: This directive may be used when itβs necessary to match
+ strings that donβt have to occur in the same order, but it is able to
+ match valid topological orderings of the program DAG, with edges from
+ the definition to the uses of a variable.
+- ``CHECK-NEXT``: This directive is used to check that the matched line
+ occurs right after the previously matched line with no other lines in
+ between them.
+- ``CHECK-SAME``: This directive is used when we want to match lines and
+ would like to verify that matches happen on the same line as the
+ previous match.
+
+To find more details about the above directives, or to learn about other
+available directives, please refer to the `FileCheck
+documentation `__.
+
+Test dialect
+------------
+
+xDSL provides a ``Test`` dialect
+(`source `__),
+which contains many operations that are useful for unit testing. In our
+testing, we found the ``TestOp`` operation to be the most useful. This
+operation can produce arbitrary results, which we can use to limit
+artificial dependencies on other dialects.
+
+For example, if I just need to assert that a specific gate is present,
+without ``TestOp``, I would need my module to contain an ``AllocOp``
+that creates a quantum register, an ``ExtractOp`` that extracts a qubit
+from the register, and only then I can use the qubit for the gate Iβm
+trying to match. Instead, I can just insert a ``TestOp`` that returns a
+qubit and use that.
+
+This is very powerful for unit testing, as it makes writing tests much
+simpler, while also limiting the scope of the test as one would expect
+for unit tests.
+
+In the code block in the previous section, ``TestOp`` has been used in
+exactly the described way - we use it to create 2 qubits that are then
+used as input for the 2 ``PauliX`` gates.
+
+.. _pennylane-integration-1:
+
+PennyLane integration
+---------------------
+
+To use FileCheck with ``pytest``, we use the ```filecheck`` Python
+package `__, which allows us to use
+assertions for testing in a way that ``pytest`` can understand. All of
+the ``filecheck`` API has been captured inside two fixtures available
+within the ``tests/python_compiler`` folder:
+
+- ``run_filecheck``: This fixture is for unit testing. One can specify a
+ program along with filecheck directives as a multi-line string.
+- ``run_filecheck_qjit``: This fixture is for integration testing. One
+ can create a normal ``qml.qjit``-ed workflow and include filecheck
+ directives as in-line comments.
+
+Letβs write tests for the ``HToXPass`` that was implemented in the
+`βPass APIβ <#pass-api>`__ sub-section to illustrate. The dev comments
+will explain what is going on.
+
+``run_filecheck`` example
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ def test_h_to_x_pass(run_filecheck):
+ """Test that Hadamard gets converted into PauliX."""
+ # The original program creates a qubit, and gives it to a
+ # CustomOp that is a Hadamard. The CHECK directives check that
+ # the transformed program has a CustomOp that is a PauliX applied
+ # to the same qubit, and no CustomOp that is a Hadamard.
+
+ # Below we also see how we can use `test.op`, which we use to
+ # create a qubit to give to the Hadamard.
+ program = """
+ func.func @test_func() {
+ // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit
+ %0 = "test.op"() : () -> !quantum.bit
+ // CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit
+ // CHECK-NOT: quantum.custom "Hadamard"
+ %1 = quantum.custom "Hadamard"() %0 : !quantum.bit
+ return
+ }
+ """
+
+ # First, we must create the pass pipeline that we want to apply
+ # to our program.
+ pipeline = (HToXPass(),)
+
+ # Next, we use run_filecheck to run filecheck testing on the
+ # program. The fixture will create an xDSL module from the program
+ # string, call the filecheck API, and assert correctness.
+ run_filecheck(program, pipeline)
+
+ # Optionally, there are two keyword arguments that can modify the
+ # behaviour of run_filecheck. Both the roundtrip and verify
+ # arguments are false by default.
+ run_filecheck(program, pipeline, roundtrip=True, verify=True)
+
+ # roundtrip=True makes it so that after parsing the program string
+ # to an xDSL module, we print it as a string again, and then parse
+ # it back into an xDSL module. This is useful when writing tests
+ # for dialects, so that we can check that the dialects can be printed
+ # parsed correctly.
+
+ # verify=True simply runs `module.verify()`, which iteratively uses
+ # xDSL's verifiers to verify all operations and attributes in the
+ # module.
+
+``run_filecheck_qjit`` example
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath
+
+ def test_h_to_x_pass_integration(run_filecheck_qjit):
+ """Test that Hadamard gets converted into PauliX."""
+ # The original program simply applies a Hadamard to a circuit
+ # Remember that we need to specify the pass plugin. Additionally,
+ # with qjit, we can use the decorator created using
+ # `compiler_transform`. To make sure that the xDSL API works
+ # correctly, program capture must be enabled.
+ # qml.capture.enable()
+ @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath])
+ @h_to_x_pass
+ def circuit():
+ # CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit
+ # CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit
+ # CHECK-NOT: quantum.custom "Hadamard"
+ qml.Hadamard(0)
+ return qml.state()
+
+ # Finally, we use the run_filecheck_qjit fixture. We pass it our
+ # original qjitted workflow. It extracts the filecheck directives
+ # from the workflow, creates an MLIR program, parses it to xDSL,
+ # applies the specified transforms, and uses the filecheck API to
+ # assert correctness.
+ run_filecheck_qjit(circuit)
+
+ # run_filecheck_qjit also accepts a boolean `verify` argument
+ # that is false by default, which works exactly the same way as
+ # the `verify` argument of `run_filecheck`.
+ run_filecheck_qjit(circuit, verify=True)
+
+Key blockers
+============
+
+There are several blockers that are currently disabling developers from
+taking full advantage of the ``python_compiler`` submodule. These
+include:
+
+* Lack of support for quantum subroutines. This impacts pattern
+ matching passes that need to substitute the matched operation(s) with
+ subroutines containing quantum instructions.
+
+Strategies to circumvent blockers
+---------------------------------
+
+* We can use dummy subroutines for now. We know what the inputs and
+ outputs of these subroutines should be, so we can create our own
+ ``FuncOp``\ s that adhere to the input/output spec and just have
+ their body be empty for now. To see an example where we create a
+ dummy quantum subroutine and use it to develop a pass, check out the
+ :doc:`xDSL subroutines tutorial <./xdsl_dummy_quantum_subroutines>`.
+
+Suggested reading
+=================
+
+Useful dialects
+---------------
+
+- ``scf``: Structured control flow
+- ``func``: Functions
+- ``builtin``: Core types and attributes
+- ``arith``: arithmetic operations
+- ``stablehlo``: Advanced match dialect
+
+References
+==========
+
+#. Wikimedia Foundation. (2025, August 11). *Static single-assignment
+ form*. Wikipedia.
+ https://en.wikipedia.org/wiki/Static_single-assignment_form
+#. *MLIR Language Reference*. MLIR. (n.d.).
+ https://mlir.llvm.org/docs/LangRef/
+#. *Understanding the IR Structure*. MLIR. (n.d.-b).
+ https://mlir.llvm.org/docs/Tutorials/UnderstandingTheIRStructure/
+#. *Mlir::Value class reference*. MLIR. (n.d.-b).
+ https://mlir.llvm.org/doxygen/classmlir_1_1Value.html
+#. *LLVM Programmerβs Manual*. LLVM. (n.d.).
+ https://llvm.org/docs/ProgrammersManual.html
diff --git a/frontend/catalyst/python_interface/doc/xdsl_dummy_quantum_subroutines.rst b/frontend/catalyst/python_interface/doc/xdsl_dummy_quantum_subroutines.rst
new file mode 100644
index 0000000000..a1e3a53324
--- /dev/null
+++ b/frontend/catalyst/python_interface/doc/xdsl_dummy_quantum_subroutines.rst
@@ -0,0 +1,216 @@
+.. code-block:: python
+
+ from dataclasses import dataclass
+
+ import pennylane as qml
+ from catalyst.python_interface.conversion import xdsl_from_qjit
+ from catalyst.python_interface.dialects.quantum import CustomOp, QubitType
+
+ from xdsl import context, passes, pattern_rewriter
+ from xdsl.builder import ImplicitBuilder
+ from xdsl.dialects import builtin, func
+ from xdsl.ir import Block, Region
+
+Convert into xDSL module
+========================
+
+.. code-block:: python
+
+ dev = qml.device("lightning.qubit", wires=5)
+
+ @xdsl_from_qjit
+ @qml.qjit(target="mlir")
+ @qml.qnode(dev)
+ def circuit(x):
+ qml.H(0)
+ return qml.expval(qml.Z(0))
+
+
+>>> qjit_mod = circuit(1.5)
+>>> print(qjit_mod)
+builtin.module @circuit {
+ func.func public @jit_circuit(%arg2 : tensor) -> (tensor) attributes {llvm.emit_c_interface} {
+ %0 = catalyst.launch_kernel @module_circuit::@circuit(%arg2) : (tensor) -> tensor
+ func.return %0 : tensor
+ }
+ builtin.module @module_circuit {
+ builtin.module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) {
+ transform.yield
+ }
+ }
+ func.func public @circuit(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} {
+ %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor
+ %1 = tensor.extract %0[] : tensor
+ quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
+ %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor
+ %3 = quantum.alloc(5) : !quantum.reg
+ %4 = tensor.extract %0[] : tensor
+ %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit
+ %6 = quantum.custom "Hadamard"() %5 : !quantum.bit
+ %7 = quantum.namedobs %6[PauliZ] : !quantum.obs
+ %8 = quantum.expval %7 : f64
+ %9 = tensor.from_elements %8 : tensor
+ %10 = tensor.extract %0[] : tensor
+ %11 = quantum.insert %3[%10], %6 : !quantum.reg, !quantum.bit
+ quantum.dealloc %11 : !quantum.reg
+ quantum.device_release
+ func.return %9 : tensor
+ }
+ }
+ func.func @setup() {
+ quantum.init
+ func.return
+ }
+ func.func @teardown() {
+ quantum.finalize
+ func.return
+ }
+}
+
+
+Letβs create a quantum subroutine
+=================================
+
+This subroutineβs purpose is to replace Hadamard gates with a blackbox
+(which will be empty for now), so it must take a single qubit as its
+argument. Additionally, we are assuming that the qubits on which the
+subroutine will act are not just the ones that the subroutine takes as
+input, so we must also provide the quantum register as input, and also
+give it as the output.
+
+``FuncOp``\ s have a single region with a single block that contains the
+body of the function.
+
+We need to build the functionβs body by populating its inner ``Block``
+with operations. We can do so using the ``xdsl.builder.ImplicitBuilder``
+class. This class can be used as a context manager that takes a
+``Block`` as input, and any operations created within the context of the
+builder get added to the block. Letβs try it out.
+
+Here, we create a subroutine that we will use to replace
+``Hadamard``\ s. This subroutine applies a gate provided by the user,
+and returns the ``out_qubit`` of the gate.
+
+.. code-block:: python
+
+ def create_hadamard_replacement_subroutine(gate_name):
+ input_types = (QubitType(),)
+ output_types = (QubitType(),)
+ block = Block(arg_types=input_types)
+
+ with ImplicitBuilder(block):
+ in_qubits = [block.args[0]]
+ op1 = CustomOp(in_qubits=in_qubits, gate_name=gate_name)
+ func.ReturnOp(op1.out_qubits[0])
+
+ region = Region([block])
+ funcOp = func.FuncOp("replace_hadamard", (input_types, output_types), region=region)
+ return funcOp
+
+
+>>> funcOp = create_hadamard_replacement_subroutine("S")
+>>> print(funcOp)
+func.func @replace_hadamard(%0 : !quantum.bit) -> !quantum.bit {
+ %1 = quantum.custom "S"() %0 : !quantum.bit
+ func.return %1 : !quantum.bit
+}
+
+
+Now, we write a pass to do the substitution
+===========================================
+
+.. code-block:: python
+
+ class ReplaceHadamardPattern(pattern_rewriter.RewritePattern):
+
+ def __init__(self, subroutine: func.FuncOp):
+ self.subroutine = subroutine
+
+ @pattern_rewriter.op_type_rewrite_pattern
+ def match_and_rewrite(self, customOp: CustomOp, rewriter: pattern_rewriter.PatternRewriter):
+ if customOp.gate_name.data != "Hadamard":
+ return
+
+ callOp = func.CallOp(
+ builtin.SymbolRefAttr("replace_hadamard"),
+ [customOp.in_qubits[0]],
+ self.subroutine.function_type.outputs.data,
+ )
+ rewriter.insert_op_after_matched_op(callOp)
+ rewriter.replace_all_uses_with(customOp.out_qubits[0], callOp.results[0])
+ rewriter.erase_op(customOp)
+
+
+ @dataclass(frozen=True)
+ class ReplaceHadamardPass(passes.ModulePass):
+ name = "replace-hadamard"
+ gate_name: str
+
+ def apply(self, ctx: context.Context, module: builtin.ModuleOp):
+ funcOp = create_hadamard_replacement_subroutine(self.gate_name)
+ module.regions[0].blocks.first.add_op(funcOp)
+
+ pattern_rewriter.PatternRewriteWalker(
+ pattern_rewriter.GreedyRewritePatternApplier([ReplaceHadamardPattern(funcOp)])
+ ).rewrite_module(module)
+
+Letβs see it in action
+======================
+
+Here, we will replace all ``Hadamard``\ s with ``S``\ s
+
+.. code-block:: python
+
+ ctx = context.Context()
+
+ pipeline = passes.PassPipeline((ReplaceHadamardPass("S"),))
+ pipeline.apply(ctx, qjit_mod)
+
+Great! We can see below that ``Hadamard`` was replaced by a call to
+``replace_hadamard``, which applies a single ``S`` gate.
+
+>>> print(qjit_mod)
+builtin.module @circuit {
+ func.func public @jit_circuit(%arg2 : tensor) -> (tensor) attributes {llvm.emit_c_interface} {
+ %0 = catalyst.launch_kernel @module_circuit::@circuit(%arg2) : (tensor) -> tensor
+ func.return %0 : tensor
+ }
+ builtin.module @module_circuit {
+ builtin.module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) {
+ transform.yield
+ }
+ }
+ func.func public @circuit(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} {
+ %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor
+ %1 = tensor.extract %0[] : tensor
+ quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
+ %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor
+ %3 = quantum.alloc(5) : !quantum.reg
+ %4 = tensor.extract %0[] : tensor
+ %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit
+ %6 = func.call @replace_hadamard(%5) : (!quantum.bit) -> !quantum.bit
+ %7 = quantum.namedobs %6[PauliZ] : !quantum.obs
+ %8 = quantum.expval %7 : f64
+ %9 = tensor.from_elements %8 : tensor
+ %10 = tensor.extract %0[] : tensor
+ %11 = quantum.insert %3[%10], %6 : !quantum.reg, !quantum.bit
+ quantum.dealloc %11 : !quantum.reg
+ quantum.device_release
+ func.return %9 : tensor
+ }
+ }
+ func.func @setup() {
+ quantum.init
+ func.return
+ }
+ func.func @teardown() {
+ quantum.finalize
+ func.return
+ }
+ func.func @replace_hadamard(%0 : !quantum.bit) -> !quantum.bit {
+ %1 = quantum.custom "S"() %0 : !quantum.bit
+ func.return %1 : !quantum.bit
+ }
+}
diff --git a/frontend/catalyst/python_interface/doc/xdsl_post_processing.rst b/frontend/catalyst/python_interface/doc/xdsl_post_processing.rst
new file mode 100644
index 0000000000..b9ca0ac3ae
--- /dev/null
+++ b/frontend/catalyst/python_interface/doc/xdsl_post_processing.rst
@@ -0,0 +1,225 @@
+Simple tutorial for injecting functions into xDSL modules
+=========================================================
+
+.. code-block:: python
+
+ from dataclasses import dataclass
+ import jax
+
+ import pennylane as qml
+ from catalyst.python_interface.conversion import inline_module, xdsl_from_qjit, xdsl_module
+
+ from xdsl import context, passes, pattern_rewriter
+ from xdsl.dialects import builtin, func
+ from xdsl.traits import SymbolTable
+ from xdsl.rewriter import InsertPoint
+
+Create workflow and convert to xDSL module
+==========================================
+
+.. code-block:: python
+
+ @xdsl_from_qjit
+ @qml.qjit(target="mlir")
+ def workflow(x, y):
+ dev = qml.device("lightning.qubit", wires=5)
+
+ @qml.qnode(dev)
+ def circuit(x):
+ qml.RX(x, 0)
+ return qml.expval(qml.Z(0))
+
+ res = circuit(x)
+ return res - y
+
+
+>>> xmod = workflow(3.5, 4.5)
+>>> print(xmod)
+builtin.module @workflow {
+ func.func public @jit_workflow(%arg2 : tensor, %arg3 : tensor) -> (tensor) attributes {llvm.emit_c_interface} {
+ %0 = catalyst.launch_kernel @module_circuit::@circuit(%arg2) : (tensor) -> tensor
+ %1 = "stablehlo.convert"(%arg3) : (tensor) -> tensor
+ %2 = "stablehlo.subtract"(%0, %1) : (tensor, tensor) -> tensor
+ func.return %2 : tensor
+ }
+ builtin.module @module_circuit {
+ builtin.module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) {
+ transform.yield
+ }
+ }
+ func.func public @circuit(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} {
+ %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor
+ %1 = tensor.extract %0[] : tensor
+ quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
+ %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor
+ %3 = quantum.alloc(5) : !quantum.reg
+ %4 = tensor.extract %0[] : tensor
+ %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit
+ %6 = tensor.extract %arg0[] : tensor
+ %7 = quantum.custom "RX"(%6) %5 : !quantum.bit
+ %8 = quantum.namedobs %7[PauliZ] : !quantum.obs
+ %9 = quantum.expval %8 : f64
+ %10 = tensor.from_elements %9 : tensor
+ %11 = tensor.extract %0[] : tensor
+ %12 = quantum.insert %3[%11], %7 : !quantum.reg, !quantum.bit
+ quantum.dealloc %12 : !quantum.reg
+ quantum.device_release
+ func.return %10 : tensor
+ }
+ }
+ func.func @setup() {
+ quantum.init
+ func.return
+ }
+ func.func @teardown() {
+ quantum.finalize
+ func.return
+ }
+}
+
+
+Now, letβs try creating a pass that squares the output of the qnode
+===================================================================
+
+To do so, we can use the ``inline_module`` utility to easily add our
+post-processing function into the module weβre transforming. First, we
+create the function that squares the input value and turn it into an
+xDSL module.
+
+.. code-block:: python
+
+ @jax.jit
+ def square(x):
+ return x * x
+
+
+>>> square_mod = xdsl_module(square)(1.5)
+>>> print(square_mod)
+builtin.module @jit_square attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+ func.func public @main(%arg0 : tensor) -> (tensor {jax.result_info = "result"}) {
+ %0 = "stablehlo.multiply"(%arg0, %arg0) : (tensor, tensor) -> tensor
+ func.return %0 : tensor
+ }
+}
+
+
+.. code-block:: python
+
+ def is_kernel_launch(op):
+ return op.name == "catalyst.launch_kernel"
+
+
+ class SquarePattern(pattern_rewriter.RewritePattern):
+
+ @pattern_rewriter.op_type_rewrite_pattern
+ def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter):
+ # We only rewrite the function that calls the qnode, and the caller of the qnode will
+ # always have catalyst.launch_kernel present. Additionally, we only rewrite the caller
+ # if it hasn't already been rewritten. We can put a UnitAttr() inside the caller's
+ # attributes to indicate whether it has been rewritten or not.
+ if funcOp.attributes.get("transformed") == builtin.UnitAttr() or not any(
+ is_kernel_launch(op) for op in funcOp.body.ops
+ ):
+ return
+
+ # Update funcOp to inidicate that it has been rewritten
+ funcOp.attributes["transformed"] = builtin.UnitAttr()
+
+ # Insert square into the module
+ mod = funcOp.parent_op()
+ inline_module(square_mod, mod, change_main_to="square")
+ square_fn = SymbolTable.lookup_symbol(mod, "square")
+
+ # Call square_fn and use its results instead of the qnode's results for
+ # the rest of the function
+ for op in funcOp.body.walk():
+ if is_kernel_launch(op):
+ callOp = func.CallOp(
+ builtin.SymbolRefAttr(square_fn.sym_name),
+ op.results,
+ square_fn.function_type.outputs.data,
+ )
+ rewriter.insert_op(callOp, InsertPoint.after(op))
+
+ # We have inserted a CallOp that takes the output of the qnode as input. Let's call
+ # the qnode output %0, and the CallOp output %1. The following replaces all uses of
+ # %0 with %1 EXCEPT for the case where %0 is an input to callOp
+ op.results[0].replace_by_if(callOp.results[0], lambda use: use.operation != callOp)
+ rewriter.notify_op_modified(funcOp)
+
+
+ @dataclass(frozen=True)
+ class SquarePass(passes.ModulePass):
+ name = "square"
+
+ def apply(self, ctx: context.Context, module: builtin.ModuleOp):
+ pattern_rewriter.PatternRewriteWalker(
+ pattern_rewriter.GreedyRewritePatternApplier([SquarePattern()])
+ ).rewrite_module(module)
+
+Letβs apply the pass to our workflow
+====================================
+
+.. code-block:: python
+
+ ctx = context.Context()
+
+ pipeline = passes.PassPipeline((SquarePass(),))
+ pipeline.apply(ctx, xmod)
+
+Great! Letβs see what the transformed module looks like
+=======================================================
+
+As you can see below, the ``square_xdsl`` function is the first function
+in the module, and it gets called by ``jit_workflow``, and its
+inputs/outputs are consistent with the behaviour we wanted.
+
+>>> print(xmod)
+builtin.module @workflow {
+ func.func public @jit_workflow(%arg2 : tensor, %arg3 : tensor) -> (tensor) attributes {llvm.emit_c_interface, transformed} {
+ %0 = catalyst.launch_kernel @module_circuit::@circuit(%arg2) : (tensor) -> tensor
+ %1 = func.call @square(%0) : (tensor) -> tensor
+ %2 = "stablehlo.convert"(%arg3) : (tensor) -> tensor
+ %3 = "stablehlo.subtract"(%1, %2) : (tensor, tensor) -> tensor
+ func.return %3 : tensor
+ }
+ builtin.module @module_circuit {
+ builtin.module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) {
+ transform.yield
+ }
+ }
+ func.func public @circuit(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} {
+ %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor
+ %1 = tensor.extract %0[] : tensor
+ quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
+ %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor
+ %3 = quantum.alloc(5) : !quantum.reg
+ %4 = tensor.extract %0[] : tensor
+ %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit
+ %6 = tensor.extract %arg0[] : tensor
+ %7 = quantum.custom "RX"(%6) %5 : !quantum.bit
+ %8 = quantum.namedobs %7[PauliZ] : !quantum.obs
+ %9 = quantum.expval %8 : f64
+ %10 = tensor.from_elements %9 : tensor
+ %11 = tensor.extract %0[] : tensor
+ %12 = quantum.insert %3[%11], %7 : !quantum.reg, !quantum.bit
+ quantum.dealloc %12 : !quantum.reg
+ quantum.device_release
+ func.return %10 : tensor
+ }
+ }
+ func.func @setup() {
+ quantum.init
+ func.return
+ }
+ func.func @teardown() {
+ quantum.finalize
+ func.return
+ }
+ func.func public @square(%arg0 : tensor) -> (tensor {jax.result_info = "result"}) {
+ %0 = "stablehlo.multiply"(%arg0, %arg0) : (tensor, tensor) -> tensor
+ func.return %0 : tensor
+ }
+}
diff --git a/frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst b/frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst
new file mode 100644
index 0000000000..8c0163a3de
--- /dev/null
+++ b/frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst
@@ -0,0 +1,404 @@
+Python compiler utilities
+=========================
+
+All utilities we care about are in the
+``catalyst.python_interface.conversion`` submodule.
+
+.. code-block:: python
+
+ import pennylane as qml
+
+ from catalyst.python_interface.conversion import (
+ inline_jit_to_module,
+ inline_module,
+ xdsl_from_qjit,
+ xdsl_module,
+ )
+
+``xdsl_module``
+===============
+
+This function takes a ``jax.jit``-ed function as input, and returns a
+wrapper. This wrapper can be called to return an xDSL module. Note that
+this function is intended to be used to covert purely classical
+functions into xDSL modules. Letβs take a look at a very simple example:
+
+.. code-block:: python
+
+ import jax
+
+
+ @jax.jit
+ def inner(x):
+ return x**2
+
+
+ @jax.jit
+ def outer(x, y):
+ return inner(x) - y
+
+
+>>> wrapped_outer = xdsl_module(outer)
+>>> jit_mod = wrapped_outer(1.5, 2.5)
+>>> print(jit_mod)
+builtin.module @jit_outer attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+ func.func public @main(%arg1 : tensor, %arg2 : tensor) -> (tensor {jax.result_info = "result"}) {
+ %0 = func.call @inner(%arg1) : (tensor) -> tensor
+ %1 = "stablehlo.subtract"(%0, %arg2) : (tensor, tensor) -> tensor
+ func.return %1 : tensor
+ }
+ func.func private @inner(%arg0 : tensor) -> (tensor) {
+ %0 = "stablehlo.multiply"(%arg0, %arg0) : (tensor, tensor) -> tensor
+ func.return %0 : tensor