diff --git a/.codecov.yml b/.codecov.yml
index e45b6334f3..057bd2b172 100644
--- a/.codecov.yml
+++ b/.codecov.yml
@@ -7,3 +7,6 @@ coverage:
status:
project: false
patch: true
+
+ignore:
+ - "frontend/catalyst/python_interface"
diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml
index 999e5ad137..d622074d73 100644
--- a/.github/workflows/check-catalyst.yaml
+++ b/.github/workflows/check-catalyst.yaml
@@ -473,6 +473,9 @@ jobs:
# macOS requirements.txt
python3 -m pip install cuda-quantum==0.6.0
python3 -m pip install oqc-qcaas-client
+ # Install graphviz for testing the mlir-op-graph integration
+ sudo apt-get install -y graphviz
+ python3 -m pip install graphviz
make frontend
- name: Get Cached LLVM Build
@@ -556,6 +559,9 @@ jobs:
sudo apt-get install -y libasan6 make
python3 --version | grep ${{ needs.constants.outputs.primary_python_version }}
python3 -m pip install -r requirements.txt
+ # Install graphviz for testing the mlir-op-graph integration
+ sudo apt-get install -y graphviz
+ python3 -m pip install graphviz
make frontend
- name: Get Cached LLVM Build
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 8087a6d084..2ac3b66fd5 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -69,6 +69,48 @@
Improvements 🛠
+* Add an experimental `outline_state_evolution_pass` xDSL pass to `catalyst.python_interface.transforms`,
+ which moves all quantum gate operations to a private callable.
+ [(#8367)](https://github.com/PennyLaneAI/pennylane/pull/8367)
+
+* A new experimental `split_non_commuting_pass` compiler pass has been added to
+ `catalyst.python_interface.transforms`. This pass splits quantum functions that
+ measure observables on the same wires into multiple function executions, where
+ each execution measures observables on different wires (using the "wires" grouping
+ strategy). The original function is replaced with calls to these generated functions,
+ and the results are combined appropriately.
+ [(#8531)](https://github.com/PennyLaneAI/pennylane/pull/8531)
+
+* Add the `PCPhaseOp` operation to the xDSL Quantum dialect.
+ [(#8621)](https://github.com/PennyLaneAI/pennylane/pull/8621)
+
+* Users can now apply xDSL passes without the need to pass the `pass_plugins` argument to
+ the `qjit` decorator.
+ [(#8572)](https://github.com/PennyLaneAI/pennylane/pull/8572)
+ [(#8573)](https://github.com/PennyLaneAI/pennylane/pull/8573)
+ [(#2169)](https://github.com/PennyLaneAI/catalyst/pull/2169)
+ [(#2183)](https://github.com/PennyLaneAI/catalyst/pull/2183)
+
+* The :meth:`catalyst.python_interface.transforms.convert_to_mbqc_formalism_pass` now
+ supports :class:`~xdsl.dialects.scf.IndexSwitchOp` in IR and ignores regions that have no body.
+ [(#8632)](https://github.com/PennyLaneAI/pennylane/pull/8632)
+
+* The `convert_to_mbqc_formalism` compilation pass now outlines the operations to represent a gate
+ in the MBQC formalism into subroutines in order to reduce the IR size for large programs.
+ [(#8619)](https://github.com/PennyLaneAI/pennylane/pull/8619)
+
+* The :meth:`catalyst.python_interface.Compiler.run` method now accepts a string as input,
+ which is parsed and transformed with xDSL.
+ [(#8587)](https://github.com/PennyLaneAI/pennylane/pull/8587)
+
+* An `is_xdsl_pass` function has been added to the `catalyst.python_interface.pass_api` module.
+ This function checks if a pass name corresponds to an xDSL implemented pass.
+ [(#8572)](https://github.com/PennyLaneAI/pennylane/pull/8572)
+
+* A new `catalyst.python_interface.utils` submodule has been added, containing general-purpose utilities for
+ working with xDSL. This includes a function that extracts the concrete value of scalar, constant SSA values.
+ [(#8514)](https://github.com/PennyLaneAI/pennylane/pull/8514)
+
* A new ``"changed"`` option has been added to the ``keep_intermediate`` parameter of
:func:`~.qjit`. This option saves intermediate IR files after each pass,
but only when the IR is actually modified by the pass.
@@ -184,6 +226,20 @@
Bug fixes 🐛
+* The experimental xDSL :func:`~catalyst.python_interface.transforms.measurements_from_samples_pass`
+ pass has been updated to support `shots` defined by an `arith.constant` operation.
+ [(#8460)](https://github.com/PennyLaneAI/pennylane/pull/8460)
+
+* The experimental xDSL :func:`~catalyst.python_interface.transforms.diagonalize_measurements`
+ pass has been updated to fix a bug that included the wrong SSA value for final qubit insertion
+ and deallocation at the end of the circuit. A clear error is now also raised when there are
+ observables with overlapping wires.
+ [(#8383)](https://github.com/PennyLaneAI/pennylane/pull/8383)
+
+* Fixes a bug in the constructor of the xDSL Quantum dialect's `QubitUnitaryOp` that
+ prevented an instance from being constructed.
+ [(#8456)](https://github.com/PennyLaneAI/pennylane/pull/8456)
+
* Fixes an issue where a heap-to-stack allocation conversion pass was causing SIGSEGV issues
during program execution at runtime.
[(#2172)](https://github.com/PennyLaneAI/catalyst/pull/2172)
@@ -243,6 +299,15 @@
Internal changes ⚙️
+* The `catalyst.python_interface.visualization` module has been renamed to
+ `catalyst.python_interface.inspection`, and various utility functions within this module
+ have been streamlined.
+ [(#2237)](https://github.com/PennyLaneAI/catalyst/pull/2237)
+
+* Migrated the `pennylane.compiler.python_compiler` submodule from PennyLane to Catalyst.
+ It is now accessible as `catalyst.python_interface`.
+ [(#2199)](https://github.com/PennyLaneAI/catalyst/pull/2199)
+
* Resource tracking now writes out at device destruction time instead of qubit deallocation
time. The written resources will be the total amount of resources collected throughout the
lifetime of the execution. For executions that split work between multiple functions,
@@ -313,6 +378,11 @@
Documentation 📝
+* Added a "Unified Compiler Cookbook" RST file, along with tutorials, to `catalyst.python_interface.doc`,
+ which provides a quickstart guide for getting started with xDSL and its integration with PennyLane and
+ Catalyst.
+ [(#8571)](https://github.com/PennyLaneAI/pennylane/pull/8571)
+
* A typo in the code example for :func:`~.passes.ppr_to_ppm` has been corrected.
[(#2136)](https://github.com/PennyLaneAI/catalyst/pull/2136)
diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py
index a2bf73d0a9..9b56ee8fff 100644
--- a/frontend/catalyst/compiler.py
+++ b/frontend/catalyst/compiler.py
@@ -384,13 +384,13 @@ def to_mlir_opt(
*args, stdin=None, options: Optional[CompileOptions] = None, using_python_compiler=False
):
"""echo ${input} | catalyst --tool=opt *args *opts -"""
- # Check if we need to use Python compiler for xDSL passes
+ # Check if we need to use the Python interface for xDSL passes
if using_python_compiler:
- # Use Python compiler path for xDSL passes
+ # Use the Python interface path for xDSL passes
# pylint: disable-next=import-outside-toplevel
- from pennylane.compiler.python_compiler import Compiler as PythonCompiler
+ from catalyst.python_interface import Compiler as UnifiedCompiler
- compiler = PythonCompiler()
+ compiler = UnifiedCompiler()
stdin = compiler.run(stdin, callback=None)
# These are the options that may affect compilation
@@ -548,7 +548,7 @@ def check_nested_operations(op):
@debug_logger
def is_using_python_compiler(self, mlir_module=None):
- """Returns true if we need the Python compiler path.
+ """Returns true if we need the Python interface path.
This happens when:
1. xDSL plugin is explicitly loaded (legacy), OR
@@ -606,7 +606,7 @@ def _create_xdsl_pass_save_callback(self, workspace):
os.makedirs(user_transform_dir, exist_ok=True)
class SavePassIRCallback:
- """Callback to save IR after each pass in python_compiler."""
+ """Callback to save IR after each pass in python_interface."""
def __init__(self, transform_dir):
self.transform_dir = transform_dir
@@ -666,10 +666,10 @@ def run(self, mlir_module, *args, **kwargs):
# We keep this module here to keep xDSL requirement optional
# Only move this is it has been decided that xDSL is no longer optional.
# pylint: disable-next=import-outside-toplevel
- from pennylane.compiler.python_compiler import Compiler as PythonCompiler
+ from catalyst.python_interface import Compiler as UnifiedCompiler
callback = self._create_xdsl_pass_save_callback(workspace)
- compiler = PythonCompiler()
+ compiler = UnifiedCompiler()
ir = compiler.run(ir, callback=callback)
return self.run_from_ir(
diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py
index 99d2d181d1..d720f0d577 100644
--- a/frontend/catalyst/jax_primitives_utils.py
+++ b/frontend/catalyst/jax_primitives_utils.py
@@ -376,9 +376,7 @@ def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipelin
try:
# pylint: disable=import-outside-toplevel
- from pennylane.compiler.python_compiler.pass_api import (
- is_xdsl_pass,
- )
+ from catalyst.python_interface.pass_api import is_xdsl_pass
if is_xdsl_pass(_pass.name):
uses_xdsl_passes = True
diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py
index f17051e41b..74f1df8865 100644
--- a/frontend/catalyst/jit.py
+++ b/frontend/catalyst/jit.py
@@ -582,13 +582,13 @@ def mlir_opt(self):
"""Obtain the MLIR representation after optimization"""
if not self.mlir_module:
return None
- using_python_compiler = self.compiler.is_using_python_compiler(self.mlir_module)
+ using_unified_compiler = self.compiler.is_using_unified_compiler(self.mlir_module)
stdin = self.mlir_module.operation.get_asm(
- print_generic_op_form=using_python_compiler,
+ print_generic_op_form=using_unified_compiler,
enable_debug_info=self.compile_options.use_nameloc,
)
return to_mlir_opt(
- stdin=stdin, options=self.compile_options, using_python_compiler=using_python_compiler
+ stdin=stdin, options=self.compile_options, using_unified_compiler=using_unified_compiler
)
@debug_logger
diff --git a/frontend/catalyst/python_interface/__init__.py b/frontend/catalyst/python_interface/__init__.py
new file mode 100644
index 0000000000..5eda3e4454
--- /dev/null
+++ b/frontend/catalyst/python_interface/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unified Compiler API for integration of Catalyst with xDSL."""
+
+from .compiler import Compiler
+from .inspection import QMLCollector
+from .parser import QuantumParser
+from .pass_api import compiler_transform
+
+__all__ = [
+ "Compiler",
+ "compiler_transform",
+ "QuantumParser",
+ "QMLCollector",
+]
diff --git a/frontend/catalyst/python_interface/compiler.py b/frontend/catalyst/python_interface/compiler.py
new file mode 100644
index 0000000000..f98ff2b8d8
--- /dev/null
+++ b/frontend/catalyst/python_interface/compiler.py
@@ -0,0 +1,83 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This file contains the implementation of the PennyLane-xDSL integration API."""
+
+
+import io
+
+from jax._src.interpreters import mlir
+from jaxlib.mlir.dialects import stablehlo
+from jaxlib.mlir.ir import Context as jaxContext
+from jaxlib.mlir.ir import Module as jaxModule
+from pennylane.typing import Callable
+from xdsl.context import Context as xContext
+from xdsl.dialects.builtin import ModuleOp
+from xdsl.passes import ModulePass, PassPipeline
+from xdsl.printer import Printer
+
+from catalyst.python_interface.parser import QuantumParser
+from catalyst.python_interface.pass_api import ApplyTransformSequence
+
+
+# pylint: disable=too-few-public-methods
+class Compiler:
+ """Compiler namespace"""
+
+ @staticmethod
+ def run(
+ module: jaxModule | str,
+ callback: Callable[[ModulePass, ModuleOp, ModulePass], None] | None = None,
+ ) -> jaxModule | str:
+ """Runs the apply-transform-sequence pass.
+
+ The apply-transform-sequence pass is a "meta-pass". In other words,
+ it is a pass that runs other passes.
+
+ Args:
+ module: Either a Jax MLIR module or MLIR IR as a string
+ callback: Optional callback function called between passes
+
+ Returns:
+ jaxModule | str: jaxModule if the input was a jaxModule, else a string.
+ """
+ # Convert to generic text format
+ is_jax_module = isinstance(module, jaxModule)
+ if is_jax_module:
+ gentxtmod = module.operation.get_asm(
+ binary=False, print_generic_op_form=True, assume_verified=True
+ )
+ else:
+ gentxtmod = module
+
+ # Parse and transform with xDSL
+ ctx = xContext(allow_unregistered=True)
+ parser = QuantumParser(ctx, gentxtmod)
+ # xmod is modified in place
+ xmod = parser.parse_module()
+ pipeline = PassPipeline((ApplyTransformSequence(callback=callback),))
+ pipeline.apply(ctx, xmod)
+
+ # Convert back to string
+ buffer = io.StringIO()
+ Printer(stream=buffer, print_generic_format=True).print_op(xmod)
+
+ # Convert back to jaxModule if input was jaxModule
+ if is_jax_module:
+ with jaxContext() as jctx:
+ jctx.allow_unregistered_dialects = True
+ jctx.append_dialect_registry(mlir.upstream_dialects)
+ stablehlo.register_dialect(jctx) # pylint: disable=no-member
+ newmod: jaxModule = jaxModule.parse(buffer.getvalue())
+ return newmod
+ return buffer.getvalue()
diff --git a/frontend/catalyst/python_interface/conversion.py b/frontend/catalyst/python_interface/conversion.py
new file mode 100644
index 0000000000..975a46b3cb
--- /dev/null
+++ b/frontend/catalyst/python_interface/conversion.py
@@ -0,0 +1,171 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for converting to xDSL module."""
+
+from collections.abc import Callable, Sequence
+from functools import wraps
+from typing import TypeAlias
+
+from jax._src.lib import _jax
+from jaxlib.mlir.dialects import stablehlo as jstablehlo
+from jaxlib.mlir.ir import Context as jContext
+from jaxlib.mlir.ir import Module as jModule
+from xdsl.context import Context as xContext
+from xdsl.dialects import builtin as xbuiltin
+from xdsl.dialects import func as xfunc
+from xdsl.ir import Dialect as xDialect
+from xdsl.traits import SymbolTable as xSymbolTable
+
+from catalyst import QJIT
+from catalyst.python_interface.parser import QuantumParser
+
+JaxJittedFunction: TypeAlias = _jax.PjitFunction # pylint: disable=c-extension-no-member
+
+
+def _mlir_module_inline(func: JaxJittedFunction, *args, **kwargs) -> jModule:
+ """Get the MLIR module from a jax.jitted function"""
+ return func.lower(*args, **kwargs).compiler_ir()
+
+
+def mlir_module(func: JaxJittedFunction) -> Callable[..., jModule]:
+ """Returns a wrapper that creates an MLIR module from a jax.jitted function."""
+
+ @wraps(func)
+ def wrapper(*args, **kwargs) -> jModule:
+ return _mlir_module_inline(func, *args, **kwargs)
+
+ return wrapper
+
+
+def _generic_str_inline(func: JaxJittedFunction, *args, **kwargs) -> str: # pragma: no cover
+ """Create the generic textual representation for a jax.jitted function"""
+ lowered = func.lower(*args, **kwargs)
+ mod = lowered.compiler_ir()
+ return mod.operation.get_asm(binary=False, print_generic_op_form=True, assume_verified=True)
+
+
+def generic_str(func: JaxJittedFunction) -> Callable[..., str]: # pragma: no cover
+ """Returns a wrapper that creates the generic textual representation for a
+ jax.jitted function."""
+
+ @wraps(func)
+ def wrapper(*args, **kwargs) -> str:
+ return _generic_str_inline(func, *args, **kwargs)
+
+ return wrapper
+
+
+def parse_generic_to_xdsl_module(
+ program: str, extra_dialects: Sequence[xDialect] | None = None
+) -> xbuiltin.ModuleOp: # pragma: no cover
+ """Parses a generic MLIR program string to an xDSL module."""
+ ctx = xContext(allow_unregistered=True)
+ parser = QuantumParser(ctx, program, extra_dialects=extra_dialects)
+ moduleOp: xbuiltin.ModuleOp = parser.parse_module()
+ return moduleOp
+
+
+def parse_generic_to_mlir_module(program: str) -> jModule: # pragma: no cover
+ """Parses a generic MLIR program string to an MLIR module."""
+ with jContext() as ctx:
+ ctx.allow_unregistered_dialects = True
+ jstablehlo.register_dialect(ctx) # pylint: disable=no-member
+ return jModule.parse(program)
+
+
+def mlir_from_docstring(func: Callable) -> jModule: # pragma: no cover
+ """Returns a wrapper that parses an MLIR program string located in the docstring
+ into an MLIR module."""
+
+ @wraps(func)
+ def wrapper(*_, **__):
+ return parse_generic_to_mlir_module(func.__doc__)
+
+ return wrapper
+
+
+def _xdsl_module_inline(
+ func: JaxJittedFunction, *args, **kwargs
+) -> xbuiltin.ModuleOp: # pragma: no cover
+ """Get the xDSL module from a jax.jitted function"""
+ generic_repr = _generic_str_inline(func, *args, **kwargs)
+ return parse_generic_to_xdsl_module(generic_repr)
+
+
+def xdsl_from_docstring(func: Callable) -> xbuiltin.ModuleOp: # pragma: no cover
+ """Returns a wrapper that parses an MLIR program string located in the docstring
+ into an xDSL module."""
+
+ @wraps(func)
+ def wrapper(*_, **__):
+ return parse_generic_to_xdsl_module(func.__doc__)
+
+ return wrapper
+
+
+def xdsl_module(func: JaxJittedFunction) -> Callable[..., xbuiltin.ModuleOp]: # pragma: no cover
+ """Returns a wrapper that creates an xDSL module from a jax.jitted function."""
+
+ @wraps(func)
+ def wrapper(*args, **kwargs) -> xbuiltin.ModuleOp:
+ return _xdsl_module_inline(func, *args, **kwargs)
+
+ return wrapper
+
+
+def inline_module(
+ from_mod: xbuiltin.ModuleOp, to_mod: xbuiltin.ModuleOp, change_main_to: str = None
+) -> None:
+ """Inline the contents of one xDSL module into another xDSL module. The inlined body is appended
+ to the end of ``to_mod``.
+
+ If ``from_mod`` has a ``main`` function, its name is changed to ``change_main_to`` if specified.
+ """
+ if change_main_to:
+ main = xSymbolTable.lookup_symbol(from_mod, "main")
+ if main is not None:
+ assert isinstance(main, xfunc.FuncOp)
+ main.properties["sym_name"] = xbuiltin.StringAttr(change_main_to)
+
+ for op in from_mod.body.ops:
+ xSymbolTable.insert_or_update(to_mod, op.clone())
+
+
+def inline_jit_to_module(func: JaxJittedFunction, mod: xbuiltin.ModuleOp) -> Callable[..., None]:
+ """Inline a ``jax.jit``-ed Python function to an xDSL module. The inlined body is appended
+ to the end of ``mod`` in-place. The name of the entry point function of ``func`` is the same
+ as the name of ``func``."""
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ func_mod = _xdsl_module_inline(func, *args, **kwargs)
+ inline_module(func_mod, mod, change_main_to=func.__name__)
+
+ return wrapper
+
+
+def xdsl_from_qjit(func: QJIT) -> Callable[..., xbuiltin.ModuleOp]:
+ """Decorator to convert QJIT-ed functions into xDSL modules."""
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ func.jaxpr, *_ = func.capture(args, **kwargs)
+ _mlir_module = func.generate_ir()
+ _generic_str = _mlir_module.operation.get_asm(
+ binary=False, print_generic_op_form=True, assume_verified=True
+ )
+ return parse_generic_to_xdsl_module(_generic_str)
+
+ return wrapper
diff --git a/frontend/catalyst/python_interface/dialects/__init__.py b/frontend/catalyst/python_interface/dialects/__init__.py
new file mode 100644
index 0000000000..1b0ab7d1a9
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This submodule contains xDSL dialects for the unified compiler."""
+
+from .catalyst import Catalyst
+from .mbqc import MBQC
+from .qec import QEC
+from .quantum import Quantum
+from .stablehlo import StableHLO
+from .transform import Transform
+
+__all__ = ["Catalyst", "MBQC", "Quantum", "QEC", "StableHLO", "Transform"]
diff --git a/frontend/catalyst/python_interface/dialects/catalyst.py b/frontend/catalyst/python_interface/dialects/catalyst.py
new file mode 100644
index 0000000000..a1cf627272
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/catalyst.py
@@ -0,0 +1,268 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This file contains the Catalyst dialect for the unified compiler.
+
+This file was originally ported automatically by xDSL (using the ``xdsl-tblgen`` tool)
+and modified manually to support the unified compiler.
+
+The catalyst dialect serves as a standard library for the Catalyst compiler.
+It contains data structures that support core compiler functionality.
+"""
+
+from typing import ClassVar
+
+from xdsl.dialects.builtin import (
+ I64,
+ ArrayAttr,
+ DenseArrayBase,
+ DictionaryAttr,
+ FlatSymbolRefAttrConstr,
+ FunctionType,
+ IntegerAttr,
+ IntegerType,
+ MemRefType,
+ StringAttr,
+ SymbolRefAttr,
+ UnitAttr,
+ i32,
+)
+from xdsl.ir import AttributeCovT, Dialect, Generic, ParametrizedAttribute, TypeAttribute
+from xdsl.irdl import (
+ AnyAttr,
+ IRDLOperation,
+ ParsePropInAttrDict,
+ VarConstraint,
+ base,
+ irdl_attr_definition,
+ irdl_op_definition,
+ operand_def,
+ opt_operand_def,
+ opt_prop_def,
+ prop_def,
+ region_def,
+ result_def,
+ var_operand_def,
+ var_result_def,
+)
+
+
+@irdl_attr_definition
+class ArrayListType(Generic[AttributeCovT], ParametrizedAttribute, TypeAttribute):
+ """A dynamically resizable array"""
+
+ name = "catalyst.arraylist"
+
+ element_type: AttributeCovT
+
+
+@irdl_op_definition
+class AssertionOp(IRDLOperation):
+ """Asserts condition at runtime."""
+
+ name = "catalyst.assert"
+
+ assertion = operand_def(IntegerType(1))
+
+ error = prop_def(StringAttr)
+
+
+@irdl_op_definition
+class CallbackCallOp(IRDLOperation):
+ """CallbackCallOp operation."""
+
+ name = "catalyst.callback_call"
+
+ assembly_format = """
+ $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
+ """
+
+ callee = prop_def(FlatSymbolRefAttrConstr)
+
+ inputs = var_operand_def()
+
+ arg_attrs = opt_prop_def(ArrayAttr[DictionaryAttr])
+
+ res_attrs = opt_prop_def(ArrayAttr[DictionaryAttr])
+
+ callback_results = var_result_def()
+
+ irdl_options = [ParsePropInAttrDict()]
+
+
+@irdl_op_definition
+class CallbackOp(IRDLOperation):
+ """Operation denoting a symbol to refer to user callbacks."""
+
+ name = "catalyst.callback"
+
+ sym_name = prop_def(StringAttr)
+
+ function_type = prop_def(FunctionType)
+
+ id = prop_def(IntegerAttr[I64])
+
+ argc = prop_def(IntegerAttr[I64])
+
+ resc = prop_def(IntegerAttr[I64])
+
+ arg_attrs = opt_prop_def(ArrayAttr[DictionaryAttr])
+
+ res_attrs = opt_prop_def(ArrayAttr[DictionaryAttr])
+
+ body = region_def()
+
+
+@irdl_op_definition
+class CustomCallOp(IRDLOperation):
+ """CustomCall operation"""
+
+ name = "catalyst.custom_call"
+
+ assembly_format = """
+ `fn` `(`$call_target_name`)` `(` $inputs `)`
+ attr-dict `:` functional-type(operands, results)
+ """
+
+ inputs = var_operand_def()
+
+ call_target_name = prop_def(StringAttr)
+
+ number_original_arg = opt_prop_def(DenseArrayBase.constr(i32))
+
+ custom_results = var_result_def()
+
+ irdl_options = [ParsePropInAttrDict()]
+
+
+@irdl_op_definition
+class LaunchKernelOp(IRDLOperation):
+ """LaunchKernelOp operation."""
+
+ name = "catalyst.launch_kernel"
+
+ assembly_format = """
+ $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
+ """
+
+ callee = prop_def(SymbolRefAttr)
+
+ inputs = var_operand_def()
+
+ arg_attrs = opt_prop_def(ArrayAttr[DictionaryAttr])
+
+ res_attrs = opt_prop_def(ArrayAttr[DictionaryAttr])
+
+ kernel_results = var_result_def()
+
+ irdl_options = [ParsePropInAttrDict()]
+
+
+@irdl_op_definition
+class ListDeallocOp(IRDLOperation):
+ """Deallocate the underlying memory of an arraylist."""
+
+ name = "catalyst.list_dealloc"
+
+ assembly_format = """ $list attr-dict `:` type($list) """
+
+ list = operand_def(ArrayListType)
+
+
+@irdl_op_definition
+class ListInitOp(IRDLOperation):
+ """Initialize a dynamically resizable arraylist."""
+
+ name = "catalyst.list_init"
+
+ assembly_format = """ attr-dict `:` type($list) """
+
+ list = result_def(ArrayListType)
+
+
+@irdl_op_definition
+class ListLoadDataOp(IRDLOperation):
+ """Get the underlying memref storing the data of an array list."""
+
+ name = "catalyst.list_load_data"
+
+ assembly_format = """ $list attr-dict `:` type($list) `->` type($data) """
+
+ list = operand_def(ArrayListType)
+
+ data = result_def(MemRefType)
+
+
+@irdl_op_definition
+class ListPopOp(IRDLOperation):
+ """Remove an element from the end of an array list and return it."""
+
+ name = "catalyst.list_pop"
+
+ assembly_format = """ $list attr-dict `:` type($list) """
+
+ T: ClassVar = VarConstraint("T", AnyAttr())
+
+ list = operand_def(base(ArrayListType[T]))
+
+ result = result_def(T)
+
+
+@irdl_op_definition
+class ListPushOp(IRDLOperation):
+ """Append an element to the end of an array list."""
+
+ name = "catalyst.list_push"
+
+ assembly_format = """ $value `,` $list attr-dict `:` type($list) """
+
+ T: ClassVar = VarConstraint("T", AnyAttr())
+
+ value = operand_def(T)
+
+ list = operand_def(base(ArrayListType[T]))
+
+
+@irdl_op_definition
+class PrintOp(IRDLOperation):
+ """Prints numeric values or constant strings at runtime."""
+
+ name = "catalyst.print"
+
+ val = opt_operand_def()
+
+ const_val = opt_prop_def(StringAttr)
+
+ print_descriptor = prop_def(UnitAttr)
+
+
+Catalyst = Dialect(
+ "catalyst",
+ [
+ AssertionOp,
+ CallbackCallOp,
+ CallbackOp,
+ CustomCallOp,
+ LaunchKernelOp,
+ ListDeallocOp,
+ ListInitOp,
+ ListLoadDataOp,
+ ListPopOp,
+ ListPushOp,
+ PrintOp,
+ ],
+ [
+ ArrayListType,
+ ],
+)
diff --git a/frontend/catalyst/python_interface/dialects/mbqc.py b/frontend/catalyst/python_interface/dialects/mbqc.py
new file mode 100644
index 0000000000..f08ae517d7
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/mbqc.py
@@ -0,0 +1,174 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This module contains the definition of the MBQC dialect for the unified compiler.
+
+The MBQC dialect is a set of operations and types used to represent measurement-based
+quantum-computing instructions in the xDSL framework.
+
+It was initially generated by xDSL (using the ``xdsl-tblgen`` tool) starting from the
+catalyst/mlir/include/MBQC/IR/MBQCDialect.td file in the catalyst repository.
+
+For detailed documentation on the operations contained in this dialect, please refer to the MBQC
+dialect documentation in Catalyst.
+"""
+
+from typing import TypeAlias
+
+from xdsl.dialects.builtin import (
+ I32,
+ AnyAttr,
+ Float64Type,
+ IntegerAttr,
+ IntegerType,
+ StringAttr,
+ i1,
+)
+from xdsl.ir import Dialect, EnumAttribute, Operation, SpacedOpaqueSyntaxAttribute, SSAValue
+from xdsl.irdl import (
+ IRDLOperation,
+ irdl_attr_definition,
+ irdl_op_definition,
+ operand_def,
+ opt_prop_def,
+ prop_def,
+ result_def,
+)
+from xdsl.utils.exceptions import VerifyException
+from xdsl.utils.str_enum import StrEnum # StrEnum is standard in Python>=3.11
+
+from catalyst.python_interface.xdsl_extras import MemRefConstraint, TensorConstraint
+
+from .quantum import QubitType, QuregType
+
+QubitSSAValue: TypeAlias = SSAValue[QubitType]
+
+
+class MeasurementPlaneEnum(StrEnum):
+ """Enum containing supported measurement-plane attributes"""
+
+ XY = "XY"
+ YZ = "YZ"
+ ZX = "ZX"
+
+
+@irdl_attr_definition
+class MeasurementPlaneAttr(EnumAttribute[MeasurementPlaneEnum], SpacedOpaqueSyntaxAttribute):
+ """Planes in the Bloch sphere representation with support for arbitrary-basis measurements"""
+
+ name = "mbqc.measurement_plane"
+
+
+@irdl_op_definition
+class MeasureInBasisOp(IRDLOperation):
+ """A parametric single-qubit projective measurement in an arbitrary basis."""
+
+ name = "mbqc.measure_in_basis"
+
+ assembly_format = """
+ `[` $plane `,` $angle `]` $in_qubit (`postselect` $postselect^)? attr-dict `:` type(results)
+ """
+
+ in_qubit = operand_def(QubitType)
+
+ plane = prop_def(MeasurementPlaneAttr)
+
+ angle = operand_def(Float64Type())
+
+ postselect = opt_prop_def(IntegerAttr[I32])
+
+ mres = result_def(IntegerType(1))
+
+ out_qubit = result_def(QubitType)
+
+ def __init__(
+ self,
+ in_qubit: QubitSSAValue | Operation,
+ plane: MeasurementPlaneAttr,
+ angle: SSAValue[Float64Type],
+ postselect: int | IntegerAttr | None = None,
+ ):
+ properties = {"plane": plane}
+
+ if isinstance(postselect, int):
+ postselect = IntegerAttr.from_int_and_width(postselect, 32)
+
+ if postselect is not None:
+ properties["postselect"] = postselect
+
+ super().__init__(
+ operands=(in_qubit, angle),
+ properties=properties,
+ result_types=(IntegerType(1), QubitType()),
+ )
+
+ def verify_(self):
+ """Verify operation when rewriting."""
+ if self.postselect is None:
+ return
+
+ if self.postselect.value.data not in [0, 1]: # pylint: disable=no-member
+ raise VerifyException("'postselect' must be 0 or 1.")
+
+
+@irdl_op_definition
+class GraphStatePrepOp(IRDLOperation):
+ """Allocate resources for a new graph state."""
+
+ name = "mbqc.graph_state_prep"
+
+ assembly_format = """
+ `(` $adj_matrix `:` type($adj_matrix) `)` `[` `init` $init_op `,` `entangle` $entangle_op `]` attr-dict `:` type(results)
+ """
+
+ adj_matrix = operand_def(
+ TensorConstraint(element_type=i1, rank=1) | MemRefConstraint(element_type=i1, rank=1)
+ )
+
+ init_op = prop_def(StringAttr)
+
+ entangle_op = prop_def(StringAttr)
+
+ qreg = result_def(QuregType)
+
+ def __init__(
+ self, adj_matrix: AnyAttr, init_op: str | StringAttr, entangle_op: str | StringAttr
+ ):
+ if isinstance(init_op, str):
+ init_op = StringAttr(data=init_op)
+
+ if isinstance(entangle_op, str):
+ entangle_op = StringAttr(data=entangle_op)
+
+ properties = {"init_op": init_op, "entangle_op": entangle_op}
+
+ qreg = QuregType()
+
+ super().__init__(
+ operands=(adj_matrix,),
+ result_types=(qreg,),
+ properties=properties,
+ )
+
+
+MBQC = Dialect(
+ "mbqc",
+ [
+ MeasureInBasisOp,
+ GraphStatePrepOp,
+ ],
+ [
+ MeasurementPlaneAttr,
+ ],
+)
diff --git a/frontend/catalyst/python_interface/dialects/qec.py b/frontend/catalyst/python_interface/dialects/qec.py
new file mode 100644
index 0000000000..b88e05203f
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/qec.py
@@ -0,0 +1,216 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This module contains the QEC dialect for the unified compiler.
+
+The QEC dialect is a set of operations and types used to represent quantum error correction
+instructions in the xDSL framework.
+
+It was initially generated by xDSL (using the ``xdsl-tblgen`` tool) starting from the
+catalyst/mlir/include/QEC/IR/QECDialect.td file in the catalyst repository.
+"""
+
+from xdsl.dialects.builtin import I16, ArrayAttr, IntegerAttr, IntegerType, StringAttr, i16
+from xdsl.dialects.utils import AbstractYieldOperation
+from xdsl.ir import Attribute, Dialect, EnumAttribute, SpacedOpaqueSyntaxAttribute
+from xdsl.irdl import (
+ AttrSizedOperandSegments,
+ IRDLOperation,
+ irdl_attr_definition,
+ irdl_op_definition,
+ lazy_traits_def,
+ operand_def,
+ opt_operand_def,
+ opt_prop_def,
+ prop_def,
+ region_def,
+ result_def,
+ traits_def,
+ var_operand_def,
+ var_result_def,
+)
+from xdsl.traits import HasParent, IsTerminator, Pure, SingleBlockImplicitTerminator
+from xdsl.utils.str_enum import StrEnum
+
+from .quantum import QubitType
+
+
+class LogicalInitKind(StrEnum):
+ """The initial state of a logical qubit such as |0⟩, |1⟩, |+⟩, |−⟩, |Y⟩, |-Y⟩, |m⟩, or |m̅⟩."""
+
+ Zero = "zero" # |0⟩ Non-magic state
+ One = "one" # |1⟩ Non-magic state
+ Plus = "plus" # |+⟩ = (|0⟩ + |1⟩) / sqrt(2) Non-magic state
+ Minus = "minus" # |-⟩ = (|0⟩ - |1⟩) / sqrt(2) Non-magic state
+ PlusI = "plus_i" # |Y⟩ = (|0⟩ + i|1⟩) / sqrt(2) Non-magic / Magic state
+ MinusI = "minus_i" # |-Y⟩ = (|0⟩ - i|1⟩) / sqrt(2) Non-magic / Magic state
+ Magic = "magic" # |m⟩ = |0⟩ + e^{iπ/4}|1⟩ Magic state
+ MagicConj = "magic_conj" # |m̅⟩ = |0⟩ + e^{-iπ/4}|1⟩ Magic state
+
+
+@irdl_attr_definition
+class LogicalInit(EnumAttribute[LogicalInitKind], SpacedOpaqueSyntaxAttribute):
+ """The initial state of a logical qubit such as |0⟩, |1⟩, |+⟩, |−⟩, |Y⟩, |-Y⟩, |m⟩, or |m̅⟩."""
+
+ name = "qec.enum"
+
+
+# Type alias for a product of Pauli operators, aka a Pauli word.
+PauliWord = ArrayAttr[StringAttr]
+
+
+@irdl_op_definition
+class YieldOp(AbstractYieldOperation[Attribute]):
+ """Return results from a layer region"""
+
+ name = "qec.yield"
+
+ traits = lazy_traits_def(lambda: (IsTerminator(), HasParent(LayerOp), Pure()))
+
+
+@irdl_op_definition
+class FabricateOp(IRDLOperation):
+ """Fabricate axillary qubits from qubit factories."""
+
+ name = "qec.fabricate"
+
+ assembly_format = """
+ $init_state attr-dict `:` type($out_qubits)
+ """
+
+ init_state = prop_def(LogicalInit)
+
+ out_qubits = var_result_def(QubitType)
+
+
+@irdl_op_definition
+class LayerOp(IRDLOperation):
+ """A layer operation"""
+
+ name = "qec.layer"
+
+ initArgs = var_operand_def()
+
+ results = var_result_def()
+
+ region = region_def("single_block")
+
+ traits = traits_def(SingleBlockImplicitTerminator(YieldOp))
+
+ # TODO: add a custom parse and print for this operation
+
+
+@irdl_op_definition
+class PPMeasurementOp(IRDLOperation):
+ """Pauli Product Measurement on qubits."""
+
+ name = "qec.ppm"
+
+ assembly_format = """
+ $pauli_product (`(` $rotation_sign^ `)`)? $in_qubits (`cond` `(` $condition^ `)`)? attr-dict `:` type(results)
+ """
+
+ irdl_options = [AttrSizedOperandSegments(as_property=True)]
+
+ pauli_product = prop_def(PauliWord)
+
+ rotation_sign = opt_prop_def(IntegerAttr[I16], default_value=IntegerAttr(1, i16))
+
+ in_qubits = var_operand_def(QubitType)
+
+ condition = opt_operand_def(IntegerType(1))
+
+ mres = result_def(IntegerType(1))
+
+ out_qubits = var_result_def(QubitType)
+
+
+@irdl_op_definition
+class PPRotationOp(IRDLOperation):
+ """Pauli Product Rotation on qubits."""
+
+ name = "qec.ppr"
+
+ assembly_format = """
+ $pauli_product `(` $rotation_kind `)` $in_qubits attr-dict (`cond` `(` $condition^ `)`)? `:` type($out_qubits)
+ """
+
+ irdl_options = [AttrSizedOperandSegments(as_property=True)]
+
+ pauli_product = prop_def(PauliWord)
+
+ rotation_kind = prop_def(IntegerAttr[IntegerType(16)])
+
+ in_qubits = var_operand_def(QubitType)
+
+ condition = opt_operand_def(IntegerType(1))
+
+ out_qubits = var_result_def(QubitType)
+
+
+@irdl_op_definition
+class PrepareStateOp(IRDLOperation):
+ """Initialize existing qubits into a given state."""
+
+ name = "qec.prepare"
+
+ assembly_format = """
+ $init_state $in_qubits attr-dict `:` type($out_qubits)
+ """
+
+ init_state = prop_def(LogicalInit)
+
+ in_qubits = var_operand_def(QubitType)
+
+ out_qubits = var_result_def(QubitType)
+
+
+@irdl_op_definition
+class SelectPPMeasurementOp(IRDLOperation):
+ """Multiplexed Pauli product measurement."""
+
+ name = "qec.select.ppm"
+
+ assembly_format = """
+ `(` $select_switch `,` $pauli_product_0 `,` $pauli_product_1 `)` $in_qubits attr-dict `:` type(results)
+ """
+
+ select_switch = operand_def(IntegerType(1))
+
+ pauli_product_0 = prop_def(PauliWord)
+
+ pauli_product_1 = prop_def(PauliWord)
+
+ in_qubits = var_operand_def(QubitType)
+
+ mres = result_def(IntegerType(1))
+
+ out_qubits = var_result_def(QubitType)
+
+
+QEC = Dialect(
+ "qec",
+ [
+ FabricateOp,
+ LayerOp,
+ PPMeasurementOp,
+ PPRotationOp,
+ PrepareStateOp,
+ SelectPPMeasurementOp,
+ YieldOp,
+ ],
+ [
+ LogicalInit,
+ ],
+)
diff --git a/frontend/catalyst/python_interface/dialects/quantum.py b/frontend/catalyst/python_interface/dialects/quantum.py
new file mode 100644
index 0000000000..8b549f5440
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/quantum.py
@@ -0,0 +1,1138 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This file contains the definition of the Quantum dialect for the unified compiler.
+
+The Quantum dialect is a set of operations and types used to represent quantum computations
+in the xDSL framework.
+
+It was initially generated by xDSL (using the ``xdsl-tblgen`` tool)
+starting from the catalyst/mlir/include/Quantum/IR/QuantumOps.td file in the catalyst repository.
+"""
+# pylint: disable=too-many-lines
+
+from collections.abc import Sequence
+from typing import TypeAlias
+
+from xdsl.dialects.builtin import (
+ I32,
+ I64,
+ ComplexType,
+ Float64Type,
+ FloatAttr,
+ IntegerAttr,
+ IntegerType,
+ MemRefType,
+ StringAttr,
+ TensorType,
+ UnitAttr,
+ i1,
+ i64,
+)
+from xdsl.ir import (
+ Block,
+ Dialect,
+ EnumAttribute,
+ Operation,
+ ParametrizedAttribute,
+ Region,
+ SpacedOpaqueSyntaxAttribute,
+ SSAValue,
+ StrEnum,
+ TypeAttribute,
+)
+from xdsl.irdl import (
+ AtLeast,
+ AttrSizedOperandSegments,
+ AttrSizedResultSegments,
+ IntSetConstraint,
+ IRDLOperation,
+ ParsePropInAttrDict,
+ SameVariadicResultSize,
+ irdl_attr_definition,
+ irdl_op_definition,
+ lazy_traits_def,
+ operand_def,
+ opt_operand_def,
+ opt_prop_def,
+ opt_result_def,
+ prop_def,
+ region_def,
+ result_def,
+ traits_def,
+ var_operand_def,
+ var_result_def,
+)
+from xdsl.traits import (
+ HasParent,
+ IsTerminator,
+ NoMemoryEffect,
+ Pure,
+ ReturnLike,
+ SingleBlockImplicitTerminator,
+)
+
+from catalyst.python_interface.xdsl_extras import MemRefConstraint, TensorConstraint
+
+################################################################
+######################## ATTRIBUTES ############################
+################################################################
+
+
+@irdl_attr_definition
+class ObservableType(ParametrizedAttribute, TypeAttribute):
+ """A quantum observable for use in measurements."""
+
+ name = "quantum.obs"
+
+
+@irdl_attr_definition
+class QubitType(ParametrizedAttribute, TypeAttribute):
+ """A value-semantic qubit (state)."""
+
+ name = "quantum.bit"
+
+
+@irdl_attr_definition
+class QuregType(ParametrizedAttribute, TypeAttribute):
+ """An array of value-semantic qubits (i.e. quantum register)."""
+
+ name = "quantum.reg"
+
+
+@irdl_attr_definition
+class ResultType(ParametrizedAttribute, TypeAttribute):
+ """A quantum measurement result."""
+
+ name = "quantum.res"
+
+
+class NamedObservable(StrEnum):
+ """Known named observables"""
+
+ Identity = "Identity"
+ PauliX = "PauliX"
+ PauliY = "PauliY"
+ PauliZ = "PauliZ"
+ Hadamard = "Hadamard"
+
+
+@irdl_attr_definition
+class NamedObservableAttr(EnumAttribute[NamedObservable], SpacedOpaqueSyntaxAttribute):
+ """Known named observables"""
+
+ name = "quantum.named_observable"
+
+
+################################################################
+######################## OPERATIONS ############################
+################################################################
+
+
+QubitSSAValue: TypeAlias = SSAValue[QubitType]
+QuregSSAValue: TypeAlias = SSAValue[QuregType]
+ObservableSSAValue: TypeAlias = SSAValue[ObservableType]
+
+
+@irdl_op_definition
+class AdjointOp(IRDLOperation):
+ """Calculate the adjoint of the enclosed operations"""
+
+ name = "quantum.adjoint"
+
+ assembly_format = """
+ `(` $qreg `)` attr-dict `:` type(operands) $region
+ """
+
+ qreg = operand_def(QuregType)
+
+ out_qreg = result_def(QuregType)
+
+ region = region_def("single_block")
+
+ traits = lazy_traits_def(lambda: (NoMemoryEffect(), SingleBlockImplicitTerminator(YieldOp)))
+
+ def __init__(
+ self,
+ qreg: QuregSSAValue | Operation,
+ region: Region | Sequence[Operation] | Sequence[Block],
+ ):
+ super().__init__(operands=(qreg,), result_types=(QuregType(),), regions=(region,))
+
+
+@irdl_op_definition
+class AllocOp(IRDLOperation):
+ """Allocate n qubits into a quantum register."""
+
+ name = "quantum.alloc"
+
+ assembly_format = """
+ `(` ($nqubits^):($nqubits_attr)? `)` attr-dict `:` type(results)
+ """
+
+ nqubits = opt_operand_def(i64)
+
+ nqubits_attr = opt_prop_def(IntegerAttr.constr(type=I64, value=AtLeast(0)))
+
+ qreg = result_def(QuregType)
+
+ def __init__(self, nqubits):
+ if isinstance(nqubits, int):
+ nqubits = IntegerAttr.from_int_and_width(nqubits, 64)
+
+ if isinstance(nqubits, IntegerAttr):
+ operands = (None,)
+ properties = {"nqubits_attr": nqubits}
+ else:
+ operands = (nqubits,)
+ properties = {}
+
+ super().__init__(operands=operands, properties=properties, result_types=(QuregType(),))
+
+
+@irdl_op_definition
+class AllocQubitOp(IRDLOperation):
+ """Allocate a single qubit."""
+
+ name = "quantum.alloc_qb"
+
+ assembly_format = """attr-dict `:` type(results)"""
+
+ qubit = result_def(QubitType)
+
+ def __init__(self):
+ super().__init__(
+ result_types=(QubitType(),),
+ )
+
+
+@irdl_op_definition
+class ComputationalBasisOp(IRDLOperation):
+ """Define a pseudo-obeservable of the computational basis for use in measurements"""
+
+ name = "quantum.compbasis"
+
+ assembly_format = """
+ (`qubits` $qubits^)? (`qreg` $qreg^)? attr-dict `:` type(results)
+ """
+
+ irdl_options = [AttrSizedOperandSegments(as_property=True)]
+
+ qubits = var_operand_def(QubitType)
+
+ qreg = opt_operand_def(QuregType)
+
+ obs = result_def(ObservableType)
+
+
+@irdl_op_definition
+class CountsOp(IRDLOperation):
+ """Compute sample counts for the given observable for the current state"""
+
+ name = "quantum.counts"
+
+ assembly_format = """
+ $obs ( `shape` $dynamic_shape^ )?
+ ( `in` `(` $in_eigvals^ `:` type($in_eigvals) `,` $in_counts `:` type($in_counts) `)` )?
+ attr-dict ( `:` type($eigvals)^ `,` type($counts) )?
+ """
+
+ irdl_options = [
+ AttrSizedOperandSegments(as_property=True),
+ SameVariadicResultSize(),
+ ]
+
+ obs = operand_def(ObservableType)
+
+ dynamic_shape = opt_operand_def(i64)
+
+ in_eigvals = opt_operand_def(MemRefConstraint(element_type=Float64Type(), rank=1))
+
+ in_counts = opt_operand_def(MemRefConstraint(element_type=i64, rank=1))
+
+ eigvals = opt_result_def(TensorConstraint(element_type=Float64Type(), rank=1))
+
+ counts = opt_result_def(TensorConstraint(element_type=i64, rank=1))
+
+
+@irdl_op_definition
+class CustomOp(IRDLOperation):
+ """A generic quantum gate on n qubits with m floating point parameters."""
+
+ name = "quantum.custom"
+
+ assembly_format = """
+ $gate_name `(` $params `)` $in_qubits
+ (`adj` $adjoint^)?
+ attr-dict
+ ( `ctrls` `(` $in_ctrl_qubits^ `)` )?
+ ( `ctrlvals` `(` $in_ctrl_values^ `)` )?
+ `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )?
+ """
+
+ irdl_options = [
+ AttrSizedOperandSegments(as_property=True),
+ AttrSizedResultSegments(as_property=True),
+ ]
+
+ params = var_operand_def(Float64Type())
+
+ in_qubits = var_operand_def(QubitType)
+
+ gate_name = prop_def(StringAttr)
+
+ adjoint = opt_prop_def(UnitAttr)
+
+ in_ctrl_qubits = var_operand_def(QubitType)
+
+ in_ctrl_values = var_operand_def(i1)
+
+ out_qubits = var_result_def(QubitType)
+
+ out_ctrl_qubits = var_result_def(QubitType)
+
+ traits = traits_def(NoMemoryEffect())
+
+ # pylint: disable=too-many-arguments
+ def __init__(
+ self,
+ *,
+ gate_name: str | StringAttr,
+ params: SSAValue[Float64Type] | Sequence[SSAValue[Float64Type]] | None = None,
+ in_qubits: QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation],
+ in_ctrl_qubits: (
+ QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation] | None
+ ) = None,
+ in_ctrl_values: (
+ SSAValue[IntegerType]
+ | Operation
+ | Sequence[SSAValue[IntegerType]]
+ | Sequence[Operation]
+ | None
+ ) = None,
+ adjoint: UnitAttr | bool = False,
+ ):
+ params = () if params is None else params
+ in_ctrl_qubits = () if in_ctrl_qubits is None else in_ctrl_qubits
+ in_ctrl_values = () if in_ctrl_values is None else in_ctrl_values
+
+ if not isinstance(params, Sequence):
+ params = (params,)
+ if not isinstance(in_qubits, Sequence):
+ in_qubits = (in_qubits,)
+ if not isinstance(in_ctrl_qubits, Sequence):
+ in_ctrl_qubits = (in_ctrl_qubits,)
+ if not isinstance(in_ctrl_values, Sequence):
+ in_ctrl_values = (in_ctrl_values,)
+
+ if isinstance(gate_name, str):
+ gate_name = StringAttr(data=gate_name)
+
+ out_qubits = tuple(QubitType() for _ in in_qubits)
+ out_ctrl_qubits = tuple(QubitType() for _ in in_ctrl_qubits)
+ properties = {"gate_name": gate_name}
+ if adjoint:
+ properties["adjoint"] = UnitAttr()
+
+ super().__init__(
+ operands=(params, in_qubits, in_ctrl_qubits, in_ctrl_values),
+ result_types=(out_qubits, out_ctrl_qubits),
+ properties=properties,
+ )
+
+
+@irdl_op_definition
+class DeallocOp(IRDLOperation):
+ """Deallocate a quantum register."""
+
+ name = "quantum.dealloc"
+
+ assembly_format = """
+ $qreg attr-dict `:` type(operands)
+ """
+
+ qreg = operand_def(QuregType)
+
+ def __init__(self, qreg: QuregSSAValue | Operation):
+ super().__init__(operands=(qreg,))
+
+
+@irdl_op_definition
+class DeallocQubitOp(IRDLOperation):
+ """Deallocate a single qubit."""
+
+ name = "quantum.dealloc_qb"
+
+ assembly_format = """$qubit attr-dict `:` type(operands)"""
+
+ qubit = operand_def(QubitType)
+
+ def __init__(self, qubit: QubitSSAValue | Operation):
+ super().__init__(
+ operands=(qubit,),
+ )
+
+
+@irdl_op_definition
+class DeviceInitOp(IRDLOperation):
+ """Initialize a quantum device."""
+
+ name = "quantum.device"
+
+ assembly_format = """
+ (`shots` `(` $shots^ `)`)? `[` $lib `,` $device_name `,` $kwargs `]` attr-dict
+ """
+
+ irdl_options = [ParsePropInAttrDict()]
+
+ shots = opt_operand_def(i64)
+
+ auto_qubit_management = opt_prop_def(UnitAttr)
+
+ lib = prop_def(StringAttr)
+
+ device_name = prop_def(StringAttr)
+
+ kwargs = prop_def(StringAttr)
+
+
+@irdl_op_definition
+class DeviceReleaseOp(IRDLOperation):
+ """Release the active quantum device."""
+
+ name = "quantum.device_release"
+
+ assembly_format = "attr-dict"
+
+
+@irdl_op_definition
+class ExpvalOp(IRDLOperation):
+ """Compute the expectation value of the given observable for the current state"""
+
+ name = "quantum.expval"
+
+ assembly_format = "$obs attr-dict `:` type(results)"
+
+ obs = operand_def(ObservableType)
+
+ expval = result_def(Float64Type())
+
+ def __init__(self, obs: ObservableSSAValue | Operation):
+ super().__init__(operands=(obs,), result_types=(Float64Type(),))
+
+
+@irdl_op_definition
+class ExtractOp(IRDLOperation):
+ """Extract a qubit value from a register."""
+
+ name = "quantum.extract"
+
+ assembly_format = """
+ $qreg `[` ($idx^):($idx_attr)? `]` attr-dict `:` type($qreg) `->` type(results)
+ """
+
+ qreg = operand_def(QuregType)
+
+ idx = opt_operand_def(i64)
+
+ idx_attr = opt_prop_def(IntegerAttr.constr(type=i64, value=AtLeast(0)))
+
+ qubit = result_def(QubitType)
+
+ traits = traits_def(NoMemoryEffect())
+
+ def __init__(
+ self,
+ qreg: QuregSSAValue | Operation,
+ idx: int | SSAValue[IntegerType] | Operation | IntegerAttr,
+ ):
+ if isinstance(idx, int):
+ idx = IntegerAttr.from_int_and_width(idx, 64)
+
+ if isinstance(idx, IntegerAttr):
+ operands = (qreg, None)
+ properties = {"idx_attr": idx}
+ else:
+ operands = (qreg, idx)
+ properties = {}
+
+ super().__init__(
+ operands=operands,
+ result_types=(QubitType(),),
+ properties=properties,
+ )
+
+
+@irdl_op_definition
+class FinalizeOp(IRDLOperation):
+ """Teardown the quantum runtime."""
+
+ name = "quantum.finalize"
+
+ assembly_format = "attr-dict"
+
+
+@irdl_op_definition
+class GlobalPhaseOp(IRDLOperation):
+ """Global Phase."""
+
+ name = "quantum.gphase"
+
+ assembly_format = """
+ `(` $params `)`
+ attr-dict
+ ( `ctrls` `(` $in_ctrl_qubits^ `)` )?
+ ( `ctrlvals` `(` $in_ctrl_values^ `)` )?
+ `:` type(results)
+ """
+
+ irdl_options = [AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()]
+
+ params = operand_def(Float64Type())
+
+ adjoint = opt_prop_def(UnitAttr)
+
+ in_ctrl_qubits = var_operand_def(QubitType)
+
+ in_ctrl_values = var_operand_def(i1)
+
+ out_ctrl_qubits = var_result_def(QubitType)
+
+ def __init__(
+ self,
+ *,
+ params: float | SSAValue[Float64Type],
+ in_ctrl_qubits: (
+ QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation] | None
+ ) = None,
+ in_ctrl_values: (
+ SSAValue[IntegerType]
+ | Operation
+ | Sequence[SSAValue[IntegerType]]
+ | Sequence[Operation]
+ | None
+ ) = None,
+ ):
+ if isinstance(params, float):
+ params = FloatAttr(data=params, type=Float64Type())
+ in_ctrl_qubits = () if in_ctrl_qubits is None else in_ctrl_qubits
+ in_ctrl_values = () if in_ctrl_values is None else in_ctrl_values
+
+ if not isinstance(in_ctrl_qubits, Sequence):
+ in_ctrl_qubits = (in_ctrl_qubits,)
+ if not isinstance(in_ctrl_values, Sequence):
+ in_ctrl_values = (in_ctrl_values,)
+
+ out_ctrl_qubits = tuple(QubitType() for _ in in_ctrl_qubits)
+
+ super().__init__(
+ operands=(params, in_ctrl_qubits, in_ctrl_values),
+ result_types=(out_ctrl_qubits,),
+ )
+
+
+@irdl_op_definition
+class HamiltonianOp(IRDLOperation):
+ """Define a Hamiltonian observable for use in measurements"""
+
+ name = "quantum.hamiltonian"
+
+ assembly_format = """
+ `(` $coeffs `:` type($coeffs) `)` $terms attr-dict `:` type(results)
+ """
+
+ coeffs = operand_def(
+ TensorConstraint(element_type=Float64Type(), rank=1)
+ | (MemRefConstraint(element_type=Float64Type(), rank=1))
+ )
+
+ terms = var_operand_def(ObservableType)
+
+ obs = result_def(ObservableType)
+
+
+@irdl_op_definition
+class HermitianOp(IRDLOperation):
+ """Define a Hermitian observable for use in measurements"""
+
+ name = "quantum.hermitian"
+
+ assembly_format = """
+ `(` $matrix `:` type($matrix) `)` $qubits attr-dict `:` type(results)
+ """
+
+ matrix = operand_def(
+ TensorConstraint(element_type=ComplexType(Float64Type()), rank=2)
+ | MemRefConstraint(element_type=ComplexType(Float64Type()), rank=2)
+ )
+
+ qubits = var_operand_def(QubitType)
+
+ obs = result_def(ObservableType)
+
+
+@irdl_op_definition
+class InitializeOp(IRDLOperation):
+ """Initialize the quantum runtime."""
+
+ name = "quantum.init"
+
+ assembly_format = "attr-dict"
+
+
+@irdl_op_definition
+class InsertOp(IRDLOperation):
+ """Update the qubit value of a register."""
+
+ name = "quantum.insert"
+
+ assembly_format = """
+ $in_qreg `[` ($idx^):($idx_attr)? `]` `,` $qubit attr-dict `:` type($in_qreg) `,` type($qubit)
+ """
+
+ in_qreg = operand_def(QuregType)
+
+ idx = opt_operand_def(i64)
+
+ idx_attr = opt_prop_def(IntegerAttr.constr(type=i64, value=AtLeast(0)))
+
+ qubit = operand_def(QubitType)
+
+ out_qreg = result_def(QuregType)
+
+ traits = traits_def(NoMemoryEffect())
+
+ def __init__(
+ self,
+ in_qreg: QuregSSAValue | Operation,
+ idx: SSAValue[IntegerType] | Operation | int | IntegerAttr,
+ qubit: QubitSSAValue | Operation,
+ ):
+ if isinstance(idx, int):
+ idx = IntegerAttr.from_int_and_width(idx, 64)
+
+ if isinstance(idx, IntegerAttr):
+ operands = (in_qreg, None, qubit)
+ properties = {"idx_attr": idx}
+ else:
+ operands = (in_qreg, idx, qubit)
+ properties = {}
+
+ super().__init__(operands=operands, properties=properties, result_types=(QuregType(),))
+
+
+@irdl_op_definition
+class MeasureOp(IRDLOperation):
+ """A single-qubit projective measurement in the computational basis."""
+
+ name = "quantum.measure"
+
+ assembly_format = """
+ $in_qubit (`postselect` $postselect^)? attr-dict `:` type(results)
+ """
+
+ in_qubit = operand_def(QubitType)
+
+ postselect = opt_prop_def(
+ IntegerAttr.constr(type=I32, value=IntSetConstraint(frozenset((0, 1))))
+ )
+
+ mres = result_def(i1)
+
+ out_qubit = result_def(QubitType)
+
+ def __init__(
+ self, in_qubit: QubitSSAValue | Operation, postselect: int | IntegerAttr | None = None
+ ):
+ if isinstance(postselect, int):
+ postselect = IntegerAttr.from_int_and_width(postselect, 32)
+
+ if postselect is None:
+ properties = {}
+ else:
+ properties = {"postselect": postselect}
+
+ super().__init__(
+ operands=(in_qubit,), properties=properties, result_types=(i1, QubitType())
+ )
+
+
+@irdl_op_definition
+class MultiRZOp(IRDLOperation):
+ """Apply an arbitrary multi Z rotation"""
+
+ name = "quantum.multirz"
+
+ assembly_format = """
+ `(` $theta `)` $in_qubits
+ (`adj` $adjoint^)?
+ attr-dict
+ ( `ctrls` `(` $in_ctrl_qubits^ `)` )?
+ ( `ctrlvals` `(` $in_ctrl_values^ `)` )?
+ `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )?
+ """
+
+ irdl_options = [
+ AttrSizedOperandSegments(as_property=True),
+ AttrSizedResultSegments(as_property=True),
+ ]
+
+ theta = operand_def(Float64Type())
+
+ in_qubits = var_operand_def(QubitType)
+
+ adjoint = opt_prop_def(UnitAttr)
+
+ in_ctrl_qubits = var_operand_def(QubitType)
+
+ in_ctrl_values = var_operand_def(i1)
+
+ out_qubits = var_result_def(QubitType)
+
+ out_ctrl_qubits = var_result_def(QubitType)
+
+ traits = traits_def(NoMemoryEffect())
+
+ # pylint: disable=too-many-arguments
+ def __init__(
+ self,
+ *,
+ theta: SSAValue[Float64Type],
+ in_qubits: QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation],
+ in_ctrl_qubits: (
+ QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation] | None
+ ) = None,
+ in_ctrl_values: (
+ SSAValue[IntegerType]
+ | Operation
+ | Sequence[SSAValue[IntegerType]]
+ | Sequence[Operation]
+ | None
+ ) = None,
+ adjoint: UnitAttr | bool = False,
+ ):
+ in_ctrl_qubits = () if in_ctrl_qubits is None else in_ctrl_qubits
+ in_ctrl_values = () if in_ctrl_values is None else in_ctrl_values
+
+ if not isinstance(in_qubits, Sequence):
+ in_qubits = (in_qubits,)
+ if not isinstance(in_ctrl_qubits, Sequence):
+ in_ctrl_qubits = (in_ctrl_qubits,)
+ if not isinstance(in_ctrl_values, Sequence):
+ in_ctrl_values = (in_ctrl_values,)
+
+ out_qubits = tuple(QubitType() for _ in in_qubits)
+ out_ctrl_qubits = tuple(QubitType() for _ in in_ctrl_qubits)
+ properties = {"adjoint": UnitAttr()} if adjoint else {}
+
+ super().__init__(
+ operands=(theta, in_qubits, in_ctrl_qubits, in_ctrl_values),
+ result_types=(out_qubits, out_ctrl_qubits),
+ properties=properties,
+ )
+
+
+@irdl_op_definition
+class NamedObsOp(IRDLOperation):
+ """Define a Named observable for use in measurements"""
+
+ name = "quantum.namedobs"
+
+ assembly_format = """
+ $qubit `[` $type `]` attr-dict `:` type(results)
+ """
+
+ qubit = operand_def(QubitType)
+
+ type = prop_def(NamedObservableAttr)
+
+ obs = result_def(ObservableType)
+
+ def __init__(self, qubit: QubitSSAValue | Operation, obs_type: NamedObservableAttr):
+ super().__init__(
+ operands=(qubit,), properties={"type": obs_type}, result_types=(ObservableType(),)
+ )
+
+
+@irdl_op_definition
+class NumQubitsOp(IRDLOperation):
+ """Get the number of currently allocated qubits."""
+
+ name = "quantum.num_qubits"
+
+ assembly_format = """
+ attr-dict `:` type(results)
+ """
+
+ num_qubits = result_def(i64)
+
+
+@irdl_op_definition
+class PCPhaseOp(IRDLOperation):
+ """Apply a projector-controlled phase gate"""
+
+ name = "quantum.pcphase"
+
+ assembly_format = """
+ `(` $theta `,` $dim `)` $in_qubits
+ (`adj` $adjoint^)?
+ attr-dict
+ ( `ctrls` `(` $in_ctrl_qubits^ `)` )?
+ ( `ctrlvals` `(` $in_ctrl_values^ `)` )?
+ `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )?
+ """
+
+ irdl_options = [
+ AttrSizedOperandSegments(as_property=True),
+ AttrSizedResultSegments(as_property=True),
+ ]
+
+ theta = operand_def(Float64Type())
+
+ dim = operand_def(Float64Type())
+
+ in_qubits = var_operand_def(QubitType)
+
+ adjoint = opt_prop_def(UnitAttr)
+
+ in_ctrl_qubits = var_operand_def(QubitType)
+
+ in_ctrl_values = var_operand_def(i1)
+
+ out_qubits = var_result_def(QubitType)
+
+ out_ctrl_qubits = var_result_def(QubitType)
+
+ traits = traits_def(NoMemoryEffect())
+
+ # pylint: disable=too-many-arguments
+ def __init__(
+ self,
+ *,
+ theta: SSAValue[Float64Type],
+ dim: SSAValue[Float64Type],
+ in_qubits: QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation],
+ in_ctrl_qubits: (
+ QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation] | None
+ ) = None,
+ in_ctrl_values: (
+ SSAValue[IntegerType]
+ | Operation
+ | Sequence[SSAValue[IntegerType]]
+ | Sequence[Operation]
+ | None
+ ) = None,
+ adjoint: UnitAttr | bool = False,
+ ):
+ in_ctrl_qubits = () if in_ctrl_qubits is None else in_ctrl_qubits
+ in_ctrl_values = () if in_ctrl_values is None else in_ctrl_values
+
+ if not isinstance(in_qubits, Sequence):
+ in_qubits = (in_qubits,)
+ if not isinstance(in_ctrl_qubits, Sequence):
+ in_ctrl_qubits = (in_ctrl_qubits,)
+ if not isinstance(in_ctrl_values, Sequence):
+ in_ctrl_values = (in_ctrl_values,)
+
+ out_qubits = tuple(QubitType() for _ in in_qubits)
+ out_ctrl_qubits = tuple(QubitType() for _ in in_ctrl_qubits)
+ properties = {"adjoint": UnitAttr()} if adjoint else {}
+
+ super().__init__(
+ operands=(theta, dim, in_qubits, in_ctrl_qubits, in_ctrl_values),
+ result_types=(out_qubits, out_ctrl_qubits),
+ properties=properties,
+ )
+
+
+@irdl_op_definition
+class ProbsOp(IRDLOperation):
+ """Compute computational basis probabilities for the current state"""
+
+ name = "quantum.probs"
+
+ assembly_format = """
+ $obs ( `shape` $dynamic_shape^ )?
+ ( `in` `(` $state_in^ `:` type($state_in) `)` )?
+ attr-dict ( `:` type($probabilities)^ )?
+ """
+
+ irdl_options = [AttrSizedOperandSegments(as_property=True)]
+
+ obs = operand_def(ObservableType)
+
+ dynamic_shape = opt_operand_def(i64)
+
+ state_in = opt_operand_def(MemRefConstraint(element_type=Float64Type(), rank=1))
+
+ probabilities = opt_result_def(TensorConstraint(element_type=Float64Type(), rank=1))
+
+
+@irdl_op_definition
+class QubitUnitaryOp(IRDLOperation):
+ """Apply an arbitrary fixed unitary matrix"""
+
+ name = "quantum.unitary"
+
+ assembly_format = """
+ `(` $matrix `:` type($matrix) `)` $in_qubits
+ (`adj` $adjoint^)?
+ attr-dict
+ ( `ctrls` `(` $in_ctrl_qubits^ `)` )?
+ ( `ctrlvals` `(` $in_ctrl_values^ `)` )?
+ `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )?
+ """
+
+ irdl_options = [
+ AttrSizedOperandSegments(as_property=True),
+ AttrSizedResultSegments(as_property=True),
+ ]
+
+ matrix = operand_def(
+ (TensorConstraint(element_type=ComplexType(Float64Type()), rank=2))
+ | (MemRefConstraint(element_type=ComplexType(Float64Type()), rank=2))
+ )
+
+ in_qubits = var_operand_def(QubitType)
+
+ adjoint = opt_prop_def(UnitAttr)
+
+ in_ctrl_qubits = var_operand_def(QubitType)
+
+ in_ctrl_values = var_operand_def(i1)
+
+ out_qubits = var_result_def(QubitType)
+
+ out_ctrl_qubits = var_result_def(QubitType)
+
+ traits = traits_def(NoMemoryEffect())
+
+ # pylint: disable=too-many-arguments
+ def __init__(
+ self,
+ *,
+ matrix: SSAValue[TensorType | MemRefType],
+ in_qubits: QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation],
+ in_ctrl_qubits: (
+ QubitSSAValue | Operation | Sequence[QubitSSAValue | Operation] | None
+ ) = None,
+ in_ctrl_values: (
+ SSAValue[IntegerType]
+ | Operation
+ | Sequence[SSAValue[IntegerType]]
+ | Sequence[Operation]
+ | None
+ ) = None,
+ adjoint: UnitAttr | bool = False,
+ ):
+ in_ctrl_qubits = () if in_ctrl_qubits is None else in_ctrl_qubits
+ in_ctrl_values = () if in_ctrl_values is None else in_ctrl_values
+
+ if not isinstance(in_qubits, Sequence):
+ in_qubits = (in_qubits,)
+ if not isinstance(in_ctrl_qubits, Sequence):
+ in_ctrl_qubits = (in_ctrl_qubits,)
+ if not isinstance(in_ctrl_values, Sequence):
+ in_ctrl_values = (in_ctrl_values,)
+
+ out_qubits = tuple(QubitType() for _ in in_qubits)
+ out_ctrl_qubits = tuple(QubitType() for _ in in_ctrl_qubits)
+ properties = {}
+ if adjoint:
+ properties["adjoint"] = UnitAttr()
+
+ super().__init__(
+ operands=(matrix, in_qubits, in_ctrl_qubits, in_ctrl_values),
+ result_types=(out_qubits, out_ctrl_qubits),
+ properties=properties,
+ )
+
+
+@irdl_op_definition
+class SampleOp(IRDLOperation):
+ """Sample eigenvalues from the given observable for the current state"""
+
+ name = "quantum.sample"
+
+ assembly_format = """
+ $obs ( `shape` $dynamic_shape^ )?
+ ( `in` `(` $in_data^ `:` type($in_data) `)` )?
+ attr-dict ( `:` type($samples)^ )?
+ """
+
+ irdl_options = [AttrSizedOperandSegments(as_property=True)]
+
+ obs = operand_def(ObservableType)
+
+ dynamic_shape = var_operand_def(i64)
+
+ in_data = opt_operand_def(MemRefConstraint(element_type=Float64Type(), rank=(1, 2)))
+
+ samples = opt_result_def(TensorConstraint(element_type=Float64Type(), rank=(1, 2)))
+
+
+@irdl_op_definition
+class SetBasisStateOp(IRDLOperation):
+ """Set basis state."""
+
+ name = "quantum.set_basis_state"
+
+ assembly_format = """
+ `(` $basis_state`)` $in_qubits attr-dict `:` functional-type(operands, results)
+ """
+
+ basis_state = operand_def(
+ (TensorConstraint(element_type=i1, rank=1)) | (MemRefConstraint(element_type=i1, rank=1))
+ )
+
+ in_qubits = var_operand_def(QubitType)
+
+ out_qubits = var_result_def(QubitType)
+
+
+@irdl_op_definition
+class SetStateOp(IRDLOperation):
+ """Set state to a complex vector."""
+
+ name = "quantum.set_state"
+
+ assembly_format = """
+ `(` $in_state `)` $in_qubits attr-dict `:` functional-type(operands, results)
+ """
+
+ in_state = operand_def(
+ (TensorConstraint(element_type=ComplexType(Float64Type()), rank=1))
+ | (MemRefConstraint(element_type=ComplexType(Float64Type()), rank=1))
+ )
+
+ in_qubits = var_operand_def(QubitType)
+
+ out_qubits = var_result_def(QubitType)
+
+
+@irdl_op_definition
+class StateOp(IRDLOperation):
+ """Return the current statevector"""
+
+ name = "quantum.state"
+
+ assembly_format = """
+ $obs ( `shape` $dynamic_shape^ )?
+ ( `in` `(` $state_in^ `:` type($state_in) `)` )?
+ attr-dict ( `:` type($state)^ )?
+ """
+
+ irdl_options = [AttrSizedOperandSegments(as_property=True)]
+
+ obs = operand_def(ObservableType)
+
+ dynamic_shape = opt_operand_def(i64)
+
+ state_in = opt_operand_def(MemRefConstraint(element_type=ComplexType(Float64Type()), rank=1))
+
+ state = opt_result_def(TensorConstraint(element_type=ComplexType(Float64Type()), rank=1))
+
+
+@irdl_op_definition
+class TensorOp(IRDLOperation):
+ """Define a tensor product of observables for use in measurements"""
+
+ name = "quantum.tensor"
+
+ assembly_format = """
+ $terms attr-dict `:` type(results)
+ """
+
+ terms = var_operand_def(ObservableType)
+
+ obs = result_def(ObservableType)
+
+
+@irdl_op_definition
+class VarianceOp(IRDLOperation):
+ """Compute the variance of the given observable for the current state"""
+
+ name = "quantum.var"
+
+ assembly_format = """
+ $obs attr-dict `:` type(results)
+ """
+
+ obs = operand_def(ObservableType)
+
+ variance = result_def(Float64Type())
+
+ def __init__(self, obs: ObservableSSAValue | Operation):
+ super().__init__(operands=(obs,), result_types=(Float64Type(),))
+
+
+@irdl_op_definition
+class YieldOp(IRDLOperation):
+ """Return results from quantum program regions"""
+
+ name = "quantum.yield"
+
+ assembly_format = """
+ attr-dict ($retvals^ `:` type($retvals))?
+ """
+
+ retvals = var_operand_def(QuregType)
+
+ traits = traits_def(HasParent(AdjointOp), IsTerminator(), Pure(), ReturnLike())
+
+
+Quantum = Dialect(
+ "quantum",
+ [
+ AdjointOp,
+ AllocOp,
+ AllocQubitOp,
+ ComputationalBasisOp,
+ CountsOp,
+ CustomOp,
+ DeallocOp,
+ DeallocQubitOp,
+ DeviceInitOp,
+ DeviceReleaseOp,
+ ExpvalOp,
+ ExtractOp,
+ FinalizeOp,
+ GlobalPhaseOp,
+ HamiltonianOp,
+ HermitianOp,
+ InitializeOp,
+ InsertOp,
+ MeasureOp,
+ MultiRZOp,
+ NamedObsOp,
+ NumQubitsOp,
+ PCPhaseOp,
+ ProbsOp,
+ QubitUnitaryOp,
+ SampleOp,
+ SetBasisStateOp,
+ SetStateOp,
+ StateOp,
+ TensorOp,
+ VarianceOp,
+ YieldOp,
+ ],
+ [
+ ObservableType,
+ QubitType,
+ QuregType,
+ ResultType,
+ NamedObservableAttr,
+ ],
+)
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/__init__.py b/frontend/catalyst/python_interface/dialects/stablehlo/__init__.py
new file mode 100644
index 0000000000..64380529ac
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/__init__.py
@@ -0,0 +1,161 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+StableHLO dialect package for PennyLane's compiler infrastructure.
+
+This package contains organized elementwise operations and other StableHLO-related
+functionality.
+"""
+
+from .attributes import (
+ CustomCallApiVersion,
+ CustomCallApiVersionAttr,
+ GatherDimensionNumbers,
+ OutputOperandAlias,
+ ResultAccuracyModeAttr,
+ ScatterDimensionNumbers,
+)
+from .control_flow import (
+ IfOp,
+ OptimizationBarrierOp,
+ WhileOp,
+)
+from .data_movement import (
+ BroadcastInDimOp,
+ ConcatenateOp,
+ DynamicSliceOp,
+ GatherOp,
+ ReshapeOp,
+ ScatterOp,
+ SliceOp,
+)
+
+# Import the main StableHLO dialect
+from .dialect import StableHLO
+from .dynamism import (
+ DynamicBroadcastInDimOp,
+)
+from .elementwise_binary import (
+ ComplexOp,
+ DivideOp,
+ MaximumOp,
+ MinimumOp,
+ PowerOp,
+ RemainderOp,
+)
+from .elementwise_other import (
+ ClampOp,
+ CompareOp,
+ ConstantOp,
+ MapOp,
+ ReducePrecisionOp,
+ SelectOp,
+)
+
+# Import all elementwise operations explicitly
+from .elementwise_unary import (
+ ConvertOp,
+ CosineOp,
+ ExponentialMinusOneOp,
+ ExponentialOp,
+ FloorOp,
+ ImagOp,
+ IsFiniteOp,
+ LogisticOp,
+ LogOp,
+ LogPlusOneOp,
+ NegateOp,
+ RealOp,
+ RoundNearestAfzOp,
+ RoundNearestEvenOp,
+ RsqrtOp,
+ SignOp,
+ SineOp,
+ SqrtOp,
+ TanhOp,
+ TanOp,
+)
+from .extensibility import (
+ CustomCallOp,
+)
+from .reduction import (
+ ReduceOp,
+)
+
+# Export all operations and the dialect for external use
+__all__ = [
+ # Main dialect
+ "StableHLO",
+ # Elementwise unary operations
+ "ConvertOp",
+ "CosineOp",
+ "ExponentialMinusOneOp",
+ "ExponentialOp",
+ "FloorOp",
+ "ImagOp",
+ "IsFiniteOp",
+ "LogOp",
+ "LogPlusOneOp",
+ "LogisticOp",
+ "NegateOp",
+ "RealOp",
+ "RoundNearestAfzOp",
+ "RoundNearestEvenOp",
+ "RsqrtOp",
+ "SignOp",
+ "SineOp",
+ "SqrtOp",
+ "TanOp",
+ "TanhOp",
+ # Elementwise binary operations
+ "ComplexOp",
+ "DivideOp",
+ "MaximumOp",
+ "MinimumOp",
+ "PowerOp",
+ "RemainderOp",
+ # Elementwise other operations
+ "ClampOp",
+ "CompareOp",
+ "ConstantOp",
+ "MapOp",
+ "ReducePrecisionOp",
+ "SelectOp",
+ # Control flow operations
+ "IfOp",
+ "WhileOp",
+ "OptimizationBarrierOp",
+ # Data movement operations
+ "BroadcastInDimOp",
+ "ConcatenateOp",
+ "DynamicSliceOp",
+ "GatherOp",
+ "ReshapeOp",
+ "ScatterOp",
+ "SliceOp",
+ # Dynamism operations
+ "DynamicBroadcastInDimOp",
+ # Reduction operations
+ "ReduceOp",
+ # Extensibility operations
+ "CustomCallOp",
+ # Attributes
+ "GatherDimensionNumbers",
+ "ResultAccuracyModeAttr",
+ "ScatterDimensionNumbers",
+ "CustomCallApiVersion",
+ "CustomCallApiVersionAttr",
+ "OutputOperandAlias",
+]
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/attributes.py b/frontend/catalyst/python_interface/dialects/stablehlo/attributes.py
new file mode 100644
index 0000000000..7618d77afc
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/attributes.py
@@ -0,0 +1,394 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+StableHLO attribute definitions for PennyLane's compiler infrastructure.
+
+This module provides attribute definitions based on the StableHLO specification
+(https://github.com/openxla/stablehlo/blob/main/docs/spec.md), including
+attributes for StableHLO operations.
+"""
+
+# pylint: disable=line-too-long
+
+from collections.abc import Sequence
+
+from xdsl.dialects.builtin import I64, ArrayAttr, IntegerAttr, i64
+from xdsl.ir import (
+ Attribute,
+ EnumAttribute,
+ ParametrizedAttribute,
+ SpacedOpaqueSyntaxAttribute,
+ StrEnum,
+)
+from xdsl.irdl import irdl_attr_definition
+from xdsl.parser import AttrParser
+from xdsl.printer import Printer
+
+
+# Utility functions for dimension array parsing/printing
+def parse_dims(parser: AttrParser) -> ArrayAttr[IntegerAttr[I64]]:
+ """Parse dimension array in [1, 2, 3] format"""
+ value = parser.parse_comma_separated_list(
+ AttrParser.Delimiter.SQUARE,
+ lambda: IntegerAttr(parser.parse_integer(), i64),
+ )
+ return ArrayAttr(value)
+
+
+def print_dims(printer: Printer, dims: ArrayAttr[IntegerAttr[I64]]):
+ """Print dimension array in [1, 2, 3] format"""
+ printer.print_string("[")
+ printer.print_list(
+ dims.data,
+ lambda dim: printer.print_string(f"{dim.value.data}"),
+ )
+ printer.print_string("]")
+
+
+class ResultAccuracyMode(StrEnum):
+ """
+ XLA result accuracy mode.
+ """
+
+ DEFAULT = "DEFAULT"
+ HIGH = "HIGHEST"
+ HIGHEST = "TOLERANCE"
+
+
+@irdl_attr_definition
+class ResultAccuracyModeAttr(EnumAttribute[ResultAccuracyMode], SpacedOpaqueSyntaxAttribute):
+ """
+ XLA result accuracy mode.
+
+ See external [documentation](https://github.com/openxla/stablehlo/blob/7c50d4efeaea30bff6aa5e46c7f71170f5aa06af/stablehlo/dialect/StablehloEnums.td#L49-L70).
+ """
+
+ name = "stablehlo.result_accuracy_mode"
+
+
+@irdl_attr_definition
+class GatherDimensionNumbers(ParametrizedAttribute):
+ """
+ XLA gather dimension numbers.
+
+ This attribute models the dimension information for gather operations.
+ See external [documentation](https://github.com/openxla/stablehlo/blob/b075e948092d8a27ed0be48f4f8dbaa6df7e2e3e/stablehlo/dialect/StablehloAttrs.td#L42).
+ """
+
+ name = "stablehlo.gather"
+
+ offset_dims: ArrayAttr[IntegerAttr[I64]]
+ collapsed_slice_dims: ArrayAttr[IntegerAttr[I64]]
+ operand_batching_dims: ArrayAttr[IntegerAttr[I64]]
+ start_indices_batching_dims: ArrayAttr[IntegerAttr[I64]]
+ start_index_map: ArrayAttr[IntegerAttr[I64]]
+ index_vector_dim: IntegerAttr[I64]
+
+ def print_parameters(self, printer: Printer) -> None:
+ """Print gather dimension numbers in structured format"""
+ with printer.in_angle_brackets():
+ with printer.indented():
+ # Print offset_dims
+ printer.print_string("\noffset_dims = ")
+ print_dims(printer, self.offset_dims)
+ printer.print_string(",")
+
+ # Print collapsed_slice_dims
+ printer.print_string("\ncollapsed_slice_dims = ")
+ print_dims(printer, self.collapsed_slice_dims)
+ printer.print_string(",")
+
+ # Print operand_batching_dims
+ printer.print_string("\noperand_batching_dims = ")
+ print_dims(printer, self.operand_batching_dims)
+ printer.print_string(",")
+
+ # Print start_indices_batching_dims
+ printer.print_string("\nstart_indices_batching_dims = ")
+ print_dims(printer, self.start_indices_batching_dims)
+ printer.print_string(",")
+
+ # Print start_index_map
+ printer.print_string("\nstart_index_map = ")
+ print_dims(printer, self.start_index_map)
+ printer.print_string(",")
+
+ # Print index_vector_dim
+ printer.print_string(f"\nindex_vector_dim = {self.index_vector_dim.value.data}")
+ printer.print_string("\n")
+
+ @classmethod
+ def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]:
+ """Parse gather dimension numbers from structured format"""
+ with parser.in_angle_brackets():
+ # Initialize default values for all fields
+ offset_dims = ArrayAttr([])
+ collapsed_slice_dims = ArrayAttr([])
+ operand_batching_dims = ArrayAttr([])
+ start_indices_batching_dims = ArrayAttr([])
+ start_index_map = ArrayAttr([])
+ index_vector_dim = IntegerAttr(0, i64)
+
+ # Try to parse offset_dims
+ if parser.parse_optional_characters("offset_dims") is not None:
+ parser.parse_punctuation("=")
+ offset_dims = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ # Try to parse collapsed_slice_dims
+ if parser.parse_optional_characters("collapsed_slice_dims") is not None:
+ parser.parse_punctuation("=")
+ collapsed_slice_dims = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ # Try to parse operand_batching_dims
+ if parser.parse_optional_characters("operand_batching_dims") is not None:
+ parser.parse_punctuation("=")
+ operand_batching_dims = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ # Try to parse start_indices_batching_dims
+ if parser.parse_optional_characters("start_indices_batching_dims") is not None:
+ parser.parse_punctuation("=")
+ start_indices_batching_dims = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ # Try to parse start_index_map
+ if parser.parse_optional_characters("start_index_map") is not None:
+ parser.parse_punctuation("=")
+ start_index_map = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ # Try to parse index_vector_dim
+ if parser.parse_optional_characters("index_vector_dim") is not None:
+ parser.parse_punctuation("=")
+ index_vector_dim = IntegerAttr(parser.parse_integer(), i64)
+
+ return (
+ offset_dims,
+ collapsed_slice_dims,
+ operand_batching_dims,
+ start_indices_batching_dims,
+ start_index_map,
+ index_vector_dim,
+ )
+
+
+@irdl_attr_definition
+class ScatterDimensionNumbers(ParametrizedAttribute):
+ """
+ XLA scatter dimension numbers.
+
+ This attribute models the dimension information for scatter operations.
+ See external [documentation](https://github.com/openxla/stablehlo/blob/b075e948092d8a27ed0be48f4f8dbaa6df7e2e3e/stablehlo/dialect/StablehloAttrs.td#L28).
+ """
+
+ name = "stablehlo.scatter"
+
+ update_window_dims: ArrayAttr[IntegerAttr[I64]]
+ inserted_window_dims: ArrayAttr[IntegerAttr[I64]]
+ input_batching_dims: ArrayAttr[IntegerAttr[I64]]
+ scatter_indices_batching_dims: ArrayAttr[IntegerAttr[I64]]
+ scatter_dims_to_operand_dims: ArrayAttr[IntegerAttr[I64]]
+ index_vector_dim: IntegerAttr[I64]
+
+ def print_parameters(self, printer: Printer) -> None:
+ """Print scatter dimension numbers in structured format"""
+ with printer.in_angle_brackets():
+ with printer.indented():
+ # Print update_window_dims
+ printer.print_string("\nupdate_window_dims = ")
+ print_dims(printer, self.update_window_dims)
+ printer.print_string(",")
+
+ # Print inserted_window_dims
+ printer.print_string("\ninserted_window_dims = ")
+ print_dims(printer, self.inserted_window_dims)
+ printer.print_string(",")
+
+ # Print input_batching_dims
+ printer.print_string("\ninput_batching_dims = ")
+ print_dims(printer, self.input_batching_dims)
+ printer.print_string(",")
+
+ # Print scatter_indices_batching_dims
+ printer.print_string("\nscatter_indices_batching_dims = ")
+ print_dims(printer, self.scatter_indices_batching_dims)
+ printer.print_string(",")
+
+ # Print scatter_dims_to_operand_dims
+ printer.print_string("\nscatter_dims_to_operand_dims = ")
+ print_dims(printer, self.scatter_dims_to_operand_dims)
+ printer.print_string(",")
+
+ # Print index_vector_dim
+ printer.print_string(f"\nindex_vector_dim = {self.index_vector_dim.value.data}")
+ printer.print_string("\n")
+
+ @classmethod
+ def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]:
+ """Parse scatter dimension numbers from structured format"""
+ with parser.in_angle_brackets():
+ # Initialize default values for all fields
+ update_window_dims = ArrayAttr([])
+ inserted_window_dims = ArrayAttr([])
+ input_batching_dims = ArrayAttr([])
+ scatter_indices_batching_dims = ArrayAttr([])
+ scatter_dims_to_operand_dims = ArrayAttr([])
+ index_vector_dim = IntegerAttr(0, i64)
+
+ # Try to parse update_window_dims
+ if parser.parse_optional_characters("update_window_dims") is not None:
+ parser.parse_punctuation("=")
+ update_window_dims = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ # Try to parse inserted_window_dims
+ if parser.parse_optional_characters("inserted_window_dims") is not None:
+ parser.parse_punctuation("=")
+ inserted_window_dims = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ # Try to parse input_batching_dims
+ if parser.parse_optional_characters("input_batching_dims") is not None:
+ parser.parse_punctuation("=")
+ input_batching_dims = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ # Try to parse scatter_indices_batching_dims
+ if parser.parse_optional_characters("scatter_indices_batching_dims") is not None:
+ parser.parse_punctuation("=")
+ scatter_indices_batching_dims = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ # Try to parse scatter_dims_to_operand_dims
+ if parser.parse_optional_characters("scatter_dims_to_operand_dims") is not None:
+ parser.parse_punctuation("=")
+ scatter_dims_to_operand_dims = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ # Try to parse index_vector_dim
+ if parser.parse_optional_characters("index_vector_dim") is not None:
+ parser.parse_punctuation("=")
+ index_vector_dim = IntegerAttr(parser.parse_integer(), i64)
+
+ return (
+ update_window_dims,
+ inserted_window_dims,
+ input_batching_dims,
+ scatter_indices_batching_dims,
+ scatter_dims_to_operand_dims,
+ index_vector_dim,
+ )
+
+
+# ===== CustomCall and layout-related attributes =====
+
+
+class CustomCallApiVersion(StrEnum):
+ """StableHLO CustomCall API version."""
+
+ API_VERSION_UNSPECIFIED = "API_VERSION_UNSPECIFIED"
+ API_VERSION_ORIGINAL = "API_VERSION_ORIGINAL"
+ API_VERSION_STATUS_RETURNING = "API_VERSION_STATUS_RETURNING"
+ API_VERSION_STATUS_RETURNING_UNIFIED = "API_VERSION_STATUS_RETURNING_UNIFIED"
+ API_VERSION_TYPED_FFI = "API_VERSION_TYPED_FFI"
+
+
+@irdl_attr_definition
+class CustomCallApiVersionAttr(EnumAttribute[CustomCallApiVersion], SpacedOpaqueSyntaxAttribute):
+ """StableHLO custom call API version attribute.
+
+ Mirrors StableHLO enum for CustomCall API versions.
+ """
+
+ name = "stablehlo.custom_call_api_version"
+
+
+@irdl_attr_definition
+class OutputOperandAlias(ParametrizedAttribute):
+ """
+ This attribute captures the alias relationship of the output to one of the
+ operands for a ``CustomCall`` op, denoted by ``operand_index``. The
+ ``output_tuple_indices`` and ``operand_tuple_indices`` are used to index into
+ output and operand types. These indices lists are empty if the corresponding
+ types are not tuple types, and can be arbitrarily long in case of
+ arbitrarily nested tuple types.
+
+ See https://www.tensorflow.org/xla/aliasing.
+
+ Example when used as array with in stablehlo.custom-call:
+
+ ```mlir
+ %0 = "stablehlo.custom_call"(%arg0, %arg1) {
+ // other attributes
+ output_operand_alias = [
+ #stablehlo.output_operand_alias
+ ]
+ } : (tuple, tensor<2x3xf32>>, tensor<5x5xf32>) -> tuple>
+
+ The output and the 0th operand are both tuples. The aliasing shows the
+ relationship between the 0th element in output tuple with the 1st element in
+ the 0th operand. And both of them are of the same type: ``tensor<2x3xf32>``.
+ ```
+ """
+
+ name = "stablehlo.output_operand_alias"
+
+ output_tuple_indices: ArrayAttr[IntegerAttr[I64]]
+ operand_index: IntegerAttr[I64]
+ operand_tuple_indices: ArrayAttr[IntegerAttr[I64]]
+
+ def print_parameters(self, printer: Printer) -> None:
+ """Print the OutputOperandAlias attribute."""
+ with printer.in_angle_brackets():
+ with printer.indented():
+ printer.print_string("\noutput_tuple_indices = ")
+ print_dims(printer, self.output_tuple_indices)
+ printer.print_string(",")
+
+ printer.print_string("\noperand_index = ")
+ printer.print_string(f"{self.operand_index.value.data}")
+ printer.print_string(",")
+
+ printer.print_string("\noperand_tuple_indices = ")
+ print_dims(printer, self.operand_tuple_indices)
+ printer.print_string("\n")
+
+ @classmethod
+ def parse_parameters(cls, parser: AttrParser):
+ """Parse the OutputOperandAlias attribute."""
+ with parser.in_angle_brackets():
+ output_tuple_indices = ArrayAttr([])
+ operand_index = IntegerAttr(0, i64)
+ operand_tuple_indices = ArrayAttr([])
+
+ if parser.parse_optional_characters("output_tuple_indices") is not None:
+ parser.parse_punctuation("=")
+ output_tuple_indices = parse_dims(parser)
+ parser.parse_optional_punctuation(",")
+
+ if parser.parse_optional_characters("operand_index") is not None:
+ parser.parse_punctuation("=")
+ operand_index = IntegerAttr(parser.parse_integer(), i64)
+ parser.parse_optional_punctuation(",")
+
+ if parser.parse_optional_characters("operand_tuple_indices") is not None:
+ parser.parse_punctuation("=")
+ operand_tuple_indices = parse_dims(parser)
+
+ return (output_tuple_indices, operand_index, operand_tuple_indices)
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/control_flow.py b/frontend/catalyst/python_interface/dialects/stablehlo/control_flow.py
new file mode 100644
index 0000000000..c9edced70d
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/control_flow.py
@@ -0,0 +1,160 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Control flow operations for the StableHLO dialect.
+"""
+
+from typing import TypeVar
+
+from xdsl.dialects.builtin import AnyTensorType
+from xdsl.irdl import (
+ IRDLOperation,
+ irdl_op_definition,
+ operand_def,
+ region_def,
+ traits_def,
+ var_operand_def,
+ var_result_def,
+)
+from xdsl.traits import (
+ Pure,
+ RecursivelySpeculatable,
+ RecursiveMemoryEffect,
+ SingleBlockImplicitTerminator,
+)
+from xdsl_jax.dialects.stablehlo import ReturnOp
+
+# Import our custom StableHLO types
+from .types import HLO_PredTensor, HLO_TensorOrPerAxisQuantizedTensorOrToken, HLO_TensorOrToken
+
+# Generic type variables for templating
+T_IN = TypeVar("T_IN", bound=AnyTensorType)
+T_OUT = TypeVar("T_OUT", bound=AnyTensorType)
+
+
+@irdl_op_definition
+class IfOp(IRDLOperation):
+ """
+ Produces the output from executing exactly one branch from `true_branch` or
+ `false_branch` depending on the value of `pred`.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#if
+
+ Example:
+ %result = "stablehlo.if"(%pred) ({
+ "stablehlo.return"(%result_true_branch) : (tensor) -> ()
+ }, {
+ "stablehlo.return"(%result_false_branch) : (tensor) -> ()
+ }) : (tensor) -> tensor
+ """
+
+ name = "stablehlo.if"
+
+ pred = operand_def(HLO_PredTensor)
+
+ res = var_result_def(HLO_TensorOrPerAxisQuantizedTensorOrToken)
+
+ true_branch = region_def("single_block")
+
+ false_branch = region_def("single_block")
+
+ traits = traits_def(
+ RecursiveMemoryEffect(),
+ RecursivelySpeculatable(),
+ SingleBlockImplicitTerminator(ReturnOp),
+ # TODO: InferTypeOpInterface
+ # TODO: OpAsmOpInterface
+ )
+
+ # TODO: Add custom assembly format
+
+
+# pylint: disable=line-too-long
+@irdl_op_definition
+class WhileOp(IRDLOperation):
+ """
+ Produces the output from executing `body` function 0 or more times while the
+ `cond` function outputs `true`.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#while
+
+ Example:
+ ```mlir
+ %results0, %results1 = stablehlo.while(%arg0 = %init_i, %arg1 = %init_sum) : tensor, tensor
+ cond {
+ %cond = stablehlo.compare LT, %arg0, %ten : (tensor, tensor) -> tensor
+ stablehlo.return %cond : tensor
+ } do {
+ %new_sum = stablehlo.add %arg1, %one : tensor
+ %new_i = stablehlo.add %arg0, %one : tensor
+ stablehlo.return %new_i, %new_sum : tensor, tensor
+ }
+ """
+
+ name = "stablehlo.while"
+
+ operand = var_operand_def(HLO_TensorOrPerAxisQuantizedTensorOrToken)
+
+ res = var_result_def(HLO_TensorOrPerAxisQuantizedTensorOrToken)
+
+ cond = region_def("single_block")
+
+ body = region_def("single_block")
+
+ traits = traits_def(
+ RecursiveMemoryEffect(),
+ RecursivelySpeculatable(),
+ SingleBlockImplicitTerminator(ReturnOp),
+ # TODO: InferTypeOpInterface
+ # TODO: OpAsmOpInterface
+ )
+
+
+# pylint: disable=line-too-long
+@irdl_op_definition
+class OptimizationBarrierOp(IRDLOperation):
+ """
+ Ensures that the operations that produce the `operand` are executed before any
+ operations that depend on the `result` and prevents compiler transformations
+ from moving operations across the barrier. Other than that, the operation is
+ an identity, i.e. `result` = `operand`.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#optimization_barrier
+
+ Example:
+ ```mlir
+ %result0, %result1 = stablehlo.optimization_barrier %operand0, %operand1 : tensor, tensor
+ ```
+ """
+
+ name = "stablehlo.optimization_barrier"
+
+ operand = var_operand_def(HLO_TensorOrToken)
+
+ res = var_result_def(HLO_TensorOrToken)
+
+ traits = traits_def(
+ Pure(),
+ # TODO: HLO_PairwiseSameOperandAndResultType
+ # TODO: InferTypeOpInterface
+ )
+
+ # TODO: Add custom assembly format
+ # assembly_format = """
+ # attr-dict ($operand^ `:` custom(type($operand), type($result))):(`(` `)`)?
+ # """
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/data_movement.py b/frontend/catalyst/python_interface/dialects/stablehlo/data_movement.py
new file mode 100644
index 0000000000..68ce2bddc7
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/data_movement.py
@@ -0,0 +1,416 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Data movement operations for the StableHLO dialect.
+"""
+
+from xdsl.dialects.builtin import BoolAttr, DenseArrayBase, IntegerAttr, TensorType, i64
+from xdsl.irdl import (
+ IRDLOperation,
+ irdl_op_definition,
+ operand_def,
+ opt_prop_def,
+ prop_def,
+ region_def,
+ result_def,
+ traits_def,
+ var_operand_def,
+ var_result_def,
+)
+from xdsl.irdl.attributes import eq
+from xdsl.irdl.constraints import AtLeast
+from xdsl.irdl.operations import SameVariadicOperandSize
+from xdsl.traits import (
+ ConditionallySpeculatable,
+ NoMemoryEffect,
+ Pure,
+ RecursiveMemoryEffect,
+)
+from xdsl.utils.exceptions import VerifyException
+from xdsl.utils.type import get_element_type_or_self
+
+from catalyst.python_interface.xdsl_extras import (
+ AllMatchSameOperatorTrait,
+ SameOperandsAndResultElementType,
+ TensorConstraint,
+)
+
+from .attributes import GatherDimensionNumbers, ScatterDimensionNumbers
+from .types import HLO_AnyIntegerOrIndexTensor, HLO_AnyTensor, HLO_Int, HLO_IntTensor, HLO_Tensor
+
+
+# pylint: disable=line-too-long
+@irdl_op_definition
+class BroadcastInDimOp(IRDLOperation):
+ """
+ Expands the dimensions and/or rank of an input tensor by duplicating the
+ data in the ``operand`` tensor and produces a ``result`` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#broadcast_in_dim
+
+ Example:
+ ```mlir
+ %result = stablehlo.broadcast_in_dim %operand, dims = [2, 1] : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
+ ```
+ """
+
+ name = "stablehlo.broadcast_in_dim"
+ operand = operand_def(HLO_AnyTensor)
+ broadcast_dimensions = prop_def(DenseArrayBase.constr(i64))
+ result = result_def(HLO_AnyTensor)
+
+ assembly_format = """
+ $operand `,` `dims` `=` $broadcast_dimensions
+ attr-dict `:` functional-type(operands, results)
+ """
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ # TODO: HLO_SpeculatableIfAllInputsStatic,
+ # TODO: HLO_CompatibleOperandsAndResultElementType,
+ )
+
+ def verify_(self) -> None:
+ """Verify non-quantized broadcast_in_dim constraints."""
+ o_type = self.operand_types[0]
+ r_type = self.result_types[0]
+
+ # These are constrained to tensors by the op definition
+ assert isinstance(o_type, TensorType) and isinstance(r_type, TensorType)
+
+ # broadcast_in_dim_c2: broadcast_dimensions size == operand rank
+ dims = tuple(self.broadcast_dimensions.get_values()) # pylint: disable=no-member
+ operand_rank = o_type.get_num_dims()
+ if len(dims) != operand_rank:
+ raise VerifyException(
+ "broadcast_dimensions size ("
+ f"{len(dims)}"
+ ") does not match operand rank ("
+ f"{operand_rank}"
+ ")"
+ )
+
+ # broadcast_in_dim_c4: broadcast_dimensions should not have duplicates
+ if len(set(dims)) != len(dims):
+ raise VerifyException("broadcast_dimensions should not have duplicates")
+
+ # Result rank and per-dimension checks
+ result_rank = r_type.get_num_dims()
+ o_shape = o_type.get_shape()
+ r_shape = r_type.get_shape()
+
+ for i, dim_index in enumerate(dims):
+ # broadcast_in_dim_c3: each dim index in bounds of result rank
+ if dim_index < 0 or dim_index >= result_rank:
+ raise VerifyException(
+ "broadcast_dimensions contains invalid value "
+ f"{dim_index} for result with rank {result_rank}"
+ )
+
+ # If operand dim is static, enforce broadcast_in_dim_c5
+ if o_shape[i] != -1:
+ dim_size = o_shape[i]
+ result_dim_size = r_shape[dim_index]
+ if dim_size not in (1, result_dim_size):
+ raise VerifyException(
+ "size of operand dimension "
+ f"{i} ({dim_size}) is not equal to 1 or size of result dimension "
+ f"{dim_index} ({result_dim_size})"
+ )
+
+
+# pylint: disable=line-too-long
+@irdl_op_definition
+class ConcatenateOp(IRDLOperation):
+ """
+ Concatenates a variadic number of tensors in ``inputs`` along ``dimension``
+ dimension in the same order as the given arguments and produces a ``result``
+ tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#concatenate
+
+ Example:
+ ```mlir
+ %result = stablehlo.concatenate %input0, %input1, dim = 0 : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
+ ```
+ """
+
+ name = "stablehlo.concatenate"
+
+ inputs = var_operand_def(HLO_Tensor)
+ result = result_def(HLO_Tensor)
+ dimension = prop_def(IntegerAttr.constr(type=eq(i64), value=AtLeast(0)))
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ ConditionallySpeculatable(),
+ SameOperandsAndResultElementType(),
+ # InferTypeOpInterface(),
+ )
+
+ # TODO: Implement CustomDirective
+ # assembly_format = """
+ # custom($inputs) `dim` `=` $dimension attr-dict `:` functional-type(operands, results)
+ # """
+
+
+@irdl_op_definition
+class DynamicSliceOp(IRDLOperation):
+ """
+ Extracts a slice from the ``operand`` using dynamically-computed starting
+ indices and produces a ``result`` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_slice
+
+ Example:
+ ```mlir
+ %result = stablehlo.dynamic_slice %operand, %start_indices0, %start_indices1, sizes = [2, 2]
+ : (tensor<4x4xi32>, tensor, tensor) -> tensor<2x2xi32>
+ ```
+ """
+
+ name = "stablehlo.dynamic_slice"
+ operand = operand_def(HLO_Tensor)
+ start_indices = var_operand_def(TensorConstraint(element_type=HLO_Int, rank=0))
+ slice_sizes = prop_def(DenseArrayBase.constr(i64))
+ result = result_def(HLO_Tensor)
+
+ # TODO: Implement CustomDirective
+ # assembly_format = """
+ # $operand `,` custom($start_indices)
+ # `sizes` `=` $slice_sizes attr-dict `:` functional-type(operands, results)
+ # """
+
+ traits = traits_def(
+ Pure(),
+ AllMatchSameOperatorTrait(
+ ("operand", "result"), lambda x: get_element_type_or_self(x.type), "element type"
+ ),
+ # TODO: InferTensorType(),
+ )
+
+
+# pylint: disable=line-too-long
+@irdl_op_definition
+class GatherOp(IRDLOperation):
+ """
+ Gathers slices from ``operand`` tensor from offsets specified in
+ ``start_indices`` and produces a ``result`` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#gather
+
+ Example:
+ ```mlir
+ %result = "stablehlo.gather"(%operand, %start_indices) {
+ dimension_numbers = #stablehlo.gather<
+ offset_dims = [3, 4],
+ collapsed_slice_dims = [1],
+ operand_batching_dims = [0],
+ start_indices_batching_dims = [1],
+ start_index_map = [2, 1],
+ index_vector_dim = 3>,
+ slice_sizes = array,
+ indices_are_sorted = false
+ } : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi64>
+ ```
+ """
+
+ name = "stablehlo.gather"
+ operand = operand_def(HLO_Tensor)
+ start_indices = operand_def(HLO_IntTensor)
+ dimension_numbers = prop_def(GatherDimensionNumbers)
+ slice_sizes = prop_def(DenseArrayBase.constr(i64))
+ indices_are_sorted = opt_prop_def(BoolAttr, default_value=BoolAttr.from_bool(False))
+ result = result_def(HLO_Tensor)
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ ConditionallySpeculatable(),
+ AllMatchSameOperatorTrait(
+ ("operand", "result"), lambda x: get_element_type_or_self(x.type), "element type"
+ ),
+ # TODO: InferTensorTypeWithReify(),
+ )
+
+ # TODO: Implement CustomDirective
+ # assembly_format = """
+ # custom($inputs) `dim` `=` $dimension attr-dict `:` functional-type(operands, results)
+ # """
+
+
+@irdl_op_definition
+class ReshapeOp(IRDLOperation):
+ """
+ Performs reshape of ``operand`` tensor to a ``result`` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reshape
+
+ Example:
+ ```mlir
+ %result = stablehlo.reshape %operand : (tensor<2xf32>) -> tensor<1x2xf32>
+ ```
+ """
+
+ name = "stablehlo.reshape"
+ operand = operand_def(HLO_AnyTensor)
+ result = result_def(HLO_AnyTensor)
+
+ assembly_format = """
+ operands attr-dict `:` functional-type(operands, results)
+ """
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ ConditionallySpeculatable(),
+ # TODO: HLO_CompatibleOperandsAndResultElementType,
+ )
+
+ def verify_(self) -> None:
+ """Verify that the operation has the same shape for all operands and results."""
+ o_type = self.operand_types[0]
+ r_type = self.result_types[0]
+
+ # These are constrained to tensors by the op definition
+ assert isinstance(o_type, TensorType) and isinstance(r_type, TensorType)
+
+ # If o_type or r_type is dynamically shaped there is nothing to verify.
+ if not o_type.has_static_shape() or not r_type.has_static_shape():
+ return
+
+ # If the operand type is statically shaped (not required) the number of
+ # elements must match that of the result type.
+ num_operand_elements = 1
+ for dim in o_type.get_shape():
+ num_operand_elements *= dim
+
+ num_result_elements = 1
+ for dim in r_type.get_shape():
+ num_result_elements *= dim
+
+ if num_result_elements != num_operand_elements:
+ raise VerifyException(
+ "number of output elements ("
+ f"{num_result_elements}"
+ ") doesn't match expected number of elements ("
+ f"{num_operand_elements}"
+ ")"
+ )
+
+
+@irdl_op_definition
+class ScatterOp(IRDLOperation):
+ """
+ Produces ``results`` tensors which are equal to ``inputs`` tensors except that
+ several slices specified by ``scatter_indices`` are updated with the values
+ ``updates`` using ``update_computation``.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter
+
+ Example:
+ ```mlir
+ %result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
+ ^bb0(%arg0: tensor, %arg1: tensor):
+ %0 = stablehlo.add %arg0, %arg1 : tensor
+ stablehlo.return %0 : tensor
+ }) {
+ scatter_dimension_numbers = #stablehlo.scatter<
+ update_window_dims = [3, 4],
+ inserted_window_dims = [1],
+ input_batching_dims = [0],
+ scatter_indices_batching_dims = [1],
+ scatter_dims_to_operand_dims = [2, 1],
+ index_vector_dim = 3>,
+ indices_are_sorted = false,
+ unique_indices = false
+ } : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
+ ```
+ """
+
+ name = "stablehlo.scatter"
+ inputs = var_operand_def(HLO_Tensor)
+ scatter_indices = operand_def(HLO_AnyIntegerOrIndexTensor)
+ updates = var_operand_def(HLO_Tensor)
+ scatter_dimension_numbers = prop_def(ScatterDimensionNumbers)
+ indices_are_sorted = opt_prop_def(BoolAttr, default_value=BoolAttr.from_bool(False))
+ unique_indices = opt_prop_def(BoolAttr, default_value=BoolAttr.from_bool(False))
+ result = var_result_def(HLO_Tensor)
+ update_computation = region_def("single_block")
+ # TODO: The MLIR implementation doesn't have the SingleBlockImplicitTerminator trait,
+ # However, it is checked to have a terminator in the verifier,
+ # which does not specifically check the terminator to be stablehlo.return.
+
+ traits = traits_def(
+ RecursiveMemoryEffect(),
+ ConditionallySpeculatable(),
+ # TODO: InferTypeOpInterface(),
+ )
+
+ irdl_options = [SameVariadicOperandSize()]
+
+ # TODO: MLIR has a custom verifier for the scatter operation.
+
+
+@irdl_op_definition
+class SliceOp(IRDLOperation):
+ """
+ Extracts a slice from the ``operand`` using statically-computed starting
+ indices and produces a ``result`` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#slice
+
+ Example:
+ ```mlir
+ %result = stablehlo.slice %operand [1:3, 4:8:2]
+ : (tensor<3x8xi64>) -> tensor<2x2xi64>
+
+ // Same in generic form: the `1:3` above is mapped to the first entry in
+ // `start_indices` and `limit_indices`, while `strides` is implicitly 1.
+ // The `4:8:2` above is parsed into the second entry of `start_indices`,
+ // `limit_indices` and `strides` respectively.
+ %result = "stablehlo.slice" (%operand) {
+ start_indices = array,
+ limit_indices = array,
+ strides = array
+ } : (tensor<3x8xi64>) -> tensor<2x2xi64>
+ ```
+ """
+
+ name = "stablehlo.slice"
+
+ operand = operand_def(HLO_Tensor)
+ start_indices = prop_def(DenseArrayBase.constr(i64))
+ limit_indices = prop_def(DenseArrayBase.constr(i64))
+ strides = prop_def(DenseArrayBase.constr(i64))
+ result = result_def(HLO_Tensor)
+
+ # TODO: Implement CustomDirective
+ # assembly_format = """
+ # $operand custom($start_indices, $limit_indices, $strides)
+ # attr-dict `:` functional-type(operands, results)
+ # """
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ ConditionallySpeculatable(),
+ AllMatchSameOperatorTrait(("start_indices", "limit_indices", "strides"), len, "size"),
+ SameOperandsAndResultElementType(),
+ )
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/dialect.py b/frontend/catalyst/python_interface/dialects/stablehlo/dialect.py
new file mode 100644
index 0000000000..ce880bba2b
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/dialect.py
@@ -0,0 +1,207 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Extended StableHLO dialect that dynamically includes all upstream operations
+plus custom operations for PennyLane's compiler infrastructure.
+
+This module automatically imports all operations and attributes from the upstream
+xdsl_jax.dialects.stablehlo and adds custom ones without needing to hardcode
+the upstream operation list.
+"""
+
+import xdsl_jax.dialects.stablehlo as xstablehlo
+from xdsl.ir import Dialect
+
+from .attributes import (
+ CustomCallApiVersionAttr,
+ GatherDimensionNumbers,
+ OutputOperandAlias,
+ ResultAccuracyModeAttr,
+ ScatterDimensionNumbers,
+)
+from .control_flow import (
+ IfOp,
+ OptimizationBarrierOp,
+ WhileOp,
+)
+from .data_movement import (
+ BroadcastInDimOp,
+ ConcatenateOp,
+ DynamicSliceOp,
+ GatherOp,
+ ReshapeOp,
+ ScatterOp,
+ SliceOp,
+)
+from .dynamism import (
+ DynamicBroadcastInDimOp,
+)
+from .elementwise_binary import (
+ ComplexOp,
+ DivideOp,
+ MaximumOp,
+ MinimumOp,
+ PowerOp,
+ RemainderOp,
+)
+from .elementwise_other import (
+ ClampOp,
+ CompareOp,
+ ConstantOp,
+ MapOp,
+ ReducePrecisionOp,
+ SelectOp,
+)
+
+# Import all elementwise operations from organized files
+from .elementwise_unary import (
+ ConvertOp,
+ CosineOp,
+ ExponentialMinusOneOp,
+ ExponentialOp,
+ FloorOp,
+ ImagOp,
+ IsFiniteOp,
+ LogisticOp,
+ LogOp,
+ LogPlusOneOp,
+ NegateOp,
+ RealOp,
+ RoundNearestAfzOp,
+ RoundNearestEvenOp,
+ RsqrtOp,
+ SignOp,
+ SineOp,
+ SqrtOp,
+ TanhOp,
+ TanOp,
+)
+from .extensibility import (
+ CustomCallOp,
+)
+from .reduction import (
+ ReduceOp,
+)
+from .types import UniformQuantizedPerAxisType, UniformQuantizedType
+
+# Operations to add to the dialect
+OPERATIONS = [
+ ClampOp,
+ CompareOp,
+ ComplexOp,
+ ConstantOp,
+ ConvertOp,
+ CosineOp,
+ DivideOp,
+ ExponentialMinusOneOp,
+ ExponentialOp,
+ FloorOp,
+ ImagOp,
+ IsFiniteOp,
+ LogOp,
+ LogPlusOneOp,
+ LogisticOp,
+ MapOp,
+ MaximumOp,
+ MinimumOp,
+ NegateOp,
+ PowerOp,
+ RealOp,
+ ReducePrecisionOp,
+ RemainderOp,
+ RoundNearestAfzOp,
+ RoundNearestEvenOp,
+ RsqrtOp,
+ SelectOp,
+ SignOp,
+ SineOp,
+ SqrtOp,
+ TanOp,
+ TanhOp,
+ # Data movement operations
+ BroadcastInDimOp,
+ ConcatenateOp,
+ DynamicSliceOp,
+ GatherOp,
+ ReshapeOp,
+ ScatterOp,
+ SliceOp,
+ # Control flow operations
+ IfOp,
+ WhileOp,
+ OptimizationBarrierOp,
+ # Dynamism operations
+ DynamicBroadcastInDimOp,
+ # Reduction operations
+ ReduceOp,
+ # Extensibility operations
+ CustomCallOp,
+]
+
+# Attributes to add to the dialect
+ATTRIBUTES = [
+ CustomCallApiVersionAttr,
+ GatherDimensionNumbers,
+ ResultAccuracyModeAttr,
+ OutputOperandAlias,
+ ScatterDimensionNumbers,
+ UniformQuantizedPerAxisType,
+ UniformQuantizedType,
+]
+
+# Operations/attributes from upstream that should be deleted/replaced in the local version
+UPSTREAM_OPERATIONS_TO_DELETE = [
+ xstablehlo.ConstantOp,
+]
+UPSTREAM_ATTRIBUTES_TO_DELETE = []
+
+
+def filter_and_extend_upstream(upstream_list, to_delete, to_add):
+ """Filter out operations/attributes from upstream list and add new ones.
+
+ Args:
+ upstream_list: List of operations/attributes to filter
+ to_delete: List of operations/attributes to remove
+ to_add: List of operations/attributes to add
+
+ Returns:
+ Modified list of operations/attributes
+ """
+ filtered_ops = list(upstream_list)
+
+ # Remove operations that should be deleted
+ for op_to_delete in to_delete:
+ if op_to_delete in filtered_ops:
+ filtered_ops.remove(op_to_delete)
+
+ # Add new operations
+ filtered_ops.extend(to_add)
+
+ return filtered_ops
+
+
+all_operations = filter_and_extend_upstream(
+ xstablehlo.StableHLO.operations, UPSTREAM_OPERATIONS_TO_DELETE, OPERATIONS
+)
+all_attributes = filter_and_extend_upstream(
+ xstablehlo.StableHLO.attributes, UPSTREAM_ATTRIBUTES_TO_DELETE, ATTRIBUTES
+)
+
+# Create the extended StableHLO dialect by dynamically getting upstream components
+StableHLO = Dialect(
+ "stablehlo",
+ all_operations,
+ all_attributes,
+)
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py b/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py
new file mode 100644
index 0000000000..e38b4986f7
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/dynamism.py
@@ -0,0 +1,198 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Dynamism operations for the StableHLO dialect.
+"""
+
+from xdsl.dialects.builtin import DenseArrayBase, TensorType, i64
+from xdsl.irdl import (
+ IRDLOperation,
+ ParsePropInAttrDict,
+ irdl_op_definition,
+ operand_def,
+ opt_prop_def,
+ prop_def,
+ result_def,
+ traits_def,
+)
+from xdsl.traits import (
+ ConditionallySpeculatable,
+ NoMemoryEffect,
+)
+from xdsl.utils.exceptions import VerifyException
+
+from catalyst.python_interface.xdsl_extras import TensorConstraint
+
+from .types import HLO_AnyTensor, HLO_DimensionValue
+
+
+@irdl_op_definition
+class DynamicBroadcastInDimOp(IRDLOperation):
+ """
+ This operation is functionally identical to
+ [broadcast_in_dim](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#broadcast_in_dim)
+ op, but the result shape is specified dynamically via ``output_dimensions``.
+
+ It also accepts optional attributes to express static knowledge about the
+ expanding behavior of dimensions. If not specified, all dimensions are
+ assumed to be possibly expanding. The sets of dimensions that are known to
+ be expanding and the set of dimensions that are known to be non-expanding
+ must be disjoint and they must be a subset of the operand's dimensions.
+
+ See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_broadcast_in_dim
+
+ Example:
+ ```mlir
+ %operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
+ %output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
+ %result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
+ broadcast_dimensions = array,
+ known_expanding_dimensions = array,
+ known_nonexpanding_dimensions = array
+ } : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
+ ```
+ """
+
+ name = "stablehlo.dynamic_broadcast_in_dim"
+
+ operand = operand_def(HLO_AnyTensor)
+ output_dimensions = operand_def(TensorConstraint(element_type=HLO_DimensionValue, rank=1))
+ broadcast_dimensions = prop_def(DenseArrayBase.constr(i64))
+ known_expanding_dimensions = opt_prop_def(DenseArrayBase.constr(i64))
+ known_nonexpanding_dimensions = opt_prop_def(DenseArrayBase.constr(i64))
+ result = result_def(HLO_AnyTensor)
+
+ assembly_format = (
+ "$operand `,` $output_dimensions `,` `dims` `=` $broadcast_dimensions "
+ "attr-dict `:` functional-type(operands, results)"
+ )
+
+ traits = traits_def(
+ ConditionallySpeculatable(),
+ NoMemoryEffect(),
+ # TODO: InferShapedTypeOpInterface(),
+ )
+
+ irdl_options = [ParsePropInAttrDict()]
+
+ # pylint: disable=too-many-branches
+ def verify_(self):
+ """Verify the operation."""
+ operand_ty = self.operand_types[0]
+ result_ty = self.result_types[0]
+ bcast_dims = tuple(self.broadcast_dimensions.get_values()) # pylint: disable=no-member
+
+ # Operand and result must be tensors
+ assert isinstance(operand_ty, TensorType) and isinstance(result_ty, TensorType)
+
+ self._verify_rank_constraints(bcast_dims, operand_ty, result_ty)
+
+ # dynamic_broadcast_in_dim_c4: broadcast_dimensions should not have duplicates
+ if len(set(bcast_dims)) != len(bcast_dims):
+ raise VerifyException("broadcast_dimensions should not have duplicates")
+
+ self._verify_per_dimension_bounds(bcast_dims, operand_ty, result_ty)
+
+ self._verify_expansion_hints(operand_ty)
+
+ def _verify_rank_constraints(self, bcast_dims, operand_ty, result_ty):
+ """Verify then operand and result tensors against the rank constraints."""
+
+ operand_rank = operand_ty.get_num_dims()
+ result_rank = result_ty.get_num_dims()
+
+ # dynamic_broadcast_in_dim_c2: broadcast_dimensions size == operand rank
+ if len(bcast_dims) != operand_rank:
+ raise VerifyException(
+ "broadcast_dimensions size ("
+ f"{len(bcast_dims)}"
+ ") does not match operand rank ("
+ f"{operand_rank}"
+ ")"
+ )
+
+ # dynamic_broadcast_in_dim_c3: result rank >= operand rank
+ if result_rank < operand_rank:
+ raise VerifyException(
+ "result rank ("
+ f"{result_rank}"
+ ") is less than operand rank ("
+ f"{operand_rank}"
+ ")"
+ )
+
+ # dynamic_broadcast_in_dim_c7: output_dimensions shape compatible with result rank
+ out_dims_ty = self.output_dimensions.type # pylint: disable=no-member
+ assert isinstance(out_dims_ty, TensorType)
+ # Must be rank-1 tensor (enforced by type constraint), and length must match result
+ # rank when statically known
+ out_shape = out_dims_ty.get_shape()
+ if len(out_shape) != 1:
+ raise VerifyException("output_dimensions must be a 1D tensor")
+ if out_shape[0] != -1 and out_shape[0] != result_rank:
+ raise VerifyException(
+ "length of output_dimensions ("
+ f"{out_shape[0]}"
+ ") is not compatible with result rank ("
+ f"{result_rank}"
+ ")"
+ )
+
+ def _verify_per_dimension_bounds(self, bcast_dims, operand_ty, result_ty):
+ """Verify compatibility of operand and result dimensions."""
+ # dynamic_broadcast_in_dim_c5: bounds and per-dimension compatibility
+ operand_shape = operand_ty.get_shape()
+ result_shape = result_ty.get_shape()
+ result_rank = result_ty.get_num_dims()
+
+ for i, dim_index in enumerate(bcast_dims):
+ if dim_index < 0 or dim_index >= result_rank:
+ raise VerifyException(
+ "broadcast_dimensions contains invalid value "
+ f"{dim_index} for result with rank {result_rank}"
+ )
+ op_dim = operand_shape[i]
+ res_dim = result_shape[dim_index]
+ # If operand dim is static and not size-1, require compatibility with result dim
+ if op_dim not in (-1, 1):
+ if res_dim not in (-1, op_dim):
+ raise VerifyException(
+ "size of operand dimension "
+ f"{i} ({op_dim}) is not compatible with size of result dimension "
+ f"{dim_index} ({res_dim})"
+ )
+
+ def _verify_expansion_hints(self, operand_ty):
+ """Verify the operation's expansion hints."""
+ # dynamic_broadcast_in_dim_c8: no duplicate expansion hints across both lists
+ operand_rank = operand_ty.get_num_dims()
+
+ hints = []
+ if self.known_expanding_dimensions is not None:
+ hints.extend(self.known_expanding_dimensions.get_values()) # pylint: disable=no-member
+ if self.known_nonexpanding_dimensions is not None:
+ hints.extend(
+ self.known_nonexpanding_dimensions.get_values() # pylint: disable=no-member
+ )
+ if len(set(hints)) != len(hints):
+ raise VerifyException("duplicate expansion hint for at least one operand dimension")
+
+ # dynamic_broadcast_in_dim_c9/c10: each hint must reference a valid operand dimension
+ for h in set(hints):
+ if h < 0 or h >= operand_rank:
+ raise VerifyException(
+ "hint for expanding dimension "
+ f"{h} does not refer to a valid operand dimension"
+ )
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_binary.py b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_binary.py
new file mode 100644
index 0000000000..8180c7a208
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_binary.py
@@ -0,0 +1,214 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Binary elementwise operations for the StableHLO dialect.
+"""
+
+import abc
+from typing import Generic, TypeVar
+
+from xdsl.dialects.builtin import AnyTensorType, ComplexType, Float32Type, Float64Type, TensorType
+from xdsl.ir import Attribute, SSAValue
+from xdsl.irdl import (
+ IRDLOperation,
+ irdl_op_definition,
+ operand_def,
+ result_def,
+ traits_def,
+)
+from xdsl.traits import NoMemoryEffect
+
+from catalyst.python_interface.xdsl_extras import (
+ Elementwise,
+ SameOperandsAndResultShape,
+ SameOperandsElementType,
+)
+
+from .types import (
+ HLO_ComplexTensor,
+ HLO_Fp32Or64Tensor,
+ HLO_IntFpOrComplexOrQuantizedIntTensor,
+ HLO_Tensor,
+)
+
+# Type aliases
+F32Or64Type = Float32Type | Float64Type
+F32Or64TensorType = TensorType[F32Or64Type]
+ComplexTensorType = TensorType[ComplexType]
+
+# Generic type variables for templating
+T_LHS = TypeVar("T_LHS", bound=AnyTensorType)
+T_RHS = TypeVar("T_RHS", bound=AnyTensorType)
+T_OUT = TypeVar("T_OUT", bound=AnyTensorType)
+
+
+class ElementwiseBinaryOperation(IRDLOperation, abc.ABC, Generic[T_LHS, T_RHS, T_OUT]):
+ """
+ Templated base class for elementwise binary operations.
+
+ This class provides a flexible template for binary operations that can work
+ with different tensor types.
+
+ For more information about the semantics, see:
+ https://openxla.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
+ """
+
+ lhs = operand_def(T_LHS)
+ rhs = operand_def(T_RHS)
+ result = result_def(T_OUT)
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ SameOperandsAndResultShape(),
+ Elementwise(),
+ # TODO: HLO_SpeculatableIfAllInputsStatic(),
+ )
+
+ # TODO: Implement CustomDirective
+ # assembly_format = """
+ # $lhs `,` $rhs attr-dict
+ # `:` custom(type($lhs), type($rhs), type($result))
+ # """
+
+ def __init__(self, lhs: SSAValue, rhs: SSAValue, result_type: Attribute | None = None):
+ if result_type is None:
+ result_type = lhs.type
+ super().__init__(operands=(lhs, rhs), result_types=(result_type,))
+
+
+@irdl_op_definition
+class ComplexOp(
+ ElementwiseBinaryOperation[HLO_Fp32Or64Tensor, HLO_Fp32Or64Tensor, HLO_ComplexTensor]
+):
+ """
+ Performs element-wise conversion to a complex value from a pair of real and
+ imaginary values, `lhs` and `rhs`, and produces a `result` tensor.
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#complex
+ Example:
+ ```mlir
+ %result = stablehlo.complex %lhs, %rhs : tensor<2xcomplex>
+ ```
+ """
+
+ name = "stablehlo.complex"
+
+ # assembly_format = """
+ # operands attr-dict
+ # `:` custom(type($lhs), type($rhs), type($result))
+ # """
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ SameOperandsElementType(),
+ SameOperandsAndResultShape(),
+ # TODO: HLO_SpeculatableIfAllInputsStatic(),
+ )
+
+
+@irdl_op_definition
+class DivideOp(
+ ElementwiseBinaryOperation[
+ HLO_IntFpOrComplexOrQuantizedIntTensor,
+ HLO_IntFpOrComplexOrQuantizedIntTensor,
+ HLO_IntFpOrComplexOrQuantizedIntTensor,
+ ]
+):
+ """
+ Performs element-wise division of dividend `lhs` and divisor `rhs` tensors
+ and produces a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#divide
+
+ Example:
+ ```mlir
+ %result = stablehlo.divide %lhs, %rhs : tensor<4xf32>
+ ```
+ """
+
+ name = "stablehlo.divide"
+
+
+@irdl_op_definition
+class MaximumOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]):
+ """
+ Performs element-wise max operation on tensors `lhs` and `rhs` and produces
+ a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#maximum
+
+ Example:
+ ```mlir
+ %result = stablehlo.maximum %lhs, %rhs : tensor<4xf32>
+ ```
+ """
+
+ name = "stablehlo.maximum"
+
+
+@irdl_op_definition
+class MinimumOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]):
+ """
+ Performs element-wise min operation on tensors `lhs` and `rhs` and produces a
+ `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#minimum
+
+ Example:
+ ```mlir
+ %result = stablehlo.minimum %lhs, %rhs : tensor<4xf32>
+ ```
+ """
+
+ name = "stablehlo.minimum"
+
+
+@irdl_op_definition
+class PowerOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]):
+ """
+ Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and
+ produces a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power
+
+ Example:
+ ```mlir
+ %result = stablehlo.power %lhs, %rhs : tensor<6xf64>
+ ```
+ """
+
+ name = "stablehlo.power"
+
+
+@irdl_op_definition
+class RemainderOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]):
+ """
+ Performs element-wise remainder of dividend `lhs` and divisor `rhs` tensors
+ and produces a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#remainder
+
+ Example:
+ ```mlir
+ %result = stablehlo.remainder %lhs, %rhs : tensor<4xi64>
+ ```
+ """
+
+ name = "stablehlo.remainder"
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_other.py b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_other.py
new file mode 100644
index 0000000000..18527a028d
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_other.py
@@ -0,0 +1,236 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Other elementwise operations for the StableHLO dialect.
+"""
+
+import xdsl_jax.dialects.stablehlo as xstablehlo
+from xdsl.dialects.builtin import (
+ AnyFloat,
+ DenseArrayBase,
+ DenseIntOrFPElementsAttr,
+ IntegerAttr,
+ TensorType,
+ i32,
+ i64,
+)
+from xdsl.irdl import (
+ IRDLOperation,
+ attr_def,
+ irdl_op_definition,
+ operand_def,
+ opt_attr_def,
+ prop_def,
+ result_def,
+ traits_def,
+ var_operand_def,
+ var_region_def,
+)
+from xdsl.irdl.attributes import eq
+from xdsl.irdl.constraints import AtLeast
+from xdsl.traits import NoMemoryEffect, RecursiveMemoryEffect, SingleBlockImplicitTerminator
+
+from catalyst.python_interface.xdsl_extras import Elementwise, SameOperandsAndResultShape
+
+from .types import HLO_AnyTensor, HLO_FpOrQuantizedIntTensor, HLO_PredTensor, HLO_Tensor
+
+# Type aliases
+FloatTensorType = TensorType[AnyFloat]
+
+
+@irdl_op_definition
+class ClampOp(IRDLOperation):
+ """Element-wise clamp with min and max bounds.
+
+ See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#clamp
+ """
+
+ name = "stablehlo.clamp"
+
+ min = operand_def(HLO_Tensor)
+ operand = operand_def(HLO_Tensor)
+ max = operand_def(HLO_Tensor)
+ result = result_def(HLO_Tensor)
+
+ # TODO: Implement CustomDirective
+ # assembly_format = """
+ # $min `,` $operand `,` $max attr-dict
+ # `:` custom(type($min), type($operand), type($max), type($result))
+ # """
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ # TODO: HLO_SpeculatableIfAllInputsStatic(),
+ # TODO: HLO_CompatibleOperandsAndResultElementType(),
+ # TODO: HLO_BroadcastingElementwise(),
+ # TODO: InferTensorType(),
+ # TODO: InferShapedTypeOpInterface(),
+ )
+
+
+@irdl_op_definition
+class CompareOp(IRDLOperation):
+ """Element-wise compare with direction and type attributes."""
+
+ name = "stablehlo.compare"
+
+ assembly_format = """
+ $comparison_direction `,` $lhs `,` $rhs (`,` $comparison_type^)? attr-dict `:` functional-type(operands, results)
+ """
+
+ lhs = operand_def(HLO_Tensor)
+ rhs = operand_def(HLO_Tensor)
+ result = result_def(HLO_PredTensor)
+ comparison_direction = attr_def(xstablehlo.ComparisonDirectionAttr)
+ comparison_type = opt_attr_def(xstablehlo.ComparisonTypeAttr)
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ Elementwise(),
+ SameOperandsAndResultShape(),
+ # TODO: HLO_SpeculatableIfAllInputsStatic(),
+ # TODO: HLO_CompatibleOperandsElementType(),
+ # TODO: InferTensorTypeWithReify(),
+ )
+
+
+@irdl_op_definition
+class MapOp(IRDLOperation):
+ """
+ Applies a map function `computation` to `inputs` along the `dimensions` and
+ produces a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#map
+
+ Example:
+ ```mlir
+ %result = "stablehlo.map"(%input0, %input1) ({
+ ^bb0(%arg0: tensor, %arg1: tensor):
+ %0 = stablehlo.multiply %arg0, %arg1 : tensor
+ stablehlo.return %0 : tensor
+ }) {
+ dimensions = array
+ } : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
+ ```
+ """
+
+ name = "stablehlo.map"
+
+ inputs = var_operand_def(HLO_Tensor)
+ result = result_def(HLO_Tensor)
+ dimensions = attr_def(DenseArrayBase.constr(i64))
+ computation = var_region_def("single_block")
+
+ traits = traits_def(
+ RecursiveMemoryEffect(),
+ SameOperandsAndResultShape(),
+ SingleBlockImplicitTerminator(xstablehlo.ReturnOp),
+ # TODO: HLO_RecursivelySpeculatableIfAllInputsStatic(),
+ # TODO: InferTypeOpInterface
+ # TODO: InferShapedTypeOpInterface(),
+ )
+
+
+@irdl_op_definition
+class ReducePrecisionOp(IRDLOperation):
+ """
+ Performs element-wise conversion of `operand` to another floating-point type
+ that uses `exponent_bits` and `mantissa_bits` and back to the original
+ floating-point type and produces an `output` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_precision
+
+ Example:
+ ```mlir
+ %output = stablehlo.reduce_precision %operand, format = e5m10 : tensor<6xf64>
+ ```
+ """
+
+ name = "stablehlo.reduce_precision"
+
+ # TODO: Implement CustomDirective
+ # assembly_format = """
+ # $operand `,` `format` `=` custom($exponent_bits, $mantissa_bits)
+ # attr-dict `:` custom(type($operand), type($output))
+ # """
+
+ operand = operand_def(HLO_FpOrQuantizedIntTensor)
+ result = result_def(HLO_FpOrQuantizedIntTensor)
+
+ exponent_bits = attr_def(IntegerAttr.constr(type=eq(i32), value=AtLeast(1)))
+ mantissa_bits = attr_def(IntegerAttr.constr(type=eq(i32), value=AtLeast(0)))
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ Elementwise(),
+ # TODO: HLO_CompatibleOperandsAndResultType(),
+ # TODO: HLO_SpeculatableIfStaticDimInOutputIsStaticInInput(),
+ )
+
+
+@irdl_op_definition
+class SelectOp(IRDLOperation):
+ """
+ Produces a `result` tensor where each element is selected from `on_true` or
+ `on_false` tensor based on the value of the corresponding element of `pred`.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#select
+
+ Example:
+ ```mlir
+ %result = stablehlo.select %pred, %on_true, %on_false : tensor<2x2xi1>, tensor<2x2xi32>
+ ```
+ """
+
+ name = "stablehlo.select"
+
+ # assembly_format = """
+ # operands attr-dict `:`
+ # custom(type($pred), type($on_true), type($on_false), type($result))
+ # """
+
+ pred = operand_def(HLO_PredTensor)
+ on_true = operand_def(HLO_Tensor)
+ on_false = operand_def(HLO_Tensor)
+ result = result_def(HLO_Tensor)
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ )
+
+
+@irdl_op_definition
+class ConstantOp(IRDLOperation):
+ """
+ Produces an ``output`` tensor from a constant ``value``.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant
+
+ Example:
+ ```mlir
+ %output = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
+ """
+
+ name = "stablehlo.constant"
+
+ value = prop_def(DenseIntOrFPElementsAttr)
+ output = result_def(HLO_AnyTensor)
+
+ def __init__(self, value: DenseIntOrFPElementsAttr):
+ super().__init__(properties={"value": value}, result_types=(value.type,))
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_unary.py b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_unary.py
new file mode 100644
index 0000000000..a8a6e6ca86
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/elementwise_unary.py
@@ -0,0 +1,552 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Unary elementwise operations for the StableHLO dialect.
+"""
+
+import abc
+from typing import Generic, TypeVar
+
+from xdsl.dialects.builtin import (
+ I1,
+ AnyFloat,
+ AnyTensorType,
+ ComplexType,
+ TensorType,
+)
+from xdsl.ir import Attribute, SSAValue
+from xdsl.irdl import (
+ IRDLOperation,
+ irdl_op_definition,
+ operand_def,
+ opt_attr_def,
+ result_def,
+ traits_def,
+)
+from xdsl.traits import NoMemoryEffect
+
+from catalyst.python_interface.xdsl_extras import Elementwise, SameOperandsAndResultShape
+
+from .attributes import ResultAccuracyMode, ResultAccuracyModeAttr
+from .types import (
+ HLO_FloatTensor,
+ HLO_FpComplexOrQuantizedIntTensor,
+ HLO_FpOrComplexTensor,
+ HLO_FpOrQuantizedIntTensor,
+ HLO_IntFpOrComplexOrQuantizedIntTensor,
+ HLO_NonQuantizedTensor,
+ HLO_PredTensor,
+ HLO_SIntFpComplexOrQuantizedIntTensor,
+)
+
+# Type aliases
+I1TensorType = TensorType[I1]
+FloatTensorType = TensorType[AnyFloat]
+FloatOrComplexType = AnyFloat | ComplexType
+FloatOrComplexTensorType = TensorType[FloatOrComplexType]
+ComplexTensorType = TensorType[ComplexType]
+
+# Generic type variables for templating
+T_IN = TypeVar("T_IN", bound=AnyTensorType)
+T_OUT = TypeVar("T_OUT", bound=AnyTensorType)
+
+
+class ElementwiseUnaryOperation(IRDLOperation, abc.ABC, Generic[T_IN, T_OUT]):
+ """
+ Templated base class for elementwise unary operations.
+
+ This class provides a flexible template for unary operations that can work
+ with different tensor types.
+
+ For more informtation about the semantics, see:
+ https://openxla.org/xla/operation_semantics#element-wise_unary_functions
+ """
+
+ operand = operand_def(T_IN)
+ result = result_def(T_OUT)
+
+ # TODO: Implement CustomDirective
+ # assembly_format = """
+ # $operand attr-dict `:` custom(type($operand), type($result))
+ # """
+
+ traits = traits_def(
+ NoMemoryEffect(),
+ SameOperandsAndResultShape(),
+ Elementwise(),
+ # TODO: InferShapedTypeOpInterface(),
+ # TODO: HLO_SpeculatableIfStaticDimInOutputIsStaticInInput(),
+ )
+
+ def __init__(self, operand: SSAValue, result_type: Attribute | None = None):
+ if result_type is None:
+ result_type = operand.type
+ super().__init__(operands=(operand,), result_types=(result_type,))
+
+
+@irdl_op_definition
+class ConvertOp(ElementwiseUnaryOperation[HLO_NonQuantizedTensor, HLO_NonQuantizedTensor]):
+ """
+ Performs an element-wise conversion from one element type to another on
+ `operand` tensor and produces a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convert
+
+ Example:
+ ```mlir
+ %result = stablehlo.convert %operand : (tensor<3xi64>) -> tensor<3xcomplex>
+ ```
+ """
+
+ name = "stablehlo.convert"
+
+ traits = traits_def(SameOperandsAndResultShape())
+
+
+@irdl_op_definition
+class CosineOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise cosine operation on `operand` tensor and produces a
+ `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cosine
+
+ Example:
+ ```mlir
+ %result = stablehlo.cosine %operand : tensor<2xf32>
+ ```
+ """
+
+ name = "stablehlo.cosine"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
+ # TODO: implement HLO_CompatibleOperandsAndResultType()
+ # traits = traits_def(
+ # HLO_CompatibleOperandsAndResultType()
+ # )
+
+
+@irdl_op_definition
+class ExponentialMinusOneOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise exponential minus one operation on `operand` tensor
+ and produces a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential_minus_one
+
+ Example:
+ ```mlir
+ %result = stablehlo.exponential_minus_one %operand : tensor<2xf64>
+ ```
+ """
+
+ name = "stablehlo.exponential_minus_one"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
+
+ # TODO: implement HLO_CompatibleOperandsAndResultType()
+ # traits = traits_def(
+ # HLO_CompatibleOperandsAndResultType()
+ # )
+
+
+@irdl_op_definition
+class ExponentialOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise exponential operation on `operand` tensor and produces
+ a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential
+
+ Example:
+ ```mlir
+ %result = stablehlo.exponential %operand : tensor<2x2xf64>
+ ```
+ """
+
+ name = "stablehlo.exponential"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
+
+ # TODO: implement HLO_CompatibleOperandsAndResultType()
+ # traits = traits_def(
+ # HLO_CompatibleOperandsAndResultType()
+ # )
+
+
+@irdl_op_definition
+class FloorOp(ElementwiseUnaryOperation[HLO_FpOrQuantizedIntTensor, HLO_FpOrQuantizedIntTensor]):
+ """
+ Performs element-wise floor of `operand` tensor and produces a `result`
+ tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#floor
+
+ Example:
+ ```mlir
+ %result = stablehlo.floor %operand : tensor<2xf32>
+ ```
+ """
+
+ name = "stablehlo.floor"
+
+
+@irdl_op_definition
+class ImagOp(ElementwiseUnaryOperation[HLO_FpOrComplexTensor, HLO_FloatTensor]):
+ """
+ Extracts the imaginary part, element-wise, from the `operand` and produces a
+ `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#imag
+
+ Example:
+ ```mlir
+ %result = stablehlo.imag %operand : (tensor<2xcomplex>) -> tensor<2xf32>
+ ```
+ """
+
+ name = "stablehlo.imag"
+
+
+@irdl_op_definition
+class IsFiniteOp(ElementwiseUnaryOperation[HLO_FpOrQuantizedIntTensor, HLO_PredTensor]):
+ """
+ Performs element-wise check whether the value in `x` is finite (i.e. is
+ neither +Inf, -Inf, nor NaN) and produces a `y` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#is_finite
+
+ Example:
+ ```mlir
+ %y = stablehlo.is_finite %x : (tensor<7xf64>) -> tensor<7xi1>
+ ```
+ """
+
+ name = "stablehlo.is_finite"
+
+
+@irdl_op_definition
+class LogOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise logarithm operation on `operand` tensor and produces a
+ `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#log
+
+ Example:
+ ```mlir
+ %result = stablehlo.log %operand : tensor<2x2xf64>
+ ```
+ """
+
+ name = "stablehlo.log"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
+
+
+@irdl_op_definition
+class LogPlusOneOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise logarithm plus one operation on `operand` tensor and
+ produces a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#log_plus_one
+
+ Example:
+ ```mlir
+ %result = stablehlo.log_plus_one %operand : tensor<5xf64>
+ ```
+ """
+
+ name = "stablehlo.log_plus_one"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
+
+
+@irdl_op_definition
+class LogisticOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise logistic operation on `operand` tensor and produces a
+ `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#logistic
+
+ Example:
+ ```mlir
+ %result = stablehlo.logistic %operand : tensor<2x2xf64>
+ ```
+ """
+
+ name = "stablehlo.logistic"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
+
+
+@irdl_op_definition
+class NegateOp(
+ ElementwiseUnaryOperation[
+ HLO_IntFpOrComplexOrQuantizedIntTensor, HLO_IntFpOrComplexOrQuantizedIntTensor
+ ]
+):
+ """
+ Performs element-wise negation of `operand` tensor and produces a `result`
+ tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#negate
+
+ Example:
+ ```mlir
+ %result = stablehlo.negate %operand : tensor<2x3xi32>
+ ```
+ """
+
+ name = "stablehlo.negate"
+
+
+@irdl_op_definition
+class RealOp(ElementwiseUnaryOperation[HLO_FpOrComplexTensor, HLO_FloatTensor]):
+ """
+ Extracts the real part, element-wise, from the `operand` and produces a
+ `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#real
+
+ Example:
+ ```mlir
+ %result = stablehlo.real %operand : tensor<2xcomplex> : tensor<2xf32>
+ ```
+ """
+
+ name = "stablehlo.real"
+
+
+@irdl_op_definition
+class RoundNearestAfzOp(
+ ElementwiseUnaryOperation[HLO_FpOrQuantizedIntTensor, HLO_FpOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise rounding towards the nearest integer, breaking ties
+ away from zero, on the `operand` tensor and produces a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_afz
+
+ Example:
+ ```mlir
+ %result = stablehlo.round_nearest_afz %operand : tensor<5xf64>
+ ```
+ """
+
+ name = "stablehlo.round_nearest_afz"
+
+
+@irdl_op_definition
+class RoundNearestEvenOp(
+ ElementwiseUnaryOperation[HLO_FpOrQuantizedIntTensor, HLO_FpOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise rounding towards the nearest integer, breaking ties
+ towards the even integer, on the `operand` tensor and produces a `result`
+ tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_even
+
+ Example:
+ ```mlir
+ %result = stablehlo.round_nearest_even %operand : tensor<5xf64>
+ ```
+ """
+
+ name = "stablehlo.round_nearest_even"
+
+
+@irdl_op_definition
+class RsqrtOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise reciprocal square root operation on `operand` tensor
+ and produces a `result` tensor, implementing the `rSqrt` operation from the
+ IEEE-754 specification.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rsqrt
+
+ Example:
+ ```mlir
+ %result = stablehlo.rsqrt %operand : tensor<2x2xf32>
+ ```
+ """
+
+ name = "stablehlo.rsqrt"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
+
+
+@irdl_op_definition
+class SignOp(
+ ElementwiseUnaryOperation[
+ HLO_SIntFpComplexOrQuantizedIntTensor, HLO_SIntFpComplexOrQuantizedIntTensor
+ ]
+):
+ """
+ Returns the sign of the `operand` element-wise and produces a `result`
+ tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sign
+
+ Example:
+ ```mlir
+ %result = stablehlo.sign %operand : tensor<5xf64>
+ ```
+ """
+
+ name = "stablehlo.sign"
+
+
+@irdl_op_definition
+class SineOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise sine operation on `operand` tensor and produces a
+ `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sine
+
+ Example:
+ ```mlir
+ %result = stablehlo.sine %operand : tensor<2xf32>
+ ```
+ """
+
+ name = "stablehlo.sine"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
+
+
+@irdl_op_definition
+class SqrtOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise square root operation on `operand` tensor and produces
+ a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sqrt
+
+ Example:
+ ```mlir
+ %result = stablehlo.sqrt %operand : tensor<2x2xf32>
+ ```
+ """
+
+ name = "stablehlo.sqrt"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
+
+
+@irdl_op_definition
+class TanOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise tangent operation on `operand` tensor and
+ produces a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tan
+
+ Example:
+ ```mlir
+ %result = stablehlo.tan %operand : tensor<2x2xf64>
+ ```
+ """
+
+ name = "stablehlo.tan"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
+
+
+@irdl_op_definition
+class TanhOp(
+ ElementwiseUnaryOperation[HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor]
+):
+ """
+ Performs element-wise hyperbolic tangent operation on `operand` tensor and
+ produces a `result` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tanh
+
+ Example:
+ ```mlir
+ %result = stablehlo.tanh %operand : tensor<2xf32>
+ ```
+ """
+
+ name = "stablehlo.tanh"
+
+ result_accuracy = opt_attr_def(
+ ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+ )
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/extensibility.py b/frontend/catalyst/python_interface/dialects/stablehlo/extensibility.py
new file mode 100644
index 0000000000..e6f8f0542d
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/extensibility.py
@@ -0,0 +1,167 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Dynamism operations for the StableHLO dialect.
+"""
+
+
+from xdsl.dialects.builtin import (
+ ArrayAttr,
+ BoolAttr,
+ DenseIntElementsAttr,
+ DictionaryAttr,
+ FlatSymbolRefAttr,
+ StringAttr,
+ TensorType,
+ TupleType,
+)
+from xdsl.ir import Attribute
+from xdsl.irdl import (
+ AnyAttr,
+ IRDLOperation,
+ irdl_op_definition,
+ opt_prop_def,
+ prop_def,
+ traits_def,
+ var_operand_def,
+ var_result_def,
+)
+from xdsl.traits import (
+ MemoryEffect,
+)
+from xdsl.utils.exceptions import VerifyException
+
+from .attributes import CustomCallApiVersion, CustomCallApiVersionAttr, OutputOperandAlias
+
+
+@irdl_op_definition
+class CustomCallOp(IRDLOperation):
+ """
+ Encapsulates an implementation-defined operation ``call_target_name`` that
+ takes ``inputs`` and ``called_computations`` and produces ``results``.
+
+ Depending on the API version there are two ways to pass extra bits of static
+ information to the external function:
+ 1. Use ``API_VERSION_TYPED_FFI`` which allows passing a dictionary attribute.
+ 2. Use a previous API version with a ``StringAttr`` to encode backend config.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#custom_call
+
+ Example:
+ ```mlir
+ %results = stablehlo.custom_call @foo(%input0) {
+ backend_config = {bar = 42 : i32},
+ api_version = 4 : i32,
+ called_computations = [@foo]
+ } : (tensor) -> tensor
+ ```
+ """
+
+ name = "stablehlo.custom_call"
+
+ inputs = var_operand_def(AnyAttr())
+ call_target_name = prop_def(StringAttr)
+ has_side_effect = prop_def(BoolAttr, default_value=BoolAttr.from_bool(False))
+ backend_config = opt_prop_def(DictionaryAttr | StringAttr)
+ api_version = prop_def(
+ CustomCallApiVersionAttr,
+ default_value=CustomCallApiVersionAttr(CustomCallApiVersion.API_VERSION_ORIGINAL),
+ )
+ called_computations = opt_prop_def(ArrayAttr[FlatSymbolRefAttr], default_value=ArrayAttr([]))
+ operand_layouts = opt_prop_def(ArrayAttr[DenseIntElementsAttr])
+ result_layouts = opt_prop_def(ArrayAttr[DenseIntElementsAttr])
+ output_operand_aliases = prop_def(ArrayAttr[OutputOperandAlias])
+
+ result = var_result_def(AnyAttr())
+
+ traits = traits_def(
+ MemoryEffect(),
+ )
+
+ # TODO: Implement CustomDirective
+ # assembly_format = """
+ # custom($call_target_name) `(` $inputs `)`
+ # attr-dict `:` functional-type(operands, results)
+ # """
+
+ def verify_(self) -> None:
+ """Verify the CustomCallOp."""
+ # If both operand and result layout attributes are not specified then nothing to verify.
+ if self.operand_layouts is None and self.result_layouts is None:
+ return
+
+ # Layout constraints for either both operands & results or none should be specified.
+ if (self.operand_layouts is None) != (self.result_layouts is None):
+ raise VerifyException(
+ "Layout attributes should be specified for either both operands and results "
+ "or none."
+ )
+
+ assert self.operand_layouts is not None and self.result_layouts is not None
+
+ def verify_types_and_layouts(
+ types: tuple[Attribute, ...], layouts: ArrayAttr, value_name: str
+ ):
+ if len(types) != len(layouts.data):
+ raise VerifyException(
+ "Number of "
+ f"{value_name}s must match the number of {value_name} layouts, "
+ f"{len(types)} != {len(layouts.data)}"
+ )
+
+ for index, (ty, layout_attr) in enumerate(zip(types, layouts.data)):
+ # Tuple types are not fully supported with layout constraints yet
+ if isinstance(ty, TupleType):
+ raise VerifyException(
+ "Tuple types are not fully supported with layout constraints yet"
+ )
+
+ try:
+ dims = list(layout_attr.get_values())
+ except Exception as exc:
+ raise VerifyException("invalid layout attribute") from exc
+
+ # For non-tensor types, layout must be empty
+ if not isinstance(ty, TensorType):
+ if len(dims) == 0:
+ continue
+ raise VerifyException(
+ "Only tensor types can have non-empty layout: "
+ f"{value_name} #{index} of type {ty} has layout {dims}"
+ )
+
+ # For ranked tensors, require permutation of [0, rank)
+ rank = ty.get_num_dims()
+ if rank != len(dims) or sorted(dims) != list(range(rank)):
+ raise VerifyException(
+ f"incorrect layout {dims} for type {ty}, layout must be a permutation "
+ f"of [0, {rank})"
+ )
+
+ # Operand types
+ operand_types: tuple[Attribute, ...] = tuple(op.type for op in self.operands)
+
+ # Result types: if single tuple result, use its element types
+ if len(self.result_types) == 1 and isinstance(self.result_types[0], TupleType):
+ tuple_ty: TupleType = self.result_types[0]
+ result_types = tuple(tuple_ty.types.data)
+ else:
+ result_types = tuple(self.result_types)
+
+ # Verify that operands and operand layouts match.
+ verify_types_and_layouts(operand_types, self.operand_layouts, "operand")
+ # Verify that results and result layouts match.
+ verify_types_and_layouts(result_types, self.result_layouts, "result")
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py b/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py
new file mode 100644
index 0000000000..7e9bdcfdf8
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/reduction.py
@@ -0,0 +1,169 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Dynamism operations for the StableHLO dialect.
+"""
+
+from xdsl.dialects.builtin import DenseArrayBase, i64
+from xdsl.irdl import (
+ IRDLOperation,
+ irdl_op_definition,
+ prop_def,
+ region_def,
+ traits_def,
+ var_operand_def,
+ var_result_def,
+)
+from xdsl.irdl.operations import SameVariadicOperandSize
+from xdsl.traits import (
+ RecursiveMemoryEffect,
+ SingleBlockImplicitTerminator,
+)
+from xdsl.utils.exceptions import VerifyException
+from xdsl_jax.dialects import stablehlo as xstablehlo
+
+from .types import HLO_Tensor
+
+
+@irdl_op_definition
+class ReduceOp(IRDLOperation):
+ """
+ Applies a reduction function ``body`` to ``inputs`` and ``init_values`` along the
+ ``dimensions`` and produces a ``result`` tensor.
+
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce
+
+ Example:
+ ```mlir
+ %result = "stablehlo.reduce"(%input, %init_value) ({
+ ^bb0(%arg0: tensor, %arg1: tensor):
+ %0 = stablehlo.add %arg0, %arg1 : tensor
+ stablehlo.return %0 : tensor
+ }) {
+ dimensions = array
+ } : (tensor<1x6xi64>, tensor) -> tensor<1xi64>
+ ```
+ """
+
+ name = "stablehlo.reduce"
+
+ inputs = var_operand_def(HLO_Tensor)
+ init_values = var_operand_def(HLO_Tensor)
+ dimensions = prop_def(DenseArrayBase.constr(i64))
+ result = var_result_def(HLO_Tensor)
+ body = region_def("single_block")
+
+ irdl_options = [SameVariadicOperandSize()]
+
+ traits = traits_def(
+ RecursiveMemoryEffect(),
+ # TODO: InferShapedTypeOpInterface(),
+ # TODO: HLO_RecursivelySpeculatableIfAllInputsStatic,
+ # TODO: InferTensorTypeWithReify(),
+ SingleBlockImplicitTerminator(xstablehlo.ReturnOp),
+ )
+
+ # pylint: disable=no-member
+ # pylint: disable=too-many-branches
+ def verify_(self):
+ """Verify the ReduceOp."""
+ # Gather shaped operand/result types
+ input_types = [op.type for op in self.inputs]
+ init_types = [op.type for op in self.init_values]
+
+ self._verify_input_and_init_types(input_types, init_types)
+
+ self._verify_reducer_region(input_types)
+
+ def _verify_input_and_init_types(self, input_types, init_types):
+ """Verify the types of the inputs and init values."""
+
+ # Basic structural checks mirroring verifyReduceOpInputsAndInferShape
+ if len(input_types) == 0:
+ raise VerifyException("expected at least 1 input for reduce")
+ if len(input_types) != len(init_types):
+ raise VerifyException("number of inputs must match number of init_values")
+
+ # reduce_c1/c4/c5/i3: verify inputs and infer shape compatibility
+ dims_attr = self.dimensions
+ dims = tuple(dims_attr.get_values()) if dims_attr is not None else tuple()
+
+ # All inputs must have equal rank; dimensions must be within rank and unique
+ # and not empty.
+ ranks = []
+ for t in input_types:
+ # Tensors by op definition
+ assert hasattr(t, "get_num_dims")
+ ranks.append(t.get_num_dims())
+ rank0 = ranks[0]
+ if any(r != rank0 for r in ranks):
+ raise VerifyException("all inputs must have the same rank")
+
+ if len(dims) == 0:
+ raise VerifyException("dimensions cannot be empty for reduce")
+ if len(set(dims)) != len(dims):
+ raise VerifyException("dimensions should not have duplicates")
+ if any(d < 0 or d >= rank0 for d in dims):
+ raise VerifyException("dimensions contains an invalid value")
+
+ # Element type compatibility between each input and its init value
+ for it, iv in zip(input_types, init_types):
+ it_elem = it.get_element_type()
+ iv_elem = iv.get_element_type()
+ if it_elem != iv_elem:
+ raise VerifyException("input and init_value must have the same element type")
+
+ def _verify_reducer_region(self, input_types):
+ """Verify the operation's reducer region."""
+
+ # reduce_c2/c6: verify reducer region shape
+ # Expect block with arity 2 * number of inputs, with matching tensor element types
+ # and 0D tensors
+ if len(self.body.blocks) != 1:
+ raise VerifyException("reducer must have a single block")
+ block = self.body.blocks[0]
+
+ expected_args = 2 * len(input_types)
+ if len(block.args) != expected_args:
+ raise VerifyException(
+ f"reducer must take {expected_args} arguments, got {len(block.args)}"
+ )
+
+ # Each pair (arg_i, arg_{i+N}) must be 0D tensors of the input element type
+ for i, it in enumerate(input_types):
+ it_elem = it.get_element_type()
+ acc = block.args[i]
+ val = block.args[i + len(input_types)]
+ for a in (acc, val):
+ a_ty = a.type
+ if not hasattr(a_ty, "get_num_dims") or a_ty.get_num_dims() != 0:
+ raise VerifyException("reducer arguments must be rank-0 tensors")
+ if a_ty.get_element_type() != it_elem:
+ raise VerifyException(
+ "reducer argument element types must match input element type"
+ )
+
+ # Region must terminate with exactly len(inputs) results
+ ret = block.ops.last
+ if len(ret.operands) != len(input_types):
+ raise VerifyException("reducer must return exactly one value per input")
+ for i, it in enumerate(input_types):
+ it_elem = it.get_element_type()
+ rty = ret.operands[i].type
+ if not hasattr(rty, "get_num_dims") or rty.get_num_dims() != 0:
+ raise VerifyException("reducer return values must be rank-0 tensors")
+ if rty.get_element_type() != it_elem:
+ raise VerifyException("reducer return element types must match input element type")
diff --git a/frontend/catalyst/python_interface/dialects/stablehlo/types.py b/frontend/catalyst/python_interface/dialects/stablehlo/types.py
new file mode 100644
index 0000000000..a27f4fc8f2
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/stablehlo/types.py
@@ -0,0 +1,247 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+StableHLO type definitions for PennyLane's compiler infrastructure.
+
+This module provides type definitions based on the StableHLO specification
+(https://github.com/openxla/stablehlo/blob/main/docs/spec.md), including
+token types and other necessary type definitions for StableHLO operations.
+"""
+
+from typing import TypeAlias
+
+from xdsl.dialects.builtin import (
+ AnyFloatConstr,
+ ComplexType,
+ Float32Type,
+ Float64Type,
+ IndexType,
+ IntAttr,
+ IntAttrConstraint,
+ IntegerType,
+ ParametrizedAttribute,
+ Signedness,
+ SignednessAttr,
+ TensorType,
+ i1,
+)
+from xdsl.irdl import eq, irdl_attr_definition
+from xdsl.irdl.attributes import EqAttrConstraint, ParamAttrConstraint
+from xdsl.irdl.constraints import IntSetConstraint
+from xdsl_jax.dialects.stablehlo import TokenType
+
+from catalyst.python_interface.xdsl_extras.constraints import (
+ NestedTupleOfConstraint,
+)
+
+
+def _create_param_constrained_type(
+ base_attr: type, widths: list[int], signedness: Signedness | None = None
+):
+ """Create an integer type constrained using ParamAttrConstraint with IntSetConstraint."""
+ width_constraint = IntAttrConstraint(IntSetConstraint(frozenset(widths)))
+
+ if signedness is None:
+ signedness_constraint = None
+ else:
+ signedness_constraint = EqAttrConstraint(SignednessAttr(signedness))
+
+ return ParamAttrConstraint(base_attr, [width_constraint, signedness_constraint])
+
+
+# =============================================================================
+# Core StableHLO types constraints
+# =============================================================================
+
+HLO_Pred = eq(i1)
+HLO_PredTensor: TypeAlias = TensorType[HLO_Pred]
+
+# NOTE: IntegerType is defined in the StableHLO spec as:
+# IntegerType ::= SignedIntegerType | UnsignedIntegerType,
+# but the MLIR implementation is using signless integers instead of signed,
+# and there is a TODO to fix it.
+
+_HLO_INT_WIDTHS = [2, 4, 8, 16, 32, 64]
+HLO_SignedInt = _create_param_constrained_type(IntegerType, _HLO_INT_WIDTHS, Signedness.SIGNED)
+HLO_UnsignedInt = _create_param_constrained_type(IntegerType, _HLO_INT_WIDTHS, Signedness.UNSIGNED)
+HLO_SignlessInt = _create_param_constrained_type(IntegerType, _HLO_INT_WIDTHS, None)
+
+HLO_Int: TypeAlias = HLO_UnsignedInt | HLO_SignlessInt
+HLO_IntTensor: TypeAlias = TensorType[HLO_Int]
+
+_HLO_INT_OR_PRED_WIDTHS = [1, 2, 4, 8, 16, 32, 64]
+HLO_IntOrPred = _create_param_constrained_type(IntegerType, _HLO_INT_OR_PRED_WIDTHS, None)
+
+
+HLO_AnyIntegerOrIndex: TypeAlias = IntegerType | IndexType
+HLO_AnyIntegerOrIndexTensor: TypeAlias = TensorType.constr(HLO_AnyIntegerOrIndex)
+
+HLO_DimensionValue: TypeAlias = HLO_Int | IndexType
+
+# Constraint variants for use in unions with ParamAttrConstraint
+HLO_Float: TypeAlias = AnyFloatConstr
+HLO_Float32Or64: TypeAlias = Float32Type | Float64Type
+HLO_FloatTensor: TypeAlias = TensorType.constr(HLO_Float)
+HLO_Fp32Or64Tensor: TypeAlias = TensorType.constr(HLO_Float32Or64)
+
+# Complex as a constraint over element types {f32,f64}
+HLO_Complex: TypeAlias = ComplexType[HLO_Float32Or64]
+HLO_ComplexTensor: TypeAlias = TensorType.constr(HLO_Complex)
+
+# =============================================================================
+# Quantized element type definitions
+# =============================================================================
+
+
+@irdl_attr_definition
+class UniformQuantizedType(ParametrizedAttribute):
+ """
+ Placeholder for StableHLO per-tensor uniform quantized types.
+
+ Parameterized by width to support different quantized integer widths
+ (e.g., 8-bit, 16-bit quantization).
+ """
+
+ name = "stablehlo.uniform_quantized"
+ width: IntAttr
+ signedness: SignednessAttr
+
+
+@irdl_attr_definition
+class UniformQuantizedPerAxisType(ParametrizedAttribute):
+ """
+ Placeholder for StableHLO per-axis uniform quantized types.
+
+ Parameterized by width to support different quantized integer widths
+ (e.g., 8-bit, 16-bit quantization).
+ """
+
+ name = "stablehlo.uniform_quantized_per_axis"
+ width: IntAttr
+ signedness: SignednessAttr
+
+
+# =============================================================================
+# StableHLO quantized type aliases
+# =============================================================================
+
+_HLO_QUANTIZED_WIDTHS = [2, 4, 8, 16, 32]
+
+# Constraint-based types for operation definitions
+HLO_QuantizedSignedInt = _create_param_constrained_type(
+ UniformQuantizedType, _HLO_QUANTIZED_WIDTHS, Signedness.SIGNED
+)
+HLO_QuantizedUnsignedInt = _create_param_constrained_type(
+ UniformQuantizedType, _HLO_QUANTIZED_WIDTHS, Signedness.UNSIGNED
+)
+HLO_QuantizedAnySignednessInt = _create_param_constrained_type(
+ UniformQuantizedType, _HLO_QUANTIZED_WIDTHS, None
+)
+HLO_QuantizedInt: TypeAlias = HLO_QuantizedSignedInt | HLO_QuantizedUnsignedInt
+
+HLO_PerAxisQuantizedSignedInt = _create_param_constrained_type(
+ UniformQuantizedPerAxisType, _HLO_QUANTIZED_WIDTHS, Signedness.SIGNED
+)
+HLO_PerAxisQuantizedUnsignedInt = _create_param_constrained_type(
+ UniformQuantizedPerAxisType, _HLO_QUANTIZED_WIDTHS, Signedness.UNSIGNED
+)
+HLO_PerAxisQuantizedAnySignednessInt = _create_param_constrained_type(
+ UniformQuantizedPerAxisType, _HLO_QUANTIZED_WIDTHS, None
+)
+HLO_PerAxisQuantizedInt: TypeAlias = HLO_PerAxisQuantizedSignedInt | HLO_PerAxisQuantizedUnsignedInt
+
+# =============================================================================
+# Main tensor type definitions
+# =============================================================================
+
+HLO_Tensor: TypeAlias = TensorType[HLO_Float | HLO_Complex | HLO_IntOrPred | HLO_QuantizedInt]
+HLO_NonQuantizedTensor: TypeAlias = TensorType[HLO_Float | HLO_Complex | HLO_IntOrPred]
+
+# Note: There is a discrepancy between the StableHLO spec and the MLIR implementation.
+# The spec does not allow unranked tensors, but the MLIR implementation
+# defines it as a tensor of any type and rank. There is a TODO to fix this in MLIR.
+# Therefore, we use the correct ranked tensor type.
+HLO_AnyTensor: TypeAlias = TensorType[
+ HLO_Float | HLO_Complex | HLO_IntOrPred | HLO_QuantizedInt | HLO_PerAxisQuantizedInt
+]
+HLO_TensorOrToken: TypeAlias = HLO_Tensor | TokenType
+HLO_TensorOrPerAxisQuantizedTensorOrToken: TypeAlias = HLO_AnyTensor | TokenType
+
+# HLO_AnyTuple : NestedTupleOf<[HLO_AnyTensor, HLO_Token]>
+HLO_AnyTuple = NestedTupleOfConstraint([HLO_AnyTensor, TokenType])
+
+HLO_CustomCallValue: TypeAlias = HLO_Tensor | TokenType | HLO_AnyTuple
+
+# =============================================================================
+# HLO combined type definitions
+# =============================================================================
+
+HLO_PredOrIntTensor: TypeAlias = TensorType.constr(HLO_IntOrPred)
+
+HLO_FpOrComplexTensor: TypeAlias = TensorType.constr(HLO_Float | HLO_Complex)
+HLO_FpOrQuantizedIntTensor: TypeAlias = TensorType.constr(HLO_Float | HLO_QuantizedInt)
+HLO_FpComplexOrQuantizedIntTensor: TypeAlias = TensorType.constr(
+ HLO_Float | HLO_Complex | HLO_QuantizedInt
+)
+HLO_IntFpOrComplexOrQuantizedIntTensor: TypeAlias = TensorType.constr(
+ HLO_Int | HLO_Float | HLO_Complex | HLO_QuantizedInt
+)
+HLO_SIntFpComplexOrQuantizedIntTensor: TypeAlias = TensorType.constr(
+ HLO_SignedInt | HLO_Float | HLO_Complex | HLO_QuantizedInt
+)
+
+
+__all__ = [
+ # Core types
+ "HLO_Pred",
+ "HLO_PredTensor",
+ "HLO_Int",
+ "HLO_IntTensor",
+ "HLO_AnyIntegerOrIndex",
+ "HLO_AnyIntegerOrIndexTensor",
+ "HLO_DimensionValue",
+ "HLO_Float",
+ "HLO_Float32Or64",
+ "HLO_FloatTensor",
+ "HLO_Fp32Or64Tensor",
+ "HLO_ComplexTensor",
+ "HLO_SignedInt",
+ "HLO_UnsignedInt",
+ "HLO_SignlessInt",
+ "HLO_QuantizedSignedInt",
+ "HLO_QuantizedUnsignedInt",
+ "HLO_QuantizedAnySignednessInt",
+ "HLO_QuantizedInt",
+ "HLO_PerAxisQuantizedSignedInt",
+ "HLO_PerAxisQuantizedUnsignedInt",
+ "HLO_PerAxisQuantizedAnySignednessInt",
+ "HLO_PerAxisQuantizedInt",
+ # Quantized types
+ "UniformQuantizedType",
+ "UniformQuantizedPerAxisType",
+ "HLO_Tensor",
+ "HLO_NonQuantizedTensor",
+ "HLO_AnyTensor",
+ "HLO_TensorOrToken",
+ "HLO_TensorOrPerAxisQuantizedTensorOrToken",
+ "HLO_CustomCallValue",
+ # Combined types
+ "HLO_PredOrIntTensor",
+ "HLO_FpOrComplexTensor",
+ "HLO_FpOrQuantizedIntTensor",
+ "HLO_FpComplexOrQuantizedIntTensor",
+ "HLO_IntFpOrComplexOrQuantizedIntTensor",
+ "HLO_SIntFpComplexOrQuantizedIntTensor",
+]
diff --git a/frontend/catalyst/python_interface/dialects/transform.py b/frontend/catalyst/python_interface/dialects/transform.py
new file mode 100644
index 0000000000..32449e7d5e
--- /dev/null
+++ b/frontend/catalyst/python_interface/dialects/transform.py
@@ -0,0 +1,122 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This file contains an updated version of the transform dialect.
+As of the time of writing, xDSL uses the MLIR released with LLVM's
+version 20.1.7. However, https://github.com/PennyLaneAI/catalyst/pull/1916
+will be updating MLIR where the transform dialect has the
+`apply_registered_pass` operation re-defined.
+
+See the following changelog on the above PR
+
+ Things related to transform.apply_registered_pass op:
+
+ It now takes in a dynamic_options
+
+ [MLIR][Transform] Allow ApplyRegisteredPassOp to take options as
+ a param llvm/llvm-project#142683. We don't need to use this as all our pass options are static.
+ https://github.com/llvm/llvm-project/pull/142683
+
+ The options it takes in are now dictionaries instead of strings
+ [MLIR][Transform] apply_registered_pass op's options as a dict llvm/llvm-project#143159
+ https://github.com/llvm/llvm-project/pull/143159
+
+This file will re-define the apply_registered_pass operation in xDSL
+and the transform dialect.
+
+Once xDSL moves to a newer version of MLIR, these changes should
+be contributed upstream.
+"""
+
+from xdsl.dialects.builtin import Dialect
+from xdsl.dialects.transform import ApplyRegisteredPassOp as xApplyRegisteredPassOp
+from xdsl.dialects.transform import (
+ DictionaryAttr,
+ StringAttr,
+)
+from xdsl.dialects.transform import Transform as xTransform
+from xdsl.dialects.transform import (
+ TransformHandleType,
+ irdl_op_definition,
+ operand_def,
+ prop_def,
+ result_def,
+)
+from xdsl.ir import Attribute, SSAValue
+from xdsl.irdl import IRDLOperation, ParsePropInAttrDict
+
+
+# pylint: disable=line-too-long
+@irdl_op_definition
+class ApplyRegisteredPassOp(IRDLOperation):
+ """
+ See external [documentation](https://mlir.llvm.org/docs/Dialects/Transform/#transformapply_registered_pass-transformapplyregisteredpassop).
+ """
+
+ name = "transform.apply_registered_pass"
+
+ options = prop_def(DictionaryAttr, default_value=DictionaryAttr({}))
+ pass_name = prop_def(StringAttr)
+ target = operand_def(TransformHandleType)
+ result = result_def(TransformHandleType)
+ # While this assembly format doesn't match
+ # the one in upstream MLIR,
+ # this is because xDSL currently lacks CustomDirectives
+ # https://mlir.llvm.org/docs/DefiningDialects/Operations/#custom-directives
+ # https://github.com/xdslproject/xdsl/pull/4829
+ # However, storing the property in the attribute should still work
+ # specially when parsing and printing in generic format.
+ # Which is how Catalyst and XDSL currently communicate at the moment.
+ # TODO: Add test.
+ assembly_format = "$pass_name `to` $target attr-dict `:` functional-type(operands, results)"
+ irdl_options = [ParsePropInAttrDict()]
+
+ def __init__(
+ self,
+ pass_name: str | StringAttr,
+ target: SSAValue,
+ options: dict[str | StringAttr, Attribute | str | bool | int] | None = None,
+ ):
+ if isinstance(pass_name, str):
+ pass_name = StringAttr(pass_name)
+
+ if isinstance(options, dict):
+ options = DictionaryAttr(options)
+
+ super().__init__(
+ properties={
+ "pass_name": pass_name,
+ "options": options,
+ },
+ operands=[target],
+ result_types=[target.type],
+ )
+
+
+# Copied over from xDSL's sources
+# the main difference will be the use
+# of a different ApplyRegisteredPassOp
+operations = list(xTransform.operations)
+del operations[operations.index(xApplyRegisteredPassOp)]
+operations.append(ApplyRegisteredPassOp)
+
+Transform = Dialect(
+ "transform",
+ [
+ *operations,
+ ],
+ [
+ *xTransform.attributes,
+ ],
+)
diff --git a/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst b/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst
new file mode 100644
index 0000000000..25e2d23c57
--- /dev/null
+++ b/frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst
@@ -0,0 +1,1375 @@
+Unified Compiler Cookbook
+=========================
+
+**Note:** The cookbook is developed with the following package versions,
+on Python 3.12.11:
+
+.. code-block:: bash
+
+ jax==0.6.2
+ jaxlib==0.6.2
+ numpy==2.3.1
+ pennylane==0.44.0-dev19
+ pennylane-lightning==0.43.0
+ pennylane-catalyst==0.14.0-dev15
+ xdsl==0.53.0
+ xdsl-jax==git+https://github.com/xdslproject/xdsl-jax.git@895f7c13e8d0f02bbe99d7fb9ebcaafea4ea629f#egg=xdsl_jax
+
+Note that ``xdsl-jax`` does not currently have a release published on
+PyPI, so it needs to be installed from GitHub by running the following:
+
+.. code-block:: bash
+
+ pip install git+https://github.com/xdslproject/xdsl-jax.git
+
+Motivation
+==========
+
+As we approach FTQC, quantum compilation becomes a more and more
+important research area. Catalyst uses MLIR as its intermediate
+representation, which is also the layer in which a majority of the
+optimizations happen. However, quantum compilation researchers are not
+likely to be accustomed with MLIR or C++.
+
+So, the motivation of the “Unified Compiler” is to provide a Python
+layer in which compilation passes can be implemented and applied.
+Additionally, we also want to enable researchers to use abstractions
+that they are familiar with. We’re aiming to do this using xDSL, which
+is a reimplementation of MLIR in Python.
+
+This document is meant to be a quickstart for users that are interested
+in developing compiler passes for Catalyst, but are not familiar with
+MLIR or xDSL.
+
+MLIR basics
+===========
+
+If readers are already familiar to some degree with MLIR, this section
+can be skipped.
+
+SSA
+---
+
+“In compiler design, static single assignment form (often abbreviated as
+SSA form or simply SSA) is a type of intermediate representation (IR)
+where each variable is assigned exactly once” [`1 <#references>`__].
+
+SSA is powerful because it allows us to define chains of uses and
+definitions, i.e, we can keep track of which operation created a
+variable (since each variable only gets created once), and all
+operations that use that variable. These chains of uses and definitions,
+if used well, can make transformations quite performant, and in some
+cases, very simple to implement, but they can also make the IR harder to
+parse.
+
+IR structure
+------------
+
+MLIR represents programs using a hierarchical, graph-like structure. In
+this structure, nodes are called *operations*, and edges are called
+*values*. Each value is the result of exactly one operation or argument
+for a block (more on that later), and has a type that is defined by the
+type system (more on that later also) [`2 <#references>`__].
+
+The IR is recursively nested - operations may contain regions, which
+contain blocks, which contain operations. More concretely, an operation
+may contain zero or more regions, each of which must contain one or more
+blocks, each of which holds a list of arguments and an ordered list of
+operations that may use those arguments [`3 <#references>`__].
+
+Operations
+~~~~~~~~~~
+
+Operations are basic units of execution. Operations are fully
+extensible, i.e., there is no fixed list of operations. Operations can
+return zero or more results, take in zero or more operands, declare
+properties and attributes, and can have zero or more successors and
+regions [`2 <#references>`__].
+
+Regions
+~~~~~~~
+
+A region is an ordered list of blocks. Its semantics are defined by the
+operation that contains them [`2 <#references>`__]. For example, an
+``scf.IfOp``, which represents a conditional operation, contains two
+regions, one for the true branch, and another for the false branch, and
+the way we interpret these regions is dependent on the fact that they
+belong to the ``scf.IfOp``.
+
+Blocks
+~~~~~~
+
+Blocks are lists of operations. The operations inside blocks are
+executed in order. Blocks take a list of block arguments, annotated in a
+function-like way. The first block in a region is special, and is called
+the “entry block”. The block arguments of the entry block are also
+arguments to its outer region [`2 <#references>`__].
+
+Values
+~~~~~~
+
+In MLIR, there are computable values with a type, a single defining
+operation, and zero or more uses [`4 <#references>`__]. These values are
+either the results of operations or block arguments, and adhere to SSA
+semantics.
+
+Dialects
+--------
+
+MLIR uses dialects to allow developers to define a set of high level
+operations, attributes, and types, which can be used in the intermediate
+representation to represent custom subroutines, etc. and can be
+converted into lower level representations using interpretation rules
+that can be defined. For example, the MLIR ``arith`` dialect defines
+operations for arithmetic computations, the ``linalg`` dialect defines
+linear algebra operations, etc. We use dialects to define quantum
+instructions such as gates, measurements, and attributes such as qubits,
+quantum registers.
+
+Def-use and use-def chains
+--------------------------
+
+Frequently, we might have a value and we want to determine which
+instructions use this value. The list of all users of a value is the
+*def-use chain*. Conversely, we might have an instruction and we want to
+determine which values it uses. The list of values used by an
+instruction is the *use-def chain* [`5 <#references>`__]. These chains
+allow us to iterate through operations topologically, which can be very
+powerful when implementing passes.
+
+xDSL API
+========
+
+Now that we’re familiar with the high-level constructs that MLIR uses,
+let’s go over what their xDSL actual implementation looks like.
+
+``SSAValue``
+------------
+
+``SSAValue`` is the class used to represent variables. SSA values may
+used by operations as operands, and returned as results. The three key
+properties that this class provides are listed below:
+
+- ``type``: The type of the value. Since SSA values are variables, their
+ value is not known at compile time, but their type is.
+- ``uses``: A set of all operations that use a given ``SSAValue`` as an
+ operand. The operations are wrapped around a convenience class called
+ ``Use``, and the corresponding operation can be accessed using
+ ``Use.operation``.
+- ``owner``: The operation or block that defined a given ``SSAValue``
+ (more on that later)
+
+``SSAValue`` has two subclasses, which will be seen a lot when
+inspecting xDSL modules. These are:
+
+- ``OpResult``: This subclass represents SSA values that are defined by
+ an operation
+
+ .. code-block::
+
+ c = add a b
+
+ In the above pseudocode, ``c`` is defined by the operation ``add``,
+ and in xDSL, will be represented as an ``OpResult`` instance.
+
+- ``BlockArgument``: This subclass represents SSA values that are block
+ arguments
+
+``Attribute``
+-------------
+
+An attribute defines a compile-time value. These can be used to define
+types, and can also be used to define properties of operations that are
+known at compile time. For example:
+
+- ``CustomOp``, which is an operation from the ``Quantum`` dialect (more
+ on that later) has a ``gate_name`` that must be a string, represented
+ by the ``StringAttr`` attribute. ``StringAttr`` has a reference to the
+ concrete string representing the gate name, and we can access this
+ concrete value at compile-time using ``CustomOp.gate_name.data``.
+- ``CustomOp`` also takes qubits as inputs and outputs, which are of
+ type ``QubitType``. The ``QubitType`` class inherits ``Attribute``,
+ but we use it to declare a type that can be used to define SSA values.
+
+Below is the definition of ``QubitType`` to illustrate:
+
+.. code-block:: python
+
+ # TypeAttribute just means that this attribute can be used to represent
+ # the types of the operands and results of an operation, but not its
+ # properties.
+ # ParametrizedAttribute means that this attribute can take parameters as
+ # input (although QubitType doesn't have any parameters).
+ @irdl_attr_definition
+ class QubitType(ParametrizedAttribute, TypeAttribute):
+ """A value-semantic qubit (state)."""
+
+ name = "quantum.bit"
+
+``Operation``
+-------------
+
+The ``Operation`` class is used to represent operations, which are basic
+units of execution. All instructions in a module are operations. In
+fact, the modules themselves are operations. Operations contain several
+fields used to define their form and function:
+
+- Operands: operands are runtime values that the operation consumes.
+ Note that when defining an operation, we only declare the *types* of
+ the operands, not the actual operands. Only when we construct an
+ instance of an operation do we provide actual ``SSAValue``\ s as the
+ operands, and these must adhere to the type system.
+- Properties: properties are compile-time values used to define the
+ semantics of an operation. For example, the ``CustomOp`` operation
+ that is used to define quantum gates in Catalyst has a property called
+ ``gate_name``, which is a string specifier of the gate’s name. This
+ name directly impacts how the operation should be interpreted, and its
+ value is known at compile-time.
+- Attributes: Operation attributes are stored as a dictionary,
+ containing more compile-time values. Generally, these don’t get used
+ much in xDSL, but they serve a purpose very similar to properties.
+- Result types: operations may return values, and if so, the types of
+ the return values must be defined.
+
+Below is the definition of ``CustomOp`` from the ``Quantum`` dialect,
+which represents general quantum gates to illustrate what defining
+operations looks like:
+
+.. code-block:: python
+
+ @irdl_op_definition
+ class CustomOp(IRDLOperation):
+ """A generic quantum gate on n qubits with m floating point parameters."""
+
+ name = "quantum.custom"
+
+ # assembly_format defines what the operation should look like
+ # when pretty-printed in the IR
+ assembly_format = """
+ $gate_name `(` $params `)` $in_qubits
+ (`adj` $adjoint^)?
+ attr-dict
+ ( `ctrls` `(` $in_ctrl_qubits^ `)` )?
+ ( `ctrlvals` `(` $in_ctrl_values^ `)` )?
+ `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )?
+ """
+
+ # These options are used because we have operands whose lengths aren't
+ # known (eg. different instances of CustomOp may have different number
+ # of qubits depending on the gate they represent (2 for CNOT, 1 for
+ # RX, etc.). These options basically say, "when the operation instance
+ # is initialized, create 2 properties that store the length of each of
+ # the different groups of operands and results.
+ irdl_options = [
+ AttrSizedOperandSegments(as_property=True),
+ AttrSizedResultSegments(as_property=True),
+ ]
+
+ # var_operand_def means that the length of this operand
+ # can vary.
+ params = var_operand_def(EqAttrConstraint(Float64Type()))
+
+ in_qubits = var_operand_def(BaseAttr(QubitType))
+
+ # prop_def means that gate_name is a required property
+ gate_name = prop_def(BaseAttr(StringAttr))
+
+ # opt_prop_def means adjoint is an optional property. Additionally,
+ # it's type is a UnitAttr(), which essentially means that a given
+ # instance of CustomOp is an adjoint gate iff it has an adjoint property.
+ # The value of the property is irrelevant; it only gets meaning from its
+ # existance.
+ adjoint = opt_prop_def(EqAttrConstraint(UnitAttr()))
+
+ in_ctrl_qubits = var_operand_def(BaseAttr(QubitType))
+
+ in_ctrl_values = var_operand_def(EqAttrConstraint(IntegerType(1)))
+
+ # var_result_def means that the length of out_qubits can vary
+ out_qubits = var_result_def(BaseAttr(QubitType))
+
+ out_ctrl_qubits = var_result_def(BaseAttr(QubitType))
+
+.. _dialects-1:
+
+Dialects
+--------
+
+The ``Dialect`` class in xDSL is used as a container around a list of
+operations, types, and attributes, and it also declares a name for the
+dialect. At the time of writing this, there are currently 4 custom
+dialects available in the xDSL layer of Catalyst:
+
+- ``Quantum``: this dialect contains operations and attributes necessary
+ for general qubit-level operations, such as gates, measurements,
+ qubits, etc.
+- ``Catalyst``: this dialect contains operations and attributes for
+ classical computing features unavailable out of the box with xDSL/MLIR
+- ``QEC``: This dialect contains operations and attributes useful for
+ QEC, such as PPRs/PPMs.
+- ``MBQC``: This dialect contains operations and attributes for
+ representing MBQC formalism.
+
+Pass API
+--------
+
+xDSL provides an API for defining and applying transformations on
+programs (or modules), which is described below:
+
+``ModulePass``
+~~~~~~~~~~~~~~
+
+``ModulePass`` is used to create rewrite passes over an IR module. It is
+the parent class used to define compiler passes (or transforms).
+``ModulePass`` has two key fields that must be implemented
+
+- ``name``: This is the name that is used to reference the pass.
+- ``apply``: This method takes a ``ModuleOp`` as input and applies the
+ rewrite pattern of the pass to the module. Note that this mutates the
+ input module *in-place* rather than returning a new, transformed
+ module.
+
+``RewritePattern``
+~~~~~~~~~~~~~~~~~~
+
+``RewritePattern`` is the class that provides the API for pattern
+matching. The most important method of this class is
+``match_and_rewrite``. The first argument to this method is an
+operation, and it must be type-hinted using the specific operation we’re
+trying to match. This type hint gets used by xDSL to match the operation
+we’re trying to rewrite. The second argument is a ``PatternRewriter``
+instance. I will cover this class in detail below, but it is essentially
+the class that provides the API for rewriting the operations that we’re
+matching.
+
+For example, if I wanted to match all Hadamard gates, I would use
+``CustomOp`` from the ``Quantum`` dialect in the type hint for the first
+argument (since there is no ``Hadamard`` operation in the ``Quantum``
+dialect), and check in the body of the method if the op is a
+``Hadamard``:
+
+.. code-block:: python
+
+ from xdsl import pattern_rewriter
+ from catalyst.python_interface.dialects.quantum import CustomOp
+
+ class MyPattern(pattern_rewriter.RewritePattern):
+ """Dummy class for example."""
+
+ # This decorator is what xDSL uses to match operations
+ # based on the type hint.
+ @pattern_rewriter.op_type_rewrite_pattern
+ def match_and_rewrite(
+ self, op: CustomOp, rewriter: pattern_rewriter.PatternRewriter
+ ):
+ if op.gate_name.data != "Hadamard":
+ # If not Hadamard, we do nothing
+ return
+
+ # Do whatever we need
+
+``PatternRewriter``
+~~~~~~~~~~~~~~~~~~~
+
+``PatternRewriter`` is the class that provides the API for rewriting the
+IR. It includes several methods for replacing/removing/updating
+operations, replacing values, replacing uses of a value with another
+value, etc. In most cases, any rewriting that users want to do must be
+done through this API rather than manually, as it includes state
+management for keeping track of whether any changes were made, which is
+necessary for the worklist algorithm (more on that later).
+
+Some key methods are:
+
+- ``replace_op``: Replaces one operation with another.
+- ``replace_all_uses_with``: Replaces all uses of one value with
+ another.
+- ``erase_op``: Erases an operation. If this operation returns any
+ values, all uses of these values must be updated accordingly before
+ the erasure.
+- ``notify_op_modified``: Method to notify the rewriter that a change
+ was made to an operation manually.
+
+The example below shows us implementing a ``RewritePattern`` that
+updates all ``Hadamard``\ s with ``PauliX``\ s:
+
+.. code-block:: python
+
+ from xdsl import pattern_rewriter
+ from xdsl.dialects import builtin
+ from catalyst.python_interface.dialects.quantum import CustomOp
+
+ class HToXPattern(pattern_rewriter.RewritePattern):
+ """Dummy class for example."""
+
+ @pattern_rewriter.op_type_rewrite_pattern
+ def match_and_rewrite(
+ self, op: CustomOp, rewriter: pattern_rewriter.PatternRewriter
+ ):
+ if op.gate_name.data != "Hadamard":
+ # If not Hadamard, we do nothing
+ return
+
+ # Update the gate name to PauliX and notify the rewriter that
+ # the op was manually updated
+ op.gate_name = builtin.StringAttr("PauliX")
+ rewriter.notify_op_modified(op)
+
+ # Alternatively, we could also create a new CustomOp for
+ # the PauliX from scratch, and replace the Hadamard with
+ # the new op:
+ # new_op = CustomOp(
+ # gate_name="PauliX",
+ # in_qubits=op.in_qubits,
+ # )
+ # rewriter.replace_op(op, new_op)
+
+``PatternRewriteWalker``
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+``PatternRewriteWalker`` walks over the IR in depth-first order, and
+applies a provided ``RewritePattern`` to it. By default, it implements a
+worklist algorithm that keeps iterating over the operations and matching
+and rewriting until a steady state is reached (i.e. no new changes are
+detected; this is why ``PatternRewriter`` needs to keep track of whether
+any changes were made).
+
+Putting everything together, we can create a ``ModulePass`` that
+replaces all ``Hadamard``\ s with ``PauliX``\ s
+
+.. code-block:: python
+
+ from xdsl import passes, pattern_rewriter
+ from xdsl.dialects import builtin
+ from catalyst.python_interface.dialects.quantum import CustomOp
+ from catalyst.python_interface import compiler_transform
+
+ class HToXPattern(pattern_rewriter.RewritePattern):
+ """Dummy class for example."""
+
+ @pattern_rewriter.op_type_rewrite_pattern
+ def match_and_rewrite(
+ self, op: CustomOp, rewriter: pattern_rewriter.PatternRewriter
+ ):
+ if op.gate_name.data != "Hadamard":
+ # If not Hadamard, we do nothing
+ return
+
+ # Update the gate name to PauliX and notify the rewriter that
+ # the op was manually updated
+ op.gate_name = builtin.StringAttr("PauliX")
+ rewriter.notify_op_modified(op)
+
+ class HToXPass(passes.ModulePass):
+ """Pass that replaces Hadamards with PauliXs"""
+
+ name = "h-to-x"
+
+ def apply(self, ctx, module):
+ """Apply the iterative pass."""
+ walker = pattern_rewriter.PatternRewriteWalker(
+ pattern=HToXPattern()
+ )
+ walker.rewrite_module(module)
+
+ # We will cover this later
+ h_to_x_pass = compiler_transform(HToXPass)
+
+``PassPipeline``
+~~~~~~~~~~~~~~~~
+
+``PassPipeline`` is a meta-pass that takes a sequence of
+``ModulePass``\ es as input, and applies them to the input module. The
+following example shows how a sequence of ``ModulePass``\ s can be
+applied to a module ``mod``:
+
+.. code-block:: python
+
+ from xdsl import passes
+
+ pipeline = passes.PassPipeline((Pass1(), Pass2(), Pass3()))
+ pipeline.apply(xdsl.context.Context(), mod)
+
+To complete the example we’ve been building in this section, let’s put
+it all together and implement a ``PassPipeline`` to apply the
+``HToXPass`` to an xDSL module.
+
+Let’s first create the module to which we want to apply the pass. For
+this, we will use the ``xdsl_from_qjit`` utility, which is described in
+the “PennyLane integration” section below.
+
+- **Creating the module**
+
+ .. code-block:: python
+
+ import pennylane as qml
+ from catalyst.python_interface.conversion import xdsl_from_qjit
+
+ dev = qml.device("lightning.qubit", wires=3)
+
+ @xdsl_from_qjit
+ @qml.qjit(target="mlir")
+ @qml.qnode(dev)
+ def circuit():
+ qml.Hadamard(0)
+ qml.Hadamard(1)
+ qml.Hadamard(2)
+ return qml.state()
+
+ >>> mod = circuit()
+ >>> print(mod)
+ builtin.module @circuit {
+ func.func public @jit_circuit() -> (tensor<8xcomplex>) attributes {llvm.emit_c_interface} {
+ %0 = catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<8xcomplex>
+ func.return %0 : tensor<8xcomplex>
+ }
+ builtin.module @module_circuit {
+ builtin.module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) {
+ transform.yield
+ }
+ }
+ func.func public @circuit() -> (tensor<8xcomplex>) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} {
+ %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor
+ %1 = tensor.extract %0[] : tensor
+ quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
+ %2 = "stablehlo.constant"() <{value = dense<3> : tensor}> : () -> tensor
+ %3 = quantum.alloc(3) : !quantum.reg
+ %4 = tensor.extract %0[] : tensor
+ %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit
+ %6 = quantum.custom "Hadamard"() %5 : !quantum.bit
+ %7 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor
+ %8 = tensor.extract %7[] : tensor
+ %9 = quantum.extract %3[%8] : !quantum.reg -> !quantum.bit
+ %10 = quantum.custom "Hadamard"() %9 : !quantum.bit
+ %11 = "stablehlo.constant"() <{value = dense<2> : tensor}> : () -> tensor
+ %12 = tensor.extract %11[] : tensor
+ %13 = quantum.extract %3[%12] : !quantum.reg -> !quantum.bit
+ %14 = quantum.custom "Hadamard"() %13 : !quantum.bit
+ %15 = tensor.extract %0[] : tensor
+ %16 = quantum.insert %3[%15], %6 : !quantum.reg, !quantum.bit
+ %17 = tensor.extract %7[] : tensor
+ %18 = quantum.insert %16[%17], %10 : !quantum.reg, !quantum.bit
+ %19 = tensor.extract %11[] : tensor
+ %20 = quantum.insert %18[%19], %14 : !quantum.reg, !quantum.bit
+ %21 = quantum.compbasis qreg %20 : !quantum.obs
+ %22 = quantum.state %21 : tensor<8xcomplex>
+ quantum.dealloc %20 : !quantum.reg
+ quantum.device_release
+ func.return %22 : tensor<8xcomplex>
+ }
+ }
+ func.func @setup() {
+ quantum.init
+ func.return
+ }
+ func.func @teardown() {
+ quantum.finalize
+ func.return
+ }
+ }
+
+- **Transforming the module**
+
+ In the above module, there are 3 ``CustomOp``\ s, each with gate name
+ ``Hadamard``. Let’s try applying our pass to it. Bear in mind that
+ passes update modules in-place:
+
+ .. code-block:: python
+
+ from xdsl import passes
+
+ pipeline = passes.PassPipeline((HToXPass(),))
+ pipeline.apply(xdsl.context.Context(), mod)
+
+ >>> print(mod)
+ builtin.module @circuit {
+ func.func public @jit_circuit() -> (tensor<8xcomplex>) attributes {llvm.emit_c_interface} {
+ %0 = catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<8xcomplex>
+ func.return %0 : tensor<8xcomplex>
+ }
+ builtin.module @module_circuit {
+ builtin.module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) {
+ transform.yield
+ }
+ }
+ func.func public @circuit() -> (tensor<8xcomplex>) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} {
+ %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor
+ %1 = tensor.extract %0[] : tensor
+ quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
+ %2 = "stablehlo.constant"() <{value = dense<3> : tensor}> : () -> tensor
+ %3 = quantum.alloc(3) : !quantum.reg
+ %4 = tensor.extract %0[] : tensor
+ %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit
+ %6 = quantum.custom "PauliX"() %5 : !quantum.bit
+ %7 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor
+ %8 = tensor.extract %7[] : tensor
+ %9 = quantum.extract %3[%8] : !quantum.reg -> !quantum.bit
+ %10 = quantum.custom "PauliX"() %9 : !quantum.bit
+ %11 = "stablehlo.constant"() <{value = dense<2> : tensor}> : () -> tensor
+ %12 = tensor.extract %11[] : tensor
+ %13 = quantum.extract %3[%12] : !quantum.reg -> !quantum.bit
+ %14 = quantum.custom "PauliX"() %13 : !quantum.bit
+ %15 = tensor.extract %0[] : tensor
+ %16 = quantum.insert %3[%15], %6 : !quantum.reg, !quantum.bit
+ %17 = tensor.extract %7[] : tensor
+ %18 = quantum.insert %16[%17], %10 : !quantum.reg, !quantum.bit
+ %19 = tensor.extract %11[] : tensor
+ %20 = quantum.insert %18[%19], %14 : !quantum.reg, !quantum.bit
+ %21 = quantum.compbasis qreg %20 : !quantum.obs
+ %22 = quantum.state %21 : tensor<8xcomplex>
+ quantum.dealloc %20 : !quantum.reg
+ quantum.device_release
+ func.return %22 : tensor<8xcomplex>
+ }
+ }
+ func.func @setup() {
+ quantum.init
+ func.return
+ }
+ func.func @teardown() {
+ quantum.finalize
+ func.return
+ }
+ }
+
+ Great! We can see that all the ``Hadamard``\ s have been replaced with
+ ``PauliX``\ s, just how we wanted.
+
+PennyLane integration
+=====================
+
+This section will cover the API in the ``catalyst.python_interface``
+submodule.
+
+Lowering to MLIR
+----------------
+
+Catalyst compiles programs using the following workflow, when program
+capture is enabled:
+
+- Trace user function to create ``ClosedJaxpr`` (plxpr)
+- Convert plxpr to value-semantic jaxpr
+- Lower value-semantic jaxpr to MLIR
+- Apply passes defined in a static pipeline to the MLIR
+
+ - These passes include optimizations, further lowering to more
+ elementary dialects, and lowering to LLVM
+
+- Generate machine code
+
+The integration with the xDSL layer happens after we lower to MLIR. We
+currently rely on JAX’s API to lower to MLIR. This has the special
+effect of lowering to a specific dialect called StableHLO, which is used
+to represent all arithmetic operations present in the program.
+
+Once lowered to MLIR, if the original ``qjit`` decorator specified the
+xDSL pass plugin, we pass control over to the xDSL layer, which applies
+all transforms that were requested by the user. We can request the use
+of the xDSL plugin like so:
+
+.. code-block:: python
+
+ from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath
+
+ @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])
+ ...
+
+.. _ir-structure-1:
+
+IR structure
+~~~~~~~~~~~~
+
+The lowered MLIR has a lot of structure to aid developers, which is
+described below using the following example:
+
+.. code-block:: python
+
+ import pennylane as qml
+
+ qml.capture.enable()
+ dev = qml.device("lightning.qubit", wires=1)
+
+ @qml.qjit
+ @qml.transforms.cancel_inverses
+ @qml.transforms.merge_rotations
+ @qml.qnode(dev)
+ def circuit():
+ qml.X(0)
+ return qml.state()
+
+>>> print(circuit.mlir)
+module @circuit {
+ func.func public @jit_circuit() -> tensor<2xcomplex> attributes {llvm.emit_c_interface} {
+ %0 = catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<2xcomplex>
+ return %0 : tensor<2xcomplex>
+ }
+ module @module_circuit {
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) {
+ %0 = transform.apply_registered_pass "merge-rotations" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module">
+ %1 = transform.apply_registered_pass "cancel-inverses" to %0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module">
+ transform.yield
+ }
+ }
+ func.func public @circuit() -> tensor<2xcomplex> attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} {
+ %c0_i64 = arith.constant 0 : i64
+ quantum.device shots(%c0_i64) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
+ %0 = quantum.alloc( 1) : !quantum.reg
+ %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
+ %out_qubits = quantum.custom "PauliX"() %1 : !quantum.bit
+ %2 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit
+ %3 = quantum.compbasis qreg %2 : !quantum.obs
+ %4 = quantum.state %3 : tensor<2xcomplex>
+ quantum.dealloc %2 : !quantum.reg
+ quantum.device_release
+ return %4 : tensor<2xcomplex>
+ }
+ }
+ func.func @setup() {
+ quantum.init
+ return
+ }
+ func.func @teardown() {
+ quantum.finalize
+ return
+ }
+}
+
+- The program is represented as a module with a function inside it. This
+ function contains non-QNode user code (which there is none in our
+ example for simplicity), and calls to QNodes using the
+ ``catalyst.launch_kernel`` operation.
+- QNodes are represented as modules as well—each QNode has its own
+ module. These modules have 4 key components:
+
+ - The ``module attributes {transform.with_named_sequence}`` contains
+ the transforms that the user requested for the QNode. In our
+ example, those are the ``merge_rotations`` and ``cancel_inverses``
+ transforms.
+ - The ``@circuit`` function represents the body of the QNode. Inside
+ this body, we initialize a device using the specified shots,
+ allocate a quantum register, and then apply both quantum and
+ classical instructions. Note that quantum instructions can *only* be
+ present inside this function. Note that this function will contain
+ an attribute called ``qnode``.
+
+SSA in the ``Quantum`` dialect
+------------------------------
+
+In the ``Quantum`` dialect, as mentioned earlier, gates are represented
+using the ``CustomOp`` operation. This operation accepts ``in_qubits``
+and ``in_ctrl_qubits``, which are variable length sequences of
+``QubitType`` attributes, which correspond to wire indices.
+``CustomOp``\ s also return ``out_qubits`` and ``out_ctrl_qubits`` of
+the same type.
+
+Before a ``QubitType`` can be used by any gates, it must be extracted
+from the quantum register, or a ``QregType``. Quantum registers can be
+thought of as a sequence containing all valid wire indices that are
+available to be used by gates/measurements. Qubits can be extracted from
+and inserted into a quantum register using the ``ExtractOp`` and
+``InsertOp`` operations. An ``AllocOp`` is used to allocate a quantum
+register with the user-provided number of device wires.
+
+Let’s take a look at a very simple example. Below, I have a very simple
+circuit with two gates applied to the same wire. Let’s take a look at
+its MLIR representation:
+
+- **Example**
+
+ .. code-block:: python
+
+ dev = qml.device("lightning.qubit", wires=3)
+
+ @qml.qjit(target="mlir")
+ @qml.qnode(dev)
+ def circuit():
+ qml.X(0)
+ qml.H(0)
+ return qml.state()
+
+ >>> print(circuit.mlir)
+ module @circuit {
+ func.func public @jit_circuit() -> tensor<8xcomplex> attributes {llvm.emit_c_interface} {
+ %0 = catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<8xcomplex>
+ return %0 : tensor<8xcomplex>
+ }
+ module @module_circuit {
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) {
+ transform.yield
+ }
+ }
+ func.func public @circuit() -> tensor<8xcomplex> attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} {
+ %c0_i64 = arith.constant 0 : i64
+ quantum.device shots(%c0_i64) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
+ %0 = quantum.alloc( 3) : !quantum.reg
+ %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
+ %out_qubits = quantum.custom "PauliX"() %1 : !quantum.bit
+ %out_qubits_0 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit
+ %2 = quantum.insert %0[ 0], %out_qubits_0 : !quantum.reg, !quantum.bit
+ %3 = quantum.compbasis qreg %2 : !quantum.obs
+ %4 = quantum.state %3 : tensor<8xcomplex>
+ quantum.dealloc %2 : !quantum.reg
+ quantum.device_release
+ return %4 : tensor<8xcomplex>
+ }
+ }
+ func.func @setup() {
+ quantum.init
+ return
+ }
+ func.func @teardown() {
+ quantum.finalize
+ return
+ }
+ }
+
+**Notes**
+
+- ``quantum.alloc`` initializes a quantum register (``%0``) with 3
+ wires, which is the number of wires used to create the PennyLane
+ device.
+- ``quantum.extract`` extracts a qubit corresponding to wire index 0
+ (``%1``)
+
+ - ``%1`` is used as input by the ``X`` gate.
+
+- The ``X`` gate returns a qubit (``%out_qubits``), which is used by the
+ ``quantum.custom`` corresponding to the ``H`` gate.
+
+ - Note that this is different from how the circuit is defined in
+ Python. Instead of just using ``%1`` again, the ``H`` gate uses
+ ``%out_qubits``.
+
+- The ``H`` gate returns a new qubit (``%out_qubits_0``). This qubit is
+ consumed by ``quantum.insert``, which inserts updates the quantum
+ register to essentially say that ``%out_qubits_0`` should be the new
+ qubit that corresponds to wire index 0.
+- In this example, note that ``%1``, ``%out_qubits``, and
+ ``%out_qubits_0`` all correspond to wire index 0. This has cool
+ implications, that I will discuss below.
+
+Dynamic wires
+~~~~~~~~~~~~~
+
+The lowering rules for Catalyst automatically handle dynamic wires. When
+a new dynamic wire, ``w``, is used, all wires used before it are first
+inserted into the quantum register using ``InsertOp``. Only then does
+the ``QubitType`` corresponding to ``w`` get extracted from the quantum
+register using ``ExtractOp``. Consider the following example:
+
+- **Example**
+
+ .. code-block:: python
+
+ import pennylane as qml
+
+ qml.capture.enable()
+
+ dev = qml.device("lightning.qubit", wires=3)
+
+ @qml.qjit(target="mlir")
+ @qml.qnode(dev)
+ def circuit(w1: int, w2: int):
+ qml.X(0)
+ qml.Y(w1)
+ qml.Z(w1)
+ qml.S(w2)
+ qml.T(w1)
+ qml.H(0)
+ return qml.state()
+
+ >>> print(circuit.mlir)
+ module @circuit {
+ func.func public @jit_circuit(%arg0: tensor, %arg1: tensor) -> tensor<8xcomplex> attributes {llvm.emit_c_interface} {
+ %0 = catalyst.launch_kernel @module_circuit::@circuit(%arg0, %arg1) : (tensor, tensor) -> tensor<8xcomplex>
+ return %0 : tensor<8xcomplex>
+ }
+ module @module_circuit {
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) {
+ transform.yield
+ }
+ }
+ func.func public @circuit(%arg0: tensor, %arg1: tensor) -> tensor<8xcomplex> attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} {
+ %c0_i64 = arith.constant 0 : i64
+ quantum.device shots(%c0_i64) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
+ %0 = quantum.alloc( 3) : !quantum.reg
+ %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
+ %out_qubits = quantum.custom "PauliX"() %1 : !quantum.bit
+ %2 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit
+ %extracted = tensor.extract %arg0[] : tensor
+ %3 = quantum.extract %2[%extracted] : !quantum.reg -> !quantum.bit
+ %out_qubits_0 = quantum.custom "PauliY"() %3 : !quantum.bit
+ %out_qubits_1 = quantum.custom "PauliZ"() %out_qubits_0 : !quantum.bit
+ %extracted_2 = tensor.extract %arg0[] : tensor
+ %4 = quantum.insert %2[%extracted_2], %out_qubits_1 : !quantum.reg, !quantum.bit
+ %extracted_3 = tensor.extract %arg1[] : tensor
+ %5 = quantum.extract %4[%extracted_3] : !quantum.reg -> !quantum.bit
+ %out_qubits_4 = quantum.custom "S"() %5 : !quantum.bit
+ %extracted_5 = tensor.extract %arg1[] : tensor
+ %6 = quantum.insert %4[%extracted_5], %out_qubits_4 : !quantum.reg, !quantum.bit
+ %extracted_6 = tensor.extract %arg0[] : tensor
+ %7 = quantum.extract %6[%extracted_6] : !quantum.reg -> !quantum.bit
+ %out_qubits_7 = quantum.custom "T"() %7 : !quantum.bit
+ %extracted_8 = tensor.extract %arg0[] : tensor
+ %8 = quantum.insert %6[%extracted_8], %out_qubits_7 : !quantum.reg, !quantum.bit
+ %9 = quantum.extract %8[ 0] : !quantum.reg -> !quantum.bit
+ %out_qubits_9 = quantum.custom "Hadamard"() %9 : !quantum.bit
+ %10 = quantum.insert %8[ 0], %out_qubits_9 : !quantum.reg, !quantum.bit
+ %11 = quantum.compbasis qreg %10 : !quantum.obs
+ %12 = quantum.state %11 : tensor<8xcomplex>
+ quantum.dealloc %10 : !quantum.reg
+ quantum.device_release
+ return %12 : tensor<8xcomplex>
+ }
+ }
+ func.func @setup() {
+ quantum.init
+ return
+ }
+ func.func @teardown() {
+ quantum.finalize
+ return
+ }
+ }
+
+ **Notes**
+
+ - If a qubit is inserted into the quantum register, it is not reused
+ again, and a new qubit corresponding to the same wire label must be
+ extracted from the register.
+ - When a dynamic wire is going to be used for the first time, all
+ qubits that have been used previously without being re-inserted into
+ the quantum register must be inserted.
+ - If a dynamic wire is reused before any other wires, then it does not
+ need to be inserted to and extracted from the quantum register
+ again.
+ - To use static wires after dynamic wires, the dynamic wires are again
+ re-inserted into the quantum register.
+ - All of the above make it such that dynamic wires essentially create
+ barriers that break the qubit def-use chains. This causes some
+ functionality to be lost, but makes sure that we’re using qubits
+ safely.
+
+Implications/notes
+~~~~~~~~~~~~~~~~~~
+
+- Operations on the same wires can be tracked quickly using the chain of
+ definitions and uses of the qubits (def-use chains).
+- Operations on dynamic wires (i.e wires whose values are only known at
+ run-time) are handled automatically and we don’t need to worry about
+ managing how we work around them. For context, this is an issue that
+ we found in the plxpr variant of ``cancel_inverses``, where two
+ operations on the same wire that were separated by another operation
+ on a dynamic wire would get cancelled, which is incorrect
+ (`source `__).
+- One thing to keep in mind is that qubits (``QubitType``) and quantum
+ registers (``QregType``) are there so that we conform to SSA semantics
+ in a way that works well for our purposes. They get meaning from how
+ we choose to interpret them. We could just as easily have defined the
+ ``Quantum`` dialect and Catalyst’s lowering rules to use wire indices
+ the same way we do in Python, but we may lose capabilities that MLIR
+ and xDSL enable through the SSA form.
+
+``compiler_transform``
+----------------------
+
+``compiler_transform`` is the function used to register xDSL
+``ModulePass``\ es to be used with ``qjit``-ed workflows. It is
+currently accessible as
+``catalyst.python_interface.compiler_transform``.
+
+.. code-block:: python
+
+ from catalyst.python_interface import compiler_transform
+
+ class MyPass(xdsl.passes.ModulePass):
+ """MyPass that does something"""
+
+ name = "my-pass"
+
+ def apply(self, ctx, module):
+ # Apply the pass to module
+ return
+
+ my_pass = compiler_transform(MyPass)
+
+ # Program capture must be enabled to use the compiler transform
+ # as a decorator
+ qml.capture.enable()
+ dev = qml.device("lightning.qubit", wires=1)
+
+ @qml.qjit(
+ pass_plugins=[catalyst.passes.xdsl_plugin.getXDSLPluginAbsolutePath()]
+ )
+ @my_pass
+ @qml.qnode(dev)
+ def circuit(x):
+ qml.RX(x, 0)
+ return qml.expval(qml.Z(0))
+
+ circuit(1.5)
+
+The ``compiler_transform`` function returns an object that gives easy
+access to the underlying ``ModulePass``, as well as its name as seen by
+the compiler.
+
+>>> my_pass.module_pass
+__main__.MyPass
+>>> my_pass.name
+'my-pass'
+
+Additionally, we don’t need to manually apply passes using
+``PassPipeline`` when decorating QNodes with registered compiler
+transforms. Those transforms get applied automatically when the workflow
+is compiled!
+
+Conversion utilities
+--------------------
+
+The ``catalyst.python_interface.conversion`` submodule provides several utilities
+for creating xDSL modules from Python functions. There are many more
+utilities in the submodule, but I will focus on the most important ones,
+and provide examples of how to use them.
+
+- ``xdsl_module(jitted_fn: Callable) -> Callable[Any, [xdsl.dialects.builtin.ModuleOp]]``:
+ Create a wrapper around a ``jax.jit``-ed function that returns an xDSL
+ module
+- ``xdsl_from_qjit(qjitted_fn: QJIT) -> Callable[..., xbuiltin.ModuleOp]``:
+ Create a wrapper around a ``qjit``-ed function that returns an xDSL
+ module. This is currently not merged to ``master``
+- ``inline_module(from_mod: xbuiltin.ModuleOp, to_mod: xbuiltin.ModuleOp, change_main_to: str = None) -> None``:
+ This function takes two modules as input, and inlines the whole body
+ of the first module into the second. Additionally, if
+ ``change_main_to`` is provided, it looks for a function named
+ ``main``, and updates its name to ``change_main_to``
+- ``inline_jit_to_module(func: JaxJittedFunction, mod: xbuiltin.ModuleOp, *args, **kwargs) -> None``:
+ This function takes a ``jax.jit``-ed function, converts it into an
+ xDSL module, and then inlines the contents of the xDSL module into
+ ``mod``. Note that this function does not return anything; instead, it
+ modifies ``mod`` in-place.
+
+Check out the :doc:`xDSL conversion utilities tutorial <./xdsl_utils_tutorial>` to see examples of how each of
+the utilities can be used.
+
+Useful patterns
+===============
+
+Now that we have gone over compilers, xDSL, and how it’s being used in
+Catalyst, let’s take a look at some common patterns that might be
+useful.
+
+Post-processing functions
+-------------------------
+
+Post-processing functions are purely classical, so we can leverage the
+``xdsl_module`` utility function to create xDSL modules from Python
+code, and inject it into the modules we are rewriting as needed. The
+:doc:`xDSL post-processing tutorial <./xdsl_post_processing>` shows an
+example where we perform very simple post-processing on a QNode that
+returns an expectation value by squaring the expectation value.
+
+Splitting tapes
+---------------
+
+When implementing passes that may transform our module such that
+multiple device executions are required (akin to tape transforms that
+return multiple tapes), there are various strategies that can be used.
+Because there is no guarantee of shared structure between the tapes,
+there is no one perfect strategy that can be used for all transforms.
+I’ll use tapes to provide details below about some common cases:
+
+- If all the tapes are identical (eg. ``dynamic_one_shot``), then the
+ entire execution can be put inside a ``for`` loop, with
+ post-processing on the execution outputs done how I showed in the
+ above notebook.
+- If all gates are identical, but measurements are different (eg.
+ ``split_non_commuting``), we can capture all gates in a single
+ function, and then use a ``for`` loop that iterates over the number of
+ tapes. Within this ``for`` loop, we would call the aforementioned
+ function, which would evolve the state, and apply a different
+ measurement inside each iteration of the loop. Post-processing can be
+ handled same as above.
+- If there is little/no shared structure between the tapes, we would
+ need separate functions for each of the transformed tapes. We would
+ need to call each function one by one, and then use the results for
+ post-processing.
+
+All of the above are very non-trivial. I will leave out code examples
+for now, as that may be unnecessarily time consuming. If we get to the
+stage where we need to write a transform that splits into multiple
+tapes, we can revisit this section and the unified compiler/compilation
+team can assist in developing such transforms.
+
+Note
+~~~~
+
+Currently, we don’t have a consistent API for implementing transforms
+that split QNodes into multiple device executions (eg.
+``split_non_commuting``, etc.). Thus, it is *very* hard to implement
+transforms that do that. We cannot guarantee that one transform that
+does split QNodes into multiple device executions will work well when
+there are other transforms in the pipeline.
+
+Writing tests
+=============
+
+**Note to readers**: this section is written based on how the testing
+infrastructure for the unified compiler exists in Catalyst currently.
+However, this infrastructure may change in the future.
+
+FileCheck
+---------
+
+`FileCheck `__ is a
+pattern matching file verifier developed by the LLVM project and is
+widely used to test MLIR. Since xDSL and MLIR are syntactically
+identical, we can use the string representation of xDSL modules for
+testing with FileCheck.
+
+Below is an example of how an MLIR/xDSL string may look when populated
+with directives that FileCheck can use for testing. In this example, we
+have a void function ``test_func`` that takes no arguments, creates two
+qubits (corresponding to different wires), and applies a ``PauliX`` to
+each of these wires. We can see that there are 4 comments starting with
+``CHECK`` - these comments are what FileCheck uses to match the expected
+string against the actual string.
+
+The number of ``CHECK`` comments does not need to be the same as the
+number of operations in the program - we can see below that there is no
+``CHECK`` for ``func.func`` and ``return``.
+
+``CHECK`` statements can assign parameters that can be reused by later
+``CHECK``\ s. Below, the first two ``CHECK``\ s create ``q0`` and
+``q1``, which are matched using the regular expression ``%.+``, which
+matches any expressions that start with ``%`` and contain at least one
+character after it. The latter 2 ``CHECK``\ s then use ``q0`` and ``q1``
+as the input qubits for the assertion.
+
+``CHECK`` statements can contain partial statements. Below, the last two
+``CHECK``\ s don’t include the outputs of the ``quantum.custom``
+operations.
+
+.. code-block:: python
+
+ program = """
+ func.func @test_func() {
+ // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit
+ // CHECK: [[q1:%.+]] = "test.op"() : () -> !quantum.bit
+ %0 = "test.op"() : () -> !quantum.bit
+ %1 = "test.op"() : () -> !quantum.bit
+ // CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit
+ // CHECK: quantum.custom "PauliX"() [[q1]] : !quantum.bit
+ %2 = quantum.custom "PauliX"() %0 : !quantum.bit
+ %3 = quantum.custom "PauliX"() %0 : !quantum.bit
+ return
+ }
+ """
+
+FileCheck essentially uses the MLIR generated by a program and asserts
+it against directives that can be specified by the user. Commonly used
+directives are:
+
+- ``CHECK``: This is used to check if a specified pattern is found. Its
+ syntax is very simple: they are fixed strings that must occur in
+ order, and horizontal whitespace is ignored by default.
+- ``CHECK-NOT``: This directive is used to check that the provided
+ string does not occur between two matches, before the first match, or
+ after the last match.
+- ``CHECK-DAG``: This directive may be used when it’s necessary to match
+ strings that don’t have to occur in the same order, but it is able to
+ match valid topological orderings of the program DAG, with edges from
+ the definition to the uses of a variable.
+- ``CHECK-NEXT``: This directive is used to check that the matched line
+ occurs right after the previously matched line with no other lines in
+ between them.
+- ``CHECK-SAME``: This directive is used when we want to match lines and
+ would like to verify that matches happen on the same line as the
+ previous match.
+
+To find more details about the above directives, or to learn about other
+available directives, please refer to the `FileCheck
+documentation `__.
+
+Test dialect
+------------
+
+xDSL provides a ``Test`` dialect
+(`source `__),
+which contains many operations that are useful for unit testing. In our
+testing, we found the ``TestOp`` operation to be the most useful. This
+operation can produce arbitrary results, which we can use to limit
+artificial dependencies on other dialects.
+
+For example, if I just need to assert that a specific gate is present,
+without ``TestOp``, I would need my module to contain an ``AllocOp``
+that creates a quantum register, an ``ExtractOp`` that extracts a qubit
+from the register, and only then I can use the qubit for the gate I’m
+trying to match. Instead, I can just insert a ``TestOp`` that returns a
+qubit and use that.
+
+This is very powerful for unit testing, as it makes writing tests much
+simpler, while also limiting the scope of the test as one would expect
+for unit tests.
+
+In the code block in the previous section, ``TestOp`` has been used in
+exactly the described way - we use it to create 2 qubits that are then
+used as input for the 2 ``PauliX`` gates.
+
+.. _pennylane-integration-1:
+
+PennyLane integration
+---------------------
+
+To use FileCheck with ``pytest``, we use the ```filecheck`` Python
+package `__, which allows us to use
+assertions for testing in a way that ``pytest`` can understand. All of
+the ``filecheck`` API has been captured inside two fixtures available
+within the ``tests/python_interface`` folder:
+
+- ``run_filecheck``: This fixture is for unit testing. One can specify a
+ program along with filecheck directives as a multi-line string.
+- ``run_filecheck_qjit``: This fixture is for integration testing. One
+ can create a normal ``qml.qjit``-ed workflow and include filecheck
+ directives as in-line comments.
+
+Let’s write tests for the ``HToXPass`` that was implemented in the
+`“Pass API” <#pass-api>`__ sub-section to illustrate. The dev comments
+will explain what is going on.
+
+``run_filecheck`` example
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ def test_h_to_x_pass(run_filecheck):
+ """Test that Hadamard gets converted into PauliX."""
+ # The original program creates a qubit, and gives it to a
+ # CustomOp that is a Hadamard. The CHECK directives check that
+ # the transformed program has a CustomOp that is a PauliX applied
+ # to the same qubit, and no CustomOp that is a Hadamard.
+
+ # Below we also see how we can use `test.op`, which we use to
+ # create a qubit to give to the Hadamard.
+ program = """
+ func.func @test_func() {
+ // CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit
+ %0 = "test.op"() : () -> !quantum.bit
+ // CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit
+ // CHECK-NOT: quantum.custom "Hadamard"
+ %1 = quantum.custom "Hadamard"() %0 : !quantum.bit
+ return
+ }
+ """
+
+ # First, we must create the pass pipeline that we want to apply
+ # to our program.
+ pipeline = (HToXPass(),)
+
+ # Next, we use run_filecheck to run filecheck testing on the
+ # program. The fixture will create an xDSL module from the program
+ # string, call the filecheck API, and assert correctness.
+ run_filecheck(program, pipeline)
+
+ # Optionally, there are two keyword arguments that can modify the
+ # behaviour of run_filecheck. Both the roundtrip and verify
+ # arguments are false by default.
+ run_filecheck(program, pipeline, roundtrip=True, verify=True)
+
+ # roundtrip=True makes it so that after parsing the program string
+ # to an xDSL module, we print it as a string again, and then parse
+ # it back into an xDSL module. This is useful when writing tests
+ # for dialects, so that we can check that the dialects can be printed
+ # parsed correctly.
+
+ # verify=True simply runs `module.verify()`, which iteratively uses
+ # xDSL's verifiers to verify all operations and attributes in the
+ # module.
+
+``run_filecheck_qjit`` example
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath
+
+ def test_h_to_x_pass_integration(run_filecheck_qjit):
+ """Test that Hadamard gets converted into PauliX."""
+ # The original program simply applies a Hadamard to a circuit
+ # Remember that we need to specify the pass plugin. Additionally,
+ # with qjit, we can use the decorator created using
+ # `compiler_transform`. To make sure that the xDSL API works
+ # correctly, program capture must be enabled.
+ # qml.capture.enable()
+ @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath])
+ @h_to_x_pass
+ def circuit():
+ # CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit
+ # CHECK: quantum.custom "PauliX"() [[q0]] : !quantum.bit
+ # CHECK-NOT: quantum.custom "Hadamard"
+ qml.Hadamard(0)
+ return qml.state()
+
+ # Finally, we use the run_filecheck_qjit fixture. We pass it our
+ # original qjitted workflow. It extracts the filecheck directives
+ # from the workflow, creates an MLIR program, parses it to xDSL,
+ # applies the specified transforms, and uses the filecheck API to
+ # assert correctness.
+ run_filecheck_qjit(circuit)
+
+ # run_filecheck_qjit also accepts a boolean `verify` argument
+ # that is false by default, which works exactly the same way as
+ # the `verify` argument of `run_filecheck`.
+ run_filecheck_qjit(circuit, verify=True)
+
+Key blockers
+============
+
+There are several blockers that are currently disabling developers from
+taking full advantage of the ``catalyst.python_interface`` submodule. These
+include:
+
+* Lack of support for quantum subroutines. This impacts pattern
+ matching passes that need to substitute the matched operation(s) with
+ subroutines containing quantum instructions.
+
+Strategies to circumvent blockers
+---------------------------------
+
+* We can use dummy subroutines for now. We know what the inputs and
+ outputs of these subroutines should be, so we can create our own
+ ``FuncOp``\ s that adhere to the input/output spec and just have
+ their body be empty for now. To see an example where we create a
+ dummy quantum subroutine and use it to develop a pass, check out the
+ :doc:`xDSL subroutines tutorial <./xdsl_dummy_quantum_subroutines>`.
+
+Suggested reading
+=================
+
+Useful dialects
+---------------
+
+- ``scf``: Structured control flow
+- ``func``: Functions
+- ``builtin``: Core types and attributes
+- ``arith``: arithmetic operations
+- ``stablehlo``: Advanced match dialect
+
+References
+==========
+
+#. Wikimedia Foundation. (2025, August 11). *Static single-assignment
+ form*. Wikipedia.
+ https://en.wikipedia.org/wiki/Static_single-assignment_form
+#. *MLIR Language Reference*. MLIR. (n.d.).
+ https://mlir.llvm.org/docs/LangRef/
+#. *Understanding the IR Structure*. MLIR. (n.d.-b).
+ https://mlir.llvm.org/docs/Tutorials/UnderstandingTheIRStructure/
+#. *Mlir::Value class reference*. MLIR. (n.d.-b).
+ https://mlir.llvm.org/doxygen/classmlir_1_1Value.html
+#. *LLVM Programmer’s Manual*. LLVM. (n.d.).
+ https://llvm.org/docs/ProgrammersManual.html
diff --git a/frontend/catalyst/python_interface/doc/xdsl_dummy_quantum_subroutines.rst b/frontend/catalyst/python_interface/doc/xdsl_dummy_quantum_subroutines.rst
new file mode 100644
index 0000000000..a1e3a53324
--- /dev/null
+++ b/frontend/catalyst/python_interface/doc/xdsl_dummy_quantum_subroutines.rst
@@ -0,0 +1,216 @@
+.. code-block:: python
+
+ from dataclasses import dataclass
+
+ import pennylane as qml
+ from catalyst.python_interface.conversion import xdsl_from_qjit
+ from catalyst.python_interface.dialects.quantum import CustomOp, QubitType
+
+ from xdsl import context, passes, pattern_rewriter
+ from xdsl.builder import ImplicitBuilder
+ from xdsl.dialects import builtin, func
+ from xdsl.ir import Block, Region
+
+Convert into xDSL module
+========================
+
+.. code-block:: python
+
+ dev = qml.device("lightning.qubit", wires=5)
+
+ @xdsl_from_qjit
+ @qml.qjit(target="mlir")
+ @qml.qnode(dev)
+ def circuit(x):
+ qml.H(0)
+ return qml.expval(qml.Z(0))
+
+
+>>> qjit_mod = circuit(1.5)
+>>> print(qjit_mod)
+builtin.module @circuit {
+ func.func public @jit_circuit(%arg2 : tensor) -> (tensor) attributes {llvm.emit_c_interface} {
+ %0 = catalyst.launch_kernel @module_circuit::@circuit(%arg2) : (tensor) -> tensor
+ func.return %0 : tensor
+ }
+ builtin.module @module_circuit {
+ builtin.module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) {
+ transform.yield
+ }
+ }
+ func.func public @circuit(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} {
+ %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor
+ %1 = tensor.extract %0[] : tensor
+ quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
+ %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor
+ %3 = quantum.alloc(5) : !quantum.reg
+ %4 = tensor.extract %0[] : tensor
+ %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit
+ %6 = quantum.custom "Hadamard"() %5 : !quantum.bit
+ %7 = quantum.namedobs %6[PauliZ] : !quantum.obs
+ %8 = quantum.expval %7 : f64
+ %9 = tensor.from_elements %8 : tensor
+ %10 = tensor.extract %0[] : tensor
+ %11 = quantum.insert %3[%10], %6 : !quantum.reg, !quantum.bit
+ quantum.dealloc %11 : !quantum.reg
+ quantum.device_release
+ func.return %9 : tensor
+ }
+ }
+ func.func @setup() {
+ quantum.init
+ func.return
+ }
+ func.func @teardown() {
+ quantum.finalize
+ func.return
+ }
+}
+
+
+Let’s create a quantum subroutine
+=================================
+
+This subroutine’s purpose is to replace Hadamard gates with a blackbox
+(which will be empty for now), so it must take a single qubit as its
+argument. Additionally, we are assuming that the qubits on which the
+subroutine will act are not just the ones that the subroutine takes as
+input, so we must also provide the quantum register as input, and also
+give it as the output.
+
+``FuncOp``\ s have a single region with a single block that contains the
+body of the function.
+
+We need to build the function’s body by populating its inner ``Block``
+with operations. We can do so using the ``xdsl.builder.ImplicitBuilder``
+class. This class can be used as a context manager that takes a
+``Block`` as input, and any operations created within the context of the
+builder get added to the block. Let’s try it out.
+
+Here, we create a subroutine that we will use to replace
+``Hadamard``\ s. This subroutine applies a gate provided by the user,
+and returns the ``out_qubit`` of the gate.
+
+.. code-block:: python
+
+ def create_hadamard_replacement_subroutine(gate_name):
+ input_types = (QubitType(),)
+ output_types = (QubitType(),)
+ block = Block(arg_types=input_types)
+
+ with ImplicitBuilder(block):
+ in_qubits = [block.args[0]]
+ op1 = CustomOp(in_qubits=in_qubits, gate_name=gate_name)
+ func.ReturnOp(op1.out_qubits[0])
+
+ region = Region([block])
+ funcOp = func.FuncOp("replace_hadamard", (input_types, output_types), region=region)
+ return funcOp
+
+
+>>> funcOp = create_hadamard_replacement_subroutine("S")
+>>> print(funcOp)
+func.func @replace_hadamard(%0 : !quantum.bit) -> !quantum.bit {
+ %1 = quantum.custom "S"() %0 : !quantum.bit
+ func.return %1 : !quantum.bit
+}
+
+
+Now, we write a pass to do the substitution
+===========================================
+
+.. code-block:: python
+
+ class ReplaceHadamardPattern(pattern_rewriter.RewritePattern):
+
+ def __init__(self, subroutine: func.FuncOp):
+ self.subroutine = subroutine
+
+ @pattern_rewriter.op_type_rewrite_pattern
+ def match_and_rewrite(self, customOp: CustomOp, rewriter: pattern_rewriter.PatternRewriter):
+ if customOp.gate_name.data != "Hadamard":
+ return
+
+ callOp = func.CallOp(
+ builtin.SymbolRefAttr("replace_hadamard"),
+ [customOp.in_qubits[0]],
+ self.subroutine.function_type.outputs.data,
+ )
+ rewriter.insert_op_after_matched_op(callOp)
+ rewriter.replace_all_uses_with(customOp.out_qubits[0], callOp.results[0])
+ rewriter.erase_op(customOp)
+
+
+ @dataclass(frozen=True)
+ class ReplaceHadamardPass(passes.ModulePass):
+ name = "replace-hadamard"
+ gate_name: str
+
+ def apply(self, ctx: context.Context, module: builtin.ModuleOp):
+ funcOp = create_hadamard_replacement_subroutine(self.gate_name)
+ module.regions[0].blocks.first.add_op(funcOp)
+
+ pattern_rewriter.PatternRewriteWalker(
+ pattern_rewriter.GreedyRewritePatternApplier([ReplaceHadamardPattern(funcOp)])
+ ).rewrite_module(module)
+
+Let’s see it in action
+======================
+
+Here, we will replace all ``Hadamard``\ s with ``S``\ s
+
+.. code-block:: python
+
+ ctx = context.Context()
+
+ pipeline = passes.PassPipeline((ReplaceHadamardPass("S"),))
+ pipeline.apply(ctx, qjit_mod)
+
+Great! We can see below that ``Hadamard`` was replaced by a call to
+``replace_hadamard``, which applies a single ``S`` gate.
+
+>>> print(qjit_mod)
+builtin.module @circuit {
+ func.func public @jit_circuit(%arg2 : tensor) -> (tensor) attributes {llvm.emit_c_interface} {
+ %0 = catalyst.launch_kernel @module_circuit::@circuit(%arg2) : (tensor) -> tensor
+ func.return %0 : tensor
+ }
+ builtin.module @module_circuit {
+ builtin.module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) {
+ transform.yield
+ }
+ }
+ func.func public @circuit(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} {
+ %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor
+ %1 = tensor.extract %0[] : tensor
+ quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
+ %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor
+ %3 = quantum.alloc(5) : !quantum.reg
+ %4 = tensor.extract %0[] : tensor
+ %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit
+ %6 = func.call @replace_hadamard(%5) : (!quantum.bit) -> !quantum.bit
+ %7 = quantum.namedobs %6[PauliZ] : !quantum.obs
+ %8 = quantum.expval %7 : f64
+ %9 = tensor.from_elements %8 : tensor
+ %10 = tensor.extract %0[] : tensor
+ %11 = quantum.insert %3[%10], %6 : !quantum.reg, !quantum.bit
+ quantum.dealloc %11 : !quantum.reg
+ quantum.device_release
+ func.return %9 : tensor
+ }
+ }
+ func.func @setup() {
+ quantum.init
+ func.return
+ }
+ func.func @teardown() {
+ quantum.finalize
+ func.return
+ }
+ func.func @replace_hadamard(%0 : !quantum.bit) -> !quantum.bit {
+ %1 = quantum.custom "S"() %0 : !quantum.bit
+ func.return %1 : !quantum.bit
+ }
+}
diff --git a/frontend/catalyst/python_interface/doc/xdsl_post_processing.rst b/frontend/catalyst/python_interface/doc/xdsl_post_processing.rst
new file mode 100644
index 0000000000..b9ca0ac3ae
--- /dev/null
+++ b/frontend/catalyst/python_interface/doc/xdsl_post_processing.rst
@@ -0,0 +1,225 @@
+Simple tutorial for injecting functions into xDSL modules
+=========================================================
+
+.. code-block:: python
+
+ from dataclasses import dataclass
+ import jax
+
+ import pennylane as qml
+ from catalyst.python_interface.conversion import inline_module, xdsl_from_qjit, xdsl_module
+
+ from xdsl import context, passes, pattern_rewriter
+ from xdsl.dialects import builtin, func
+ from xdsl.traits import SymbolTable
+ from xdsl.rewriter import InsertPoint
+
+Create workflow and convert to xDSL module
+==========================================
+
+.. code-block:: python
+
+ @xdsl_from_qjit
+ @qml.qjit(target="mlir")
+ def workflow(x, y):
+ dev = qml.device("lightning.qubit", wires=5)
+
+ @qml.qnode(dev)
+ def circuit(x):
+ qml.RX(x, 0)
+ return qml.expval(qml.Z(0))
+
+ res = circuit(x)
+ return res - y
+
+
+>>> xmod = workflow(3.5, 4.5)
+>>> print(xmod)
+builtin.module @workflow {
+ func.func public @jit_workflow(%arg2 : tensor, %arg3 : tensor) -> (tensor) attributes {llvm.emit_c_interface} {
+ %0 = catalyst.launch_kernel @module_circuit::@circuit(%arg2) : (tensor) -> tensor
+ %1 = "stablehlo.convert"(%arg3) : (tensor) -> tensor
+ %2 = "stablehlo.subtract"(%0, %1) : (tensor, tensor) -> tensor
+ func.return %2 : tensor
+ }
+ builtin.module @module_circuit {
+ builtin.module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) {
+ transform.yield
+ }
+ }
+ func.func public @circuit(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} {
+ %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor
+ %1 = tensor.extract %0[] : tensor
+ quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
+ %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor
+ %3 = quantum.alloc(5) : !quantum.reg
+ %4 = tensor.extract %0[] : tensor
+ %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit
+ %6 = tensor.extract %arg0[] : tensor
+ %7 = quantum.custom "RX"(%6) %5 : !quantum.bit
+ %8 = quantum.namedobs %7[PauliZ] : !quantum.obs
+ %9 = quantum.expval %8 : f64
+ %10 = tensor.from_elements %9 : tensor
+ %11 = tensor.extract %0[] : tensor
+ %12 = quantum.insert %3[%11], %7 : !quantum.reg, !quantum.bit
+ quantum.dealloc %12 : !quantum.reg
+ quantum.device_release
+ func.return %10 : tensor
+ }
+ }
+ func.func @setup() {
+ quantum.init
+ func.return
+ }
+ func.func @teardown() {
+ quantum.finalize
+ func.return
+ }
+}
+
+
+Now, let’s try creating a pass that squares the output of the qnode
+===================================================================
+
+To do so, we can use the ``inline_module`` utility to easily add our
+post-processing function into the module we’re transforming. First, we
+create the function that squares the input value and turn it into an
+xDSL module.
+
+.. code-block:: python
+
+ @jax.jit
+ def square(x):
+ return x * x
+
+
+>>> square_mod = xdsl_module(square)(1.5)
+>>> print(square_mod)
+builtin.module @jit_square attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+ func.func public @main(%arg0 : tensor) -> (tensor {jax.result_info = "result"}) {
+ %0 = "stablehlo.multiply"(%arg0, %arg0) : (tensor, tensor) -> tensor
+ func.return %0 : tensor
+ }
+}
+
+
+.. code-block:: python
+
+ def is_kernel_launch(op):
+ return op.name == "catalyst.launch_kernel"
+
+
+ class SquarePattern(pattern_rewriter.RewritePattern):
+
+ @pattern_rewriter.op_type_rewrite_pattern
+ def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter):
+ # We only rewrite the function that calls the qnode, and the caller of the qnode will
+ # always have catalyst.launch_kernel present. Additionally, we only rewrite the caller
+ # if it hasn't already been rewritten. We can put a UnitAttr() inside the caller's
+ # attributes to indicate whether it has been rewritten or not.
+ if funcOp.attributes.get("transformed") == builtin.UnitAttr() or not any(
+ is_kernel_launch(op) for op in funcOp.body.ops
+ ):
+ return
+
+ # Update funcOp to inidicate that it has been rewritten
+ funcOp.attributes["transformed"] = builtin.UnitAttr()
+
+ # Insert square into the module
+ mod = funcOp.parent_op()
+ inline_module(square_mod, mod, change_main_to="square")
+ square_fn = SymbolTable.lookup_symbol(mod, "square")
+
+ # Call square_fn and use its results instead of the qnode's results for
+ # the rest of the function
+ for op in funcOp.body.walk():
+ if is_kernel_launch(op):
+ callOp = func.CallOp(
+ builtin.SymbolRefAttr(square_fn.sym_name),
+ op.results,
+ square_fn.function_type.outputs.data,
+ )
+ rewriter.insert_op(callOp, InsertPoint.after(op))
+
+ # We have inserted a CallOp that takes the output of the qnode as input. Let's call
+ # the qnode output %0, and the CallOp output %1. The following replaces all uses of
+ # %0 with %1 EXCEPT for the case where %0 is an input to callOp
+ op.results[0].replace_by_if(callOp.results[0], lambda use: use.operation != callOp)
+ rewriter.notify_op_modified(funcOp)
+
+
+ @dataclass(frozen=True)
+ class SquarePass(passes.ModulePass):
+ name = "square"
+
+ def apply(self, ctx: context.Context, module: builtin.ModuleOp):
+ pattern_rewriter.PatternRewriteWalker(
+ pattern_rewriter.GreedyRewritePatternApplier([SquarePattern()])
+ ).rewrite_module(module)
+
+Let’s apply the pass to our workflow
+====================================
+
+.. code-block:: python
+
+ ctx = context.Context()
+
+ pipeline = passes.PassPipeline((SquarePass(),))
+ pipeline.apply(ctx, xmod)
+
+Great! Let’s see what the transformed module looks like
+=======================================================
+
+As you can see below, the ``square_xdsl`` function is the first function
+in the module, and it gets called by ``jit_workflow``, and its
+inputs/outputs are consistent with the behaviour we wanted.
+
+>>> print(xmod)
+builtin.module @workflow {
+ func.func public @jit_workflow(%arg2 : tensor, %arg3 : tensor) -> (tensor) attributes {llvm.emit_c_interface, transformed} {
+ %0 = catalyst.launch_kernel @module_circuit::@circuit(%arg2) : (tensor) -> tensor
+ %1 = func.call @square(%0) : (tensor) -> tensor
+ %2 = "stablehlo.convert"(%arg3) : (tensor) -> tensor
+ %3 = "stablehlo.subtract"(%1, %2) : (tensor, tensor) -> tensor
+ func.return %3 : tensor
+ }
+ builtin.module @module_circuit {
+ builtin.module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.op<"builtin.module">) {
+ transform.yield
+ }
+ }
+ func.func public @circuit(%arg0 : tensor) -> (tensor) attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} {
+ %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> tensor
+ %1 = tensor.extract %0[] : tensor
+ quantum.device shots(%1) ["/Users/mudit.pandey/.pyenv/versions/pennylane-xdsl/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
+ %2 = "stablehlo.constant"() <{value = dense<5> : tensor}> : () -> tensor
+ %3 = quantum.alloc(5) : !quantum.reg
+ %4 = tensor.extract %0[] : tensor
+ %5 = quantum.extract %3[%4] : !quantum.reg -> !quantum.bit
+ %6 = tensor.extract %arg0[] : tensor
+ %7 = quantum.custom "RX"(%6) %5 : !quantum.bit
+ %8 = quantum.namedobs %7[PauliZ] : !quantum.obs
+ %9 = quantum.expval %8 : f64
+ %10 = tensor.from_elements %9 : tensor
+ %11 = tensor.extract %0[] : tensor
+ %12 = quantum.insert %3[%11], %7 : !quantum.reg, !quantum.bit
+ quantum.dealloc %12 : !quantum.reg
+ quantum.device_release
+ func.return %10 : tensor
+ }
+ }
+ func.func @setup() {
+ quantum.init
+ func.return
+ }
+ func.func @teardown() {
+ quantum.finalize
+ func.return
+ }
+ func.func public @square(%arg0 : tensor) -> (tensor {jax.result_info = "result"}) {
+ %0 = "stablehlo.multiply"(%arg0, %arg0) : (tensor, tensor) -> tensor
+ func.return %0 : tensor
+ }
+}
diff --git a/frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst b/frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst
new file mode 100644
index 0000000000..2e39c07404
--- /dev/null
+++ b/frontend/catalyst/python_interface/doc/xdsl_utils_tutorial.rst
@@ -0,0 +1,404 @@
+Unified compiler utilities
+=========================
+
+All utilities we care about are in the
+``catalyst.python_interface.conversion`` submodule.
+
+.. code-block:: python
+
+ import pennylane as qml
+
+ from catalyst.python_interface.conversion import (
+ inline_jit_to_module,
+ inline_module,
+ xdsl_from_qjit,
+ xdsl_module,
+ )
+
+``xdsl_module``
+===============
+
+This function takes a ``jax.jit``-ed function as input, and returns a
+wrapper. This wrapper can be called to return an xDSL module. Note that
+this function is intended to be used to covert purely classical
+functions into xDSL modules. Let’s take a look at a very simple example:
+
+.. code-block:: python
+
+ import jax
+
+
+ @jax.jit
+ def inner(x):
+ return x**2
+
+
+ @jax.jit
+ def outer(x, y):
+ return inner(x) - y
+
+
+>>> wrapped_outer = xdsl_module(outer)
+>>> jit_mod = wrapped_outer(1.5, 2.5)
+>>> print(jit_mod)
+builtin.module @jit_outer attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+ func.func public @main(%arg1 : tensor, %arg2 : tensor) -> (tensor {jax.result_info = "result"}) {
+ %0 = func.call @inner(%arg1) : (tensor) -> tensor
+ %1 = "stablehlo.subtract"(%0, %arg2) : (tensor, tensor) -> tensor
+ func.return %1 : tensor
+ }
+ func.func private @inner(%arg0 : tensor) -> (tensor) {
+ %0 = "stablehlo.multiply"(%arg0, %arg0) : (tensor, tensor) -> tensor
+ func.return %0 : tensor