Skip to content

Commit 029fc88

Browse files
committed
Deprecate pytensor.get_underlying_scalar_constant
1 parent 7462fdf commit 029fc88

File tree

4 files changed

+28
-22
lines changed

4 files changed

+28
-22
lines changed

pytensor/__init__.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
# pytensor code, since this code may want to log some messages.
2525
import logging
2626
import sys
27+
import warnings
2728
from functools import singledispatch
2829
from pathlib import Path
2930
from typing import Any, NoReturn, Optional
@@ -148,12 +149,10 @@ def get_underlying_scalar_constant(v):
148149
If `v` is not some view of constant data, then raise a
149150
`NotScalarConstantError`.
150151
"""
151-
# Is it necessary to test for presence of pytensor.sparse at runtime?
152-
sparse = globals().get("sparse")
153-
if sparse and isinstance(v.type, sparse.SparseTensorType):
154-
if v.owner is not None and isinstance(v.owner.op, sparse.CSM):
155-
data = v.owner.inputs[0]
156-
return tensor.get_underlying_scalar_constant_value(data)
152+
warnings.warn(
153+
"get_underlying_scalar_constant is deprecated. Use tensor.get_underlying_scalar_constant_value instead.",
154+
DeprecationWarning,
155+
)
157156
return tensor.get_underlying_scalar_constant_value(v)
158157

159158

pytensor/gradient.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,7 +1330,7 @@ def try_to_copy_if_needed(var):
13301330
f" {i}. Since this input is only connected "
13311331
"to integer-valued outputs, it should "
13321332
"evaluate to zeros, but it evaluates to"
1333-
f"{pytensor.get_underlying_scalar_constant(term)}."
1333+
f"{pytensor.get_underlying_scalar_constant_value(term)}."
13341334
)
13351335
raise ValueError(msg)
13361336

@@ -2172,7 +2172,7 @@ def _is_zero(x):
21722172

21732173
no_constant_value = True
21742174
try:
2175-
constant_value = pytensor.get_underlying_scalar_constant(x)
2175+
constant_value = pytensor.get_underlying_scalar_constant_value(x)
21762176
no_constant_value = False
21772177
except pytensor.tensor.exceptions.NotScalarConstantError:
21782178
pass

pytensor/tensor/basic.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,8 @@ def get_underlying_scalar_constant_value(
319319
320320
"""
321321
from pytensor.compile.ops import DeepCopyOp, OutputGuard
322+
from pytensor.sparse import CSM
323+
from pytensor.tensor.subtensor import Subtensor
322324

323325
v = orig_v
324326
while True:
@@ -346,16 +348,16 @@ def get_underlying_scalar_constant_value(
346348
raise NotScalarConstantError()
347349

348350
if not only_process_constants and getattr(v, "owner", None) and max_recur > 0:
351+
op = v.owner.op
349352
max_recur -= 1
350353
if isinstance(
351-
v.owner.op,
352-
Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp,
354+
op, Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp
353355
):
354356
# OutputGuard is only used in debugmode but we
355357
# keep it here to avoid problems with old pickles
356358
v = v.owner.inputs[0]
357359
continue
358-
elif isinstance(v.owner.op, Shape_i):
360+
elif isinstance(op, Shape_i):
359361
i = v.owner.op.i
360362
inp = v.owner.inputs[0]
361363
if isinstance(inp, Constant):
@@ -369,10 +371,10 @@ def get_underlying_scalar_constant_value(
369371
# mess with the stabilization optimization and be too slow.
370372
# We put all the scalar Ops used by get_canonical_form_slice()
371373
# to allow it to determine the broadcast pattern correctly.
372-
elif isinstance(v.owner.op, ScalarFromTensor | TensorFromScalar):
374+
elif isinstance(op, ScalarFromTensor | TensorFromScalar):
373375
v = v.owner.inputs[0]
374376
continue
375-
elif isinstance(v.owner.op, CheckAndRaise):
377+
elif isinstance(op, CheckAndRaise):
376378
# check if all conditions are constant and true
377379
conds = [
378380
get_underlying_scalar_constant_value(c, max_recur=max_recur)
@@ -381,7 +383,7 @@ def get_underlying_scalar_constant_value(
381383
if builtins.all(0 == c.ndim and c != 0 for c in conds):
382384
v = v.owner.inputs[0]
383385
continue
384-
elif isinstance(v.owner.op, ps.ScalarOp):
386+
elif isinstance(op, ps.ScalarOp):
385387
if isinstance(v.owner.op, ps.Second):
386388
# We don't need both input to be constant for second
387389
shp, val = v.owner.inputs
@@ -398,7 +400,7 @@ def get_underlying_scalar_constant_value(
398400
# In fast_compile, we don't enable local_fill_to_alloc, so
399401
# we need to investigate Second as Alloc. So elemwise
400402
# don't disable the check for Second.
401-
elif isinstance(v.owner.op, Elemwise):
403+
elif isinstance(op, Elemwise):
402404
if isinstance(v.owner.op.scalar_op, ps.Second):
403405
# We don't need both input to be constant for second
404406
shp, val = v.owner.inputs
@@ -414,10 +416,7 @@ def get_underlying_scalar_constant_value(
414416
ret = [[None]]
415417
v.owner.op.perform(v.owner, const, ret)
416418
return np.asarray(ret[0][0].copy())
417-
elif (
418-
isinstance(v.owner.op, pytensor.tensor.subtensor.Subtensor)
419-
and v.ndim == 0
420-
):
419+
elif isinstance(op, Subtensor) and v.ndim == 0:
421420
if isinstance(v.owner.inputs[0], TensorConstant):
422421
from pytensor.tensor.subtensor import get_constant_idx
423422

@@ -541,6 +540,14 @@ def get_underlying_scalar_constant_value(
541540

542541
if isinstance(grandparent, Constant):
543542
return np.asarray(np.shape(grandparent.data)[idx])
543+
elif isinstance(op, CSM):
544+
data = get_underlying_scalar_constant_value(
545+
v.owner.inputs, elemwise=elemwise, max_recur=max_recur
546+
)
547+
# Sparse variable can only be constant if zero (or I guess if homogeneously dense)
548+
if data == 0:
549+
return data
550+
break
544551

545552
raise NotScalarConstantError()
546553

@@ -4064,7 +4071,7 @@ def make_node(self, a, choices):
40644071
static_out_shape = ()
40654072
for s in out_shape:
40664073
try:
4067-
s_val = pytensor.get_underlying_scalar_constant(s)
4074+
s_val = get_underlying_scalar_constant_value(s)
40684075
except (NotScalarConstantError, AttributeError):
40694076
s_val = None
40704077

tests/tensor/test_elemwise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -804,8 +804,8 @@ def test_partial_static_shape_info(self):
804804

805805
assert len(res_shape) == 1
806806
assert len(res_shape[0]) == 2
807-
assert pytensor.get_underlying_scalar_constant(res_shape[0][0]) == 1
808-
assert pytensor.get_underlying_scalar_constant(res_shape[0][1]) == 1
807+
assert pytensor.get_underlying_scalar_constant_value(res_shape[0][0]) == 1
808+
assert pytensor.get_underlying_scalar_constant_value(res_shape[0][1]) == 1
809809

810810
def test_infer_shape_multi_output(self):
811811
class CustomElemwise(Elemwise):

0 commit comments

Comments
 (0)