diff --git a/frontend/catalyst/passes/builtin_passes.py b/frontend/catalyst/passes/builtin_passes.py index 139738baff..e30588d7f7 100644 --- a/frontend/catalyst/passes/builtin_passes.py +++ b/frontend/catalyst/passes/builtin_passes.py @@ -17,9 +17,9 @@ import copy import functools import json +from pennylane import transform from catalyst.compiler import _options_to_cli_flags, _quantum_opt -from catalyst.passes.pass_api import PassPipelineWrapper from catalyst.utils.exceptions import CompileError # pylint: disable=line-too-long, too-many-lines @@ -136,7 +136,7 @@ def circuit(x: float): %2 = quantum.namedobs %out_qubits[ PauliZ] : !quantum.obs %3 = quantum.expval %2 : f64 """ - return PassPipelineWrapper(qnode, "cancel-inverses") + return transform(pass_name="cancel-inverses")(qnode) def disentangle_cnot(qnode): @@ -225,7 +225,7 @@ def circuit(): %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit %out_qubits_0 = quantum.custom "PauliX"() %2 : !quantum.bit """ - return PassPipelineWrapper(qnode, "disentangle-CNOT") + return transform(pass_name="disentangle-CNOT")(qnode) def disentangle_swap(qnode): @@ -325,7 +325,7 @@ def circuit(): %out_qubits_2:2 = quantum.custom "CNOT"() %out_qubits_1, %out_qubits : !quantum.bit, !quantum.bit %out_qubits_3:2 = quantum.custom "CNOT"() %out_qubits_2#1, %out_qubits_2#0 : !quantum.bit, !quantum.bit """ - return PassPipelineWrapper(qnode, "disentangle-SWAP") + return transform(pass_name="disentangle-SWAP")(qnode) def merge_rotations(qnode): @@ -391,7 +391,7 @@ def circuit(x: float): >>> circuit(0.54) Array(0.5965506257017892, dtype=float64) """ - return PassPipelineWrapper(qnode, "merge-rotations") + return transform(pass_name="merge-rotations")(qnode) def decompose_lowering(qnode): @@ -410,7 +410,7 @@ def decompose_lowering(qnode): // TODO: add example here """ - return PassPipelineWrapper(qnode, "decompose-lowering") # pragma: no cover + return transform(pass_name="decompose-lowering")(qnode) def ions_decomposition(qnode): # pragma: nocover @@ -532,7 +532,7 @@ def circuit(): %out_qubits_8 = quantum.custom "RY"(%cst_2) %out_qubits_6#1 : !quantum.bit %out_qubits_9 = quantum.custom "RY"(%cst_2) %out_qubits_7 : !quantum.bit """ - return PassPipelineWrapper(qnode, "ions-decomposition") + return transform(pass_name="ions-decomposition")(qnode) def to_ppr(qnode): @@ -611,8 +611,7 @@ def circuit(): In the above output, ``PPR-theta-weight`` denotes the type of PPR present in the circuit, where ``theta`` is the PPR angle (:math:`\theta`) and ``weight`` is the PPR weight. """ - return PassPipelineWrapper(qnode, "to-ppr") - + return transform(pass_name="to-ppr")(qnode) def commute_ppr(qnode=None, *, max_pauli_size=0): R""" @@ -701,8 +700,7 @@ def circuit(): if qnode is None: return functools.partial(commute_ppr, max_pauli_size=max_pauli_size) - commute_ppr_pass = {"commute_ppr": {"max-pauli-size": max_pauli_size}} - return PassPipelineWrapper(qnode, commute_ppr_pass) + return transform(pass_name="commute-ppr")(qnode, max_pauli_size=max_pauli_size) def merge_ppr_ppm(qnode=None, *, max_pauli_size=0): @@ -782,8 +780,7 @@ def circuit(): if qnode is None: return functools.partial(merge_ppr_ppm, max_pauli_size=max_pauli_size) - merge_ppr_ppm_pass = {"merge_ppr_ppm": {"max-pauli-size": max_pauli_size}} - return PassPipelineWrapper(qnode, merge_ppr_ppm_pass) + return transform(pass_name="merge-ppr-ppm")(qnode, max_pauli_size=max_pauli_size) def ppr_to_ppm(qnode=None, *, decompose_method="pauli-corrected", avoid_y_measure=False): @@ -882,19 +879,13 @@ def circuit(): :math:`P(\tfrac{\pi}{2}) = \exp(-iP\tfrac{\pi}{2}) = P`. Pauli operators can be commuted to the end of the circuit and absorbed into terminal measurements. """ - passes = { - "ppr_to_ppm": { - "decompose-method": decompose_method, - "avoid-y-measure": avoid_y_measure, - }, - } if qnode is None: return functools.partial( ppr_to_ppm, decompose_method=decompose_method, avoid_y_measure=avoid_y_measure ) - return PassPipelineWrapper(qnode, passes) + return transform(pass_name="ppr-to-ppm")(qnode, decompose_method=decompose_method, avoid_y_measure=avoid_y_measure) def ppm_compilation( @@ -998,13 +989,6 @@ def circuit(): ``max_pauli_size`` qubits (here, ``max_pauli_size = 2``), that commutation or merge would be skipped. """ - passes = { - "ppm-compilation": { - "decompose-method": decompose_method, - "avoid-y-measure": avoid_y_measure, - "max-pauli-size": max_pauli_size, - } - } if qnode is None: return functools.partial( @@ -1014,8 +998,7 @@ def circuit(): max_pauli_size=max_pauli_size, ) - return PassPipelineWrapper(qnode, passes) - + return transform(pass_name="ppm-compilation")(qnode, decompose_method=decompose_method, avoid_y_measure=avoid_y_measure, max_pauli_size=max_pauli_size) def ppm_specs(fn): R""" @@ -1088,34 +1071,31 @@ def loop(i): . . . """ - - if fn.mlir_module is not None: - # aot mode - new_options = copy.copy(fn.compile_options) - if new_options.pipelines is None: - raise CompileError("No pipeline found") - - # add ppm-spec pass at the end to existing pipeline - _, pass_list = new_options.pipelines[0] # first pipeline runs the user passes - # check if ppm-specs is already in the pass list - if "ppm-specs" not in pass_list: # pragma: nocover - pass_list.append("ppm-specs") - - new_options = _options_to_cli_flags(new_options) - raw_result = _quantum_opt(*new_options, [], stdin=str(fn.mlir_module)) - - try: - return json.loads( - raw_result[: raw_result.index("module")] - ) # remove MLIR starting with substring "module..." - except Exception as e: # pragma: nocover - raise CompileError( - "Invalid json format encountered in ppm_specs. " - f"Expected valid JSON but got {raw_result[: raw_result.index('module')]}" - ) from e - - else: + if fn.mlir_module is None: raise NotImplementedError("PPM passes only support AOT (Ahead-Of-Time) compilation mode.") + # aot mode + new_options = copy.copy(fn.compile_options) + if new_options.pipelines is None: + raise CompileError("No pipeline found") + + # add ppm-spec pass at the end to existing pipeline + _, pass_list = new_options.pipelines[0] # first pipeline runs the user passes + # check if ppm-specs is already in the pass list + if "ppm-specs" not in pass_list: # pragma: nocover + pass_list.append("ppm-specs") + + new_options = _options_to_cli_flags(new_options) + raw_result = _quantum_opt(*new_options, [], stdin=str(fn.mlir_module)) + + try: + return json.loads( + raw_result[: raw_result.index("module")] + ) # remove MLIR starting with substring "module..." + except Exception as e: # pragma: nocover + raise CompileError( + "Invalid json format encountered in ppm_specs. " + f"Expected valid JSON but got {raw_result[: raw_result.index('module')]}" + ) from e def reduce_t_depth(qnode): @@ -1194,8 +1174,7 @@ def circuit(): %9:3 = qec.ppr ["X", "X", "Y"](8) %8#0, %8#1, %8#2:!quantum.bit, !quantum.bit, !quantum.bit . . . """ - - return PassPipelineWrapper(qnode, "reduce-t-depth") + return transform(pass_name="reduce-t-depth")(qnode) def ppr_to_mbqc(qnode): @@ -1284,4 +1263,4 @@ def circuit(): ... """ - return PassPipelineWrapper(qnode, "ppr-to-mbqc") + return transform(pass_name="ppr-to-mbqc")(qnode) diff --git a/frontend/catalyst/passes/pass_api.py b/frontend/catalyst/passes/pass_api.py index 7c861d3ce3..a50fc4bef4 100644 --- a/frontend/catalyst/passes/pass_api.py +++ b/frontend/catalyst/passes/pass_api.py @@ -204,7 +204,7 @@ def module(): """ def decorator(qnode): - return PassPipelineWrapper(qnode, pass_name, *flags, **valued_options) + return qml.transform(pass_name=pass_name)(qnode, *flags, **valued_options) return decorator @@ -244,7 +244,7 @@ def module(): raise FileNotFoundError(f"File '{path_to_plugin}' does not exist.") def decorator(qnode): - return PassPipelineWrapper(qnode, pass_name, *flags, **valued_options) + return qml.transform(pass_name=pass_name)(qnode, *flags, **valued_options) return decorator diff --git a/frontend/test/pytest/test_peephole_optimizations.py b/frontend/test/pytest/test_peephole_optimizations.py index 334f14d63e..cd8e9364f6 100644 --- a/frontend/test/pytest/test_peephole_optimizations.py +++ b/frontend/test/pytest/test_peephole_optimizations.py @@ -163,29 +163,6 @@ def classical_func(): ): pipeline({})(classical_func) - with pytest.raises( - TypeError, - match="A QNode is expected, got the classical function", - ): - cancel_inverses(classical_func) - - with pytest.raises( - TypeError, - match="A QNode is expected, got the classical function", - ): - merge_rotations(classical_func) - - with pytest.raises( - TypeError, - match="A QNode is expected, got the classical function", - ): - disentangle_cnot(classical_func) - - with pytest.raises( - TypeError, - match="A QNode is expected, got the classical function", - ): - disentangle_swap(classical_func) test_passes_not_on_qnode() @@ -211,6 +188,26 @@ def test_chained_apply_passes_workflow(x: float): assert "merge-rotations" in mlir +def test_chained_transforms(): + """ + Test that chained transforms are present in the transform passes. + """ + + @qjit + @qml.transforms.merge_rotations + @qml.transforms.cancel_inverses + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def test_chained_apply_passes_workflow(x: float): + qml.Hadamard(wires=[1]) + qml.RX(x, wires=[0]) + qml.RX(-x, wires=[0]) + qml.Hadamard(wires=[1]) + return qml.expval(qml.PauliY(wires=0)) + + assert "cancel-inverses" in test_chained_apply_passes_workflow.mlir + assert "merge-rotations" in test_chained_apply_passes_workflow.mlir + + def test_disentangle_passes(): """ Test that disentangle passes are present in the transform passes