|  | 
|  | 1 | +# Copyright 2025 Xanadu Quantum Technologies Inc. | 
|  | 2 | + | 
|  | 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 4 | +# you may not use this file except in compliance with the License. | 
|  | 5 | +# You may obtain a copy of the License at | 
|  | 6 | + | 
|  | 7 | +#     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 8 | + | 
|  | 9 | +# Unless required by applicable law or agreed to in writing, software | 
|  | 10 | +# distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 12 | +# See the License for the specific language governing permissions and | 
|  | 13 | +# limitations under the License. | 
|  | 14 | + | 
|  | 15 | +"""Unit test module for pennylane/compiler/python_compiler/transform.py.""" | 
|  | 16 | + | 
|  | 17 | +from dataclasses import dataclass | 
|  | 18 | + | 
|  | 19 | +import pytest | 
|  | 20 | + | 
|  | 21 | +# pylint: disable=wrong-import-position | 
|  | 22 | + | 
|  | 23 | +xdsl = pytest.importorskip("xdsl") | 
|  | 24 | +filecheck = pytest.importorskip("filecheck") | 
|  | 25 | + | 
|  | 26 | +pytestmark = pytest.mark.external | 
|  | 27 | + | 
|  | 28 | +from xdsl import passes | 
|  | 29 | +from xdsl.context import Context | 
|  | 30 | +from xdsl.dialects import builtin | 
|  | 31 | +from xdsl.dialects.builtin import DictionaryAttr, IntegerAttr, i64 | 
|  | 32 | +from xdsl.dialects.transform import AnyOpType | 
|  | 33 | +from xdsl.utils.exceptions import VerifyException | 
|  | 34 | +from xdsl.utils.test_value import create_ssa_value | 
|  | 35 | + | 
|  | 36 | +from pennylane.compiler.python_compiler.dialects import transform | 
|  | 37 | +from pennylane.compiler.python_compiler.dialects.transform import ApplyRegisteredPassOp | 
|  | 38 | +from pennylane.compiler.python_compiler.jax_utils import xdsl_from_docstring | 
|  | 39 | +from pennylane.compiler.python_compiler.transforms.api import ( | 
|  | 40 | +    ApplyTransformSequence, | 
|  | 41 | +    compiler_transform, | 
|  | 42 | +) | 
|  | 43 | + | 
|  | 44 | + | 
|  | 45 | +def test_dict_options(): | 
|  | 46 | +    """Test ApplyRegisteredPassOp constructor with dict options.""" | 
|  | 47 | +    target = create_ssa_value(AnyOpType()) | 
|  | 48 | +    options = {"option1": 1, "option2": True} | 
|  | 49 | + | 
|  | 50 | +    op = ApplyRegisteredPassOp("canonicalize", target, options) | 
|  | 51 | + | 
|  | 52 | +    assert op.pass_name.data == "canonicalize" | 
|  | 53 | +    assert isinstance(op.options, DictionaryAttr) | 
|  | 54 | +    assert op.options == DictionaryAttr({"option1": 1, "option2": True}) | 
|  | 55 | +    assert op.verify_() is None | 
|  | 56 | + | 
|  | 57 | + | 
|  | 58 | +def test_attr_options(): | 
|  | 59 | +    """Test ApplyRegisteredPassOp constructor with DictionaryAttr options.""" | 
|  | 60 | +    target = create_ssa_value(AnyOpType()) | 
|  | 61 | +    options = DictionaryAttr({"test-option": IntegerAttr(42, i64)}) | 
|  | 62 | + | 
|  | 63 | +    # This should trigger the __init__ method | 
|  | 64 | +    op = ApplyRegisteredPassOp("canonicalize", target, options) | 
|  | 65 | + | 
|  | 66 | +    assert op.pass_name.data == "canonicalize" | 
|  | 67 | +    assert isinstance(op.options, DictionaryAttr) | 
|  | 68 | +    assert op.options == DictionaryAttr({"test-option": IntegerAttr(42, i64)}) | 
|  | 69 | +    assert op.verify_() is None | 
|  | 70 | + | 
|  | 71 | + | 
|  | 72 | +def test_none_options(): | 
|  | 73 | +    """Test ApplyRegisteredPassOp constructor with None options.""" | 
|  | 74 | +    target = create_ssa_value(AnyOpType()) | 
|  | 75 | + | 
|  | 76 | +    # This should trigger the __init__ method | 
|  | 77 | +    op = ApplyRegisteredPassOp("canonicalize", target, None) | 
|  | 78 | + | 
|  | 79 | +    assert op.pass_name.data == "canonicalize" | 
|  | 80 | +    assert isinstance(op.options, DictionaryAttr) | 
|  | 81 | +    assert op.options == DictionaryAttr({}) | 
|  | 82 | +    assert op.verify_() is None | 
|  | 83 | + | 
|  | 84 | + | 
|  | 85 | +def test_invalid_options(): | 
|  | 86 | +    """Test ApplyRegisteredPassOp constructor with invalid options type.""" | 
|  | 87 | +    target = create_ssa_value(AnyOpType()) | 
|  | 88 | + | 
|  | 89 | +    with pytest.raises( | 
|  | 90 | +        VerifyException, match="invalid_options should be of base attribute dictionary" | 
|  | 91 | +    ): | 
|  | 92 | +        ApplyRegisteredPassOp("canonicalize", target, "invalid_options").verify_() | 
|  | 93 | + | 
|  | 94 | + | 
|  | 95 | +def test_transform_dialect_filecheck(run_filecheck): | 
|  | 96 | +    """Test that the transform dialect operations are parsed correctly.""" | 
|  | 97 | +    program = """ | 
|  | 98 | +        "builtin.module"() ({ | 
|  | 99 | +            "transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({ | 
|  | 100 | +            ^bb0(%arg0: !transform.any_op): | 
|  | 101 | +                %0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op | 
|  | 102 | +                // CHECK: options = {"invalid-option" = 1 : i64} | 
|  | 103 | +                %1 = "transform.apply_registered_pass"(%0) <{options = {"invalid-option" = 1 : i64}, pass_name = "canonicalize"}> : (!transform.any_op) -> !transform.any_op | 
|  | 104 | +                "transform.yield"() : () -> () | 
|  | 105 | +            }) : () -> () | 
|  | 106 | +        }) {transform.with_named_sequence} : () -> () | 
|  | 107 | +    """ | 
|  | 108 | + | 
|  | 109 | +    run_filecheck(program) | 
|  | 110 | + | 
|  | 111 | + | 
|  | 112 | +def test_integration_for_transform_interpreter(capsys): | 
|  | 113 | +    """Test that a pass with options is run via the transform interpreter""" | 
|  | 114 | + | 
|  | 115 | +    @compiler_transform | 
|  | 116 | +    @dataclass(frozen=True) | 
|  | 117 | +    class _HelloWorld(passes.ModulePass): | 
|  | 118 | +        name = "hello-world" | 
|  | 119 | + | 
|  | 120 | +        custom_print: str | None = None | 
|  | 121 | + | 
|  | 122 | +        def apply(self, _ctx: Context, _module: builtin.ModuleOp) -> None: | 
|  | 123 | +            if self.custom_print: | 
|  | 124 | +                print(self.custom_print) | 
|  | 125 | +            else: | 
|  | 126 | +                print("hello world") | 
|  | 127 | + | 
|  | 128 | +    @xdsl_from_docstring | 
|  | 129 | +    def program(): | 
|  | 130 | +        """ | 
|  | 131 | +        builtin.module { | 
|  | 132 | +          builtin.module { | 
|  | 133 | +            transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) { | 
|  | 134 | +              %0 = "transform.apply_registered_pass"(%arg0) <{options = {"custom_print" = "Hello from custom option!"}, pass_name = "hello-world"}> : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> | 
|  | 135 | +              transform.yield | 
|  | 136 | +            } | 
|  | 137 | +          } | 
|  | 138 | +        } | 
|  | 139 | +        """ | 
|  | 140 | + | 
|  | 141 | +    ctx = xdsl.context.Context() | 
|  | 142 | +    ctx.load_dialect(builtin.Builtin) | 
|  | 143 | +    ctx.load_dialect(transform.Transform) | 
|  | 144 | + | 
|  | 145 | +    mod = program() | 
|  | 146 | +    pipeline = xdsl.passes.PassPipeline((ApplyTransformSequence(),)) | 
|  | 147 | +    pipeline.apply(ctx, mod) | 
|  | 148 | + | 
|  | 149 | +    assert "Hello from custom option!" in capsys.readouterr().out | 
0 commit comments