| 
 | 1 | +# Copyright 2025 Xanadu Quantum Technologies Inc.  | 
 | 2 | + | 
 | 3 | +# Licensed under the Apache License, Version 2.0 (the "License");  | 
 | 4 | +# you may not use this file except in compliance with the License.  | 
 | 5 | +# You may obtain a copy of the License at  | 
 | 6 | + | 
 | 7 | +#     http://www.apache.org/licenses/LICENSE-2.0  | 
 | 8 | + | 
 | 9 | +# Unless required by applicable law or agreed to in writing, software  | 
 | 10 | +# distributed under the License is distributed on an "AS IS" BASIS,  | 
 | 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  | 
 | 12 | +# See the License for the specific language governing permissions and  | 
 | 13 | +# limitations under the License.  | 
 | 14 | + | 
 | 15 | +"""Unit tests for xDSL utilities."""  | 
 | 16 | + | 
 | 17 | +import pytest  | 
 | 18 | + | 
 | 19 | +pytestmark = pytest.mark.external  | 
 | 20 | +xdsl = pytest.importorskip("xdsl")  | 
 | 21 | + | 
 | 22 | +# pylint: disable=wrong-import-position  | 
 | 23 | +from xdsl.dialects import arith, builtin, tensor, test  | 
 | 24 | + | 
 | 25 | +from pennylane.compiler.python_compiler.dialects.stablehlo import ConstantOp as hloConstantOp  | 
 | 26 | +from pennylane.compiler.python_compiler.utils import get_constant_from_ssa  | 
 | 27 | + | 
 | 28 | + | 
 | 29 | +class TestGetConstantFromSSA:  | 
 | 30 | +    """Unit tests for ``get_constant_from_ssa``."""  | 
 | 31 | + | 
 | 32 | +    def test_non_constant(self):  | 
 | 33 | +        """Test that ``None`` is returned if the input is not a constant."""  | 
 | 34 | +        val = test.TestOp(result_types=(builtin.Float64Type(),)).results[0]  | 
 | 35 | +        assert get_constant_from_ssa(val) is None  | 
 | 36 | + | 
 | 37 | +    @pytest.mark.parametrize(  | 
 | 38 | +        "const, attr_type, dtype",  | 
 | 39 | +        [  | 
 | 40 | +            (11, builtin.IntegerAttr, builtin.IntegerType(64)),  | 
 | 41 | +            (5, builtin.IntegerAttr, builtin.IndexType()),  | 
 | 42 | +            (2.5, builtin.FloatAttr, builtin.Float64Type()),  | 
 | 43 | +        ],  | 
 | 44 | +    )  | 
 | 45 | +    def test_scalar_constant_arith(self, const, attr_type, dtype):  | 
 | 46 | +        """Test that constants created by ``arith.constant`` are returned correctly."""  | 
 | 47 | +        const_attr = attr_type(const, dtype)  | 
 | 48 | +        val = arith.ConstantOp(value=const_attr).results[0]  | 
 | 49 | + | 
 | 50 | +        assert get_constant_from_ssa(val) == const  | 
 | 51 | + | 
 | 52 | +    @pytest.mark.parametrize(  | 
 | 53 | +        "const, elt_type",  | 
 | 54 | +        [  | 
 | 55 | +            (11, builtin.IntegerType(64)),  | 
 | 56 | +            (9, builtin.IndexType()),  | 
 | 57 | +            (2.5, builtin.Float64Type()),  | 
 | 58 | +            (-1.1 + 2.3j, builtin.ComplexType(builtin.Float64Type())),  | 
 | 59 | +        ],  | 
 | 60 | +    )  | 
 | 61 | +    @pytest.mark.parametrize("constant_op", [arith.ConstantOp, hloConstantOp])  | 
 | 62 | +    def test_scalar_constant_extracted_from_rank0_tensor(self, const, elt_type, constant_op):  | 
 | 63 | +        """Test that constants created by ``stablehlo.constant`` are returned correctly."""  | 
 | 64 | +        data = const  | 
 | 65 | +        if isinstance(const, complex):  | 
 | 66 | +            # For complex numbers, the number must be split into a 2-tuple containing  | 
 | 67 | +            # the real and imaginary part when initializing a dense elements attr.  | 
 | 68 | +            data = (const.real, const.imag)  | 
 | 69 | + | 
 | 70 | +        dense_attr = builtin.DenseIntOrFPElementsAttr.from_list(  | 
 | 71 | +            type=builtin.TensorType(element_type=elt_type, shape=()),  | 
 | 72 | +            data=(data,),  | 
 | 73 | +        )  | 
 | 74 | +        tensor_ = constant_op(value=dense_attr).results[0]  | 
 | 75 | +        val = tensor.ExtractOp(tensor=tensor_, indices=[], result_type=elt_type).results[0]  | 
 | 76 | + | 
 | 77 | +        assert get_constant_from_ssa(val) == const  | 
 | 78 | + | 
 | 79 | +    def test_tensor_constant_arith(self):  | 
 | 80 | +        """Test that ``None`` is returned if the input is a tensor created by ``arith.constant``."""  | 
 | 81 | +        dense_attr = builtin.DenseIntOrFPElementsAttr.from_list(  | 
 | 82 | +            type=builtin.TensorType(element_type=builtin.Float64Type(), shape=(3,)),  | 
 | 83 | +            data=(1, 2, 3),  | 
 | 84 | +        )  | 
 | 85 | +        val = arith.ConstantOp(value=dense_attr).results[0]  | 
 | 86 | + | 
 | 87 | +        assert get_constant_from_ssa(val) is None  | 
 | 88 | + | 
 | 89 | +    def test_tensor_constant_stablehlo(self):  | 
 | 90 | +        """Test that ``None`` is returned if the input is a tensor created by ``stablehlo.constant``."""  | 
 | 91 | +        dense_attr = builtin.DenseIntOrFPElementsAttr.from_list(  | 
 | 92 | +            type=builtin.TensorType(element_type=builtin.Float64Type(), shape=(3,)),  | 
 | 93 | +            data=(1.0, 2.0, 3.0),  | 
 | 94 | +        )  | 
 | 95 | +        val = hloConstantOp(value=dense_attr).results[0]  | 
 | 96 | + | 
 | 97 | +        assert get_constant_from_ssa(val) is None  | 
 | 98 | + | 
 | 99 | +    def test_extract_scalar_from_constant_tensor_stablehlo(self):  | 
 | 100 | +        """Test that ``None`` is returned if the input is a scalar constant, but it was extracted  | 
 | 101 | +        from a non-scalar constant."""  | 
 | 102 | +        # Index SSA value to be used for extracting a value from a tensor  | 
 | 103 | +        dummy_index = test.TestOp(result_types=(builtin.IndexType(),)).results[0]  | 
 | 104 | + | 
 | 105 | +        dense_attr = builtin.DenseIntOrFPElementsAttr.from_list(  | 
 | 106 | +            type=builtin.TensorType(element_type=builtin.Float64Type(), shape=(3,)),  | 
 | 107 | +            data=(1.0, 2.0, 3.0),  | 
 | 108 | +        )  | 
 | 109 | +        tensor_ = hloConstantOp(value=dense_attr).results[0]  | 
 | 110 | +        val = tensor.ExtractOp(  | 
 | 111 | +            tensor=tensor_, indices=[dummy_index], result_type=builtin.Float64Type()  | 
 | 112 | +        ).results[0]  | 
 | 113 | +        # val is a value that we got by indexing into a constant tensor with rank >= 1  | 
 | 114 | +        assert isinstance(val.type, builtin.Float64Type)  | 
 | 115 | + | 
 | 116 | +        assert get_constant_from_ssa(val) is None  | 
 | 117 | + | 
 | 118 | + | 
 | 119 | +if __name__ == "__main__":  | 
 | 120 | +    pytest.main(["-x", __file__])  | 
0 commit comments