From 53df1bce8ea42bb325ab7e45391d1c322304ae98 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 24 Oct 2025 17:32:56 -0400 Subject: [PATCH 01/20] use pass name from transform --- .github/workflows/check-catalyst.yaml | 12 +++++++ frontend/catalyst/from_plxpr/from_plxpr.py | 41 +++++++++++----------- frontend/catalyst/jax_primitives_utils.py | 16 +++++++-- 3 files changed, 46 insertions(+), 23 deletions(-) diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index 999e5ad137..2f97fec69d 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -475,6 +475,10 @@ jobs: python3 -m pip install oqc-qcaas-client make frontend + - name: Install PennyLane branch + run: | + pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-pass-name + - name: Get Cached LLVM Build id: cache-llvm-build uses: actions/cache@v4 @@ -558,6 +562,10 @@ jobs: python3 -m pip install -r requirements.txt make frontend + - name: Install PennyLane branch + run: | + pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-pass-name + - name: Get Cached LLVM Build id: cache-llvm-build uses: actions/cache@v4 @@ -620,6 +628,10 @@ jobs: python3 -m pip install -r requirements.txt make frontend + - name: Install PennyLane branch + run: | + pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-pass-name + - name: Get Cached LLVM Build id: cache-llvm-build uses: actions/cache@v4 diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 64d74400ba..a949c5968c 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -138,7 +138,7 @@ class WorkflowInterpreter(PlxprInterpreter): """An interpreter that converts a qnode primitive from a plxpr variant to a catalyst jaxpr variant.""" def __init__(self): - self._pass_pipeline = [] + self._pass_pipeline = qml.transforms.core.TransformProgram() self.init_qreg = None # Compiler options for the new decomposition system @@ -231,23 +231,22 @@ def calling_convention(*args): # The map below describes the parity between PL transforms and Catalyst passes. -# PL transforms having a Catalyst pass counterpart will have a name as value, -# otherwise their value will be None. The second value indicates if the transform +# The value indicates if the transform # requires decomposition to be supported by Catalyst. transforms_to_passes = { - pl_cancel_inverses: ("remove-chained-self-inverse", False), - pl_commute_controlled: (None, False), - pl_decompose: (None, False), - pl_map_wires: (None, False), - pl_merge_amplitude_embedding: (None, True), - pl_merge_rotations: ("merge-rotations", False), - pl_single_qubit_fusion: (None, False), - pl_unitary_to_rot: (None, False), + pl_cancel_inverses: False, + pl_commute_controlled: False, + pl_decompose: False, + pl_map_wires: False, + pl_merge_amplitude_embedding: True, + pl_merge_rotations: False, + pl_single_qubit_fusion: False, + pl_unitary_to_rot: False, } # pylint: disable-next=redefined-outer-name -def register_transform(pl_transform, pass_name, decomposition): +def register_transform(pl_transform, decomposition): """Register pennylane transforms and their conversion to Catalyst transforms""" # pylint: disable=too-many-arguments @@ -260,9 +259,8 @@ def handle_transform( inner_jaxpr, targs_slice, tkwargs, - catalyst_pass_name=pass_name, requires_decomposition=decomposition, - pl_plxpr_transform=pl_transform._plxpr_transform, + transform=pl_transform, ): """Handle the conversion from plxpr to Catalyst jaxpr for a PL transform.""" @@ -273,8 +271,8 @@ def handle_transform( # If the transform is a decomposition transform # and the graph-based decomposition is enabled if ( - hasattr(pl_plxpr_transform, "__name__") - and pl_plxpr_transform.__name__ == "decompose_plxpr_to_plxpr" + hasattr(transform._plxpr_transform, "__name__") + and transform._plxpr_transform.__name__ == "decompose_plxpr_to_plxpr" and qml.decomposition.enabled_graph() ): if not self.requires_decompose_lowering: @@ -315,7 +313,7 @@ def handle_transform( # return self.eval(final_jaxpr.jaxpr, consts, *non_const_args) return self.eval(inner_jaxpr, consts, *non_const_args) - if catalyst_pass_name is None: + if transform.pass_name is None: # Use PL's ExpandTransformsInterpreter to expand this and any embedded # transform according to PL rules. It works by overriding the primitive # registration, making all embedded transforms follow the PL rules @@ -324,7 +322,7 @@ def wrapper(*args): return ExpandTransformsInterpreter().eval(inner_jaxpr, consts, *args) unravelled_jaxpr = jax.make_jaxpr(wrapper)(*non_const_args) - final_jaxpr = pl_plxpr_transform( + final_jaxpr = transform._plxpr_transform( unravelled_jaxpr.jaxpr, unravelled_jaxpr.consts, targs, tkwargs, *non_const_args ) @@ -336,7 +334,8 @@ def wrapper(*args): return self.eval(final_jaxpr.jaxpr, final_jaxpr.consts, *non_const_args) # Apply the corresponding Catalyst pass counterpart - self._pass_pipeline.insert(0, Pass(catalyst_pass_name)) + container = qml.transforms.core.TransformContainer(transform, args=targs, kwargs=tkwargs) + self._pass_pipeline.insert_front(container) return self.eval(inner_jaxpr, consts, *non_const_args) @@ -344,8 +343,8 @@ def wrapper(*args): # across the map above and generates a custom handler for each transform. # In order to ensure early binding, we pass the PL plxpr transform and the # Catalyst pass as arguments whose default values are set by the loop. -for pl_transform, (pass_name, decomposition) in transforms_to_passes.items(): - register_transform(pl_transform, pass_name, decomposition) +for pl_transform, decomposition in transforms_to_passes.items(): + register_transform(pl_transform, decomposition) # pylint: disable=too-many-positional-arguments diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py index 1227b7f1b5..ab52e9a5cd 100644 --- a/frontend/catalyst/jax_primitives_utils.py +++ b/frontend/catalyst/jax_primitives_utils.py @@ -308,6 +308,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.ctx.module_context = self.old_module_context +def _lowered_options(kwargs): + lowered_options = {} + for option, value in kwargs.items(): + mlir_option = str(option).replace("_", "-") + lowered_options[mlir_option] = get_mlir_attribute_from_pyval(value) + return lowered_options + def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipeline): """Generate a transform module embedded in the current module and schedule the transformations in pipeline""" @@ -350,11 +357,16 @@ def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipelin with ir.InsertionPoint(bb_named_sequence): target = bb_named_sequence.arguments[0] for _pass in pipeline: - options = _pass.get_options() + if isinstance(_pass, qml.transforms.core.TransformContainer): + options = _lowered_options(_pass.kwargs) + name = _pass.pass_name + else: + options = _pass.get_options() + name = _pass.name apply_registered_pass_op = ApplyRegisteredPassOp( result=transform_mod_type, target=target, - pass_name=_pass.name, + pass_name=name, options=options, dynamic_options={}, ) From cf0150129523ba8656cda8289c405b0b95d319e4 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Mon, 3 Nov 2025 13:47:17 -0500 Subject: [PATCH 02/20] some udpates --- frontend/catalyst/from_plxpr/__init__.py | 2 +- frontend/catalyst/from_plxpr/from_plxpr.py | 199 +++++++++------------ 2 files changed, 87 insertions(+), 114 deletions(-) diff --git a/frontend/catalyst/from_plxpr/__init__.py b/frontend/catalyst/from_plxpr/__init__.py index a39039fe1d..3599025f6e 100644 --- a/frontend/catalyst/from_plxpr/__init__.py +++ b/frontend/catalyst/from_plxpr/__init__.py @@ -15,4 +15,4 @@ """Conversion from plxpr to catalyst jaxpr""" from catalyst.from_plxpr.control_flow import handle_cond, handle_for_loop, handle_while_loop -from catalyst.from_plxpr.from_plxpr import from_plxpr, register_transform, trace_from_pennylane +from catalyst.from_plxpr.from_plxpr import from_plxpr, trace_from_pennylane diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index a949c5968c..9e3f93e691 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -27,14 +27,9 @@ from jax.extend.linear_util import wrap_init from pennylane.capture import PlxprInterpreter, qnode_prim from pennylane.capture.expand_transforms import ExpandTransformsInterpreter -from pennylane.ops.functions.map_wires import _map_wires_transform as pl_map_wires -from pennylane.transforms import cancel_inverses as pl_cancel_inverses -from pennylane.transforms import commute_controlled as pl_commute_controlled from pennylane.transforms import decompose as pl_decompose from pennylane.transforms import merge_amplitude_embedding as pl_merge_amplitude_embedding -from pennylane.transforms import merge_rotations as pl_merge_rotations -from pennylane.transforms import single_qubit_fusion as pl_single_qubit_fusion -from pennylane.transforms import unitary_to_rot as pl_unitary_to_rot +from pennylane.transforms.core.transform_dispatcher import _create_transform_primitive from catalyst.device import extract_backend_info from catalyst.from_plxpr.decompose import COMPILER_OPS_FOR_DECOMPOSITION, DecompRuleInterpreter @@ -54,6 +49,8 @@ QubitIndexRecorder, ) +transform_prim = _create_transform_primitive() + def _get_device_kwargs(device) -> dict: """Calulcate the params for a device equation.""" @@ -226,125 +223,101 @@ def calling_convention(*args): wrap_init(calling_convention, debug_info=qfunc_jaxpr.debug_info), *non_const_args, qnode=qnode, - pipeline=self._pass_pipeline, + pipeline=tuple(self._pass_pipeline), ) -# The map below describes the parity between PL transforms and Catalyst passes. -# The value indicates if the transform -# requires decomposition to be supported by Catalyst. -transforms_to_passes = { - pl_cancel_inverses: False, - pl_commute_controlled: False, - pl_decompose: False, - pl_map_wires: False, - pl_merge_amplitude_embedding: True, - pl_merge_rotations: False, - pl_single_qubit_fusion: False, - pl_unitary_to_rot: False, -} - - -# pylint: disable-next=redefined-outer-name -def register_transform(pl_transform, decomposition): - """Register pennylane transforms and their conversion to Catalyst transforms""" - - # pylint: disable=too-many-arguments - @WorkflowInterpreter.register_primitive(pl_transform._primitive) - def handle_transform( - self, - *args, - args_slice, - consts_slice, - inner_jaxpr, - targs_slice, - tkwargs, - requires_decomposition=decomposition, - transform=pl_transform, +require_decomposition_transforms = {pl_merge_amplitude_embedding} + + +# pylint: disable=too-many-arguments +@WorkflowInterpreter.register_primitive(transform_prim) +def handle_transform( + self, + *args, + args_slice, + consts_slice, + inner_jaxpr, + targs_slice, + tkwargs, + transform, +): + """Handle the conversion from plxpr to Catalyst jaxpr for a + PL transform.""" + consts = args[consts_slice] + non_const_args = args[args_slice] + targs = args[targs_slice] + + # If the transform is a decomposition transform + # and the graph-based decomposition is enabled + if ( + hasattr(transform._plxpr_transform, "__name__") + and transform._plxpr_transform.__name__ == "decompose_plxpr_to_plxpr" + and qml.decomposition.enabled_graph() ): - """Handle the conversion from plxpr to Catalyst jaxpr for a - PL transform.""" - consts = args[consts_slice] - non_const_args = args[args_slice] - targs = args[targs_slice] - - # If the transform is a decomposition transform - # and the graph-based decomposition is enabled - if ( - hasattr(transform._plxpr_transform, "__name__") - and transform._plxpr_transform.__name__ == "decompose_plxpr_to_plxpr" - and qml.decomposition.enabled_graph() - ): - if not self.requires_decompose_lowering: - self.requires_decompose_lowering = True - else: - raise NotImplementedError( - "Multiple decomposition transforms are not yet supported." - ) - - # Update the decompose_gateset to be used by the quantum kernel primitive - # TODO: we originally wanted to treat decompose_gateset as a queue of - # gatesets to be used by the decompose-lowering pass at MLIR - # but this requires a C++ implementation of the graph-based decomposition - # which doesn't exist yet. - self.decompose_tkwargs = tkwargs - - # Note. We don't perform the compiler-specific decomposition here - # to be able to support multiple decomposition transforms - # and collect all the required gatesets - # as well as being able to support other transforms in between. - - # The compiler specific transformation will be performed - # in the qnode handler. - - # Add the decompose-lowering pass to the start of the pipeline - self._pass_pipeline.insert(0, Pass("decompose-lowering")) - - # We still need to construct and solve the graph based on - # the current jaxpr based on the current gateset - # but we don't rewrite the jaxpr at this stage. - - # gds_interpreter = DecompRuleInterpreter(*targs, **tkwargs) - - # def gds_wrapper(*args): - # return gds_interpreter.eval(inner_jaxpr, consts, *args) - - # final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args) - # return self.eval(final_jaxpr.jaxpr, consts, *non_const_args) - return self.eval(inner_jaxpr, consts, *non_const_args) - - if transform.pass_name is None: - # Use PL's ExpandTransformsInterpreter to expand this and any embedded - # transform according to PL rules. It works by overriding the primitive - # registration, making all embedded transforms follow the PL rules - # from now on, hence ignoring the Catalyst pass conversion - def wrapper(*args): - return ExpandTransformsInterpreter().eval(inner_jaxpr, consts, *args) - - unravelled_jaxpr = jax.make_jaxpr(wrapper)(*non_const_args) - final_jaxpr = transform._plxpr_transform( - unravelled_jaxpr.jaxpr, unravelled_jaxpr.consts, targs, tkwargs, *non_const_args + if not self.requires_decompose_lowering: + self.requires_decompose_lowering = True + else: + raise NotImplementedError( + "Multiple decomposition transforms are not yet supported." ) - if requires_decomposition: - final_jaxpr = pl_decompose._plxpr_transform( - final_jaxpr.jaxpr, final_jaxpr.consts, targs, tkwargs, *non_const_args - ) + # Update the decompose_gateset to be used by the quantum kernel primitive + # TODO: we originally wanted to treat decompose_gateset as a queue of + # gatesets to be used by the decompose-lowering pass at MLIR + # but this requires a C++ implementation of the graph-based decomposition + # which doesn't exist yet. + self.decompose_tkwargs = tkwargs - return self.eval(final_jaxpr.jaxpr, final_jaxpr.consts, *non_const_args) + # Note. We don't perform the compiler-specific decomposition here + # to be able to support multiple decomposition transforms + # and collect all the required gatesets + # as well as being able to support other transforms in between. - # Apply the corresponding Catalyst pass counterpart - container = qml.transforms.core.TransformContainer(transform, args=targs, kwargs=tkwargs) + # The compiler specific transformation will be performed + # in the qnode handler. + + # Add the decompose-lowering pass to the start of the pipeline + container = qml.transforms.core.TransformContainer(qml.transforms.decompose) self._pass_pipeline.insert_front(container) + # We still need to construct and solve the graph based on + # the current jaxpr based on the current gateset + # but we don't rewrite the jaxpr at this stage. + + # gds_interpreter = DecompRuleInterpreter(*targs, **tkwargs) + + # def gds_wrapper(*args): + # return gds_interpreter.eval(inner_jaxpr, consts, *args) + + # final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args) + # return self.eval(final_jaxpr.jaxpr, consts, *non_const_args) return self.eval(inner_jaxpr, consts, *non_const_args) + if transform.pass_name is None: + # Use PL's ExpandTransformsInterpreter to expand this and any embedded + # transform according to PL rules. It works by overriding the primitive + # registration, making all embedded transforms follow the PL rules + # from now on, hence ignoring the Catalyst pass conversion + def wrapper(*args): + return ExpandTransformsInterpreter().eval(inner_jaxpr, consts, *args) + + unravelled_jaxpr = jax.make_jaxpr(wrapper)(*non_const_args) + final_jaxpr = transform._plxpr_transform( + unravelled_jaxpr.jaxpr, unravelled_jaxpr.consts, targs, tkwargs, *non_const_args + ) + + if transform in require_decomposition_transforms: + final_jaxpr = pl_decompose._plxpr_transform( + final_jaxpr.jaxpr, final_jaxpr.consts, targs, tkwargs, *non_const_args + ) + + return self.eval(final_jaxpr.jaxpr, final_jaxpr.consts, *non_const_args) + + # Apply the corresponding Catalyst pass counterpart + container = qml.transforms.core.TransformContainer(transform, args=targs, kwargs=tkwargs) + self._pass_pipeline.insert_front(container) + return self.eval(inner_jaxpr, consts, *non_const_args) -# This is our registration factory for PL transforms. The loop below iterates -# across the map above and generates a custom handler for each transform. -# In order to ensure early binding, we pass the PL plxpr transform and the -# Catalyst pass as arguments whose default values are set by the loop. -for pl_transform, decomposition in transforms_to_passes.items(): - register_transform(pl_transform, decomposition) # pylint: disable=too-many-positional-arguments From a08905b25b49cd5a8b903b7ac0bf21ccce7acb15 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 5 Nov 2025 12:03:14 -0500 Subject: [PATCH 03/20] make backwards compatible --- frontend/catalyst/from_plxpr/from_plxpr.py | 970 ++++++++++++--------- frontend/catalyst/qfunc.py | 27 +- 2 files changed, 599 insertions(+), 398 deletions(-) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 9e3f93e691..16997f3cf6 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -1,4 +1,4 @@ -# Copyright 2024 Xanadu Quantum Technologies Inc. +# Copyright 2022-2024 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,474 +11,660 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """ -This submodule defines a utility for converting plxpr into Catalyst jaxpr. +This module contains a patch for the upstream qml.QNode behaviour, in particular around +what happens when a QNode object is called during tracing. Mostly this involves bypassing +the default behaviour and replacing it with a function-like "QNode" primitive. """ -# pylint: disable=protected-access - - -import warnings -from functools import partial -from typing import Callable +import logging +from copy import copy +from dataclasses import dataclass, replace +from typing import Callable, Sequence -import jax +import jax.numpy as jnp import pennylane as qml -from jax.extend.core import ClosedJaxpr, Jaxpr -from jax.extend.linear_util import wrap_init -from pennylane.capture import PlxprInterpreter, qnode_prim -from pennylane.capture.expand_transforms import ExpandTransformsInterpreter -from pennylane.transforms import decompose as pl_decompose -from pennylane.transforms import merge_amplitude_embedding as pl_merge_amplitude_embedding -from pennylane.transforms.core.transform_dispatcher import _create_transform_primitive - -from catalyst.device import extract_backend_info -from catalyst.from_plxpr.decompose import COMPILER_OPS_FOR_DECOMPOSITION, DecompRuleInterpreter -from catalyst.jax_extras import make_jaxpr2, transient_jax_config -from catalyst.jax_primitives import ( - device_init_p, - device_release_p, - qalloc_p, - qdealloc_p, - quantum_kernel_p, +from jax.core import eval_jaxpr +from jax.tree_util import tree_flatten, tree_unflatten +from pennylane import exceptions +from pennylane.measurements import CountsMP, ExpectationMP, ProbabilityMP, SampleMP, VarianceMP +from pennylane.transforms.dynamic_one_shot import ( + gather_non_mcm, + init_auxiliary_tape, + parse_native_mid_circuit_measurements, ) -from catalyst.passes.pass_api import Pass -from .qfunc_interpreter import PLxPRToQuantumJaxprInterpreter -from .qubit_handler import ( - QubitHandler, - QubitIndexRecorder, -) +import catalyst +from catalyst.api_extensions import MidCircuitMeasure +from catalyst.device import QJITDevice +from catalyst.device.qjit_device import is_dynamic_wires +from catalyst.jax_extras import deduce_avals, get_implicit_and_explicit_flat_args, unzip2 +from catalyst.jax_extras.tracing import uses_transform +from catalyst.jax_primitives import quantum_kernel_p +from catalyst.jax_tracer import Function, trace_quantum_function +from catalyst.logging import debug_logger +from catalyst.passes.pass_api import dictionary_to_list_of_passes, Pass +from catalyst.tracing.contexts import EvaluationContext +from catalyst.tracing.type_signatures import filter_static_args +from catalyst.utils.exceptions import CompileError + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +@dataclass +class OutputContext: + """Context containing parameters needed for finalizing quantum function output.""" + + cpy_tape: any + classical_values: any + classical_return_indices: any + out_tree_expected: any + snapshots: any + shot_vector: any + num_mcm: int + + +def _resolve_mcm_config(mcm_config, shots): + """Helper function for resolving and validating that the mcm_config is valid for executing.""" + updated_values = {} + + updated_values["postselect_mode"] = ( + None if isinstance(shots, int) and shots == 0 else mcm_config.postselect_mode + ) + if mcm_config.mcm_method is None: + updated_values["mcm_method"] = "one-shot" + if mcm_config.mcm_method == "deferred": + raise ValueError("mcm_method='deferred' is not supported with Catalyst.") + if ( + mcm_config.mcm_method == "single-branch-statistics" + and mcm_config.postselect_mode == "hw-like" + ): + raise ValueError( + "Cannot use postselect_mode='hw-like' with Catalyst when mcm_method != 'one-shot'." + ) + if mcm_config.mcm_method == "one-shot" and shots == 0: + raise ValueError( + "Cannot use the 'one-shot' method for mid-circuit measurements with analytic mode." + ) -transform_prim = _create_transform_primitive() + return replace(mcm_config, **updated_values) -def _get_device_kwargs(device) -> dict: - """Calulcate the params for a device equation.""" - info = extract_backend_info(device) - # Note that the value of rtd_kwargs is a string version of - # the info kwargs, not the info kwargs itself - # this is due to ease of serialization to MLIR - return { - "rtd_kwargs": str(info.kwargs), - "rtd_lib": info.lpath, - "rtd_name": info.c_interface_name, - } +def _get_total_shots(qnode): + """ + Extract total shots from qnode. + If shots is None on the qnode, this method returns 0 (static). + This method allows the qnode shots to be either static (python int + literals) or dynamic (tracers). + """ + # due to possibility of tracer, we cannot use a simple `or` here to simplify + shots_value = qnode._shots.total_shots # pylint: disable=protected-access + if shots_value is None: + shots = 0 + else: + shots = shots_value + return shots + + +def _is_one_shot_compatible_device(qnode): + device_name = qnode.device.name + exclude_devices = {"softwareq.qpp", "nvidia.custatevec", "nvidia.cutensornet"} + + # Check device name against exclude list + if device_name in exclude_devices: + return False + + # Additional check for OQDDevice class + device_class_name = qnode.device.__class__.__name__ + return device_class_name != "OQDDevice" + + +def configure_mcm_and_try_one_shot(qnode, args, kwargs): + """Configure mid-circuit measurement settings and handle one-shot execution.""" + dynamic_one_shot_called = getattr(qnode, "_dynamic_one_shot_called", False) + if not dynamic_one_shot_called: + mcm_config = copy( + qml.devices.MCMConfig( + postselect_mode=qnode.execute_kwargs["postselect_mode"], + mcm_method=qnode.execute_kwargs["mcm_method"], + ) + ) + total_shots = _get_total_shots(qnode) + user_specified_mcm_method = mcm_config.mcm_method + mcm_config = _resolve_mcm_config(mcm_config, total_shots) + + # Check if measurements_from_{samples/counts} is being used + uses_measurements_from_samples = uses_transform(qnode, "measurements_from_samples") + uses_measurements_from_counts = uses_transform(qnode, "measurements_from_counts") + has_finite_shots = isinstance(total_shots, int) and total_shots > 0 + + # For cases that user are not tend to executed with one-shot, and facing + # 1. non-one-shot compatible device, + # 2. non-finite shots, + # 3. measurement transform, + # fallback to single-branch-statistics + one_shot_compatible = _is_one_shot_compatible_device(qnode) + one_shot_compatible &= has_finite_shots + one_shot_compatible &= not uses_measurements_from_samples + one_shot_compatible &= not uses_measurements_from_counts + + should_fallback = ( + not one_shot_compatible + and user_specified_mcm_method is None + and mcm_config.mcm_method == "one-shot" + ) + + if should_fallback: + mcm_config = replace(mcm_config, mcm_method="single-branch-statistics") + if mcm_config.mcm_method == "one-shot": + # If measurements_from_samples/counts while one-shot is used, raise an error + if uses_measurements_from_samples: + raise CompileError("measurements_from_samples is not supported with one-shot") + if uses_measurements_from_counts: + raise CompileError("measurements_from_counts is not supported with one-shot") -# code example has long lines -# pylint: disable=line-too-long -def from_plxpr(plxpr: ClosedJaxpr) -> Callable[..., Jaxpr]: - """Convert PennyLane variant jaxpr to Catalyst variant jaxpr. + mcm_config = replace( + mcm_config, postselect_mode=mcm_config.postselect_mode or "hw-like" + ) + try: + return Function(dynamic_one_shot(qnode, mcm_config=mcm_config))(*args, **kwargs) + except (TypeError, ValueError, CompileError, NotImplementedError) as e: + # If user specified mcm_method, we can't fallback to single-branch-statistics, + # reraise the original error + if user_specified_mcm_method is not None: + raise + + # Fallback only if mcm was auto-determined + error_msg = str(e) + unsupported_measurement_error = any( + pattern in error_msg + for pattern in [ + "Native mid-circuit measurement mode does not support", + "qml.var(obs) cannot be returned when `mcm_method='one-shot'`", + "empty wires is not supported with dynamic wires in one-shot mode", + "No need to run one-shot mode", + ] + ) + + # Fallback if error is related to unsupported measurements + if unsupported_measurement_error: + logger.warning("Fallback to single-branch-statistics: %s", e) + mcm_config = replace(mcm_config, mcm_method="single-branch-statistics") + else: + raise + return None + + +def _reconstruct_output_with_classical_values( + measurement_results, classical_values, classical_return_indices +): + """ + Reconstruct the output values from the classical values and measurement results. Args: - jaxpr (ClosedJaxpr): PennyLane variant jaxpr - + out: Output from measurement processing + classical_values: Classical values + classical_return_indices: Indices of classical values Returns: - Callable: A function that accepts the same arguments as the plxpr and returns catalyst - variant jaxpr. + results: Reconstructed output with classical values inserted + """ + if not classical_values: + return measurement_results - Note that the input jaxpr should be workflow level and contain qnode primitives, rather than - qfunc level with individual operators. + total_expected = len(classical_values) + len(measurement_results) + classical_iter = iter(classical_values) + measurement_iter = iter(measurement_results) - .. code-block:: python + def get_next_value(idx): + return next(classical_iter) if idx in classical_return_indices else next(measurement_iter) + + results = [get_next_value(i) for i in range(total_expected)] + return results - from catalyst.from_plxpr import from_plxpr - - qml.capture.enable() - - @qml.qnode(qml.device('lightning.qubit', wires=2)) - def circuit(x): - qml.RX(x, 0) - return qml.probs(wires=(0, 1)) - - def f(x): - return circuit(2 * x) ** 2 - - plxpr = jax.make_jaxpr(circuit)(0.5) - - print(from_plxpr(plxpr)(0.5)) - - .. code-block:: none - - { lambda ; a:f64[]. let - b:f64[4] = func[ - call_jaxpr={ lambda ; c:f64[]. let - device_init[ - rtd_kwargs={'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None} - rtd_lib=*** - rtd_name=LightningSimulator - ] - d:AbstractQreg() = qalloc 2 - e:AbstractQbit() = qextract d 0 - f:AbstractQbit() = qinst[ - adjoint=False - ctrl_len=0 - op=RX - params_len=1 - qubits_len=1 - ] e c - g:AbstractQbit() = qextract d 1 - h:AbstractObs(num_qubits=2,primitive=compbasis) = compbasis f g - i:f64[4] = probs[shape=(4,) shots=None] h - j:AbstractQreg() = qinsert d 0 f - qdealloc j - in (i,) } - qnode= - ] a - in (b,) } +def _extract_classical_and_measurement_results(results, classical_return_indices): """ - return jax.make_jaxpr(partial(WorkflowInterpreter().eval, plxpr.jaxpr, plxpr.consts)) + Split results into classical values and measurement results. + It assume that the results are in the order of classical values and measurement results. + """ + num_classical_return_indices = len(classical_return_indices) + classical_values = results[:num_classical_return_indices] + measurement_results = results[num_classical_return_indices:] + return classical_values, measurement_results -class WorkflowInterpreter(PlxprInterpreter): - """An interpreter that converts a qnode primitive from a plxpr variant to a catalyst jaxpr variant.""" +def _finalize_output(out, ctx: OutputContext): + """ + Finalize the output by reconstructing with classical values and unflattening to the + expected tree structure. + Args: + out: The output to finalize + context: OutputContext containing all necessary parameters for finalization + """ + # Handle case with no measurements + if len(ctx.cpy_tape.measurements) == 0: + out = out[: -ctx.num_mcm] - def __init__(self): - self._pass_pipeline = qml.transforms.core.TransformProgram() - self.init_qreg = None + out = _reconstruct_output_with_classical_values( + out, ctx.classical_values, ctx.classical_return_indices + ) - # Compiler options for the new decomposition system - self.requires_decompose_lowering = False - self.decompose_tkwargs = {} # target gateset + out_tree_expected = ctx.out_tree_expected + if ctx.snapshots is not None: + out = (out[0], tree_unflatten(out_tree_expected[1], out[1])) + else: + out = tree_unflatten(out_tree_expected[0], out) + return out - super().__init__() +class QFunc: + """A device specific quantum function. -# pylint: disable=unused-argument, too-many-arguments -@WorkflowInterpreter.register_primitive(qnode_prim) -def handle_qnode( - self, *args, qnode, device, shots_len, execution_config, qfunc_jaxpr, n_consts, batch_dims=None -): - """Handle the conversion from plxpr to Catalyst jaxpr for the qnode primitive""" + Args: + qfunc (Callable): the quantum function + shots (int): How many times the circuit should be evaluated (or sampled) to estimate + the expectation values + device (a derived class from QubitDevice): a device specification which determines + the valid gate set for the quantum function + """ - self.qubit_index_recorder = QubitIndexRecorder() + def __new__(cls): + raise NotImplementedError() # pragma: no-cover - if shots_len > 1: - raise NotImplementedError("shot vectors are not yet supported for catalyst conversion.") + # pylint: disable=no-member + # pylint: disable=self-cls-assignment + @debug_logger + def __call__(self, *args, **kwargs): - shots = args[0] if shots_len else 0 - consts = args[shots_len : n_consts + shots_len] - non_const_args = args[shots_len + n_consts :] + if EvaluationContext.is_quantum_tracing(): + raise CompileError("Can't nest qnodes under qjit") - closed_jaxpr = ( - ClosedJaxpr(qfunc_jaxpr, consts) - if not self.requires_decompose_lowering - else _apply_compiler_decompose_to_plxpr( - inner_jaxpr=qfunc_jaxpr, - consts=consts, - ncargs=non_const_args, - tgateset=list(self.decompose_tkwargs.get("gate_set", [])), - ) - ) + assert isinstance(self, qml.QNode) - graph_succeeded = False - if self.requires_decompose_lowering: - closed_jaxpr, graph_succeeded = _collect_and_compile_graph_solutions( - inner_jaxpr=closed_jaxpr.jaxpr, - consts=closed_jaxpr.consts, - tkwargs=self.decompose_tkwargs, - ncargs=non_const_args, - ) + new_transform_program, new_pipeline = _extract_passes(self.transform_program) - # Fallback to the legacy decomposition if the graph-based decomposition failed - if not graph_succeeded: - # Remove the decompose-lowering pass from the pipeline - self._pass_pipeline = [p for p in self._pass_pipeline if p.name != "decompose-lowering"] - closed_jaxpr = _apply_compiler_decompose_to_plxpr( - inner_jaxpr=closed_jaxpr.jaxpr, - consts=closed_jaxpr.consts, - ncargs=non_const_args, - tkwargs=self.decompose_tkwargs, - ) + # Update the qnode with peephole pipeline + pass_pipeline = kwargs.pop("pass_pipeline", ()) or () + pass_pipeline += new_pipeline + pass_pipeline = dictionary_to_list_of_passes(pass_pipeline) + new_qnode = copy(self) + new_qnode._transform_program = new_transform_program # pylint: disable=protected-access + + # Mid-circuit measurement configuration/execution + fn_result = configure_mcm_and_try_one_shot(new_qnode, args, kwargs) + + # If the qnode is failed to execute as one-shot, fn_result will be None + if fn_result is not None: + return fn_result + + new_device = copy(new_qnode.device) + qjit_device = QJITDevice(new_device) + + static_argnums = kwargs.pop("static_argnums", ()) + out_tree_expected = kwargs.pop("_out_tree_expected", []) + classical_return_indices = kwargs.pop("_classical_return_indices", []) + num_mcm_expected = kwargs.pop("_num_mcm_expected", []) + debug_info = kwargs.pop("debug_info", None) - def calling_convention(*args): - device_init_p.bind( - shots, - auto_qubit_management=(device.wires is None), - **_get_device_kwargs(device), + print(new_qnode.transform_program) + print(pass_pipeline) + + def _eval_quantum(*args, **kwargs): + trace_result = trace_quantum_function( + new_qnode.func, + qjit_device, + args, + kwargs, + new_qnode, + static_argnums, + debug_info, + ) + closed_jaxpr = trace_result.closed_jaxpr + out_type = trace_result.out_type + out_tree = trace_result.out_tree + out_tree_exp = trace_result.return_values_tree + cls_ret_idx = trace_result.classical_return_indices + num_mcm = trace_result.num_mcm + + out_tree_expected.append(out_tree_exp) + classical_return_indices.append(cls_ret_idx) + num_mcm_expected.append(num_mcm) + dynamic_args = filter_static_args(args, static_argnums) + args_expanded = get_implicit_and_explicit_flat_args(None, *dynamic_args, **kwargs) + res_expanded = eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args_expanded) + _, out_keep = unzip2(out_type) + res_flat = [r for r, k in zip(res_expanded, out_keep) if k] + return tree_unflatten(out_tree, res_flat) + + flattened_fun, _, _, out_tree_promise = deduce_avals( + _eval_quantum, args, kwargs, static_argnums, debug_info ) - qreg = qalloc_p.bind(len(device.wires)) - self.init_qreg = QubitHandler(qreg, self.qubit_index_recorder) - converter = PLxPRToQuantumJaxprInterpreter( - device, shots, self.init_qreg, {}, self.qubit_index_recorder + dynamic_args = filter_static_args(args, static_argnums) + args_flat = tree_flatten((dynamic_args, kwargs))[0] + res_flat = quantum_kernel_p.bind( + flattened_fun, *args_flat, qnode=self, pipeline=tuple(pass_pipeline) ) - retvals = converter(closed_jaxpr, *args) - self.init_qreg.insert_all_dangling_qubits() - qdealloc_p.bind(self.init_qreg.get()) - device_release_p.bind() - return retvals - - if self.requires_decompose_lowering and graph_succeeded: - # Add gate_set attribute to the quantum kernel primitive - # decompose_gatesets is treated as a queue of gatesets to be used - # but we only support a single gateset for now in from_plxpr - # as supporting multiple gatesets requires an MLIR/C++ graph-decomposition - # implementation. The current Python implementation cannot be mixed - # with other transforms in between. - gateset = [_get_operator_name(op) for op in self.decompose_tkwargs.get("gate_set", [])] - setattr(qnode, "decompose_gatesets", [gateset]) - - return quantum_kernel_p.bind( - wrap_init(calling_convention, debug_info=qfunc_jaxpr.debug_info), - *non_const_args, - qnode=qnode, - pipeline=tuple(self._pass_pipeline), + return tree_unflatten(out_tree_promise(), res_flat)[0] + + +# pylint: disable=protected-access +def _get_shot_vector(qnode): + shot_vector = qnode._shots.shot_vector if qnode._shots else [] + return ( + shot_vector + if len(shot_vector) > 1 or any(copies > 1 for _, copies in shot_vector) + else None ) -require_decomposition_transforms = {pl_merge_amplitude_embedding} +def _get_snapshot_results(mcm_method, tape, out): + """ + Get the snapshot results from the tape. + Args: + tape: The tape to get the snapshot results from. + out: The output of the tape. + Returns: + processed_snapshots: The extracted snapshot results if available; + otherwise, returns the original output. + measurement_results: The corresponding measurement results. + """ + # if no snapshot are present, return None, out + assert mcm_method == "one-shot" + if not any(isinstance(op, qml.Snapshot) for op in tape.operations): + return None, out -# pylint: disable=too-many-arguments -@WorkflowInterpreter.register_primitive(transform_prim) -def handle_transform( - self, - *args, - args_slice, - consts_slice, - inner_jaxpr, - targs_slice, - tkwargs, - transform, -): - """Handle the conversion from plxpr to Catalyst jaxpr for a - PL transform.""" - consts = args[consts_slice] - non_const_args = args[args_slice] - targs = args[targs_slice] - - # If the transform is a decomposition transform - # and the graph-based decomposition is enabled - if ( - hasattr(transform._plxpr_transform, "__name__") - and transform._plxpr_transform.__name__ == "decompose_plxpr_to_plxpr" - and qml.decomposition.enabled_graph() - ): - if not self.requires_decompose_lowering: - self.requires_decompose_lowering = True - else: - raise NotImplementedError( - "Multiple decomposition transforms are not yet supported." - ) + # Snapshots present: out[0] = snapshots, out[1] = measurements + snapshot_results, measurement_results = out + + # Take first shot for each snapshot + processed_snapshots = [ + snapshot[0] if hasattr(snapshot, "shape") and len(snapshot.shape) > 1 else snapshot + for snapshot in snapshot_results + ] + + return processed_snapshots, measurement_results + + +def _reshape_for_shot_vector(mcm_method, result, shot_vector): + assert mcm_method == "one-shot" - # Update the decompose_gateset to be used by the quantum kernel primitive - # TODO: we originally wanted to treat decompose_gateset as a queue of - # gatesets to be used by the decompose-lowering pass at MLIR - # but this requires a C++ implementation of the graph-based decomposition - # which doesn't exist yet. - self.decompose_tkwargs = tkwargs - - # Note. We don't perform the compiler-specific decomposition here - # to be able to support multiple decomposition transforms - # and collect all the required gatesets - # as well as being able to support other transforms in between. - - # The compiler specific transformation will be performed - # in the qnode handler. - - # Add the decompose-lowering pass to the start of the pipeline - container = qml.transforms.core.TransformContainer(qml.transforms.decompose) - self._pass_pipeline.insert_front(container) - # We still need to construct and solve the graph based on - # the current jaxpr based on the current gateset - # but we don't rewrite the jaxpr at this stage. - - # gds_interpreter = DecompRuleInterpreter(*targs, **tkwargs) - - # def gds_wrapper(*args): - # return gds_interpreter.eval(inner_jaxpr, consts, *args) - - # final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args) - # return self.eval(final_jaxpr.jaxpr, consts, *non_const_args) - return self.eval(inner_jaxpr, consts, *non_const_args) - - if transform.pass_name is None: - # Use PL's ExpandTransformsInterpreter to expand this and any embedded - # transform according to PL rules. It works by overriding the primitive - # registration, making all embedded transforms follow the PL rules - # from now on, hence ignoring the Catalyst pass conversion - def wrapper(*args): - return ExpandTransformsInterpreter().eval(inner_jaxpr, consts, *args) - - unravelled_jaxpr = jax.make_jaxpr(wrapper)(*non_const_args) - final_jaxpr = transform._plxpr_transform( - unravelled_jaxpr.jaxpr, unravelled_jaxpr.consts, targs, tkwargs, *non_const_args + # Calculate the shape for reshaping based on shot vector + result_list = [] + start_idx = 0 + for shot, copies in shot_vector: + # Reshape this segment to (copies, shot, n_wires) + segment = result[start_idx : start_idx + shot * copies] + if copies > 1: + segment_shape = (copies, shot, result.shape[-1]) + segment = jnp.reshape(segment, segment_shape) + result_list.extend([segment[i] for i in range(copies)]) + else: + result_list.append(segment) + start_idx += shot * copies + result = tuple(result_list) + return result + + +def _process_terminal_measurements(mcm_method, cpy_tape, out, snapshots, shot_vector): + """Process measurements when there are no mid-circuit measurements.""" + assert mcm_method == "one-shot" + + # flatten the outs structure + out, _ = tree_flatten(out) + new_out = [] + idx = 0 + + for m in cpy_tape.measurements: + if isinstance(m, CountsMP): + if isinstance(out[idx], tuple) and len(out[idx]) == 2: + # CountsMP result is stored as (keys, counts) tuple + keys, counts = out[idx] + idx += 1 + else: + keys = out[idx] + counts = out[idx + 1] + idx += 2 + + if snapshots is not None: + counts_array = jnp.stack(counts, axis=0) + aggregated_counts = jnp.sum(counts_array, axis=0) + counts_result = (keys, aggregated_counts) + else: + aggregated_counts = jnp.sum(counts, axis=0) + counts_result = (keys[0], aggregated_counts) + + new_out.append(counts_result) + continue + + result = jnp.squeeze(out[idx]) + max_ndim = min(len(out[idx].shape), 2) + if out[idx].shape[0] == 1: + # Adding the first axis back when the first axis in the original + # array is 1, since it corresponds to the shot's dimension. + result = jnp.expand_dims(result, axis=0) + if result.ndim == 1 and max_ndim == 2: + result = jnp.expand_dims(result, axis=1) + + # Without MCMs and postselection, all samples are valid for use in MP computation. + is_valid = jnp.full((result.shape[0],), True) + processed_result = gather_non_mcm( + m, result, is_valid, postselect_mode="pad-invalid-samples" ) - if transform in require_decomposition_transforms: - final_jaxpr = pl_decompose._plxpr_transform( - final_jaxpr.jaxpr, final_jaxpr.consts, targs, tkwargs, *non_const_args + # Handle shot vector reshaping for SampleMP + if isinstance(m, SampleMP) and shot_vector is not None: + processed_result = _reshape_for_shot_vector(mcm_method, processed_result, shot_vector) + + new_out.append(processed_result) + idx += 1 + + return (snapshots, tuple(new_out)) if snapshots else tuple(new_out) + + +def _validate_one_shot_measurements( + mcm_config, tape: qml.tape.QuantumTape, user_specified_mcm_method, shot_vector, wires +) -> None: + """Validate measurements for one-shot mode. + + Args: + mcm_config: The mid-circuit measurement configuration + tape: The quantum tape containing measurements to validate + qnode: The quantum node being transformed + + Raises: + TypeError: If unsupported measurement types are used + NotImplementedError: If measurement configuration is not supported + """ + mcm_method = mcm_config.mcm_method + assert mcm_method == "one-shot" + + # Check if using shot vector with non-SampleMP measurements + has_shot_vector = len(shot_vector) > 1 or any(copies > 1 for _, copies in shot_vector) + has_wires = wires is not None and not is_dynamic_wires(wires) + + # Raise an error if there are no mid-circuit measurements, it will fallback to + # single-branch-statistics + if ( + not any(isinstance(op, MidCircuitMeasure) for op in tape.operations) + and user_specified_mcm_method is None + ): + raise ValueError("No need to run one-shot mode when there are no mid-circuit measurements.") + + for m in tape.measurements: + # Check if measurement type is supported + if not isinstance(m, (CountsMP, ExpectationMP, ProbabilityMP, SampleMP, VarianceMP)): + raise TypeError( + f"Native mid-circuit measurement mode does not support {type(m).__name__} " + "measurements." ) - return self.eval(final_jaxpr.jaxpr, final_jaxpr.consts, *non_const_args) + # Check variance with observable + if isinstance(m, VarianceMP) and m.obs: + raise TypeError( + "qml.var(obs) cannot be returned when `mcm_method='one-shot'` because " + "the Catalyst compiler does not support qml.sample(obs)." + ) - # Apply the corresponding Catalyst pass counterpart - container = qml.transforms.core.TransformContainer(transform, args=targs, kwargs=tkwargs) - self._pass_pipeline.insert_front(container) - return self.eval(inner_jaxpr, consts, *non_const_args) + # Check if the measurement is supported with shot-vector + if has_shot_vector and not isinstance(m, SampleMP): + raise NotImplementedError( + f"Measurement {type(m).__name__} does not support shot-vectors. " + "Use qml.sample() instead." + ) + # Check dynamic wires with empty wires + if not has_wires and isinstance(m, (SampleMP, CountsMP)) and (m.wires.tolist() == []): + raise NotImplementedError( + f"Measurement {type(m).__name__} with empty wires is not supported with " + "dynamic wires in one-shot mode. Please specify a constant number of wires on " + "the device." + ) -# pylint: disable=too-many-positional-arguments -def trace_from_pennylane( - fn, static_argnums, dynamic_args, abstracted_axes, sig, kwargs, debug_info=None -): - """Capture the JAX program representation (JAXPR) of the wrapped function, using - PL capure module. +# pylint: disable=protected-access,no-member,not-callable +def dynamic_one_shot(qnode, **kwargs): + """Transform a QNode to into several one-shot tapes to support dynamic circuit execution. Args: - fn(Callable): the user function to be traced - static_argnums(int or Seqence[Int]): an index or a sequence of indices that specifies the - positions of static arguments. - dynamic_args(Seqence[Any]): the abstract values of the dynamic arguments. - abstracted_axes (Sequence[Sequence[str]] or Dict[int, str] or Sequence[Dict[int, str]]): - An experimental option to specify dynamic tensor shapes. - This option affects the compilation of the annotated function. - Function arguments with ``abstracted_axes`` specified will be compiled to ranked tensors - with dynamic shapes. For more details, please see the Dynamically-shaped Arrays section - below. - sig(Sequence[Any]): a tuple indicating the argument signature of the function. Static arguments - are indicated with their literal values, and dynamic arguments are indicated by abstract - values. - kwargs(Dict[str, Any]): keyword argumemts to the function. - debug_info(jax.api_util.debug_info): a source debug information object required by jaxprs. + qnode (QNode): a quantum circuit which will run ``num_shots`` times Returns: - ClosedJaxpr: captured JAXPR - Tuple[Tuple[ShapedArray, bool]]: the return type of the captured JAXPR. - The boolean indicates whether each result is a value returned by the user function. - PyTreeDef: PyTree metadata of the function output - Tuple[Any]: the dynamic argument signature - """ + qnode (QNode): - with transient_jax_config({"jax_dynamic_shapes": True}): + The transformed circuit to be run ``num_shots`` times such as to simulate dynamic execution. - make_jaxpr_kwargs = { - "static_argnums": static_argnums, - "abstracted_axes": abstracted_axes, - "debug_info": debug_info, - } - args = sig + **Example** - if isinstance(fn, qml.QNode) and static_argnums: - # `make_jaxpr2` sees the qnode - # The static_argnum on the wrapped function takes precedence over the - # one in `make_jaxpr` - # https://github.com/jax-ml/jax/blob/636691bba40b936b8b64a4792c1d2158296e9dd4/jax/_src/linear_util.py#L231 - # Therefore we need to coordinate them manually - fn.static_argnums = static_argnums + Consider the following circuit: - plxpr, out_type, out_treedef = make_jaxpr2(fn, **make_jaxpr_kwargs)(*args, **kwargs) - jaxpr = from_plxpr(plxpr)(*dynamic_args, **kwargs) + .. code-block:: python - return jaxpr, out_type, out_treedef, sig + dev = qml.device("lightning.qubit", shots=100) + params = np.pi / 4 * np.ones(2) + @qjit + @dynamic_one_shot + @qml.qnode(dev, diff_method=None) + def circuit(x, y): + qml.RX(x, wires=0) + m0 = measure(0, reset=reset, postselect=postselect) -def _apply_compiler_decompose_to_plxpr(inner_jaxpr, consts, ncargs, tgateset=None, tkwargs=None): - """Apply the compiler-specific decomposition for a given JAXPR. + @cond(m0 == 1) + def ansatz(): + qml.RY(y, wires=1) - This function first disables the graph-based decomposition optimization - to ensure that only high-level gates and templates with a single decomposition - are decomposed. It then performs the pre-mlir decomposition using PennyLane's - `plxpr_transform` function. + ansatz() + return measure_f(wires=[0, 1]) - `tgateset` is a list of target gateset for decomposition. - If provided, it will be combined with the default compiler ops for decomposition. - If not provided, `tkwargs` will be used as the keyword arguments for the - decomposition transform. This is to ensure compatibility with the existing - PennyLane decomposition transform as well as providing a fallback mechanism. + The ``dynamic_one_shot`` decorator prompts the QNode to perform a hundred one-shot + calculations, where in each calculation the ``measure`` operations dynamically + measures the 0-wire and collapse the state vector stochastically. + """ - Args: - inner_jaxpr (Jaxpr): The input JAXPR to be decomposed. - consts (list): The constants used in the JAXPR. - ncargs (list): Non-constant arguments for the JAXPR. - tgateset (list): A list of target gateset for decomposition. Defaults to None. - tkwargs (list): The keyword arguments of the decompose transform. Defaults to None. + cpy_tape = None + mcm_config = kwargs.pop("mcm_config", None) - Returns: - ClosedJaxpr: The decomposed JAXPR. - """ + def transform_to_single_shot(qnode): + if not qnode._shots: + raise exceptions.QuantumFunctionError( + "dynamic_one_shot is only supported with finite shots." + ) - # Disable the graph decomposition optimization - - # Why? Because for the compiler-specific decomposition we want to - # only decompose higher-level gates and templates that only have - # a single decomposition, and not do any further optimization - # based on the graph solution. - # Besides, the graph-based decomposition is not supported - # yet in from_plxpr for most gates and templates. - # TODO: Enable the graph-based decomposition - qml.decomposition.disable_graph() - - kwargs = ( - {"gate_set": set(COMPILER_OPS_FOR_DECOMPOSITION.keys()).union(tgateset)} - if tgateset - else tkwargs - ) - final_jaxpr = qml.transforms.decompose.plxpr_transform(inner_jaxpr, consts, (), kwargs, *ncargs) + user_specified_mcm_method = qnode.execute_kwargs["mcm_method"] + shot_vector = qnode._shots.shot_vector if qnode._shots else [] + wires = qnode.device.wires - qml.decomposition.enable_graph() + @qml.transform + def dynamic_one_shot_partial( + tape: qml.tape.QuantumTape, + ) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: + nonlocal cpy_tape + cpy_tape = tape - return final_jaxpr + _validate_one_shot_measurements( + mcm_config, tape, user_specified_mcm_method, shot_vector, wires + ) + if tape.batch_size is not None: + raise ValueError("mcm_method='one-shot' is not compatible with broadcasting") -def _collect_and_compile_graph_solutions(inner_jaxpr, consts, tkwargs, ncargs): - """Collect and compile graph solutions for a given JAXPR. + aux_tapes = [init_auxiliary_tape(tape)] - This function uses the DecompRuleInterpreter to evaluate - the input JAXPR and obtain a new JAXPR that incorporates - the graph-based decomposition solutions. + def processing_fn(results): + return results - This function doesn't modify the underlying quantum function - but rather constructs a new JAXPR with decomposition rules. + return aux_tapes, processing_fn - Args: - inner_jaxpr (Jaxpr): The input JAXPR to be decomposed. - consts (list): The constants used in the JAXPR. - tkwargs (list): The keyword arguments of the decompose transform. - ncargs (list): Non-constant arguments for the JAXPR. + return dynamic_one_shot_partial(qnode) - Returns: - ClosedJaxpr: The decomposed JAXPR. - bool: A flag indicating whether the graph-based decomposition was successful. - """ - gds_interpreter = DecompRuleInterpreter(**tkwargs) - - def gds_wrapper(*args): - return gds_interpreter.eval(inner_jaxpr, consts, *args) - - graph_succeeded = True - - with warnings.catch_warnings(record=True) as captured_warnings: - warnings.simplefilter("always", UserWarning) - final_jaxpr = jax.make_jaxpr(gds_wrapper)(*ncargs) - - for w in captured_warnings: - warnings.showwarning(w.message, w.category, w.filename, w.lineno) - # TODO: use a custom warning class for this in PennyLane to remove this - # string matching and make it more robust. - if "The graph-based decomposition system is unable" in str(w.message): # pragma: no cover - graph_succeeded = False - warnings.warn( - "Falling back to the legacy decomposition system.", - UserWarning, + single_shot_qnode = transform_to_single_shot(qnode) + single_shot_qnode = qml.set_shots(single_shot_qnode, shots=1) + if mcm_config is not None: + single_shot_qnode.execute_kwargs["postselect_mode"] = mcm_config.postselect_mode + single_shot_qnode.execute_kwargs["mcm_method"] = mcm_config.mcm_method + single_shot_qnode._dynamic_one_shot_called = True + total_shots = _get_total_shots(qnode) + + def one_shot_wrapper(*args, **kwargs): + def wrap_single_shot_qnode(*_): + return single_shot_qnode(*args, **kwargs) + + arg_vmap = jnp.empty((total_shots,), dtype=float) + results = catalyst.vmap(wrap_single_shot_qnode)(arg_vmap) + if isinstance(results[0], tuple) and len(results) == 1: + results = results[0] + has_mcm = any(isinstance(op, MidCircuitMeasure) for op in cpy_tape.operations) + + classical_return_indices = kwargs.pop("_classical_return_indices", [[]])[0] + num_mcm = kwargs.pop("_num_mcm_expected", [0])[0] + out_tree_expected = kwargs.pop("_out_tree_expected", [[]]) + + # Split results into classical and measurement parts + classical_values, results = _extract_classical_and_measurement_results( + results, classical_return_indices + ) + + out = list(results) + + shot_vector = _get_shot_vector(qnode) + snapshots, out = _get_snapshot_results(mcm_config.mcm_method, cpy_tape, out) + + if has_mcm and len(cpy_tape.measurements) > 0: + out = parse_native_mid_circuit_measurements( + cpy_tape, results=results, postselect_mode="pad-invalid-samples" + ) + if len(cpy_tape.measurements) == 1: + out = (out,) + elif len(cpy_tape.measurements) > 0: + out = _process_terminal_measurements( + mcm_config.mcm_method, cpy_tape, out, snapshots, shot_vector ) - return final_jaxpr, graph_succeeded + ctx = OutputContext( + cpy_tape=cpy_tape, + classical_values=classical_values, + classical_return_indices=classical_return_indices, + out_tree_expected=out_tree_expected, + snapshots=snapshots, + shot_vector=shot_vector, + num_mcm=num_mcm, + ) + return _finalize_output(out, ctx) -def _get_operator_name(op): - """Get the name of a pennylane operator, handling wrapped operators. + return one_shot_wrapper - Note: Controlled and Adjoint ops aren't supported in `gate_set` - by PennyLane's DecompositionGraph; unit tests were added in PennyLane. - """ - if isinstance(op, str): - return op - # Return NoNameOp if the operator has no _primitive.name attribute. - # This is to avoid errors when we capture the program - # as we deal with such ops later in the decomposition graph. - return getattr(op._primitive, "name", "NoNameOp") +def _extract_passes(transform_program): + tape_transforms = [] + pass_pipeline = [] + for t in transform_program: + if t.pass_name: + pass_pipeline.append(Pass(t.pass_name, *t.args, **t.kwargs)) + else: + tape_transforms.append(t) + return qml.transforms.core.TransformProgram(tape_transforms), tuple(pass_pipeline) \ No newline at end of file diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py index edac9890a7..fab0d2e090 100644 --- a/frontend/catalyst/qfunc.py +++ b/frontend/catalyst/qfunc.py @@ -43,7 +43,7 @@ from catalyst.jax_primitives import quantum_kernel_p from catalyst.jax_tracer import Function, trace_quantum_function from catalyst.logging import debug_logger -from catalyst.passes.pass_api import dictionary_to_list_of_passes +from catalyst.passes.pass_api import dictionary_to_list_of_passes, Pass from catalyst.tracing.contexts import EvaluationContext from catalyst.tracing.type_signatures import filter_static_args from catalyst.utils.exceptions import CompileError @@ -283,18 +283,22 @@ def __call__(self, *args, **kwargs): assert isinstance(self, qml.QNode) + new_transform_program, new_pipeline = _extract_passes(self.transform_program) + # Update the qnode with peephole pipeline - pass_pipeline = kwargs.pop("pass_pipeline", []) + pass_pipeline = kwargs.pop("pass_pipeline", []) + new_pipeline pass_pipeline = dictionary_to_list_of_passes(pass_pipeline) + new_qnode = copy(self) + new_qnode._transform_program = new_transform_program # pylint: disable=protected-access # Mid-circuit measurement configuration/execution - fn_result = configure_mcm_and_try_one_shot(self, args, kwargs) + fn_result = configure_mcm_and_try_one_shot(new_qnode, args, kwargs) # If the qnode is failed to execute as one-shot, fn_result will be None if fn_result is not None: return fn_result - new_device = copy(self.device) + new_device = copy(new_qnode.device) qjit_device = QJITDevice(new_device) static_argnums = kwargs.pop("static_argnums", ()) @@ -305,11 +309,11 @@ def __call__(self, *args, **kwargs): def _eval_quantum(*args, **kwargs): trace_result = trace_quantum_function( - self.func, + new_qnode.func, qjit_device, args, kwargs, - self, + new_qnode, static_argnums, debug_info, ) @@ -649,3 +653,14 @@ def wrap_single_shot_qnode(*_): return _finalize_output(out, ctx) return one_shot_wrapper + + +def _extract_passes(transform_program): + tape_transforms = [] + pass_pipeline = [] + for t in transform_program: + if t.pass_name: + pass_pipeline.append(Pass(t.pass_name, *t.args, **t.kwargs)) + else: + tape_transforms.append(t) + return qml.transforms.core.TransformProgram(tape_transforms), tuple(pass_pipeline) \ No newline at end of file From 0cb9476f6535cc88e277a44ce66caf1aec4bd615 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 13 Nov 2025 12:25:41 -0500 Subject: [PATCH 04/20] messed from plxpr up somehow --- frontend/catalyst/from_plxpr/__init__.py | 2 +- frontend/catalyst/from_plxpr/control_flow.py | 150 ++- frontend/catalyst/from_plxpr/decompose.py | 9 +- frontend/catalyst/from_plxpr/from_plxpr.py | 1011 +++++++---------- .../catalyst/from_plxpr/qfunc_interpreter.py | 185 ++- frontend/catalyst/from_plxpr/qubit_handler.py | 42 +- 6 files changed, 720 insertions(+), 679 deletions(-) diff --git a/frontend/catalyst/from_plxpr/__init__.py b/frontend/catalyst/from_plxpr/__init__.py index 3599025f6e..a39039fe1d 100644 --- a/frontend/catalyst/from_plxpr/__init__.py +++ b/frontend/catalyst/from_plxpr/__init__.py @@ -15,4 +15,4 @@ """Conversion from plxpr to catalyst jaxpr""" from catalyst.from_plxpr.control_flow import handle_cond, handle_for_loop, handle_while_loop -from catalyst.from_plxpr.from_plxpr import from_plxpr, trace_from_pennylane +from catalyst.from_plxpr.from_plxpr import from_plxpr, register_transform, trace_from_pennylane diff --git a/frontend/catalyst/from_plxpr/control_flow.py b/frontend/catalyst/from_plxpr/control_flow.py index 29a27998a8..a538bc1fab 100644 --- a/frontend/catalyst/from_plxpr/control_flow.py +++ b/frontend/catalyst/from_plxpr/control_flow.py @@ -26,24 +26,70 @@ from pennylane.capture.primitives import while_loop_prim as plxpr_while_loop_prim from catalyst.from_plxpr.from_plxpr import PLxPRToQuantumJaxprInterpreter, WorkflowInterpreter -from catalyst.from_plxpr.qubit_handler import QubitHandler, QubitIndexRecorder +from catalyst.from_plxpr.qubit_handler import ( + QubitHandler, + QubitIndexRecorder, + _get_dynamically_allocated_qregs, +) from catalyst.jax_extras import jaxpr_pad_consts from catalyst.jax_primitives import cond_p, for_p, while_p -def _calling_convention(interpreter, closed_jaxpr, *args_plus_qreg): - # The last arg is the scope argument for the body jaxpr - *args, qreg = args_plus_qreg +def _calling_convention( + interpreter, closed_jaxpr, *args_plus_qregs, outer_dynqreg_handlers=(), return_qreg=True +): + # Arg structure (all args are tracers, since this function is to be `make_jaxpr`'d): + # Regular args, then dynamically allocated qregs, then global qreg + # TODO: merge dynamically allocaed qregs into regular args? + # But this is tricky, since qreg arguments need all the SSA value semantics conversion infra + # and are different from the regular plain arguments. + *args_plus_dynqregs, global_qreg = args_plus_qregs + num_dynamic_alloced_qregs = len(outer_dynqreg_handlers) + args, dynalloced_qregs = ( + args_plus_dynqregs[: len(args_plus_dynqregs) - num_dynamic_alloced_qregs], + args_plus_dynqregs[len(args_plus_dynqregs) - num_dynamic_alloced_qregs :], + ) # Launch a new interpreter for the body region # A new interpreter's root qreg value needs a new recorder converter = copy(interpreter) converter.qubit_index_recorder = QubitIndexRecorder() - init_qreg = QubitHandler(qreg, converter.qubit_index_recorder) + init_qreg = QubitHandler(global_qreg, converter.qubit_index_recorder) converter.init_qreg = init_qreg + # add dynamic qregs to recorder + qreg_map = {} + dyn_qreg_handlers = [] + for dyn_qreg, outer_dynqreg_handler in zip( + dynalloced_qregs, outer_dynqreg_handlers, strict=True + ): + dyn_qreg_handler = QubitHandler(dyn_qreg, converter.qubit_index_recorder) + dyn_qreg_handlers.append(dyn_qreg_handler) + + # plxpr global wire index does not change across scopes + # So scope arg dynamic qregs need to have the same root hash as their corresponding + # qreg tracers outside + dyn_qreg_handler.root_hash = outer_dynqreg_handler.root_hash + + # Each qreg argument of the subscope corresponds to a qreg from the outer scope + qreg_map[outer_dynqreg_handler] = dyn_qreg_handler + + # The new interpreter's recorder needs to be updated to include the qreg args + # of this scope, instead of the outer qregs + if qreg_map: + for k, outer_dynqreg_handler in interpreter.qubit_index_recorder.map.items(): + converter.qubit_index_recorder[k] = qreg_map[outer_dynqreg_handler] + retvals = converter(closed_jaxpr, *args) + if not return_qreg: + return retvals + init_qreg.insert_all_dangling_qubits() + + # Return all registers + for dyn_qreg_handler in dyn_qreg_handlers: + dyn_qreg_handler.insert_all_dangling_qubits() + retvals.append(dyn_qreg_handler.get()) return *retvals, converter.init_qreg.get() @@ -89,7 +135,18 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice): """Handle the conversion from plxpr to Catalyst jaxpr for the cond primitive""" args = plxpr_invals[args_slice] self.init_qreg.insert_all_dangling_qubits() - args_plus_qreg = [*args, self.init_qreg.get()] # Add the qreg to the args + + dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs( + plxpr_invals, self.qubit_index_recorder, self.init_qreg + ) + + # Add the qregs to the args + args_plus_qreg = [ + *args, + *[dyn_qreg.get() for dyn_qreg in dynalloced_qregs], + self.init_qreg.get(), + ] + converted_jaxpr_branches = [] all_consts = [] @@ -102,7 +159,9 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice): converted_jaxpr_branch = None closed_jaxpr = ClosedJaxpr(plxpr_branch, branch_consts) - f = partial(_calling_convention, self, closed_jaxpr) + f = partial( + _calling_convention, self, closed_jaxpr, outer_dynqreg_handlers=dynalloced_qregs + ) converted_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg) all_consts += converted_jaxpr_branch.consts @@ -111,6 +170,8 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice): predicate = [_to_bool_if_not(p) for p in plxpr_invals[: len(jaxpr_branches) - 1]] # Build Catalyst compatible input values + # strip global wire indices of dynamic wires + all_consts = tuple(const for const in all_consts if const not in dynalloced_wire_global_indices) cond_invals = [*predicate, *all_consts, *args_plus_qreg] # Perform the binding @@ -120,9 +181,12 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice): nimplicit_outputs=None, ) - # We assume the last output value is the returned qreg. + # Output structure: + # First a list of dynamically allocated qregs, then the global qreg # Update the current qreg and remove it from the output values. self.init_qreg.set(outvals.pop()) + for dyn_qreg in reversed(dynalloced_qregs): + dyn_qreg.set(outvals.pop()) # Return only the output values that match the plxpr output values return outvals @@ -192,9 +256,15 @@ def handle_for_loop( # Add the iteration start and the qreg to the args self.init_qreg.insert_all_dangling_qubits() + + dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs( + plxpr_invals, self.qubit_index_recorder, self.init_qreg + ) + start_plus_args_plus_qreg = [ start, *args, + *[dyn_qreg.get() for dyn_qreg in dynalloced_qregs], self.init_qreg.get(), ] @@ -202,7 +272,12 @@ def handle_for_loop( jaxpr = ClosedJaxpr(jaxpr_body_fn, consts) - f = partial(_calling_convention, self, jaxpr) + f = partial( + _calling_convention, + self, + jaxpr, + outer_dynqreg_handlers=dynalloced_qregs, + ) converted_jaxpr_branch = jax.make_jaxpr(f)(*start_plus_args_plus_qreg) converted_closed_jaxpr_branch = ClosedJaxpr( @@ -210,7 +285,9 @@ def handle_for_loop( ) # Build Catalyst compatible input values + # strip global wire indices of dynamic wires new_consts = converted_jaxpr_branch.consts + new_consts = tuple(const for const in new_consts if const not in dynalloced_wire_global_indices) for_loop_invals = [*new_consts, start, stop, step, *start_plus_args_plus_qreg] # Config additional for loop settings @@ -226,10 +303,14 @@ def handle_for_loop( preserve_dimensions=True, ) - # We assume the last output value is the returned qreg. + # Output structure: + # First a list of dynamically allocated qregs, then the global qreg # Update the current qreg and remove it from the output values. self.init_qreg.set(outvals.pop()) + for dyn_qreg in reversed(dynalloced_qregs): + dyn_qreg.set(outvals.pop()) + # Return only the output values that match the plxpr output values return outvals @@ -288,57 +369,59 @@ def handle_while_loop( ): """Handle the conversion from plxpr to Catalyst jaxpr for the while loop primitive""" self.init_qreg.insert_all_dangling_qubits() + dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs( + plxpr_invals, self.qubit_index_recorder, self.init_qreg + ) consts_body = plxpr_invals[body_slice] consts_cond = plxpr_invals[cond_slice] args = plxpr_invals[args_slice] - args_plus_qreg = [*args, self.init_qreg.get()] # Add the qreg to the args + args_plus_qreg = [ + *args, + *[dyn_qreg.get() for dyn_qreg in dynalloced_qregs], + self.init_qreg.get(), + ] # Add the qreg to the args jaxpr = ClosedJaxpr(jaxpr_body_fn, consts_body) - f = partial(_calling_convention, self, jaxpr) - converted_body_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg).jaxpr + f = partial(_calling_convention, self, jaxpr, outer_dynqreg_handlers=dynalloced_qregs) + converted_body_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg) + new_consts_body = converted_body_jaxpr_branch.consts converted_body_closed_jaxpr_branch = ClosedJaxpr( - convert_constvars_jaxpr(converted_body_jaxpr_branch), () + convert_constvars_jaxpr(converted_body_jaxpr_branch.jaxpr), () ) # Convert for condition from plxpr to Catalyst jaxpr # We need to be able to handle arbitrary plxpr here. # But we want to be able to create a state where: # * We do not pass the quantum register as an argument. - # So let's just remove the quantum register here at the end - jaxpr = ClosedJaxpr(jaxpr_cond_fn, consts_cond) - def remove_qreg(*args_plus_qreg): - # The last arg is the scope argument for the body jaxpr - *args, qreg = args_plus_qreg - - # Launch a new interpreter for the body region - # A new interpreter's root qreg value needs a new recorder - converter = copy(self) - converter.qubit_index_recorder = QubitIndexRecorder() - init_qreg = QubitHandler(qreg, converter.qubit_index_recorder) - converter.init_qreg = init_qreg + f_remove_qreg = partial( + _calling_convention, self, jaxpr, outer_dynqreg_handlers=dynalloced_qregs, return_qreg=False + ) - return converter(jaxpr, *args) + converted_cond_jaxpr_branch = jax.make_jaxpr(f_remove_qreg)(*args_plus_qreg) - converted_cond_jaxpr_branch = jax.make_jaxpr(remove_qreg)(*args_plus_qreg).jaxpr converted_cond_closed_jaxpr_branch = ClosedJaxpr( - convert_constvars_jaxpr(converted_cond_jaxpr_branch), () + convert_constvars_jaxpr(converted_cond_jaxpr_branch.jaxpr), () ) # Build Catalyst compatible input values - while_loop_invals = [*consts_cond, *consts_body, *args_plus_qreg] + new_consts_cond = converted_cond_jaxpr_branch.consts + new_consts_body = tuple( + const for const in new_consts_body if const not in dynalloced_wire_global_indices + ) + while_loop_invals = [*new_consts_cond, *new_consts_body, *args_plus_qreg] # Perform the binding outvals = while_p.bind( *while_loop_invals, cond_jaxpr=converted_cond_closed_jaxpr_branch, body_jaxpr=converted_body_closed_jaxpr_branch, - cond_nconsts=len(consts_cond), - body_nconsts=len(consts_body), + cond_nconsts=len(new_consts_cond), + body_nconsts=len(new_consts_body), nimplicit=0, preserve_dimensions=True, ) @@ -347,5 +430,8 @@ def remove_qreg(*args_plus_qreg): # Update the current qreg and remove it from the output values. self.init_qreg.set(outvals.pop()) + for dyn_qreg in reversed(dynalloced_qregs): + dyn_qreg.set(outvals.pop()) + # Return only the output values that match the plxpr output values return outvals diff --git a/frontend/catalyst/from_plxpr/decompose.py b/frontend/catalyst/from_plxpr/decompose.py index 58a34cb620..3df0e3a143 100644 --- a/frontend/catalyst/from_plxpr/decompose.py +++ b/frontend/catalyst/from_plxpr/decompose.py @@ -234,7 +234,14 @@ def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess num_params=num_params, requires_copy=num_wires == -1, ) - else: # pragma: no cover + elif not any( + keyword in getattr(op.op, "name", "") for keyword in ("Adjoint", "Controlled") + ): # pragma: no cover + # Note that the graph-decomposition returns abstracted rules + # for Adjoint and Controlled operations, so we skip them here. + # These abstracted rules cannot be captured and lowered. + # We use MLIR AdjointOp and ControlledOp primitives + # to deal with decomposition of symbolic operations at PLxPR. raise ValueError(f"Could not capture {op} without the number of wires.") data, struct = jax.tree_util.tree_flatten(measurement) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 16997f3cf6..86720baf5a 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -1,4 +1,4 @@ -# Copyright 2022-2024 Xanadu Quantum Technologies Inc. +# Copyright 2024 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,660 +11,515 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ -This module contains a patch for the upstream qml.QNode behaviour, in particular around -what happens when a QNode object is called during tracing. Mostly this involves bypassing -the default behaviour and replacing it with a function-like "QNode" primitive. +This submodule defines a utility for converting plxpr into Catalyst jaxpr. """ -import logging +# pylint: disable=protected-access + +import warnings from copy import copy -from dataclasses import dataclass, replace -from typing import Callable, Sequence +from functools import partial +from typing import Callable -import jax.numpy as jnp +import jax import pennylane as qml -from jax.core import eval_jaxpr -from jax.tree_util import tree_flatten, tree_unflatten -from pennylane import exceptions -from pennylane.measurements import CountsMP, ExpectationMP, ProbabilityMP, SampleMP, VarianceMP -from pennylane.transforms.dynamic_one_shot import ( - gather_non_mcm, - init_auxiliary_tape, - parse_native_mid_circuit_measurements, +from jax.extend.core import ClosedJaxpr, Jaxpr +from jax.extend.linear_util import wrap_init +from pennylane.capture import PlxprInterpreter, qnode_prim +from pennylane.capture.expand_transforms import ExpandTransformsInterpreter +from pennylane.capture.primitives import jacobian_prim as pl_jac_prim +from pennylane.capture.primitives import transform_prim +from pennylane.ops.functions.map_wires import _map_wires_transform as pl_map_wires +from pennylane.transforms import cancel_inverses as pl_cancel_inverses +from pennylane.transforms import commute_controlled as pl_commute_controlled +from pennylane.transforms import decompose as pl_decompose +from pennylane.transforms import merge_amplitude_embedding as pl_merge_amplitude_embedding +from pennylane.transforms import merge_rotations as pl_merge_rotations +from pennylane.transforms import single_qubit_fusion as pl_single_qubit_fusion +from pennylane.transforms import unitary_to_rot as pl_unitary_to_rot + +from catalyst.device import extract_backend_info +from catalyst.from_plxpr.decompose import COMPILER_OPS_FOR_DECOMPOSITION, DecompRuleInterpreter +from catalyst.jax_extras import make_jaxpr2, transient_jax_config +from catalyst.jax_primitives import ( + device_init_p, + device_release_p, + qalloc_p, + qdealloc_p, + quantum_kernel_p, ) +from catalyst.passes.pass_api import Pass -import catalyst -from catalyst.api_extensions import MidCircuitMeasure -from catalyst.device import QJITDevice -from catalyst.device.qjit_device import is_dynamic_wires -from catalyst.jax_extras import deduce_avals, get_implicit_and_explicit_flat_args, unzip2 -from catalyst.jax_extras.tracing import uses_transform -from catalyst.jax_primitives import quantum_kernel_p -from catalyst.jax_tracer import Function, trace_quantum_function -from catalyst.logging import debug_logger -from catalyst.passes.pass_api import dictionary_to_list_of_passes, Pass -from catalyst.tracing.contexts import EvaluationContext -from catalyst.tracing.type_signatures import filter_static_args -from catalyst.utils.exceptions import CompileError - -logger = logging.getLogger(__name__) -logger.addHandler(logging.NullHandler()) - - -@dataclass -class OutputContext: - """Context containing parameters needed for finalizing quantum function output.""" - - cpy_tape: any - classical_values: any - classical_return_indices: any - out_tree_expected: any - snapshots: any - shot_vector: any - num_mcm: int - - -def _resolve_mcm_config(mcm_config, shots): - """Helper function for resolving and validating that the mcm_config is valid for executing.""" - updated_values = {} - - updated_values["postselect_mode"] = ( - None if isinstance(shots, int) and shots == 0 else mcm_config.postselect_mode - ) - if mcm_config.mcm_method is None: - updated_values["mcm_method"] = "one-shot" - if mcm_config.mcm_method == "deferred": - raise ValueError("mcm_method='deferred' is not supported with Catalyst.") - if ( - mcm_config.mcm_method == "single-branch-statistics" - and mcm_config.postselect_mode == "hw-like" - ): - raise ValueError( - "Cannot use postselect_mode='hw-like' with Catalyst when mcm_method != 'one-shot'." - ) - if mcm_config.mcm_method == "one-shot" and shots == 0: - raise ValueError( - "Cannot use the 'one-shot' method for mid-circuit measurements with analytic mode." - ) +from .qfunc_interpreter import PLxPRToQuantumJaxprInterpreter +from .qubit_handler import ( + QubitHandler, + QubitIndexRecorder, +) - return replace(mcm_config, **updated_values) +def _get_device_kwargs(device) -> dict: + """Calulcate the params for a device equation.""" + info = extract_backend_info(device) + # Note that the value of rtd_kwargs is a string version of + # the info kwargs, not the info kwargs itself + # this is due to ease of serialization to MLIR + return { + "rtd_kwargs": str(info.kwargs), + "rtd_lib": info.lpath, + "rtd_name": info.c_interface_name, + } -def _get_total_shots(qnode): - """ - Extract total shots from qnode. - If shots is None on the qnode, this method returns 0 (static). - This method allows the qnode shots to be either static (python int - literals) or dynamic (tracers). - """ - # due to possibility of tracer, we cannot use a simple `or` here to simplify - shots_value = qnode._shots.total_shots # pylint: disable=protected-access - if shots_value is None: - shots = 0 - else: - shots = shots_value - return shots - - -def _is_one_shot_compatible_device(qnode): - device_name = qnode.device.name - exclude_devices = {"softwareq.qpp", "nvidia.custatevec", "nvidia.cutensornet"} - - # Check device name against exclude list - if device_name in exclude_devices: - return False - - # Additional check for OQDDevice class - device_class_name = qnode.device.__class__.__name__ - return device_class_name != "OQDDevice" - - -def configure_mcm_and_try_one_shot(qnode, args, kwargs): - """Configure mid-circuit measurement settings and handle one-shot execution.""" - dynamic_one_shot_called = getattr(qnode, "_dynamic_one_shot_called", False) - if not dynamic_one_shot_called: - mcm_config = copy( - qml.devices.MCMConfig( - postselect_mode=qnode.execute_kwargs["postselect_mode"], - mcm_method=qnode.execute_kwargs["mcm_method"], - ) - ) - total_shots = _get_total_shots(qnode) - user_specified_mcm_method = mcm_config.mcm_method - mcm_config = _resolve_mcm_config(mcm_config, total_shots) - - # Check if measurements_from_{samples/counts} is being used - uses_measurements_from_samples = uses_transform(qnode, "measurements_from_samples") - uses_measurements_from_counts = uses_transform(qnode, "measurements_from_counts") - has_finite_shots = isinstance(total_shots, int) and total_shots > 0 - - # For cases that user are not tend to executed with one-shot, and facing - # 1. non-one-shot compatible device, - # 2. non-finite shots, - # 3. measurement transform, - # fallback to single-branch-statistics - one_shot_compatible = _is_one_shot_compatible_device(qnode) - one_shot_compatible &= has_finite_shots - one_shot_compatible &= not uses_measurements_from_samples - one_shot_compatible &= not uses_measurements_from_counts - - should_fallback = ( - not one_shot_compatible - and user_specified_mcm_method is None - and mcm_config.mcm_method == "one-shot" - ) - - if should_fallback: - mcm_config = replace(mcm_config, mcm_method="single-branch-statistics") - - if mcm_config.mcm_method == "one-shot": - # If measurements_from_samples/counts while one-shot is used, raise an error - if uses_measurements_from_samples: - raise CompileError("measurements_from_samples is not supported with one-shot") - if uses_measurements_from_counts: - raise CompileError("measurements_from_counts is not supported with one-shot") - mcm_config = replace( - mcm_config, postselect_mode=mcm_config.postselect_mode or "hw-like" - ) +# code example has long lines +# pylint: disable=line-too-long +def from_plxpr(plxpr: ClosedJaxpr) -> Callable[..., Jaxpr]: + """Convert PennyLane variant jaxpr to Catalyst variant jaxpr. - try: - return Function(dynamic_one_shot(qnode, mcm_config=mcm_config))(*args, **kwargs) - except (TypeError, ValueError, CompileError, NotImplementedError) as e: - # If user specified mcm_method, we can't fallback to single-branch-statistics, - # reraise the original error - if user_specified_mcm_method is not None: - raise - - # Fallback only if mcm was auto-determined - error_msg = str(e) - unsupported_measurement_error = any( - pattern in error_msg - for pattern in [ - "Native mid-circuit measurement mode does not support", - "qml.var(obs) cannot be returned when `mcm_method='one-shot'`", - "empty wires is not supported with dynamic wires in one-shot mode", - "No need to run one-shot mode", - ] - ) - - # Fallback if error is related to unsupported measurements - if unsupported_measurement_error: - logger.warning("Fallback to single-branch-statistics: %s", e) - mcm_config = replace(mcm_config, mcm_method="single-branch-statistics") - else: - raise - return None - - -def _reconstruct_output_with_classical_values( - measurement_results, classical_values, classical_return_indices -): - """ - Reconstruct the output values from the classical values and measurement results. Args: - out: Output from measurement processing - classical_values: Classical values - classical_return_indices: Indices of classical values - Returns: - results: Reconstructed output with classical values inserted - """ - if not classical_values: - return measurement_results - - total_expected = len(classical_values) + len(measurement_results) - classical_iter = iter(classical_values) - measurement_iter = iter(measurement_results) + jaxpr (ClosedJaxpr): PennyLane variant jaxpr - def get_next_value(idx): - return next(classical_iter) if idx in classical_return_indices else next(measurement_iter) - - results = [get_next_value(i) for i in range(total_expected)] - return results + Returns: + Callable: A function that accepts the same arguments as the plxpr and returns catalyst + variant jaxpr. + Note that the input jaxpr should be workflow level and contain qnode primitives, rather than + qfunc level with individual operators. -def _extract_classical_and_measurement_results(results, classical_return_indices): - """ - Split results into classical values and measurement results. - It assume that the results are in the order of classical values and measurement results. - """ - num_classical_return_indices = len(classical_return_indices) - classical_values = results[:num_classical_return_indices] - measurement_results = results[num_classical_return_indices:] - return classical_values, measurement_results + .. code-block:: python + from catalyst.from_plxpr import from_plxpr + + qml.capture.enable() + + @qml.qnode(qml.device('lightning.qubit', wires=2)) + def circuit(x): + qml.RX(x, 0) + return qml.probs(wires=(0, 1)) + + def f(x): + return circuit(2 * x) ** 2 + + plxpr = jax.make_jaxpr(circuit)(0.5) + + print(from_plxpr(plxpr)(0.5)) + + .. code-block:: none + + { lambda ; a:f64[]. let + b:f64[4] = func[ + call_jaxpr={ lambda ; c:f64[]. let + device_init[ + rtd_kwargs={'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None} + rtd_lib=*** + rtd_name=LightningSimulator + ] + d:AbstractQreg() = qalloc 2 + e:AbstractQbit() = qextract d 0 + f:AbstractQbit() = qinst[ + adjoint=False + ctrl_len=0 + op=RX + params_len=1 + qubits_len=1 + ] e c + g:AbstractQbit() = qextract d 1 + h:AbstractObs(num_qubits=2,primitive=compbasis) = compbasis f g + i:f64[4] = probs[shape=(4,) shots=None] h + j:AbstractQreg() = qinsert d 0 f + qdealloc j + in (i,) } + qnode= + ] a + in (b,) } -def _finalize_output(out, ctx: OutputContext): - """ - Finalize the output by reconstructing with classical values and unflattening to the - expected tree structure. - Args: - out: The output to finalize - context: OutputContext containing all necessary parameters for finalization """ - # Handle case with no measurements - if len(ctx.cpy_tape.measurements) == 0: - out = out[: -ctx.num_mcm] - - out = _reconstruct_output_with_classical_values( - out, ctx.classical_values, ctx.classical_return_indices - ) + return jax.make_jaxpr(partial(WorkflowInterpreter().eval, plxpr.jaxpr, plxpr.consts)) - out_tree_expected = ctx.out_tree_expected - if ctx.snapshots is not None: - out = (out[0], tree_unflatten(out_tree_expected[1], out[1])) - else: - out = tree_unflatten(out_tree_expected[0], out) - return out +class WorkflowInterpreter(PlxprInterpreter): + """An interpreter that converts a qnode primitive from a plxpr variant to a catalyst jaxpr variant.""" -class QFunc: - """A device specific quantum function. + def __copy__(self): + new_version = WorkflowInterpreter() + new_version._pass_pipeline = copy(self._pass_pipeline) + new_version.init_qreg = self.init_qreg + new_version.requires_decompose_lowering = self.requires_decompose_lowering + new_version.decompose_tkwargs = copy(self.decompose_tkwargs) + return new_version - Args: - qfunc (Callable): the quantum function - shots (int): How many times the circuit should be evaluated (or sampled) to estimate - the expectation values - device (a derived class from QubitDevice): a device specification which determines - the valid gate set for the quantum function - """ - - def __new__(cls): - raise NotImplementedError() # pragma: no-cover + def __init__(self): + self._pass_pipeline = [] + self.init_qreg = None - # pylint: disable=no-member - # pylint: disable=self-cls-assignment - @debug_logger - def __call__(self, *args, **kwargs): + # Compiler options for the new decomposition system + self.requires_decompose_lowering = False + self.decompose_tkwargs = {} # target gateset - if EvaluationContext.is_quantum_tracing(): - raise CompileError("Can't nest qnodes under qjit") + super().__init__() - assert isinstance(self, qml.QNode) - new_transform_program, new_pipeline = _extract_passes(self.transform_program) +@WorkflowInterpreter.register_primitive(pl_jac_prim) +def handle_grad(self, *args, jaxpr, n_consts, **kwargs): + """Translate a grad equation.""" + f = partial(copy(self).eval, jaxpr, args[:n_consts]) + new_jaxpr = jax.make_jaxpr(f)(*args[n_consts:]) - # Update the qnode with peephole pipeline - pass_pipeline = kwargs.pop("pass_pipeline", ()) or () - pass_pipeline += new_pipeline - pass_pipeline = dictionary_to_list_of_passes(pass_pipeline) - new_qnode = copy(self) - new_qnode._transform_program = new_transform_program # pylint: disable=protected-access + new_args = (*new_jaxpr.consts, *args[n_consts:]) + return pl_jac_prim.bind( + *new_args, jaxpr=new_jaxpr.jaxpr, n_consts=len(new_jaxpr.consts), **kwargs + ) - # Mid-circuit measurement configuration/execution - fn_result = configure_mcm_and_try_one_shot(new_qnode, args, kwargs) - # If the qnode is failed to execute as one-shot, fn_result will be None - if fn_result is not None: - return fn_result +# pylint: disable=unused-argument, too-many-arguments +@WorkflowInterpreter.register_primitive(qnode_prim) +def handle_qnode( + self, *args, qnode, device, shots_len, execution_config, qfunc_jaxpr, n_consts, batch_dims=None +): + """Handle the conversion from plxpr to Catalyst jaxpr for the qnode primitive""" - new_device = copy(new_qnode.device) - qjit_device = QJITDevice(new_device) + self.qubit_index_recorder = QubitIndexRecorder() - static_argnums = kwargs.pop("static_argnums", ()) - out_tree_expected = kwargs.pop("_out_tree_expected", []) - classical_return_indices = kwargs.pop("_classical_return_indices", []) - num_mcm_expected = kwargs.pop("_num_mcm_expected", []) - debug_info = kwargs.pop("debug_info", None) + if shots_len > 1: + raise NotImplementedError("shot vectors are not yet supported for catalyst conversion.") - print(new_qnode.transform_program) - print(pass_pipeline) + shots = args[0] if shots_len else 0 + consts = args[shots_len : n_consts + shots_len] + non_const_args = args[shots_len + n_consts :] - def _eval_quantum(*args, **kwargs): - trace_result = trace_quantum_function( - new_qnode.func, - qjit_device, - args, - kwargs, - new_qnode, - static_argnums, - debug_info, - ) - closed_jaxpr = trace_result.closed_jaxpr - out_type = trace_result.out_type - out_tree = trace_result.out_tree - out_tree_exp = trace_result.return_values_tree - cls_ret_idx = trace_result.classical_return_indices - num_mcm = trace_result.num_mcm - - out_tree_expected.append(out_tree_exp) - classical_return_indices.append(cls_ret_idx) - num_mcm_expected.append(num_mcm) - dynamic_args = filter_static_args(args, static_argnums) - args_expanded = get_implicit_and_explicit_flat_args(None, *dynamic_args, **kwargs) - res_expanded = eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args_expanded) - _, out_keep = unzip2(out_type) - res_flat = [r for r, k in zip(res_expanded, out_keep) if k] - return tree_unflatten(out_tree, res_flat) - - flattened_fun, _, _, out_tree_promise = deduce_avals( - _eval_quantum, args, kwargs, static_argnums, debug_info + closed_jaxpr = ( + ClosedJaxpr(qfunc_jaxpr, consts) + if not self.requires_decompose_lowering + else _apply_compiler_decompose_to_plxpr( + inner_jaxpr=qfunc_jaxpr, + consts=consts, + ncargs=non_const_args, + tgateset=list(self.decompose_tkwargs.get("gate_set", [])), ) - dynamic_args = filter_static_args(args, static_argnums) - args_flat = tree_flatten((dynamic_args, kwargs))[0] - res_flat = quantum_kernel_p.bind( - flattened_fun, *args_flat, qnode=self, pipeline=tuple(pass_pipeline) - ) - return tree_unflatten(out_tree_promise(), res_flat)[0] - - -# pylint: disable=protected-access -def _get_shot_vector(qnode): - shot_vector = qnode._shots.shot_vector if qnode._shots else [] - return ( - shot_vector - if len(shot_vector) > 1 or any(copies > 1 for _, copies in shot_vector) - else None ) - -def _get_snapshot_results(mcm_method, tape, out): - """ - Get the snapshot results from the tape. - Args: - tape: The tape to get the snapshot results from. - out: The output of the tape. - Returns: - processed_snapshots: The extracted snapshot results if available; - otherwise, returns the original output. - measurement_results: The corresponding measurement results. - """ - # if no snapshot are present, return None, out - assert mcm_method == "one-shot" - - if not any(isinstance(op, qml.Snapshot) for op in tape.operations): - return None, out - - # Snapshots present: out[0] = snapshots, out[1] = measurements - snapshot_results, measurement_results = out - - # Take first shot for each snapshot - processed_snapshots = [ - snapshot[0] if hasattr(snapshot, "shape") and len(snapshot.shape) > 1 else snapshot - for snapshot in snapshot_results - ] - - return processed_snapshots, measurement_results - - -def _reshape_for_shot_vector(mcm_method, result, shot_vector): - assert mcm_method == "one-shot" - - # Calculate the shape for reshaping based on shot vector - result_list = [] - start_idx = 0 - for shot, copies in shot_vector: - # Reshape this segment to (copies, shot, n_wires) - segment = result[start_idx : start_idx + shot * copies] - if copies > 1: - segment_shape = (copies, shot, result.shape[-1]) - segment = jnp.reshape(segment, segment_shape) - result_list.extend([segment[i] for i in range(copies)]) - else: - result_list.append(segment) - start_idx += shot * copies - result = tuple(result_list) - return result - - -def _process_terminal_measurements(mcm_method, cpy_tape, out, snapshots, shot_vector): - """Process measurements when there are no mid-circuit measurements.""" - assert mcm_method == "one-shot" - - # flatten the outs structure - out, _ = tree_flatten(out) - new_out = [] - idx = 0 - - for m in cpy_tape.measurements: - if isinstance(m, CountsMP): - if isinstance(out[idx], tuple) and len(out[idx]) == 2: - # CountsMP result is stored as (keys, counts) tuple - keys, counts = out[idx] - idx += 1 - else: - keys = out[idx] - counts = out[idx + 1] - idx += 2 - - if snapshots is not None: - counts_array = jnp.stack(counts, axis=0) - aggregated_counts = jnp.sum(counts_array, axis=0) - counts_result = (keys, aggregated_counts) - else: - aggregated_counts = jnp.sum(counts, axis=0) - counts_result = (keys[0], aggregated_counts) - - new_out.append(counts_result) - continue - - result = jnp.squeeze(out[idx]) - max_ndim = min(len(out[idx].shape), 2) - if out[idx].shape[0] == 1: - # Adding the first axis back when the first axis in the original - # array is 1, since it corresponds to the shot's dimension. - result = jnp.expand_dims(result, axis=0) - if result.ndim == 1 and max_ndim == 2: - result = jnp.expand_dims(result, axis=1) - - # Without MCMs and postselection, all samples are valid for use in MP computation. - is_valid = jnp.full((result.shape[0],), True) - processed_result = gather_non_mcm( - m, result, is_valid, postselect_mode="pad-invalid-samples" + graph_succeeded = False + if self.requires_decompose_lowering: + closed_jaxpr, graph_succeeded = _collect_and_compile_graph_solutions( + inner_jaxpr=closed_jaxpr.jaxpr, + consts=closed_jaxpr.consts, + tkwargs=self.decompose_tkwargs, + ncargs=non_const_args, ) - # Handle shot vector reshaping for SampleMP - if isinstance(m, SampleMP) and shot_vector is not None: - processed_result = _reshape_for_shot_vector(mcm_method, processed_result, shot_vector) - - new_out.append(processed_result) - idx += 1 - - return (snapshots, tuple(new_out)) if snapshots else tuple(new_out) - - -def _validate_one_shot_measurements( - mcm_config, tape: qml.tape.QuantumTape, user_specified_mcm_method, shot_vector, wires -) -> None: - """Validate measurements for one-shot mode. - - Args: - mcm_config: The mid-circuit measurement configuration - tape: The quantum tape containing measurements to validate - qnode: The quantum node being transformed + # Fallback to the legacy decomposition if the graph-based decomposition failed + if not graph_succeeded: + # Remove the decompose-lowering pass from the pipeline + self._pass_pipeline = [p for p in self._pass_pipeline if p.name != "decompose-lowering"] + closed_jaxpr = _apply_compiler_decompose_to_plxpr( + inner_jaxpr=closed_jaxpr.jaxpr, + consts=closed_jaxpr.consts, + ncargs=non_const_args, + tkwargs=self.decompose_tkwargs, + ) - Raises: - TypeError: If unsupported measurement types are used - NotImplementedError: If measurement configuration is not supported - """ - mcm_method = mcm_config.mcm_method - assert mcm_method == "one-shot" + def calling_convention(*args): + device_init_p.bind( + shots, + auto_qubit_management=(device.wires is None), + **_get_device_kwargs(device), + ) + qreg = qalloc_p.bind(len(device.wires)) + self.init_qreg = QubitHandler(qreg, self.qubit_index_recorder) + converter = PLxPRToQuantumJaxprInterpreter( + device, shots, self.init_qreg, {}, self.qubit_index_recorder + ) + retvals = converter(closed_jaxpr, *args) + self.init_qreg.insert_all_dangling_qubits() + qdealloc_p.bind(self.init_qreg.get()) + device_release_p.bind() + return retvals + + if self.requires_decompose_lowering and graph_succeeded: + # Add gate_set attribute to the quantum kernel primitive + # decompose_gatesets is treated as a queue of gatesets to be used + # but we only support a single gateset for now in from_plxpr + # as supporting multiple gatesets requires an MLIR/C++ graph-decomposition + # implementation. The current Python implementation cannot be mixed + # with other transforms in between. + gateset = [_get_operator_name(op) for op in self.decompose_tkwargs.get("gate_set", [])] + setattr(qnode, "decompose_gatesets", [gateset]) + + return quantum_kernel_p.bind( + wrap_init(calling_convention, debug_info=qfunc_jaxpr.debug_info), + *non_const_args, + qnode=qnode, + pipeline=self._pass_pipeline, + ) - # Check if using shot vector with non-SampleMP measurements - has_shot_vector = len(shot_vector) > 1 or any(copies > 1 for _, copies in shot_vector) - has_wires = wires is not None and not is_dynamic_wires(wires) - # Raise an error if there are no mid-circuit measurements, it will fallback to - # single-branch-statistics +# The map below describes the parity between PL transforms and Catalyst passes. +# PL transforms having a Catalyst pass counterpart will have a name as value, +# otherwise their value will be None. The second value indicates if the transform +# requires decomposition to be supported by Catalyst. +transforms_to_passes = { + pl_cancel_inverses: ("remove-chained-self-inverse", False), + pl_commute_controlled: (None, False), + pl_decompose: (None, False), + pl_map_wires: (None, False), + pl_merge_amplitude_embedding: (None, True), + pl_merge_rotations: ("merge-rotations", False), + pl_single_qubit_fusion: (None, False), + pl_unitary_to_rot: (None, False), +} + + +def register_transform(pl_transform, pass_name, decomposition): + """Register pennylane transforms and their conversion to Catalyst transforms""" + transforms_to_passes[pl_transform] = (pass_name, decomposition) + + +# pylint: disable=too-many-arguments +@WorkflowInterpreter.register_primitive(transform_prim) +def handle_transform( + self, + *args, + args_slice, + consts_slice, + inner_jaxpr, + targs_slice, + tkwargs, + transform, +): + """Handle the conversion from plxpr to Catalyst jaxpr for a + PL transform.""" + consts = args[consts_slice] + non_const_args = args[args_slice] + targs = args[targs_slice] + + # If the transform is a decomposition transform + # and the graph-based decomposition is enabled if ( - not any(isinstance(op, MidCircuitMeasure) for op in tape.operations) - and user_specified_mcm_method is None + hasattr(transform._plxpr_transform, "__name__") + and transform._plxpr_transform.__name__ == "decompose_plxpr_to_plxpr" + and qml.decomposition.enabled_graph() ): - raise ValueError("No need to run one-shot mode when there are no mid-circuit measurements.") - - for m in tape.measurements: - # Check if measurement type is supported - if not isinstance(m, (CountsMP, ExpectationMP, ProbabilityMP, SampleMP, VarianceMP)): - raise TypeError( - f"Native mid-circuit measurement mode does not support {type(m).__name__} " - "measurements." - ) - - # Check variance with observable - if isinstance(m, VarianceMP) and m.obs: - raise TypeError( - "qml.var(obs) cannot be returned when `mcm_method='one-shot'` because " - "the Catalyst compiler does not support qml.sample(obs)." + if not self.requires_decompose_lowering: + self.requires_decompose_lowering = True + else: + raise NotImplementedError("Multiple decomposition transforms are not yet supported.") + + next_eval = copy(self) + # Update the decompose_gateset to be used by the quantum kernel primitive + # TODO: we originally wanted to treat decompose_gateset as a queue of + # gatesets to be used by the decompose-lowering pass at MLIR + # but this requires a C++ implementation of the graph-based decomposition + # which doesn't exist yet. + next_eval.decompose_tkwargs = tkwargs + + # Note. We don't perform the compiler-specific decomposition here + # to be able to support multiple decomposition transforms + # and collect all the required gatesets + # as well as being able to support other transforms in between. + + # The compiler specific transformation will be performed + # in the qnode handler. + + # Add the decompose-lowering pass to the start of the pipeline + next_eval._pass_pipeline.insert(0, Pass("decompose-lowering")) + + # We still need to construct and solve the graph based on + # the current jaxpr based on the current gateset + # but we don't rewrite the jaxpr at this stage. + + # gds_interpreter = DecompRuleInterpreter(*targs, **tkwargs) + + # def gds_wrapper(*args): + # return gds_interpreter.eval(inner_jaxpr, consts, *args) + + # final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args) + # return self.eval(final_jaxpr.jaxpr, consts, *non_const_args) + return next_eval.eval(inner_jaxpr, consts, *non_const_args) + + catalyst_pass_name = transforms_to_passes.get(transform, (None,))[0] + if catalyst_pass_name is None: + # Use PL's ExpandTransformsInterpreter to expand this and any embedded + # transform according to PL rules. It works by overriding the primitive + # registration, making all embedded transforms follow the PL rules + # from now on, hence ignoring the Catalyst pass conversion + def wrapper(*args): + return ExpandTransformsInterpreter().eval(inner_jaxpr, consts, *args) + + unravelled_jaxpr = jax.make_jaxpr(wrapper)(*non_const_args) + final_jaxpr = transform._plxpr_transform( + unravelled_jaxpr.jaxpr, unravelled_jaxpr.consts, targs, tkwargs, *non_const_args + ) + if transforms_to_passes[transform][1]: + final_jaxpr = pl_decompose._plxpr_transform( + final_jaxpr.jaxpr, final_jaxpr.consts, targs, tkwargs, *non_const_args ) - # Check if the measurement is supported with shot-vector - if has_shot_vector and not isinstance(m, SampleMP): - raise NotImplementedError( - f"Measurement {type(m).__name__} does not support shot-vectors. " - "Use qml.sample() instead." - ) + return copy(self).eval(final_jaxpr.jaxpr, final_jaxpr.consts, *non_const_args) - # Check dynamic wires with empty wires - if not has_wires and isinstance(m, (SampleMP, CountsMP)) and (m.wires.tolist() == []): - raise NotImplementedError( - f"Measurement {type(m).__name__} with empty wires is not supported with " - "dynamic wires in one-shot mode. Please specify a constant number of wires on " - "the device." - ) + # Apply the corresponding Catalyst pass counterpart + next_eval = copy(self) + next_eval._pass_pipeline.insert(0, Pass(catalyst_pass_name, *targs, **tkwargs)) + return next_eval.eval(inner_jaxpr, consts, *non_const_args) -# pylint: disable=protected-access,no-member,not-callable -def dynamic_one_shot(qnode, **kwargs): - """Transform a QNode to into several one-shot tapes to support dynamic circuit execution. +# pylint: disable=too-many-positional-arguments +def trace_from_pennylane( + fn, static_argnums, dynamic_args, abstracted_axes, sig, kwargs, debug_info=None +): + """Capture the JAX program representation (JAXPR) of the wrapped function, using + PL capure module. Args: - qnode (QNode): a quantum circuit which will run ``num_shots`` times + fn(Callable): the user function to be traced + static_argnums(int or Seqence[Int]): an index or a sequence of indices that specifies the + positions of static arguments. + dynamic_args(Seqence[Any]): the abstract values of the dynamic arguments. + abstracted_axes (Sequence[Sequence[str]] or Dict[int, str] or Sequence[Dict[int, str]]): + An experimental option to specify dynamic tensor shapes. + This option affects the compilation of the annotated function. + Function arguments with ``abstracted_axes`` specified will be compiled to ranked tensors + with dynamic shapes. For more details, please see the Dynamically-shaped Arrays section + below. + sig(Sequence[Any]): a tuple indicating the argument signature of the function. Static arguments + are indicated with their literal values, and dynamic arguments are indicated by abstract + values. + kwargs(Dict[str, Any]): keyword argumemts to the function. + debug_info(jax.api_util.debug_info): a source debug information object required by jaxprs. Returns: - qnode (QNode): - - The transformed circuit to be run ``num_shots`` times such as to simulate dynamic execution. - - - **Example** - - Consider the following circuit: - - .. code-block:: python - - dev = qml.device("lightning.qubit", shots=100) - params = np.pi / 4 * np.ones(2) - - @qjit - @dynamic_one_shot - @qml.qnode(dev, diff_method=None) - def circuit(x, y): - qml.RX(x, wires=0) - m0 = measure(0, reset=reset, postselect=postselect) + ClosedJaxpr: captured JAXPR + Tuple[Tuple[ShapedArray, bool]]: the return type of the captured JAXPR. + The boolean indicates whether each result is a value returned by the user function. + PyTreeDef: PyTree metadata of the function output + Tuple[Any]: the dynamic argument signature + """ - @cond(m0 == 1) - def ansatz(): - qml.RY(y, wires=1) + with transient_jax_config({"jax_dynamic_shapes": True}): - ansatz() - return measure_f(wires=[0, 1]) + make_jaxpr_kwargs = { + "static_argnums": static_argnums, + "abstracted_axes": abstracted_axes, + "debug_info": debug_info, + } - The ``dynamic_one_shot`` decorator prompts the QNode to perform a hundred one-shot - calculations, where in each calculation the ``measure`` operations dynamically - measures the 0-wire and collapse the state vector stochastically. - """ + args = sig - cpy_tape = None - mcm_config = kwargs.pop("mcm_config", None) + if isinstance(fn, qml.QNode) and static_argnums: + # `make_jaxpr2` sees the qnode + # The static_argnum on the wrapped function takes precedence over the + # one in `make_jaxpr` + # https://github.com/jax-ml/jax/blob/636691bba40b936b8b64a4792c1d2158296e9dd4/jax/_src/linear_util.py#L231 + # Therefore we need to coordinate them manually + fn.static_argnums = static_argnums - def transform_to_single_shot(qnode): - if not qnode._shots: - raise exceptions.QuantumFunctionError( - "dynamic_one_shot is only supported with finite shots." - ) + plxpr, out_type, out_treedef = make_jaxpr2(fn, **make_jaxpr_kwargs)(*args, **kwargs) + jaxpr = from_plxpr(plxpr)(*plxpr.in_avals) - user_specified_mcm_method = qnode.execute_kwargs["mcm_method"] - shot_vector = qnode._shots.shot_vector if qnode._shots else [] - wires = qnode.device.wires + return jaxpr, out_type, out_treedef, sig - @qml.transform - def dynamic_one_shot_partial( - tape: qml.tape.QuantumTape, - ) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: - nonlocal cpy_tape - cpy_tape = tape - _validate_one_shot_measurements( - mcm_config, tape, user_specified_mcm_method, shot_vector, wires - ) +def _apply_compiler_decompose_to_plxpr(inner_jaxpr, consts, ncargs, tgateset=None, tkwargs=None): + """Apply the compiler-specific decomposition for a given JAXPR. - if tape.batch_size is not None: - raise ValueError("mcm_method='one-shot' is not compatible with broadcasting") + This function first disables the graph-based decomposition optimization + to ensure that only high-level gates and templates with a single decomposition + are decomposed. It then performs the pre-mlir decomposition using PennyLane's + `plxpr_transform` function. - aux_tapes = [init_auxiliary_tape(tape)] + `tgateset` is a list of target gateset for decomposition. + If provided, it will be combined with the default compiler ops for decomposition. + If not provided, `tkwargs` will be used as the keyword arguments for the + decomposition transform. This is to ensure compatibility with the existing + PennyLane decomposition transform as well as providing a fallback mechanism. - def processing_fn(results): - return results + Args: + inner_jaxpr (Jaxpr): The input JAXPR to be decomposed. + consts (list): The constants used in the JAXPR. + ncargs (list): Non-constant arguments for the JAXPR. + tgateset (list): A list of target gateset for decomposition. Defaults to None. + tkwargs (list): The keyword arguments of the decompose transform. Defaults to None. - return aux_tapes, processing_fn + Returns: + ClosedJaxpr: The decomposed JAXPR. + """ - return dynamic_one_shot_partial(qnode) + # Disable the graph decomposition optimization + + # Why? Because for the compiler-specific decomposition we want to + # only decompose higher-level gates and templates that only have + # a single decomposition, and not do any further optimization + # based on the graph solution. + # Besides, the graph-based decomposition is not supported + # yet in from_plxpr for most gates and templates. + # TODO: Enable the graph-based decomposition + qml.decomposition.disable_graph() + + kwargs = ( + {"gate_set": set(COMPILER_OPS_FOR_DECOMPOSITION.keys()).union(tgateset)} + if tgateset + else tkwargs + ) + final_jaxpr = qml.transforms.decompose.plxpr_transform(inner_jaxpr, consts, (), kwargs, *ncargs) - single_shot_qnode = transform_to_single_shot(qnode) - single_shot_qnode = qml.set_shots(single_shot_qnode, shots=1) - if mcm_config is not None: - single_shot_qnode.execute_kwargs["postselect_mode"] = mcm_config.postselect_mode - single_shot_qnode.execute_kwargs["mcm_method"] = mcm_config.mcm_method - single_shot_qnode._dynamic_one_shot_called = True - total_shots = _get_total_shots(qnode) + qml.decomposition.enable_graph() - def one_shot_wrapper(*args, **kwargs): - def wrap_single_shot_qnode(*_): - return single_shot_qnode(*args, **kwargs) + return final_jaxpr - arg_vmap = jnp.empty((total_shots,), dtype=float) - results = catalyst.vmap(wrap_single_shot_qnode)(arg_vmap) - if isinstance(results[0], tuple) and len(results) == 1: - results = results[0] - has_mcm = any(isinstance(op, MidCircuitMeasure) for op in cpy_tape.operations) - classical_return_indices = kwargs.pop("_classical_return_indices", [[]])[0] - num_mcm = kwargs.pop("_num_mcm_expected", [0])[0] - out_tree_expected = kwargs.pop("_out_tree_expected", [[]]) +def _collect_and_compile_graph_solutions(inner_jaxpr, consts, tkwargs, ncargs): + """Collect and compile graph solutions for a given JAXPR. - # Split results into classical and measurement parts - classical_values, results = _extract_classical_and_measurement_results( - results, classical_return_indices - ) + This function uses the DecompRuleInterpreter to evaluate + the input JAXPR and obtain a new JAXPR that incorporates + the graph-based decomposition solutions. - out = list(results) + This function doesn't modify the underlying quantum function + but rather constructs a new JAXPR with decomposition rules. - shot_vector = _get_shot_vector(qnode) - snapshots, out = _get_snapshot_results(mcm_config.mcm_method, cpy_tape, out) + Args: + inner_jaxpr (Jaxpr): The input JAXPR to be decomposed. + consts (list): The constants used in the JAXPR. + tkwargs (list): The keyword arguments of the decompose transform. + ncargs (list): Non-constant arguments for the JAXPR. - if has_mcm and len(cpy_tape.measurements) > 0: - out = parse_native_mid_circuit_measurements( - cpy_tape, results=results, postselect_mode="pad-invalid-samples" - ) - if len(cpy_tape.measurements) == 1: - out = (out,) - elif len(cpy_tape.measurements) > 0: - out = _process_terminal_measurements( - mcm_config.mcm_method, cpy_tape, out, snapshots, shot_vector + Returns: + ClosedJaxpr: The decomposed JAXPR. + bool: A flag indicating whether the graph-based decomposition was successful. + """ + gds_interpreter = DecompRuleInterpreter(**tkwargs) + + def gds_wrapper(*args): + return gds_interpreter.eval(inner_jaxpr, consts, *args) + + graph_succeeded = True + + with warnings.catch_warnings(record=True) as captured_warnings: + warnings.simplefilter("always", UserWarning) + final_jaxpr = jax.make_jaxpr(gds_wrapper)(*ncargs) + + for w in captured_warnings: + warnings.showwarning(w.message, w.category, w.filename, w.lineno) + # TODO: use a custom warning class for this in PennyLane to remove this + # string matching and make it more robust. + if "The graph-based decomposition system is unable" in str(w.message): # pragma: no cover + graph_succeeded = False + warnings.warn( + "Falling back to the legacy decomposition system.", + UserWarning, ) - ctx = OutputContext( - cpy_tape=cpy_tape, - classical_values=classical_values, - classical_return_indices=classical_return_indices, - out_tree_expected=out_tree_expected, - snapshots=snapshots, - shot_vector=shot_vector, - num_mcm=num_mcm, - ) + return final_jaxpr, graph_succeeded - return _finalize_output(out, ctx) - return one_shot_wrapper +def _get_operator_name(op): + """Get the name of a pennylane operator, handling wrapped operators. + Note: Controlled and Adjoint ops aren't supported in `gate_set` + by PennyLane's DecompositionGraph; unit tests were added in PennyLane. + """ + if isinstance(op, str): + return op -def _extract_passes(transform_program): - tape_transforms = [] - pass_pipeline = [] - for t in transform_program: - if t.pass_name: - pass_pipeline.append(Pass(t.pass_name, *t.args, **t.kwargs)) - else: - tape_transforms.append(t) - return qml.transforms.core.TransformProgram(tape_transforms), tuple(pass_pipeline) \ No newline at end of file + # Return NoNameOp if the operator has no _primitive.name attribute. + # This is to avoid errors when we capture the program + # as we deal with such ops later in the decomposition graph. + return getattr(op._primitive, "name", "NoNameOp") diff --git a/frontend/catalyst/from_plxpr/qfunc_interpreter.py b/frontend/catalyst/from_plxpr/qfunc_interpreter.py index 25b6cb90d8..29c3aade1a 100644 --- a/frontend/catalyst/from_plxpr/qfunc_interpreter.py +++ b/frontend/catalyst/from_plxpr/qfunc_interpreter.py @@ -17,6 +17,7 @@ # pylint: disable=protected-access import textwrap from copy import copy +from functools import partial import jax import jax.numpy as jnp @@ -66,6 +67,7 @@ from .qubit_handler import ( QubitHandler, QubitIndexRecorder, + _get_dynamically_allocated_qregs, get_in_qubit_values, is_dynamically_allocated_wire, ) @@ -156,7 +158,9 @@ def interpret_operation(self, op, is_adjoint=False, control_values=(), control_w if any(not qreg.is_qubit_mode() and qreg.expired for qreg in in_qregs + in_ctrl_qregs): raise CompileError(f"Deallocated qubits cannot be used, but used in {op.name}.") - out_qubits = qinst_p.bind( + bind_fn = _special_op_bind_call.get(type(op), qinst_p.bind) + + out_qubits = bind_fn( *[*in_qubits, *op.data, *in_ctrl_qubits, *control_values], op=op.name, qubits_len=len(op.wires), @@ -164,7 +168,6 @@ def interpret_operation(self, op, is_adjoint=False, control_values=(), control_w ctrl_len=len(control_wires), adjoint=is_adjoint, ) - out_non_ctrl_qubits = out_qubits[: len(out_qubits) - len(control_wires)] out_ctrl_qubits = out_qubits[-len(control_wires) :] @@ -275,6 +278,27 @@ def __call__(self, jaxpr, *args): return self.eval(jaxpr.jaxpr, jaxpr.consts, *args) +# pylint: disable=unused-argument +def _qubit_unitary_bind_call(*invals, op, qubits_len, params_len, ctrl_len, adjoint): + wires = invals[:qubits_len] + mat = invals[qubits_len] + ctrl_inputs = invals[qubits_len + 1 :] + return unitary_p.bind( + mat, *wires, *ctrl_inputs, qubits_len=qubits_len, ctrl_len=ctrl_len, adjoint=adjoint + ) + + +# pylint: disable=unused-argument +def _gphase_bind_call(*invals, op, qubits_len, params_len, ctrl_len, adjoint): + return gphase_p.bind(*invals[qubits_len:], ctrl_len=ctrl_len, adjoint=adjoint) + + +_special_op_bind_call = { + qml.QubitUnitary: _qubit_unitary_bind_call, + qml.GlobalPhase: _gphase_bind_call, +} + + # pylint: disable=unused-argument @PLxPRToQuantumJaxprInterpreter.register_primitive(qml.allocation.allocate_prim) def handle_qml_alloc(self, *, num_wires, state=None, restored=False): @@ -316,22 +340,73 @@ def interpret_counts(self, *wires, all_outcomes): return keys, vals +def _subroutine_kernel( + interpreter, + jaxpr, + *qregs_plus_args, + outer_dynqreg_handlers=(), + wire_label_arg_to_tracer_arg_index=(), + wire_to_owner_qreg=(), +): + global_qreg, *dynqregs_plus_args = qregs_plus_args + num_dynamic_alloced_qregs = len(outer_dynqreg_handlers) + dynalloced_qregs, args = ( + dynqregs_plus_args[:num_dynamic_alloced_qregs], + dynqregs_plus_args[num_dynamic_alloced_qregs:], + ) + + # Launch a new interpreter for the body region + # A new interpreter's root qreg value needs a new recorder + converter = copy(interpreter) + converter.qubit_index_recorder = QubitIndexRecorder() + init_qreg = QubitHandler(global_qreg, converter.qubit_index_recorder) + converter.init_qreg = init_qreg + + # add dynamic qregs to recorder + qreg_map = {} + dyn_qreg_handlers = [] + arg_to_qreg = {} + for dyn_qreg, outer_dynqreg_handler in zip( + dynalloced_qregs, outer_dynqreg_handlers, strict=True + ): + dyn_qreg_handler = QubitHandler(dyn_qreg, converter.qubit_index_recorder) + dyn_qreg_handlers.append(dyn_qreg_handler) + + # plxpr global wire index does not change across scopes + # So scope arg dynamic qregs need to have the same root hash as their corresponding + # qreg tracers outside + dyn_qreg_handler.root_hash = outer_dynqreg_handler.root_hash + + # Each qreg argument of the subscope corresponds to a qreg from the outer scope + qreg_map[outer_dynqreg_handler] = dyn_qreg_handler + + for global_idx, arg_idx in wire_label_arg_to_tracer_arg_index.items(): + arg_to_qreg[args[arg_idx]] = qreg_map[wire_to_owner_qreg[global_idx]] + + # The new interpreter's recorder needs to be updated to include the qreg args + # of this scope, instead of the outer qregs + for arg in args: + if arg in arg_to_qreg: + converter.qubit_index_recorder[arg] = arg_to_qreg[arg] + + retvals = converter(jaxpr, *args) + + init_qreg.insert_all_dangling_qubits() + + # Return all registers + for dyn_qreg_handler in reversed(dyn_qreg_handlers): + dyn_qreg_handler.insert_all_dangling_qubits() + retvals.insert(0, dyn_qreg_handler.get()) + + return converter.init_qreg.get(), *retvals + + @PLxPRToQuantumJaxprInterpreter.register_primitive(quantum_subroutine_p) def handle_subroutine(self, *args, **kwargs): """ Transform the subroutine from PLxPR into JAXPR with quantum primitives. """ - if any(is_dynamically_allocated_wire(arg) for arg in args): - raise NotImplementedError( - textwrap.dedent( - """ - Dynamically allocated wires in a parent scope cannot be used in a child - scope yet. Please consider dynamical allocation inside the child scope. - """ - ) - ) - backup = dict(self.init_qreg) self.init_qreg.insert_all_dangling_qubits() @@ -339,20 +414,34 @@ def handle_subroutine(self, *args, **kwargs): plxpr = kwargs["jaxpr"] transformed = self.subroutine_cache.get(plxpr) - def wrapper(qreg, *args): - # Launch a new interpreter for the new subroutine region - # A new interpreter's root qreg value needs a new recorder - converter = copy(self) - converter.qubit_index_recorder = QubitIndexRecorder() - init_qreg = QubitHandler(qreg, converter.qubit_index_recorder) - converter.init_qreg = init_qreg - - retvals = converter(plxpr, *args) - converter.init_qreg.insert_all_dangling_qubits() - return converter.init_qreg.get(), *retvals + dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs( + args, self.qubit_index_recorder, self.init_qreg + ) + wire_to_owner_qreg = dict(zip(dynalloced_wire_global_indices, dynalloced_qregs)) + dynalloced_qregs = list(dict.fromkeys(dynalloced_qregs)) # squash duplicates + + # Convert global wire indices into local indices + new_args = () + wire_label_arg_to_tracer_arg_index = {} + for i, arg in enumerate(args): + if arg in dynalloced_wire_global_indices: + wire_label_arg_to_tracer_arg_index[arg] = i + new_args += (self.qubit_index_recorder[arg].global_index_to_local_index(arg),) + else: + new_args += (arg,) if not transformed: - converted_closed_jaxpr_branch = jax.make_jaxpr(wrapper)(self.init_qreg.get(), *args) + f = partial( + _subroutine_kernel, + self, + plxpr, + outer_dynqreg_handlers=dynalloced_qregs, + wire_label_arg_to_tracer_arg_index=wire_label_arg_to_tracer_arg_index, + wire_to_owner_qreg=wire_to_owner_qreg, + ) + converted_closed_jaxpr_branch = jax.make_jaxpr(f)( + self.init_qreg.get(), *[dyn_qreg.get() for dyn_qreg in dynalloced_qregs], *args + ) self.subroutine_cache[plxpr] = converted_closed_jaxpr_branch else: converted_closed_jaxpr_branch = transformed @@ -361,12 +450,13 @@ def wrapper(qreg, *args): # is just pjit_p with a different name. vals_out = quantum_subroutine_p.bind( self.init_qreg.get(), - *args, + *[dyn_qreg.get() for dyn_qreg in dynalloced_qregs], + *new_args, jaxpr=converted_closed_jaxpr_branch, - in_shardings=(UNSPECIFIED, *kwargs["in_shardings"]), - out_shardings=(UNSPECIFIED, *kwargs["out_shardings"]), - in_layouts=(None, *kwargs["in_layouts"]), - out_layouts=(None, *kwargs["out_layouts"]), + in_shardings=(*(UNSPECIFIED,) * (len(dynalloced_qregs) + 1), *kwargs["in_shardings"]), + out_shardings=(*(UNSPECIFIED,) * (len(dynalloced_qregs) + 1), *kwargs["out_shardings"]), + in_layouts=(*(None,) * (len(dynalloced_qregs) + 1), *kwargs["in_layouts"]), + out_layouts=(*(None,) * (len(dynalloced_qregs) + 1), *kwargs["out_layouts"]), donated_invars=kwargs["donated_invars"], ctx_mesh=kwargs["ctx_mesh"], name=kwargs["name"], @@ -376,7 +466,9 @@ def wrapper(qreg, *args): ) self.init_qreg.set(vals_out[0]) - vals_out = vals_out[1:] + for i, dyn_qreg in enumerate(dynalloced_qregs): + dyn_qreg.set(vals_out[i + 1]) + vals_out = vals_out[len(dynalloced_qregs) + 1 :] for orig_wire in backup.keys(): self.init_qreg.extract(orig_wire) @@ -441,22 +533,6 @@ def wrapper(*args): return () -@PLxPRToQuantumJaxprInterpreter.register_primitive(qml.QubitUnitary._primitive) -def handle_qubit_unitary(self, *invals, n_wires): - """Handle the conversion from plxpr to Catalyst jaxpr for the QubitUnitary primitive""" - in_qregs, in_qubits = get_in_qubit_values(invals[1:], self.qubit_index_recorder, self.init_qreg) - outvals = unitary_p.bind(invals[0], *in_qubits, qubits_len=n_wires, ctrl_len=0, adjoint=False) - for in_qreg, w, new_wire in zip(in_qregs, invals[1:], outvals): - in_qreg[in_qreg.global_index_to_local_index(w)] = new_wire - - -# pylint: disable=unused-argument -@PLxPRToQuantumJaxprInterpreter.register_primitive(qml.GlobalPhase._primitive) -def handle_global_phase(self, phase, *wires, n_wires): - """Handle the conversion from plxpr to Catalyst jaxpr for the GlobalPhase primitive""" - gphase_p.bind(phase, ctrl_len=0, adjoint=False) - - @PLxPRToQuantumJaxprInterpreter.register_primitive(qml.BasisState._primitive) def handle_basis_state(self, *invals, n_wires): """Handle the conversion from plxpr to Catalyst jaxpr for the BasisState primitive""" @@ -565,6 +641,12 @@ def handle_adjoint_transform( n_consts, ): """Handle the conversion from plxpr to Catalyst jaxpr for the adjoint primitive""" + + if any(is_dynamically_allocated_wire(arg) for arg in plxpr_invals): + raise NotImplementedError( + "Dynamically allocated wires cannot be used in quantum adjoints yet." + ) + assert jaxpr is not None consts = plxpr_invals[:n_consts] args = plxpr_invals[n_consts:] @@ -590,13 +672,16 @@ def calling_convention(*args_plus_qreg): init_qreg.insert_all_dangling_qubits() return *retvals, converter.init_qreg.get() - _, args_tree = tree_flatten((consts, args, [qreg])) - converted_jaxpr_branch = jax.make_jaxpr(calling_convention)(*consts, *args, qreg).jaxpr + converted_jaxpr_branch = jax.make_jaxpr(calling_convention)(*args, qreg) - converted_closed_jaxpr_branch = ClosedJaxpr(convert_constvars_jaxpr(converted_jaxpr_branch), ()) + converted_closed_jaxpr_branch = ClosedJaxpr( + convert_constvars_jaxpr(converted_jaxpr_branch.jaxpr), () + ) + new_consts = converted_jaxpr_branch.consts + _, args_tree = tree_flatten((new_consts, args, [qreg])) # Perform the binding outvals = adjoint_p.bind( - *consts, + *new_consts, *args, qreg, jaxpr=converted_closed_jaxpr_branch, diff --git a/frontend/catalyst/from_plxpr/qubit_handler.py b/frontend/catalyst/from_plxpr/qubit_handler.py index f9adb55e6c..b7a54b76f2 100644 --- a/frontend/catalyst/from_plxpr/qubit_handler.py +++ b/frontend/catalyst/from_plxpr/qubit_handler.py @@ -68,8 +68,6 @@ qubit SSA values on its wires? """ -import textwrap - from catalyst.jax_extras import DynamicJaxprTracer from catalyst.jax_primitives import AbstractQbit, AbstractQreg, qextract_p, qinsert_p from catalyst.utils.exceptions import CompileError @@ -422,21 +420,6 @@ def get_in_qubit_values( if not qubit_index_recorder.contains(w): # First time the global wire index w is encountered # Need to extract from fallback qreg - # TODO: this can now only be from the global qreg, because right now in from_plxpr - # conversion, subscopes (control flow, adjoint, ...) can only take in the global - # qreg as the final scope argument. They cannot take an arbitrary number of qreg - # values yet. - # Supporting multiple registers requires refactoring the from_plxpr conversion's - # implementation. - if is_dynamically_allocated_wire(w): - raise NotImplementedError( - textwrap.dedent( - """ - Dynamically allocated wires in a parent scope cannot be used in a child - scope yet. Please consider dynamical allocation inside the child scope. - """ - ) - ) in_qubits.append(fallback_qreg[fallback_qreg.global_index_to_local_index(w)]) in_qregs.append(fallback_qreg) @@ -446,3 +429,28 @@ def get_in_qubit_values( in_qubits.append(in_qreg[in_qreg.global_index_to_local_index(w)]) return in_qregs, in_qubits + + +def _get_dynamically_allocated_qregs(plxpr_invals, qubit_index_recorder, init_qreg): + """ + Get the potential dynamically allocated register values that are visible to a jaxpr. + + Note that dynamically allocated wires have their qreg tracer's id as the global wire index + so the sub jaxpr takes that id in as a "const", since it is closure from the target wire + of gates/measurements/... + We need to remove that const, so we also let this util return these global indices. + """ + dynalloced_qregs = [] + dynalloced_wire_global_indices = [] + for inval in plxpr_invals: + if ( + isinstance(inval, int) + and qubit_index_recorder.contains(inval) + and qubit_index_recorder[inval] is not init_qreg + ): + dyn_qreg = qubit_index_recorder[inval] + dyn_qreg.insert_all_dangling_qubits() + dynalloced_qregs.append(dyn_qreg) + dynalloced_wire_global_indices.append(inval) + + return dynalloced_qregs, dynalloced_wire_global_indices From 3b2c5b85ccd6090f863082ac5ab44993f379f35a Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 13 Nov 2025 14:56:25 -0500 Subject: [PATCH 05/20] more polishing --- frontend/catalyst/from_plxpr/from_plxpr.py | 87 ++++++++++++---------- frontend/catalyst/jax_primitives_utils.py | 3 +- frontend/catalyst/qfunc.py | 28 ++++--- 3 files changed, 69 insertions(+), 49 deletions(-) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 86720baf5a..d19f6e8006 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -273,6 +273,47 @@ def register_transform(pl_transform, pass_name, decomposition): transforms_to_passes[pl_transform] = (pass_name, decomposition) +def _handle_decompose_transform(self, inner_jaxpr, consts, non_const_args, targs, tkwargs): + if not self.requires_decompose_lowering: + self.requires_decompose_lowering = True + else: + raise NotImplementedError("Multiple decomposition transforms are not yet supported.") + + next_eval = copy(self) + # Update the decompose_gateset to be used by the quantum kernel primitive + # TODO: we originally wanted to treat decompose_gateset as a queue of + # gatesets to be used by the decompose-lowering pass at MLIR + # but this requires a C++ implementation of the graph-based decomposition + # which doesn't exist yet. + next_eval.decompose_tkwargs = tkwargs + + # Note. We don't perform the compiler-specific decomposition here + # to be able to support multiple decomposition transforms + # and collect all the required gatesets + # as well as being able to support other transforms in between. + + # The compiler specific transformation will be performed + # in the qnode handler. + + # Add the decompose-lowering pass to the start of the pipeline + t = qml.transform(pass_name="decompose-lowering") + pass_container = qml.transforms.core.TransformContainer(t, args=targs, kwargs=tkwargs) + next_eval._pass_pipeline.insert(0, pass_container) + + # We still need to construct and solve the graph based on + # the current jaxpr based on the current gateset + # but we don't rewrite the jaxpr at this stage. + + # gds_interpreter = DecompRuleInterpreter(*targs, **tkwargs) + + # def gds_wrapper(*args): + # return gds_interpreter.eval(inner_jaxpr, consts, *args) + + # final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args) + # return self.eval(final_jaxpr.jaxpr, consts, *non_const_args) + return next_eval.eval(inner_jaxpr, consts, *non_const_args) + + # pylint: disable=too-many-arguments @WorkflowInterpreter.register_primitive(transform_prim) def handle_transform( @@ -298,44 +339,13 @@ def handle_transform( and transform._plxpr_transform.__name__ == "decompose_plxpr_to_plxpr" and qml.decomposition.enabled_graph() ): - if not self.requires_decompose_lowering: - self.requires_decompose_lowering = True - else: - raise NotImplementedError("Multiple decomposition transforms are not yet supported.") - - next_eval = copy(self) - # Update the decompose_gateset to be used by the quantum kernel primitive - # TODO: we originally wanted to treat decompose_gateset as a queue of - # gatesets to be used by the decompose-lowering pass at MLIR - # but this requires a C++ implementation of the graph-based decomposition - # which doesn't exist yet. - next_eval.decompose_tkwargs = tkwargs - - # Note. We don't perform the compiler-specific decomposition here - # to be able to support multiple decomposition transforms - # and collect all the required gatesets - # as well as being able to support other transforms in between. - - # The compiler specific transformation will be performed - # in the qnode handler. - - # Add the decompose-lowering pass to the start of the pipeline - next_eval._pass_pipeline.insert(0, Pass("decompose-lowering")) - - # We still need to construct and solve the graph based on - # the current jaxpr based on the current gateset - # but we don't rewrite the jaxpr at this stage. - - # gds_interpreter = DecompRuleInterpreter(*targs, **tkwargs) - - # def gds_wrapper(*args): - # return gds_interpreter.eval(inner_jaxpr, consts, *args) - - # final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args) - # return self.eval(final_jaxpr.jaxpr, consts, *non_const_args) - return next_eval.eval(inner_jaxpr, consts, *non_const_args) + return _handle_decompose_transform( + self, inner_jaxpr, consts, non_const_args, targs, tkwargs + ) - catalyst_pass_name = transforms_to_passes.get(transform, (None,))[0] + catalyst_pass_name = transform.pass_name + if catalyst_pass_name is None: + catalyst_pass_name = transforms_to_passes.get(transform, (None,))[0] if catalyst_pass_name is None: # Use PL's ExpandTransformsInterpreter to expand this and any embedded # transform according to PL rules. It works by overriding the primitive @@ -357,7 +367,8 @@ def wrapper(*args): # Apply the corresponding Catalyst pass counterpart next_eval = copy(self) - next_eval._pass_pipeline.insert(0, Pass(catalyst_pass_name, *targs, **tkwargs)) + bound_pass = qml.transforms.core.TransformContainer(transform, args=targs, kwargs=tkwargs) + next_eval._pass_pipeline.insert(0, bound_pass) return next_eval.eval(inner_jaxpr, consts, *non_const_args) diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py index 0a50d0aad6..d9e753efe2 100644 --- a/frontend/catalyst/jax_primitives_utils.py +++ b/frontend/catalyst/jax_primitives_utils.py @@ -327,6 +327,7 @@ def _lowered_options(kwargs): lowered_options[mlir_option] = get_mlir_attribute_from_pyval(value) return lowered_options + def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipeline): """Generate a transform module embedded in the current module and schedule the transformations in pipeline""" @@ -394,7 +395,7 @@ def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipelin is_xdsl_pass, ) - if is_xdsl_pass(_pass.name): + if is_xdsl_pass(name): uses_xdsl_passes = True apply_registered_pass_op.operation.attributes["catalyst.xdsl_pass"] = ( ir.UnitAttr.get() diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py index fab0d2e090..dbca8c649b 100644 --- a/frontend/catalyst/qfunc.py +++ b/frontend/catalyst/qfunc.py @@ -43,7 +43,7 @@ from catalyst.jax_primitives import quantum_kernel_p from catalyst.jax_tracer import Function, trace_quantum_function from catalyst.logging import debug_logger -from catalyst.passes.pass_api import dictionary_to_list_of_passes, Pass +from catalyst.passes.pass_api import Pass, dictionary_to_list_of_passes from catalyst.tracing.contexts import EvaluationContext from catalyst.tracing.type_signatures import filter_static_args from catalyst.utils.exceptions import CompileError @@ -284,12 +284,12 @@ def __call__(self, *args, **kwargs): assert isinstance(self, qml.QNode) new_transform_program, new_pipeline = _extract_passes(self.transform_program) - # Update the qnode with peephole pipeline - pass_pipeline = kwargs.pop("pass_pipeline", []) + new_pipeline + old_pipeline = kwargs.pop("pass_pipeline", ()) or () + pass_pipeline = old_pipeline + new_pipeline pass_pipeline = dictionary_to_list_of_passes(pass_pipeline) new_qnode = copy(self) - new_qnode._transform_program = new_transform_program # pylint: disable=protected-access + new_qnode._transform_program = new_transform_program # pylint: disable=protected-access # Mid-circuit measurement configuration/execution fn_result = configure_mcm_and_try_one_shot(new_qnode, args, kwargs) @@ -656,11 +656,19 @@ def wrap_single_shot_qnode(*_): def _extract_passes(transform_program): + """Extract transforms with pass names from the end of the TransformProgram.""" tape_transforms = [] pass_pipeline = [] - for t in transform_program: - if t.pass_name: - pass_pipeline.append(Pass(t.pass_name, *t.args, **t.kwargs)) - else: - tape_transforms.append(t) - return qml.transforms.core.TransformProgram(tape_transforms), tuple(pass_pipeline) \ No newline at end of file + i = len(transform_program) + for t in reversed(transform_program): + if t.pass_name is None: + break + i -= 1 + pass_pipeline = transform_program[i:] + tape_transforms = transform_program[:i] + for t in tape_transforms: + if t.transform is None: + raise ValueError( + f"{t} without a tape definition occurs before tape transform {tape_transforms[-1]}." + ) + return qml.transforms.core.TransformProgram(tape_transforms), tuple(pass_pipeline) From e885b911592494f7d576194a60d192a812957ccb Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 14 Nov 2025 14:23:15 -0500 Subject: [PATCH 06/20] some test fixes --- frontend/catalyst/from_plxpr/from_plxpr.py | 9 +++++---- frontend/catalyst/qfunc.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 13336f288e..45361b945f 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -279,7 +279,7 @@ def register_transform(pl_transform, pass_name, decomposition): transforms_to_passes[pl_transform] = (pass_name, decomposition) -def _handle_decompose_transform(self, inner_jaxpr, consts, non_const_args, targs, tkwargs): +def _handle_decompose_transform(self, inner_jaxpr, consts, non_const_args, tkwargs): if not self.requires_decompose_lowering: self.requires_decompose_lowering = True else: @@ -303,7 +303,7 @@ def _handle_decompose_transform(self, inner_jaxpr, consts, non_const_args, targs # Add the decompose-lowering pass to the start of the pipeline t = qml.transform(pass_name="decompose-lowering") - pass_container = qml.transforms.core.TransformContainer(t, args=targs, kwargs=tkwargs) + pass_container = qml.transforms.core.TransformContainer(t) next_eval._pass_pipeline.insert(0, pass_container) # We still need to construct and solve the graph based on @@ -346,7 +346,7 @@ def handle_transform( and qml.decomposition.enabled_graph() ): return _handle_decompose_transform( - self, inner_jaxpr, consts, non_const_args, targs, tkwargs + self, inner_jaxpr, consts, non_const_args, tkwargs ) catalyst_pass_name = transform.pass_name @@ -373,7 +373,8 @@ def wrapper(*args): # Apply the corresponding Catalyst pass counterpart next_eval = copy(self) - bound_pass = qml.transforms.core.TransformContainer(transform, args=targs, kwargs=tkwargs) + t = qml.transform(pass_name=catalyst_pass_name) + bound_pass = qml.transforms.core.TransformContainer(t, args=targs, kwargs=tkwargs) next_eval._pass_pipeline.insert(0, bound_pass) return next_eval.eval(inner_jaxpr, consts, *non_const_args) diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py index dbca8c649b..ae63ff2c5f 100644 --- a/frontend/catalyst/qfunc.py +++ b/frontend/catalyst/qfunc.py @@ -285,7 +285,7 @@ def __call__(self, *args, **kwargs): new_transform_program, new_pipeline = _extract_passes(self.transform_program) # Update the qnode with peephole pipeline - old_pipeline = kwargs.pop("pass_pipeline", ()) or () + old_pipeline = tuple(kwargs.pop("pass_pipeline", ()) or ()) pass_pipeline = old_pipeline + new_pipeline pass_pipeline = dictionary_to_list_of_passes(pass_pipeline) new_qnode = copy(self) From 5c99bbadc836987fb081200250a3e26f1ba05a67 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Mon, 17 Nov 2025 10:45:04 -0500 Subject: [PATCH 07/20] fix failing test --- frontend/catalyst/from_plxpr/from_plxpr.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 45361b945f..d8d060a64c 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -48,7 +48,6 @@ qdealloc_p, quantum_kernel_p, ) -from catalyst.passes.pass_api import Pass from .qfunc_interpreter import PLxPRToQuantumJaxprInterpreter from .qubit_handler import ( @@ -215,7 +214,9 @@ def handle_qnode( # Fallback to the legacy decomposition if the graph-based decomposition failed if not graph_succeeded: # Remove the decompose-lowering pass from the pipeline - self._pass_pipeline = [p for p in self._pass_pipeline if p.name != "decompose-lowering"] + self._pass_pipeline = [ + p for p in self._pass_pipeline if p.pass_name != "decompose-lowering" + ] closed_jaxpr = _apply_compiler_decompose_to_plxpr( inner_jaxpr=closed_jaxpr.jaxpr, consts=closed_jaxpr.consts, @@ -345,9 +346,7 @@ def handle_transform( and transform._plxpr_transform.__name__ == "decompose_plxpr_to_plxpr" and qml.decomposition.enabled_graph() ): - return _handle_decompose_transform( - self, inner_jaxpr, consts, non_const_args, tkwargs - ) + return _handle_decompose_transform(self, inner_jaxpr, consts, non_const_args, tkwargs) catalyst_pass_name = transform.pass_name if catalyst_pass_name is None: From ffedfe4a05a8dbf18305f13ea5a1e7e0cece7bb0 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Mon, 17 Nov 2025 14:38:50 -0500 Subject: [PATCH 08/20] see if that fixes the failure --- frontend/catalyst/qfunc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py index ae63ff2c5f..75a1d27018 100644 --- a/frontend/catalyst/qfunc.py +++ b/frontend/catalyst/qfunc.py @@ -285,9 +285,9 @@ def __call__(self, *args, **kwargs): new_transform_program, new_pipeline = _extract_passes(self.transform_program) # Update the qnode with peephole pipeline - old_pipeline = tuple(kwargs.pop("pass_pipeline", ()) or ()) - pass_pipeline = old_pipeline + new_pipeline - pass_pipeline = dictionary_to_list_of_passes(pass_pipeline) + old_pipeline = kwargs.pop("pass_pipeline") + processed_old_pipeline = tuple(dictionary_to_list_of_passes(old_pipeline)) + pass_pipeline = processed_old_pipeline + new_pipeline new_qnode = copy(self) new_qnode._transform_program = new_transform_program # pylint: disable=protected-access From 77a211d9b6aa91ce2c50c89cd508d08fb8a24b67 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Mon, 17 Nov 2025 15:00:34 -0500 Subject: [PATCH 09/20] oops --- frontend/catalyst/qfunc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py index 75a1d27018..944418a036 100644 --- a/frontend/catalyst/qfunc.py +++ b/frontend/catalyst/qfunc.py @@ -285,7 +285,7 @@ def __call__(self, *args, **kwargs): new_transform_program, new_pipeline = _extract_passes(self.transform_program) # Update the qnode with peephole pipeline - old_pipeline = kwargs.pop("pass_pipeline") + old_pipeline = kwargs.pop("pass_pipeline", None) processed_old_pipeline = tuple(dictionary_to_list_of_passes(old_pipeline)) pass_pipeline = processed_old_pipeline + new_pipeline new_qnode = copy(self) From a1abd24d02db5885997b58425876c6dd339c8f4a Mon Sep 17 00:00:00 2001 From: albi3ro Date: Mon, 17 Nov 2025 18:44:27 -0500 Subject: [PATCH 10/20] [skip ci] starting to test --- frontend/catalyst/qfunc.py | 2 +- .../test/lit/test_peephole_optimizations.py | 34 +++++++++++++++++++ .../pytest/test_peephole_optimizations.py | 30 +++++++++++++--- 3 files changed, 61 insertions(+), 5 deletions(-) diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py index 944418a036..ea564ac043 100644 --- a/frontend/catalyst/qfunc.py +++ b/frontend/catalyst/qfunc.py @@ -43,7 +43,7 @@ from catalyst.jax_primitives import quantum_kernel_p from catalyst.jax_tracer import Function, trace_quantum_function from catalyst.logging import debug_logger -from catalyst.passes.pass_api import Pass, dictionary_to_list_of_passes +from catalyst.passes.pass_api import dictionary_to_list_of_passes from catalyst.tracing.contexts import EvaluationContext from catalyst.tracing.type_signatures import filter_static_args from catalyst.utils.exceptions import CompileError diff --git a/frontend/test/lit/test_peephole_optimizations.py b/frontend/test/lit/test_peephole_optimizations.py index 40a882e035..bffec18fe0 100644 --- a/frontend/test/lit/test_peephole_optimizations.py +++ b/frontend/test/lit/test_peephole_optimizations.py @@ -86,6 +86,40 @@ def test_pipeline_lowering_workflow(x): test_pipeline_lowering() +def test_transform_lowering(): + """ + Basic pipeline lowering on one qnode. + """ + @qjit(keep_intermediate=True) + @qml.transforms.merge_rotations + @qml.transforms.cancel_inverses + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def test_pipeline_lowering_workflow(x): + qml.RX(x, wires=[0]) + qml.Hadamard(wires=[1]) + qml.Hadamard(wires=[1]) + return qml.expval(qml.PauliY(wires=0)) + + # CHECK: pipeline=(remove-chained-self-inverse, merge-rotations) + print_jaxpr(test_pipeline_lowering_workflow, 1.2) + + # CHECK: transform.named_sequence @__transform_main + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} + # CHECK-NEXT: transform.yield + print_mlir(test_pipeline_lowering_workflow, 1.2) + + # CHECK: {{%.+}} = call @test_pipeline_lowering_workflow_0( + # CHECK: func.func public @test_pipeline_lowering_workflow_0( + # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit + # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit + # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit + test_pipeline_lowering_workflow(42.42) + flush_peephole_opted_mlir_to_iostream(test_pipeline_lowering_workflow) + + +test_transform_lowering() + def test_pipeline_lowering_keep_original(): """ diff --git a/frontend/test/pytest/test_peephole_optimizations.py b/frontend/test/pytest/test_peephole_optimizations.py index a99d9283ab..4c3809d9f6 100644 --- a/frontend/test/pytest/test_peephole_optimizations.py +++ b/frontend/test/pytest/test_peephole_optimizations.py @@ -38,7 +38,9 @@ ### Test peephole pass decorators preserve functionality of circuits ### @pytest.mark.parametrize("theta", [42.42]) -def test_cancel_inverses_functionality(theta, backend): +# should be able to get rid of catalyst.passes.cancel_inverses soon, but testing both for now. +@pytest.mark.parametrize("cancel_inverses_version", (cancel_inverses, qml.transforms.cancel_inverses)) +def test_cancel_inverses_functionality(theta, backend, cancel_inverses_version): def circuit(x): qml.RX(x, wires=0) @@ -50,14 +52,15 @@ def circuit(x): customized_device = qml.device(backend, wires=1) qjitted_workflow = qjit(qml.QNode(circuit, customized_device)) - optimized_workflow = qjit(cancel_inverses(qml.QNode(circuit, customized_device))) + optimized_workflow = qjit(cancel_inverses_version(qml.QNode(circuit, customized_device))) assert np.allclose(reference_workflow(theta), qjitted_workflow(theta)) assert np.allclose(reference_workflow(theta), optimized_workflow(theta)) @pytest.mark.parametrize("theta", [42.42]) -def test_merge_rotation_functionality(theta, backend): +@pytest.mark.parametrize("merge_rotations_version", (merge_rotations, qml.transforms.merge_rotations)) +def test_merge_rotation_functionality(theta, backend, merge_rotations_version): def circuit(x): qml.RX(x, wires=0) @@ -76,7 +79,7 @@ def circuit(x): customized_device = qml.device(backend, wires=1) qjitted_workflow = qjit(qml.QNode(circuit, customized_device)) - optimized_workflow = qjit(merge_rotations(qml.QNode(circuit, customized_device))) + optimized_workflow = qjit(merge_rotations_version(qml.QNode(circuit, customized_device))) assert np.allclose(reference_workflow(theta), qjitted_workflow(theta)) assert np.allclose(reference_workflow(theta), optimized_workflow(theta)) @@ -202,6 +205,25 @@ def test_chained_apply_passes_workflow(x: float): assert "remove-chained-self-inverse" in test_chained_apply_passes_workflow.mlir assert "merge-rotations" in test_chained_apply_passes_workflow.mlir +def test_chained_transforms(): + """ + Test that chained transforms are present in the transform passes. + """ + + @qjit + @qml.transforms.merge_rotations + @qml.transforms.cancel_inverses + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def test_chained_apply_passes_workflow(x: float): + qml.Hadamard(wires=[1]) + qml.RX(x, wires=[0]) + qml.RX(-x, wires=[0]) + qml.Hadamard(wires=[1]) + return qml.expval(qml.PauliY(wires=0)) + + assert "remove-chained-self-inverse" in test_chained_apply_passes_workflow.mlir + assert "merge-rotations" in test_chained_apply_passes_workflow.mlir + def test_disentangle_passes(): """ From 8b045d00777f7a85bf15fc47c0fa129faa997dd1 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 19 Nov 2025 14:11:40 -0500 Subject: [PATCH 11/20] adding in some tests --- frontend/catalyst/from_plxpr/from_plxpr.py | 2 - frontend/catalyst/jax_primitives_utils.py | 6 +- frontend/catalyst/qfunc.py | 3 +- .../from_plxpr/test_capture_integration.py | 27 ++++++- .../test/pytest/test_transform_pass_name.py | 80 +++++++++++++++++++ 5 files changed, 112 insertions(+), 6 deletions(-) create mode 100644 frontend/test/pytest/test_transform_pass_name.py diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 03e3f4915d..1fe4965edc 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -264,12 +264,10 @@ def calling_convention(*args): # otherwise their value will be None. The second value indicates if the transform # requires decomposition to be supported by Catalyst. transforms_to_passes = { - pl_cancel_inverses: ("cancel-inverses", False), pl_commute_controlled: (None, False), pl_decompose: (None, False), pl_map_wires: (None, False), pl_merge_amplitude_embedding: (None, True), - pl_merge_rotations: ("merge-rotations", False), pl_single_qubit_fusion: (None, False), pl_unitary_to_rot: (None, False), } diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py index d9e753efe2..19a9cd05e8 100644 --- a/frontend/catalyst/jax_primitives_utils.py +++ b/frontend/catalyst/jax_primitives_utils.py @@ -320,8 +320,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.ctx.module_context = self.old_module_context -def _lowered_options(kwargs): +def _lowered_options(args, kwargs): lowered_options = {} + for arg in args: + lowered_options[str(arg)] = get_mlir_attribute_from_pyval(True) for option, value in kwargs.items(): mlir_option = str(option).replace("_", "-") lowered_options[mlir_option] = get_mlir_attribute_from_pyval(value) @@ -375,7 +377,7 @@ def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipelin target = bb_named_sequence.arguments[0] for _pass in pipeline: if isinstance(_pass, qml.transforms.core.TransformContainer): - options = _lowered_options(_pass.kwargs) + options = _lowered_options(_pass.args, _pass.kwargs) name = _pass.pass_name else: options = _pass.get_options() diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py index 37a5b70b3e..b0af042dbe 100644 --- a/frontend/catalyst/qfunc.py +++ b/frontend/catalyst/qfunc.py @@ -291,7 +291,8 @@ def __call__(self, *args, **kwargs): processed_old_pipeline = tuple(dictionary_to_list_of_passes(old_pipeline)) pass_pipeline = processed_old_pipeline + new_pipeline new_qnode = copy(self) - new_qnode._transform_program = new_transform_program # pylint: disable=protected-access + # pylint: disable=attribute-defined-outside-init, protected-access + new_qnode._transform_program = new_transform_program # Mid-circuit measurement configuration/execution fn_result = configure_mcm_and_try_one_shot(new_qnode, args, kwargs, pass_pipeline) diff --git a/frontend/test/pytest/from_plxpr/test_capture_integration.py b/frontend/test/pytest/from_plxpr/test_capture_integration.py index ef6f9f370d..b0e3ede226 100644 --- a/frontend/test/pytest/from_plxpr/test_capture_integration.py +++ b/frontend/test/pytest/from_plxpr/test_capture_integration.py @@ -1048,7 +1048,7 @@ def circuit(x: float): assert jnp.allclose(circuit(0.1), capture_result) @pytest.mark.usefixtures("use_capture") - def test_pass_with_options(self, backend): + def test_pass_with_options_patch(self, backend): """Test the integration for a circuit with a pass that takes in options.""" @qml.transform @@ -1058,6 +1058,31 @@ def my_pass(_tape, my_option=None, my_other_option=None): # pylint: disable=unu register_transform(my_pass, "my-pass", False) + @qjit(target="mlir") + @partial(my_pass, my_option="my_option_value", my_other_option=False) + @qml.qnode(qml.device(backend, wires=1)) + def captured_circuit(): + return qml.expval(qml.PauliZ(0)) + + capture_mlir = captured_circuit.mlir + assert 'transform.apply_registered_pass "my-pass"' in capture_mlir + assert ( + 'with options = {"my-option" = "my_option_value", "my-other-option" = false}' + in capture_mlir + ) + + @pytest.mark.usefixtures("use_capture") + def test_pass_with_options(self, backend): + """Test the integration for a circuit with a pass that takes in options.""" + + @qml.transform + def my_pass(_tape, my_option=None, my_other_option=None): # pylint: disable=unused-argument + """A dummy qml.transform.""" + return + + my_pass = qml.transform(pass_name="my-pass") + + @qjit(target="mlir") @partial(my_pass, my_option="my_option_value", my_other_option=False) @qml.qnode(qml.device(backend, wires=1)) diff --git a/frontend/test/pytest/test_transform_pass_name.py b/frontend/test/pytest/test_transform_pass_name.py new file mode 100644 index 0000000000..9c6bbed9ac --- /dev/null +++ b/frontend/test/pytest/test_transform_pass_name.py @@ -0,0 +1,80 @@ +# Copyright 2023 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing use of transforms with pass name integrating with qnodes. +""" +import pytest +from functools import partial + +import pennylane as qml + + + +def test_pass_with_options(backend): + """Test the integration for a circuit with a pass that takes in options.""" + + my_pass = qml.transform(pass_name="my-pass") + + @qml.qjit(target="mlir") + @partial(my_pass, my_option="my_option_value", my_other_option=False) + @qml.qnode(qml.device(backend, wires=1)) + def captured_circuit(): + return qml.expval(qml.PauliZ(0)) + + capture_mlir = captured_circuit.mlir + assert 'transform.apply_registered_pass "my-pass"' in capture_mlir + assert ( + 'with options = {"my-option" = "my_option_value", "my-other-option" = false}' + in capture_mlir + ) + +def test_pass_before_tape_transform(backend): + """Test that provided an mlir-only transform prior to a tape transform raises an error.""" + + my_pass = qml.transform(pass_name="my-pass") + + @qml.transform + def tape_transform(tape): + return (tape, ), lambda x: x[0] + + @qml.qjit + @tape_transform + @my_pass + @qml.qnode(qml.device(backend, wires=1)) + def f(x): + return qml.state() + + with pytest.raises(ValueError, match="without a tape definition occurs before tape transform"): + f(0.5) + +def test_pass_after_tape_transform(backend): + """Test that passes can be applied after tape transforms.""" + + @qml.transform + def tape_only_cancel_inverses(tape): + return qml.transforms.cancel_inverses(tape) + + my_pass = qml.transform(pass_name="my-pass") + + @qml.qjit(target="mlir") + @my_pass + @tape_only_cancel_inverses + @qml.qnode(qml.device(backend, wires=1)) + def c(): + qml.X(0) + qml.X(0) + return qml.state() + + # check inverses canceled + assert 'quantum.custom "PauliX"()' not in c.mlir + assert 'transform.apply_registered_pass "my-pass"' in c.mlir \ No newline at end of file From ed528bafb04f3c74456ba03dc828167996433c84 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 19 Nov 2025 14:12:04 -0500 Subject: [PATCH 12/20] black and isort --- .../test/lit/test_peephole_optimizations.py | 2 ++ .../from_plxpr/test_capture_integration.py | 3 +-- .../test/pytest/test_peephole_optimizations.py | 9 +++++++-- .../test/pytest/test_transform_pass_name.py | 18 +++++++++--------- 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/frontend/test/lit/test_peephole_optimizations.py b/frontend/test/lit/test_peephole_optimizations.py index d9ec39c8cf..eac1009f86 100644 --- a/frontend/test/lit/test_peephole_optimizations.py +++ b/frontend/test/lit/test_peephole_optimizations.py @@ -86,10 +86,12 @@ def test_pipeline_lowering_workflow(x): test_pipeline_lowering() + def test_transform_lowering(): """ Basic pipeline lowering on one qnode. """ + @qjit(keep_intermediate=True) @qml.transforms.merge_rotations @qml.transforms.cancel_inverses diff --git a/frontend/test/pytest/from_plxpr/test_capture_integration.py b/frontend/test/pytest/from_plxpr/test_capture_integration.py index b0e3ede226..0447d41f44 100644 --- a/frontend/test/pytest/from_plxpr/test_capture_integration.py +++ b/frontend/test/pytest/from_plxpr/test_capture_integration.py @@ -1079,9 +1079,8 @@ def test_pass_with_options(self, backend): def my_pass(_tape, my_option=None, my_other_option=None): # pylint: disable=unused-argument """A dummy qml.transform.""" return - - my_pass = qml.transform(pass_name="my-pass") + my_pass = qml.transform(pass_name="my-pass") @qjit(target="mlir") @partial(my_pass, my_option="my_option_value", my_other_option=False) diff --git a/frontend/test/pytest/test_peephole_optimizations.py b/frontend/test/pytest/test_peephole_optimizations.py index c6e7783866..9cb1ffa221 100644 --- a/frontend/test/pytest/test_peephole_optimizations.py +++ b/frontend/test/pytest/test_peephole_optimizations.py @@ -39,7 +39,9 @@ ### Test peephole pass decorators preserve functionality of circuits ### @pytest.mark.parametrize("theta", [42.42]) # should be able to get rid of catalyst.passes.cancel_inverses soon, but testing both for now. -@pytest.mark.parametrize("cancel_inverses_version", (cancel_inverses, qml.transforms.cancel_inverses)) +@pytest.mark.parametrize( + "cancel_inverses_version", (cancel_inverses, qml.transforms.cancel_inverses) +) def test_cancel_inverses_functionality(theta, backend, cancel_inverses_version): def circuit(x): @@ -59,7 +61,9 @@ def circuit(x): @pytest.mark.parametrize("theta", [42.42]) -@pytest.mark.parametrize("merge_rotations_version", (merge_rotations, qml.transforms.merge_rotations)) +@pytest.mark.parametrize( + "merge_rotations_version", (merge_rotations, qml.transforms.merge_rotations) +) def test_merge_rotation_functionality(theta, backend, merge_rotations_version): def circuit(x): @@ -206,6 +210,7 @@ def test_chained_apply_passes_workflow(x: float): assert "cancel-inverses" in mlir assert "merge-rotations" in mlir + def test_chained_transforms(): """ Test that chained transforms are present in the transform passes. diff --git a/frontend/test/pytest/test_transform_pass_name.py b/frontend/test/pytest/test_transform_pass_name.py index 9c6bbed9ac..54c9af0c2f 100644 --- a/frontend/test/pytest/test_transform_pass_name.py +++ b/frontend/test/pytest/test_transform_pass_name.py @@ -11,13 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Testing use of transforms with pass name integrating with qnodes. -""" -import pytest +"""Testing use of transforms with pass name integrating with qnodes.""" from functools import partial import pennylane as qml - +import pytest def test_pass_with_options(backend): @@ -38,14 +36,15 @@ def captured_circuit(): in capture_mlir ) + def test_pass_before_tape_transform(backend): """Test that provided an mlir-only transform prior to a tape transform raises an error.""" my_pass = qml.transform(pass_name="my-pass") - + @qml.transform def tape_transform(tape): - return (tape, ), lambda x: x[0] + return (tape,), lambda x: x[0] @qml.qjit @tape_transform @@ -57,13 +56,14 @@ def f(x): with pytest.raises(ValueError, match="without a tape definition occurs before tape transform"): f(0.5) + def test_pass_after_tape_transform(backend): """Test that passes can be applied after tape transforms.""" @qml.transform def tape_only_cancel_inverses(tape): return qml.transforms.cancel_inverses(tape) - + my_pass = qml.transform(pass_name="my-pass") @qml.qjit(target="mlir") @@ -74,7 +74,7 @@ def c(): qml.X(0) qml.X(0) return qml.state() - + # check inverses canceled assert 'quantum.custom "PauliX"()' not in c.mlir - assert 'transform.apply_registered_pass "my-pass"' in c.mlir \ No newline at end of file + assert 'transform.apply_registered_pass "my-pass"' in c.mlir From ec38c77a9b0c38634d806dda557d1783f7d5f06e Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 20 Nov 2025 11:03:48 -0500 Subject: [PATCH 13/20] minor fixes --- frontend/catalyst/from_plxpr/from_plxpr.py | 2 -- frontend/test/pytest/from_plxpr/test_capture_integration.py | 5 ----- frontend/test/pytest/test_peephole_optimizations.py | 2 +- 3 files changed, 1 insertion(+), 8 deletions(-) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 1fe4965edc..1adb6e546a 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -30,11 +30,9 @@ from pennylane.capture.primitives import jacobian_prim as pl_jac_prim from pennylane.capture.primitives import transform_prim from pennylane.ops.functions.map_wires import _map_wires_transform as pl_map_wires -from pennylane.transforms import cancel_inverses as pl_cancel_inverses from pennylane.transforms import commute_controlled as pl_commute_controlled from pennylane.transforms import decompose as pl_decompose from pennylane.transforms import merge_amplitude_embedding as pl_merge_amplitude_embedding -from pennylane.transforms import merge_rotations as pl_merge_rotations from pennylane.transforms import single_qubit_fusion as pl_single_qubit_fusion from pennylane.transforms import unitary_to_rot as pl_unitary_to_rot diff --git a/frontend/test/pytest/from_plxpr/test_capture_integration.py b/frontend/test/pytest/from_plxpr/test_capture_integration.py index 0447d41f44..b8daa65016 100644 --- a/frontend/test/pytest/from_plxpr/test_capture_integration.py +++ b/frontend/test/pytest/from_plxpr/test_capture_integration.py @@ -1075,11 +1075,6 @@ def captured_circuit(): def test_pass_with_options(self, backend): """Test the integration for a circuit with a pass that takes in options.""" - @qml.transform - def my_pass(_tape, my_option=None, my_other_option=None): # pylint: disable=unused-argument - """A dummy qml.transform.""" - return - my_pass = qml.transform(pass_name="my-pass") @qjit(target="mlir") diff --git a/frontend/test/pytest/test_peephole_optimizations.py b/frontend/test/pytest/test_peephole_optimizations.py index 9cb1ffa221..e21aa8dbbe 100644 --- a/frontend/test/pytest/test_peephole_optimizations.py +++ b/frontend/test/pytest/test_peephole_optimizations.py @@ -227,7 +227,7 @@ def test_chained_apply_passes_workflow(x: float): qml.Hadamard(wires=[1]) return qml.expval(qml.PauliY(wires=0)) - assert "remove-chained-self-inverse" in test_chained_apply_passes_workflow.mlir + assert "cancel-inverses" in test_chained_apply_passes_workflow.mlir assert "merge-rotations" in test_chained_apply_passes_workflow.mlir From 1cd4f182eca3459f4c82c6060a52acc2dcdeefb0 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 20 Nov 2025 13:41:11 -0500 Subject: [PATCH 14/20] fix test --- frontend/test/lit/test_peephole_optimizations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/lit/test_peephole_optimizations.py b/frontend/test/lit/test_peephole_optimizations.py index eac1009f86..3bb12c45fd 100644 --- a/frontend/test/lit/test_peephole_optimizations.py +++ b/frontend/test/lit/test_peephole_optimizations.py @@ -102,7 +102,7 @@ def test_pipeline_lowering_workflow(x): qml.Hadamard(wires=[1]) return qml.expval(qml.PauliY(wires=0)) - # CHECK: pipeline=(remove-chained-self-inverse, merge-rotations) + # CHECK: pipeline=(, ) print_jaxpr(test_pipeline_lowering_workflow, 1.2) # CHECK: transform.named_sequence @__transform_main From f197df9f0f50cc8cf4d239d6ecb040227af29740 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 20 Nov 2025 14:02:17 -0500 Subject: [PATCH 15/20] try and fix this lit test yet again --- frontend/test/lit/test_peephole_optimizations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/lit/test_peephole_optimizations.py b/frontend/test/lit/test_peephole_optimizations.py index 3bb12c45fd..85e254cff6 100644 --- a/frontend/test/lit/test_peephole_optimizations.py +++ b/frontend/test/lit/test_peephole_optimizations.py @@ -106,7 +106,7 @@ def test_pipeline_lowering_workflow(x): print_jaxpr(test_pipeline_lowering_workflow, 1.2) # CHECK: transform.named_sequence @__transform_main - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "cancel-inverses" to {{%.+}} # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} # CHECK-NEXT: transform.yield print_mlir(test_pipeline_lowering_workflow, 1.2) From 4a46934bf4c1cbffc6ef1838c8f6a8aca115a449 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Thu, 20 Nov 2025 14:16:15 -0500 Subject: [PATCH 16/20] Update frontend/test/pytest/test_transform_pass_name.py Co-authored-by: Paul <79805239+paul0403@users.noreply.github.com> --- frontend/test/pytest/test_transform_pass_name.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/frontend/test/pytest/test_transform_pass_name.py b/frontend/test/pytest/test_transform_pass_name.py index 54c9af0c2f..20b6796af0 100644 --- a/frontend/test/pytest/test_transform_pass_name.py +++ b/frontend/test/pytest/test_transform_pass_name.py @@ -76,5 +76,6 @@ def c(): return qml.state() # check inverses canceled - assert 'quantum.custom "PauliX"()' not in c.mlir - assert 'transform.apply_registered_pass "my-pass"' in c.mlir + c_mlir = c.mlir + assert 'quantum.custom "PauliX"()' not in c_mlir + assert 'transform.apply_registered_pass "my-pass"' in c_mlir From ce8478b167ccc9aaf4095d38998727748cd7d279 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Thu, 20 Nov 2025 14:16:24 -0500 Subject: [PATCH 17/20] Update frontend/test/pytest/test_transform_pass_name.py Co-authored-by: Paul <79805239+paul0403@users.noreply.github.com> --- frontend/test/pytest/test_transform_pass_name.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/test_transform_pass_name.py b/frontend/test/pytest/test_transform_pass_name.py index 20b6796af0..1028fad9f1 100644 --- a/frontend/test/pytest/test_transform_pass_name.py +++ b/frontend/test/pytest/test_transform_pass_name.py @@ -1,4 +1,4 @@ -# Copyright 2023 Xanadu Quantum Technologies Inc. +# Copyright 2025 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From e7db76a4a06750142c78702837b0f93c69c2a6c3 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Thu, 20 Nov 2025 17:30:46 -0500 Subject: [PATCH 18/20] Apply suggestion from @albi3ro --- frontend/test/pytest/test_transform_pass_name.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/test_transform_pass_name.py b/frontend/test/pytest/test_transform_pass_name.py index 1028fad9f1..ca76aa398c 100644 --- a/frontend/test/pytest/test_transform_pass_name.py +++ b/frontend/test/pytest/test_transform_pass_name.py @@ -50,7 +50,7 @@ def tape_transform(tape): @tape_transform @my_pass @qml.qnode(qml.device(backend, wires=1)) - def f(x): + def f(x): # pylint: disable=unused-argument return qml.state() with pytest.raises(ValueError, match="without a tape definition occurs before tape transform"): From b6f3c8ce664e17e10c9a56544ed8a2ecd700c720 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 3 Dec 2025 09:56:45 -0500 Subject: [PATCH 19/20] update version, remove unnecessary test --- .dep-versions | 2 +- .github/workflows/check-catalyst.yaml | 12 ----------- doc/releases/changelog-dev.md | 4 ++++ doc/requirements.txt | 2 +- .../pytest/test_peephole_optimizations.py | 20 ------------------- .../test/pytest/test_transform_pass_name.py | 2 +- 6 files changed, 7 insertions(+), 35 deletions(-) diff --git a/.dep-versions b/.dep-versions index da9dfc2894..cfcf034381 100644 --- a/.dep-versions +++ b/.dep-versions @@ -10,7 +10,7 @@ enzyme=v0.0.203 # For a custom PL version, update the package version here and at # 'doc/requirements.txt' -pennylane=0.44.0-dev42 +pennylane=0.44.0-dev44 # For a custom LQ/LK version, update the package version here and at # 'doc/requirements.txt' diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index 2f97fec69d..999e5ad137 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -475,10 +475,6 @@ jobs: python3 -m pip install oqc-qcaas-client make frontend - - name: Install PennyLane branch - run: | - pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-pass-name - - name: Get Cached LLVM Build id: cache-llvm-build uses: actions/cache@v4 @@ -562,10 +558,6 @@ jobs: python3 -m pip install -r requirements.txt make frontend - - name: Install PennyLane branch - run: | - pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-pass-name - - name: Get Cached LLVM Build id: cache-llvm-build uses: actions/cache@v4 @@ -628,10 +620,6 @@ jobs: python3 -m pip install -r requirements.txt make frontend - - name: Install PennyLane branch - run: | - pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-pass-name - - name: Get Cached LLVM Build id: cache-llvm-build uses: actions/cache@v4 diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index f4def776b0..028264e4fe 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -70,6 +70,10 @@

Improvements 🛠

+* Catalyst can now use the new `pass_name` property of pennylane transform objects. Passes can now + be created using `qml.transform(pass_name=pass_name)` instead of `PassPipelineWrapper`. + [(#2149](https://github.com/PennyLaneAI/catalyst/pull/2149) + * A new ``"changed"`` option has been added to the ``keep_intermediate`` parameter of :func:`~.qjit`. This option saves intermediate IR files after each pass, but only when the IR is actually modified by the pass. diff --git a/doc/requirements.txt b/doc/requirements.txt index 4480d3fc1a..4937e0d28f 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -33,4 +33,4 @@ lxml_html_clean --extra-index-url https://test.pypi.org/simple/ pennylane-lightning-kokkos==0.44.0-dev16 pennylane-lightning==0.44.0-dev16 -pennylane==0.44.0-dev42 +pennylane==0.44.0-dev44 diff --git a/frontend/test/pytest/test_peephole_optimizations.py b/frontend/test/pytest/test_peephole_optimizations.py index 8a52f89207..334f14d63e 100644 --- a/frontend/test/pytest/test_peephole_optimizations.py +++ b/frontend/test/pytest/test_peephole_optimizations.py @@ -211,26 +211,6 @@ def test_chained_apply_passes_workflow(x: float): assert "merge-rotations" in mlir -def test_chained_transforms(): - """ - Test that chained transforms are present in the transform passes. - """ - - @qjit - @qml.transforms.merge_rotations - @qml.transforms.cancel_inverses - @qml.qnode(qml.device("lightning.qubit", wires=2)) - def test_chained_apply_passes_workflow(x: float): - qml.Hadamard(wires=[1]) - qml.RX(x, wires=[0]) - qml.RX(-x, wires=[0]) - qml.Hadamard(wires=[1]) - return qml.expval(qml.PauliY(wires=0)) - - assert "cancel-inverses" in test_chained_apply_passes_workflow.mlir - assert "merge-rotations" in test_chained_apply_passes_workflow.mlir - - def test_disentangle_passes(): """ Test that disentangle passes are present in the transform passes diff --git a/frontend/test/pytest/test_transform_pass_name.py b/frontend/test/pytest/test_transform_pass_name.py index ca76aa398c..14cd2a1586 100644 --- a/frontend/test/pytest/test_transform_pass_name.py +++ b/frontend/test/pytest/test_transform_pass_name.py @@ -50,7 +50,7 @@ def tape_transform(tape): @tape_transform @my_pass @qml.qnode(qml.device(backend, wires=1)) - def f(x): # pylint: disable=unused-argument + def f(x): # pylint: disable=unused-argument return qml.state() with pytest.raises(ValueError, match="without a tape definition occurs before tape transform"): From ce518040c53cb780203b180817fd48b36a1139ca Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 3 Dec 2025 11:32:29 -0500 Subject: [PATCH 20/20] fix test failure --- frontend/test/lit/test_decomposition.py | 1 + 1 file changed, 1 insertion(+) diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 0e7b6920dd..59a850bc48 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -46,6 +46,7 @@ def wrapper(): error_msg = str(e) if ( "Unsupported type annotation None for parameter pauli_word" in error_msg + or "Unsupported type annotation for parameter pauli_word" in error_msg or "index is out of bounds for axis" in error_msg ): print(f"# SKIPPED {test_func.__name__}: PauliRot type annotation issue")