Skip to content

Commit aad6fb7

Browse files
committed
Deprecate pytensor.get_underlying_scalar_constant
1 parent a120dc2 commit aad6fb7

File tree

4 files changed

+36
-25
lines changed

4 files changed

+36
-25
lines changed

pytensor/__init__.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
# pytensor code, since this code may want to log some messages.
2525
import logging
2626
import sys
27+
import warnings
2728
from functools import singledispatch
2829
from pathlib import Path
2930
from typing import Any, NoReturn, Optional
@@ -148,13 +149,13 @@ def get_underlying_scalar_constant(v):
148149
If `v` is not some view of constant data, then raise a
149150
`NotScalarConstantError`.
150151
"""
151-
# Is it necessary to test for presence of pytensor.sparse at runtime?
152-
sparse = globals().get("sparse")
153-
if sparse and isinstance(v.type, sparse.SparseTensorType):
154-
if v.owner is not None and isinstance(v.owner.op, sparse.CSM):
155-
data = v.owner.inputs[0]
156-
return tensor.get_underlying_scalar_constant_value(data)
157-
return tensor.get_underlying_scalar_constant_value(v)
152+
warnings.warn(
153+
"get_underlying_scalar_constant is deprecated. Use tensor.get_underlying_scalar_constant_value instead.",
154+
FutureWarning,
155+
)
156+
from pytensor.tensor.basic import get_underlying_scalar_constant_value
157+
158+
return get_underlying_scalar_constant_value(v)
158159

159160

160161
# isort: off

pytensor/gradient.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1329,7 +1329,7 @@ def try_to_copy_if_needed(var):
13291329
f" {i}. Since this input is only connected "
13301330
"to integer-valued outputs, it should "
13311331
"evaluate to zeros, but it evaluates to"
1332-
f"{pytensor.get_underlying_scalar_constant(term)}."
1332+
f"{pytensor.get_underlying_scalar_constant_value(term)}."
13331333
)
13341334
raise ValueError(msg)
13351335

@@ -2157,6 +2157,9 @@ def _is_zero(x):
21572157
'maybe' means that x is an expression that is complicated enough
21582158
that we can't tell that it simplifies to 0.
21592159
"""
2160+
from pytensor.tensor import get_underlying_scalar_constant_value
2161+
from pytensor.tensor.exceptions import NotScalarConstantError
2162+
21602163
if not hasattr(x, "type"):
21612164
return np.all(x == 0.0)
21622165
if isinstance(x.type, NullType):
@@ -2166,9 +2169,9 @@ def _is_zero(x):
21662169

21672170
no_constant_value = True
21682171
try:
2169-
constant_value = pytensor.get_underlying_scalar_constant(x)
2172+
constant_value = get_underlying_scalar_constant_value(x)
21702173
no_constant_value = False
2171-
except pytensor.tensor.exceptions.NotScalarConstantError:
2174+
except NotScalarConstantError:
21722175
pass
21732176

21742177
if no_constant_value:

pytensor/tensor/basic.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,8 @@ def get_underlying_scalar_constant_value(
320320
321321
"""
322322
from pytensor.compile.ops import DeepCopyOp, OutputGuard
323+
from pytensor.sparse import CSM
324+
from pytensor.tensor.subtensor import Subtensor
323325

324326
v = orig_v
325327
while True:
@@ -350,16 +352,16 @@ def get_underlying_scalar_constant_value(
350352
raise NotScalarConstantError()
351353

352354
if not only_process_constants and getattr(v, "owner", None) and max_recur > 0:
355+
op = v.owner.op
353356
max_recur -= 1
354357
if isinstance(
355-
v.owner.op,
356-
Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp,
358+
op, Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp
357359
):
358360
# OutputGuard is only used in debugmode but we
359361
# keep it here to avoid problems with old pickles
360362
v = v.owner.inputs[0]
361363
continue
362-
elif isinstance(v.owner.op, Shape_i):
364+
elif isinstance(op, Shape_i):
363365
i = v.owner.op.i
364366
inp = v.owner.inputs[0]
365367
if isinstance(inp, Constant):
@@ -373,10 +375,10 @@ def get_underlying_scalar_constant_value(
373375
# mess with the stabilization optimization and be too slow.
374376
# We put all the scalar Ops used by get_canonical_form_slice()
375377
# to allow it to determine the broadcast pattern correctly.
376-
elif isinstance(v.owner.op, ScalarFromTensor | TensorFromScalar):
378+
elif isinstance(op, ScalarFromTensor | TensorFromScalar):
377379
v = v.owner.inputs[0]
378380
continue
379-
elif isinstance(v.owner.op, CheckAndRaise):
381+
elif isinstance(op, CheckAndRaise):
380382
# check if all conditions are constant and true
381383
conds = [
382384
get_underlying_scalar_constant_value(c, max_recur=max_recur)
@@ -385,7 +387,7 @@ def get_underlying_scalar_constant_value(
385387
if builtins.all(0 == c.ndim and c != 0 for c in conds):
386388
v = v.owner.inputs[0]
387389
continue
388-
elif isinstance(v.owner.op, ps.ScalarOp):
390+
elif isinstance(op, ps.ScalarOp):
389391
if isinstance(v.owner.op, ps.Second):
390392
# We don't need both input to be constant for second
391393
shp, val = v.owner.inputs
@@ -402,7 +404,7 @@ def get_underlying_scalar_constant_value(
402404
# In fast_compile, we don't enable local_fill_to_alloc, so
403405
# we need to investigate Second as Alloc. So elemwise
404406
# don't disable the check for Second.
405-
elif isinstance(v.owner.op, Elemwise):
407+
elif isinstance(op, Elemwise):
406408
if isinstance(v.owner.op.scalar_op, ps.Second):
407409
# We don't need both input to be constant for second
408410
shp, val = v.owner.inputs
@@ -418,10 +420,7 @@ def get_underlying_scalar_constant_value(
418420
ret = [[None]]
419421
v.owner.op.perform(v.owner, const, ret)
420422
return np.asarray(ret[0][0].copy())
421-
elif (
422-
isinstance(v.owner.op, pytensor.tensor.subtensor.Subtensor)
423-
and v.ndim == 0
424-
):
423+
elif isinstance(op, Subtensor) and v.ndim == 0:
425424
if isinstance(v.owner.inputs[0], TensorConstant):
426425
from pytensor.tensor.subtensor import get_constant_idx
427426

@@ -545,6 +544,14 @@ def get_underlying_scalar_constant_value(
545544

546545
if isinstance(grandparent, Constant):
547546
return np.asarray(np.shape(grandparent.data)[idx])
547+
elif isinstance(op, CSM):
548+
data = get_underlying_scalar_constant_value(
549+
v.owner.inputs, elemwise=elemwise, max_recur=max_recur
550+
)
551+
# Sparse variable can only be constant if zero (or I guess if homogeneously dense)
552+
if data == 0:
553+
return data
554+
break
548555

549556
raise NotScalarConstantError()
550557

@@ -4071,7 +4078,7 @@ def make_node(self, a, choices):
40714078
static_out_shape = ()
40724079
for s in out_shape:
40734080
try:
4074-
s_val = pytensor.get_underlying_scalar_constant(s)
4081+
s_val = get_underlying_scalar_constant_value(s)
40754082
except (NotScalarConstantError, AttributeError):
40764083
s_val = None
40774084

tests/tensor/test_elemwise.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pytensor.link.basic import PerformLinker
2020
from pytensor.link.c.basic import CLinker, OpWiseCLinker
2121
from pytensor.tensor import as_tensor_variable
22-
from pytensor.tensor.basic import second
22+
from pytensor.tensor.basic import get_scalar_constant_value, second
2323
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
2424
from pytensor.tensor.math import Any, Sum, exp
2525
from pytensor.tensor.math import all as pt_all
@@ -807,8 +807,8 @@ def test_partial_static_shape_info(self):
807807

808808
assert len(res_shape) == 1
809809
assert len(res_shape[0]) == 2
810-
assert pytensor.get_underlying_scalar_constant(res_shape[0][0]) == 1
811-
assert pytensor.get_underlying_scalar_constant(res_shape[0][1]) == 1
810+
assert get_scalar_constant_value(res_shape[0][0]) == 1
811+
assert get_scalar_constant_value(res_shape[0][1]) == 1
812812

813813
def test_infer_shape_multi_output(self):
814814
class CustomElemwise(Elemwise):

0 commit comments

Comments
 (0)