Skip to content

Commit 593869a

Browse files
Add get_constant_from_ssa function for getting the concrete value of an xDSL variable (#8514)
There are currently two ways of defining constant values - `arith.constant`, and `stablehlo.constant` - and we need to deal with them in different ways when trying to extract the concrete value. This PR adds a function called `get_constant_from_ssa`, which takes an SSA value as input, and returns its concrete value if it is a constant, else `None`. Notes: * Added a `utils.py` file to the `python_compiler` submodule * The function can currently only extract scalar values, and explicitly returns `None` if the value is not scalar, even if it is a constant. [sc-100957] --------- Co-authored-by: Andrija Paurevic <[email protected]>
1 parent 5a2f1d4 commit 593869a

File tree

3 files changed

+202
-0
lines changed

3 files changed

+202
-0
lines changed

doc/releases/changelog-dev.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111

1212
<h3>Improvements 🛠</h3>
1313

14+
* A new `qml.compiler.python_compiler.utils` submodule has been added, containing general-purpose utilities for
15+
working with xDSL. This includes a function that extracts the concrete value of scalar, constant SSA values.
16+
[(#8514)](https://github.com/PennyLaneAI/pennylane/pull/8514)
17+
1418
* Added a keyword argument ``recursive`` to ``qml.transforms.cancel_inverses`` that enables
1519
recursive cancellation of nested pairs of mutually inverse gates. This makes the transform
1620
more powerful, because it can cancel larger blocks of inverse gates without having to scan
@@ -214,6 +218,7 @@ Astral Cai,
214218
Marcus Edwards,
215219
Lillian Frederiksen,
216220
Christina Lee,
221+
Mudit Pandey,
217222
Shuli Shu,
218223
Jay Soni,
219224
David Wierichs,
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2025 Xanadu Quantum Technologies Inc.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""General purpose utilities to use with xDSL."""
16+
17+
from numbers import Number
18+
19+
from xdsl.dialects.arith import ConstantOp as arithConstantOp
20+
from xdsl.dialects.builtin import ComplexType, ShapedType
21+
from xdsl.dialects.tensor import ExtractOp as tensorExtractOp
22+
from xdsl.ir import SSAValue
23+
24+
from .dialects.stablehlo import ConstantOp as hloConstantOp
25+
26+
27+
def get_constant_from_ssa(value: SSAValue) -> Number | None:
28+
"""Return the concrete value corresponding to an SSA value if it is a numerical constant.
29+
30+
.. note::
31+
32+
This function currently only returns constants if they are scalar. For non-scalar
33+
constants, ``None`` will be returned.
34+
35+
Args:
36+
value (xdsl.ir.SSAValue): the SSA value to check
37+
38+
Returns:
39+
Number or None: If the value corresponds to a constant, its concrete value will
40+
be returned, else ``None``.
41+
"""
42+
43+
# If the value has a shape, we can assume that it is not scalar. We check
44+
# this because constant-like operations can return container types. This includes
45+
# arith.constant, which may return containers, and stablehlo.constant, which
46+
# always returns a container.
47+
if not isinstance(value.type, ShapedType):
48+
owner = value.owner
49+
50+
if isinstance(owner, arithConstantOp):
51+
const_attr = owner.value
52+
return const_attr.value.data
53+
54+
# Constant-like operations can also create scalars by returning rank 0 tensors.
55+
# In this case, the owner of a scalar value should be a tensor.extract, which
56+
# uses the aforementioned rank 0 constant tensor as input.
57+
if isinstance(owner, tensorExtractOp):
58+
tensor_ = owner.tensor
59+
if (
60+
len(owner.indices) == 0
61+
and len(tensor_.type.shape) == 0
62+
and isinstance(tensor_.owner, (arithConstantOp, hloConstantOp))
63+
):
64+
dense_attr = tensor_.owner.value
65+
# We know that the tensor has shape (). Dense element attributes store
66+
# their data as a sequence. For a scalar, this will be a sequence with
67+
# a single element.
68+
val = dense_attr.get_values()[0]
69+
if isinstance(tensor_.type.element_type, ComplexType):
70+
# If the dtype is complex, the value will be a 2-tuple containing
71+
# the real and imaginary components of the number rather than a
72+
# Python complex number
73+
val = val[0] + 1j * val[1]
74+
75+
return val
76+
77+
return None
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright 2025 Xanadu Quantum Technologies Inc.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for xDSL utilities."""
16+
17+
import pytest
18+
19+
pytestmark = pytest.mark.external
20+
xdsl = pytest.importorskip("xdsl")
21+
22+
# pylint: disable=wrong-import-position
23+
from xdsl.dialects import arith, builtin, tensor, test
24+
25+
from pennylane.compiler.python_compiler.dialects.stablehlo import ConstantOp as hloConstantOp
26+
from pennylane.compiler.python_compiler.utils import get_constant_from_ssa
27+
28+
29+
class TestGetConstantFromSSA:
30+
"""Unit tests for ``get_constant_from_ssa``."""
31+
32+
def test_non_constant(self):
33+
"""Test that ``None`` is returned if the input is not a constant."""
34+
val = test.TestOp(result_types=(builtin.Float64Type(),)).results[0]
35+
assert get_constant_from_ssa(val) is None
36+
37+
@pytest.mark.parametrize(
38+
"const, attr_type, dtype",
39+
[
40+
(11, builtin.IntegerAttr, builtin.IntegerType(64)),
41+
(5, builtin.IntegerAttr, builtin.IndexType()),
42+
(2.5, builtin.FloatAttr, builtin.Float64Type()),
43+
],
44+
)
45+
def test_scalar_constant_arith(self, const, attr_type, dtype):
46+
"""Test that constants created by ``arith.constant`` are returned correctly."""
47+
const_attr = attr_type(const, dtype)
48+
val = arith.ConstantOp(value=const_attr).results[0]
49+
50+
assert get_constant_from_ssa(val) == const
51+
52+
@pytest.mark.parametrize(
53+
"const, elt_type",
54+
[
55+
(11, builtin.IntegerType(64)),
56+
(9, builtin.IndexType()),
57+
(2.5, builtin.Float64Type()),
58+
(-1.1 + 2.3j, builtin.ComplexType(builtin.Float64Type())),
59+
],
60+
)
61+
@pytest.mark.parametrize("constant_op", [arith.ConstantOp, hloConstantOp])
62+
def test_scalar_constant_extracted_from_rank0_tensor(self, const, elt_type, constant_op):
63+
"""Test that constants created by ``stablehlo.constant`` are returned correctly."""
64+
data = const
65+
if isinstance(const, complex):
66+
# For complex numbers, the number must be split into a 2-tuple containing
67+
# the real and imaginary part when initializing a dense elements attr.
68+
data = (const.real, const.imag)
69+
70+
dense_attr = builtin.DenseIntOrFPElementsAttr.from_list(
71+
type=builtin.TensorType(element_type=elt_type, shape=()),
72+
data=(data,),
73+
)
74+
tensor_ = constant_op(value=dense_attr).results[0]
75+
val = tensor.ExtractOp(tensor=tensor_, indices=[], result_type=elt_type).results[0]
76+
77+
assert get_constant_from_ssa(val) == const
78+
79+
def test_tensor_constant_arith(self):
80+
"""Test that ``None`` is returned if the input is a tensor created by ``arith.constant``."""
81+
dense_attr = builtin.DenseIntOrFPElementsAttr.from_list(
82+
type=builtin.TensorType(element_type=builtin.Float64Type(), shape=(3,)),
83+
data=(1, 2, 3),
84+
)
85+
val = arith.ConstantOp(value=dense_attr).results[0]
86+
87+
assert get_constant_from_ssa(val) is None
88+
89+
def test_tensor_constant_stablehlo(self):
90+
"""Test that ``None`` is returned if the input is a tensor created by ``stablehlo.constant``."""
91+
dense_attr = builtin.DenseIntOrFPElementsAttr.from_list(
92+
type=builtin.TensorType(element_type=builtin.Float64Type(), shape=(3,)),
93+
data=(1.0, 2.0, 3.0),
94+
)
95+
val = hloConstantOp(value=dense_attr).results[0]
96+
97+
assert get_constant_from_ssa(val) is None
98+
99+
def test_extract_scalar_from_constant_tensor_stablehlo(self):
100+
"""Test that ``None`` is returned if the input is a scalar constant, but it was extracted
101+
from a non-scalar constant."""
102+
# Index SSA value to be used for extracting a value from a tensor
103+
dummy_index = test.TestOp(result_types=(builtin.IndexType(),)).results[0]
104+
105+
dense_attr = builtin.DenseIntOrFPElementsAttr.from_list(
106+
type=builtin.TensorType(element_type=builtin.Float64Type(), shape=(3,)),
107+
data=(1.0, 2.0, 3.0),
108+
)
109+
tensor_ = hloConstantOp(value=dense_attr).results[0]
110+
val = tensor.ExtractOp(
111+
tensor=tensor_, indices=[dummy_index], result_type=builtin.Float64Type()
112+
).results[0]
113+
# val is a value that we got by indexing into a constant tensor with rank >= 1
114+
assert isinstance(val.type, builtin.Float64Type)
115+
116+
assert get_constant_from_ssa(val) is None
117+
118+
119+
if __name__ == "__main__":
120+
pytest.main(["-x", __file__])

0 commit comments

Comments
 (0)