Skip to content

Commit 32aadc8

Browse files
committed
Deprecate extract_constant
1 parent aad6fb7 commit 32aadc8

File tree

8 files changed

+403
-253
lines changed

8 files changed

+403
-253
lines changed

pytensor/scan/rewriting.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from pytensor.tensor.basic import (
5555
Alloc,
5656
AllocEmpty,
57+
get_scalar_constant_value,
5758
get_underlying_scalar_constant_value,
5859
)
5960
from pytensor.tensor.elemwise import DimShuffle, Elemwise
@@ -665,8 +666,10 @@ def inner_sitsot_only_last_step_used(
665666
client = fgraph.clients[outer_var][0][0]
666667
if isinstance(client, Apply) and isinstance(client.op, Subtensor):
667668
lst = get_idx_list(client.inputs, client.op.idx_list)
668-
if len(lst) == 1 and pt.extract_constant(lst[0]) == -1:
669-
return True
669+
return (
670+
len(lst) == 1
671+
and get_scalar_constant_value(lst[0], raise_not_constant=False) == -1
672+
)
670673

671674
return False
672675

@@ -1341,10 +1344,17 @@ def scan_save_mem(fgraph, node):
13411344
if isinstance(this_slice[0], slice) and this_slice[0].stop is None:
13421345
global_nsteps = None
13431346
if isinstance(cf_slice[0], slice):
1344-
stop = pt.extract_constant(cf_slice[0].stop)
1347+
stop = get_scalar_constant_value(
1348+
cf_slice[0].stop, raise_not_constant=False
1349+
)
13451350
else:
1346-
stop = pt.extract_constant(cf_slice[0]) + 1
1347-
if stop == maxsize or stop == pt.extract_constant(length):
1351+
stop = (
1352+
get_scalar_constant_value(cf_slice[0], raise_not_constant=False)
1353+
+ 1
1354+
)
1355+
if stop == maxsize or stop == get_scalar_constant_value(
1356+
length, raise_not_constant=False
1357+
):
13481358
stop = None
13491359
else:
13501360
# there is a **gotcha** here ! Namely, scan returns an
@@ -1448,9 +1458,13 @@ def scan_save_mem(fgraph, node):
14481458
cf_slice = get_canonical_form_slice(this_slice[0], length)
14491459

14501460
if isinstance(cf_slice[0], slice):
1451-
start = pt.extract_constant(cf_slice[0].start)
1461+
start = pt.get_scalar_constant_value(
1462+
cf_slice[0].start, raise_not_constant=False
1463+
)
14521464
else:
1453-
start = pt.extract_constant(cf_slice[0])
1465+
start = pt.get_scalar_constant_value(
1466+
cf_slice[0], raise_not_constant=False
1467+
)
14541468

14551469
if start == 0 or store_steps[i] == 0:
14561470
store_steps[i] = 0
@@ -1625,7 +1639,7 @@ def scan_save_mem(fgraph, node):
16251639
# 3.6 Compose the new scan
16261640
# TODO: currently we don't support scan with 0 step. So
16271641
# don't create one.
1628-
if pt.extract_constant(node_ins[0]) == 0:
1642+
if get_scalar_constant_value(node_ins[0], raise_not_constant=False) == 0:
16291643
return False
16301644

16311645
# Do not call make_node for test_value

pytensor/tensor/basic.py

Lines changed: 111 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -268,27 +268,7 @@ def _obj_is_wrappable_as_tensor(x):
268268
)
269269

270270

271-
def get_scalar_constant_value(
272-
v, elemwise=True, only_process_constants=False, max_recur=10
273-
):
274-
"""
275-
Checks whether 'v' is a scalar (ndim = 0).
276-
277-
If 'v' is a scalar then this function fetches the underlying constant by calling
278-
'get_underlying_scalar_constant_value()'.
279-
280-
If 'v' is not a scalar, it raises a NotScalarConstantError.
281-
282-
"""
283-
if isinstance(v, Variable | np.ndarray):
284-
if v.ndim != 0:
285-
raise NotScalarConstantError()
286-
return get_underlying_scalar_constant_value(
287-
v, elemwise, only_process_constants, max_recur
288-
)
289-
290-
291-
def get_underlying_scalar_constant_value(
271+
def _get_underlying_scalar_constant_value(
292272
orig_v, elemwise=True, only_process_constants=False, max_recur=10
293273
):
294274
"""Return the constant scalar(0-D) value underlying variable `v`.
@@ -381,7 +361,7 @@ def get_underlying_scalar_constant_value(
381361
elif isinstance(op, CheckAndRaise):
382362
# check if all conditions are constant and true
383363
conds = [
384-
get_underlying_scalar_constant_value(c, max_recur=max_recur)
364+
_get_underlying_scalar_constant_value(c, max_recur=max_recur)
385365
for c in v.owner.inputs[1:]
386366
]
387367
if builtins.all(0 == c.ndim and c != 0 for c in conds):
@@ -395,7 +375,7 @@ def get_underlying_scalar_constant_value(
395375
continue
396376
if isinstance(v.owner.op, _scalar_constant_value_elemwise_ops):
397377
const = [
398-
get_underlying_scalar_constant_value(i, max_recur=max_recur)
378+
_get_underlying_scalar_constant_value(i, max_recur=max_recur)
399379
for i in v.owner.inputs
400380
]
401381
ret = [[None]]
@@ -414,7 +394,7 @@ def get_underlying_scalar_constant_value(
414394
v.owner.op.scalar_op, _scalar_constant_value_elemwise_ops
415395
):
416396
const = [
417-
get_underlying_scalar_constant_value(i, max_recur=max_recur)
397+
_get_underlying_scalar_constant_value(i, max_recur=max_recur)
418398
for i in v.owner.inputs
419399
]
420400
ret = [[None]]
@@ -457,7 +437,7 @@ def get_underlying_scalar_constant_value(
457437
):
458438
idx = v.owner.op.idx_list[0]
459439
if isinstance(idx, Type):
460-
idx = get_underlying_scalar_constant_value(
440+
idx = _get_underlying_scalar_constant_value(
461441
v.owner.inputs[1], max_recur=max_recur
462442
)
463443
try:
@@ -491,14 +471,13 @@ def get_underlying_scalar_constant_value(
491471
):
492472
idx = v.owner.op.idx_list[0]
493473
if isinstance(idx, Type):
494-
idx = get_underlying_scalar_constant_value(
474+
idx = _get_underlying_scalar_constant_value(
495475
v.owner.inputs[1], max_recur=max_recur
496476
)
497-
# Python 2.4 does not support indexing with numpy.integer
498-
# So we cast it.
499-
idx = int(idx)
500477
ret = v.owner.inputs[0].owner.inputs[idx]
501-
ret = get_underlying_scalar_constant_value(ret, max_recur=max_recur)
478+
ret = _get_underlying_scalar_constant_value(
479+
ret, max_recur=max_recur
480+
)
502481
# MakeVector can cast implicitly its input in some case.
503482
return np.asarray(ret, dtype=v.type.dtype)
504483

@@ -513,7 +492,7 @@ def get_underlying_scalar_constant_value(
513492
idx_list = op.idx_list
514493
idx = idx_list[0]
515494
if isinstance(idx, Type):
516-
idx = get_underlying_scalar_constant_value(
495+
idx = _get_underlying_scalar_constant_value(
517496
owner.inputs[1], max_recur=max_recur
518497
)
519498
grandparent = leftmost_parent.owner.inputs[0]
@@ -523,7 +502,9 @@ def get_underlying_scalar_constant_value(
523502
grandparent.owner.op, Unbroadcast
524503
):
525504
ggp_shape = grandparent.owner.inputs[0].type.shape
526-
l = [get_underlying_scalar_constant_value(s) for s in ggp_shape]
505+
l = [
506+
_get_underlying_scalar_constant_value(s) for s in ggp_shape
507+
]
527508
gp_shape = tuple(l)
528509

529510
if not (idx < ndim):
@@ -545,7 +526,7 @@ def get_underlying_scalar_constant_value(
545526
if isinstance(grandparent, Constant):
546527
return np.asarray(np.shape(grandparent.data)[idx])
547528
elif isinstance(op, CSM):
548-
data = get_underlying_scalar_constant_value(
529+
data = _get_underlying_scalar_constant_value(
549530
v.owner.inputs, elemwise=elemwise, max_recur=max_recur
550531
)
551532
# Sparse variable can only be constant if zero (or I guess if homogeneously dense)
@@ -556,6 +537,93 @@ def get_underlying_scalar_constant_value(
556537
raise NotScalarConstantError()
557538

558539

540+
def get_underlying_scalar_constant_value(
541+
v,
542+
*,
543+
elemwise=True,
544+
only_process_constants=False,
545+
max_recur=10,
546+
raise_not_constant=True,
547+
):
548+
"""Return the unique constant scalar(0-D) value underlying variable `v`.
549+
550+
If `v` is the output of dimshuffles, fills, allocs, etc,
551+
cast, OutputGuard, DeepCopyOp, ScalarFromTensor, ScalarOp, Elemwise
552+
and some pattern with Subtensor, this function digs through them.
553+
554+
If `v` is not some view of constant scalar data, then raise a
555+
NotScalarConstantError.
556+
557+
This function performs symbolic reasoning about the value of `v`, as opposed to numerical reasoning by
558+
constant folding the inputs of `v`.
559+
560+
Parameters
561+
----------
562+
v: Variable
563+
elemwise : bool
564+
If False, we won't try to go into elemwise. So this call is faster.
565+
But we still investigate in Second Elemwise (as this is a substitute
566+
for Alloc)
567+
only_process_constants : bool
568+
If True, we only attempt to obtain the value of `orig_v` if it's
569+
directly constant and don't try to dig through dimshuffles, fills,
570+
allocs, and other to figure out its value.
571+
max_recur : int
572+
The maximum number of recursion.
573+
raise_not_constant: bool, default True
574+
If True, raise a NotScalarConstantError if `v` does not have an
575+
underlying constant scalar value. If False, return `v` as is.
576+
577+
578+
Raises
579+
------
580+
NotScalarConstantError
581+
`v` does not have an underlying constant scalar value.
582+
Only rasise if raise_not_constant is True.
583+
584+
"""
585+
try:
586+
return _get_underlying_scalar_constant_value(
587+
v,
588+
elemwise=elemwise,
589+
only_process_constants=only_process_constants,
590+
max_recur=max_recur,
591+
)
592+
except NotScalarConstantError:
593+
if raise_not_constant:
594+
raise
595+
return v
596+
597+
598+
def get_scalar_constant_value(
599+
v,
600+
elemwise=True,
601+
only_process_constants=False,
602+
max_recur=10,
603+
raise_not_constant: bool = True,
604+
):
605+
"""
606+
Checks whether 'v' is a scalar (ndim = 0).
607+
608+
If 'v' is a scalar then this function fetches the underlying constant by calling
609+
'get_underlying_scalar_constant_value()'.
610+
611+
If 'v' is not a scalar, it raises a NotScalarConstantError.
612+
613+
"""
614+
if isinstance(v, TensorVariable | np.ndarray):
615+
if v.ndim != 0:
616+
print(v, v.ndim)
617+
raise NotScalarConstantError("Input ndim != 0")
618+
return get_underlying_scalar_constant_value(
619+
v,
620+
elemwise=elemwise,
621+
only_process_constants=only_process_constants,
622+
max_recur=max_recur,
623+
raise_not_constant=raise_not_constant,
624+
)
625+
626+
559627
class TensorFromScalar(COp):
560628
__props__ = ()
561629

@@ -2012,16 +2080,16 @@ def extract_constant(x, elemwise=True, only_process_constants=False):
20122080
ScalarVariable, we convert it to a tensor with tensor_from_scalar.
20132081
20142082
"""
2015-
try:
2016-
x = get_underlying_scalar_constant_value(x, elemwise, only_process_constants)
2017-
except NotScalarConstantError:
2018-
pass
2019-
if isinstance(x, ps.ScalarVariable | ps.sharedvar.ScalarSharedVariable):
2020-
if x.owner and isinstance(x.owner.op, ScalarFromTensor):
2021-
x = x.owner.inputs[0]
2022-
else:
2023-
x = tensor_from_scalar(x)
2024-
return x
2083+
warnings.warn(
2084+
"extract_constant is deprecated. Use `get_underlying_scalar_constant_value(..., raise_not_constant=False)`",
2085+
FutureWarning,
2086+
)
2087+
return get_underlying_scalar_constant_value(
2088+
x,
2089+
elemwise=elemwise,
2090+
only_process_constants=only_process_constants,
2091+
raise_not_constant=False,
2092+
)
20252093

20262094

20272095
def transpose(x, axes=None):
@@ -4401,7 +4469,6 @@ def ix_(*args):
44014469
"split",
44024470
"transpose",
44034471
"matrix_transpose",
4404-
"extract_constant",
44054472
"default",
44064473
"tensor_copy",
44074474
"transfer",

0 commit comments

Comments
 (0)