Skip to content

Commit 40e0011

Browse files
committed
Deprecate extract_constant
1 parent 029fc88 commit 40e0011

File tree

8 files changed

+406
-253
lines changed

8 files changed

+406
-253
lines changed

pytensor/scan/rewriting.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from pytensor.tensor.basic import (
5555
Alloc,
5656
AllocEmpty,
57+
get_scalar_constant_value,
5758
get_underlying_scalar_constant_value,
5859
)
5960
from pytensor.tensor.elemwise import DimShuffle, Elemwise
@@ -665,8 +666,10 @@ def inner_sitsot_only_last_step_used(
665666
client = fgraph.clients[outer_var][0][0]
666667
if isinstance(client, Apply) and isinstance(client.op, Subtensor):
667668
lst = get_idx_list(client.inputs, client.op.idx_list)
668-
if len(lst) == 1 and pt.extract_constant(lst[0]) == -1:
669-
return True
669+
return (
670+
len(lst) == 1
671+
and get_scalar_constant_value(idx[0], raise_not_constant=False) == -1
672+
)
670673

671674
return False
672675

@@ -1341,10 +1344,17 @@ def scan_save_mem(fgraph, node):
13411344
if isinstance(this_slice[0], slice) and this_slice[0].stop is None:
13421345
global_nsteps = None
13431346
if isinstance(cf_slice[0], slice):
1344-
stop = pt.extract_constant(cf_slice[0].stop)
1347+
stop = get_scalar_constant_value(
1348+
cf_slice[0].stop, raise_not_constant=False
1349+
)
13451350
else:
1346-
stop = pt.extract_constant(cf_slice[0]) + 1
1347-
if stop == maxsize or stop == pt.extract_constant(length):
1351+
stop = (
1352+
get_scalar_constant_value(cf_slice[0], raise_not_constant=False)
1353+
+ 1
1354+
)
1355+
if stop == maxsize or stop == get_scalar_constant_value(
1356+
length, raise_not_constant=False
1357+
):
13481358
stop = None
13491359
else:
13501360
# there is a **gotcha** here ! Namely, scan returns an
@@ -1448,9 +1458,13 @@ def scan_save_mem(fgraph, node):
14481458
cf_slice = get_canonical_form_slice(this_slice[0], length)
14491459

14501460
if isinstance(cf_slice[0], slice):
1451-
start = pt.extract_constant(cf_slice[0].start)
1461+
start = pt.get_scalar_constant_value(
1462+
cf_slice[0].start, raise_not_constant=False
1463+
)
14521464
else:
1453-
start = pt.extract_constant(cf_slice[0])
1465+
start = pt.get_scalar_constant_value(
1466+
cf_slice[0], raise_not_constant=False
1467+
)
14541468

14551469
if start == 0 or store_steps[i] == 0:
14561470
store_steps[i] = 0
@@ -1625,7 +1639,7 @@ def scan_save_mem(fgraph, node):
16251639
# 3.6 Compose the new scan
16261640
# TODO: currently we don't support scan with 0 step. So
16271641
# don't create one.
1628-
if pt.extract_constant(node_ins[0]) == 0:
1642+
if get_scalar_constant_value(node_ins[0], raise_not_constant=False) == 0:
16291643
return False
16301644

16311645
# Do not call make_node for test_value

pytensor/tensor/basic.py

Lines changed: 110 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -267,27 +267,7 @@ def _obj_is_wrappable_as_tensor(x):
267267
)
268268

269269

270-
def get_scalar_constant_value(
271-
v, elemwise=True, only_process_constants=False, max_recur=10
272-
):
273-
"""
274-
Checks whether 'v' is a scalar (ndim = 0).
275-
276-
If 'v' is a scalar then this function fetches the underlying constant by calling
277-
'get_underlying_scalar_constant_value()'.
278-
279-
If 'v' is not a scalar, it raises a NotScalarConstantError.
280-
281-
"""
282-
if isinstance(v, Variable | np.ndarray):
283-
if v.ndim != 0:
284-
raise NotScalarConstantError()
285-
return get_underlying_scalar_constant_value(
286-
v, elemwise, only_process_constants, max_recur
287-
)
288-
289-
290-
def get_underlying_scalar_constant_value(
270+
def _get_underlying_scalar_constant_value(
291271
orig_v, elemwise=True, only_process_constants=False, max_recur=10
292272
):
293273
"""Return the constant scalar(0-D) value underlying variable `v`.
@@ -377,7 +357,7 @@ def get_underlying_scalar_constant_value(
377357
elif isinstance(op, CheckAndRaise):
378358
# check if all conditions are constant and true
379359
conds = [
380-
get_underlying_scalar_constant_value(c, max_recur=max_recur)
360+
_get_underlying_scalar_constant_value(c, max_recur=max_recur)
381361
for c in v.owner.inputs[1:]
382362
]
383363
if builtins.all(0 == c.ndim and c != 0 for c in conds):
@@ -391,7 +371,7 @@ def get_underlying_scalar_constant_value(
391371
continue
392372
if isinstance(v.owner.op, _scalar_constant_value_elemwise_ops):
393373
const = [
394-
get_underlying_scalar_constant_value(i, max_recur=max_recur)
374+
_get_underlying_scalar_constant_value(i, max_recur=max_recur)
395375
for i in v.owner.inputs
396376
]
397377
ret = [[None]]
@@ -410,7 +390,7 @@ def get_underlying_scalar_constant_value(
410390
v.owner.op.scalar_op, _scalar_constant_value_elemwise_ops
411391
):
412392
const = [
413-
get_underlying_scalar_constant_value(i, max_recur=max_recur)
393+
_get_underlying_scalar_constant_value(i, max_recur=max_recur)
414394
for i in v.owner.inputs
415395
]
416396
ret = [[None]]
@@ -453,7 +433,7 @@ def get_underlying_scalar_constant_value(
453433
):
454434
idx = v.owner.op.idx_list[0]
455435
if isinstance(idx, Type):
456-
idx = get_underlying_scalar_constant_value(
436+
idx = _get_underlying_scalar_constant_value(
457437
v.owner.inputs[1], max_recur=max_recur
458438
)
459439
try:
@@ -487,14 +467,13 @@ def get_underlying_scalar_constant_value(
487467
):
488468
idx = v.owner.op.idx_list[0]
489469
if isinstance(idx, Type):
490-
idx = get_underlying_scalar_constant_value(
470+
idx = _get_underlying_scalar_constant_value(
491471
v.owner.inputs[1], max_recur=max_recur
492472
)
493-
# Python 2.4 does not support indexing with numpy.integer
494-
# So we cast it.
495-
idx = int(idx)
496473
ret = v.owner.inputs[0].owner.inputs[idx]
497-
ret = get_underlying_scalar_constant_value(ret, max_recur=max_recur)
474+
ret = _get_underlying_scalar_constant_value(
475+
ret, max_recur=max_recur
476+
)
498477
# MakeVector can cast implicitly its input in some case.
499478
return np.asarray(ret, dtype=v.type.dtype)
500479

@@ -509,7 +488,7 @@ def get_underlying_scalar_constant_value(
509488
idx_list = op.idx_list
510489
idx = idx_list[0]
511490
if isinstance(idx, Type):
512-
idx = get_underlying_scalar_constant_value(
491+
idx = _get_underlying_scalar_constant_value(
513492
owner.inputs[1], max_recur=max_recur
514493
)
515494
grandparent = leftmost_parent.owner.inputs[0]
@@ -519,7 +498,9 @@ def get_underlying_scalar_constant_value(
519498
grandparent.owner.op, Unbroadcast
520499
):
521500
ggp_shape = grandparent.owner.inputs[0].type.shape
522-
l = [get_underlying_scalar_constant_value(s) for s in ggp_shape]
501+
l = [
502+
_get_underlying_scalar_constant_value(s) for s in ggp_shape
503+
]
523504
gp_shape = tuple(l)
524505

525506
if not (idx < ndim):
@@ -541,7 +522,7 @@ def get_underlying_scalar_constant_value(
541522
if isinstance(grandparent, Constant):
542523
return np.asarray(np.shape(grandparent.data)[idx])
543524
elif isinstance(op, CSM):
544-
data = get_underlying_scalar_constant_value(
525+
data = _get_underlying_scalar_constant_value(
545526
v.owner.inputs, elemwise=elemwise, max_recur=max_recur
546527
)
547528
# Sparse variable can only be constant if zero (or I guess if homogeneously dense)
@@ -552,6 +533,92 @@ def get_underlying_scalar_constant_value(
552533
raise NotScalarConstantError()
553534

554535

536+
def get_underlying_scalar_constant_value(
537+
v,
538+
*,
539+
elemwise=True,
540+
only_process_constants=False,
541+
max_recur=10,
542+
raise_not_constant=True,
543+
):
544+
"""Return the unique constant scalar(0-D) value underlying variable `v`.
545+
546+
If `v` is the output of dimshuffles, fills, allocs, etc,
547+
cast, OutputGuard, DeepCopyOp, ScalarFromTensor, ScalarOp, Elemwise
548+
and some pattern with Subtensor, this function digs through them.
549+
550+
If `v` is not some view of constant scalar data, then raise a
551+
NotScalarConstantError.
552+
553+
This function performs symbolic reasoning about the value of `v`, as opposed to numerical reasoning by
554+
constant folding the inputs of `v`.
555+
556+
Parameters
557+
----------
558+
v: Variable
559+
elemwise : bool
560+
If False, we won't try to go into elemwise. So this call is faster.
561+
But we still investigate in Second Elemwise (as this is a substitute
562+
for Alloc)
563+
only_process_constants : bool
564+
If True, we only attempt to obtain the value of `orig_v` if it's
565+
directly constant and don't try to dig through dimshuffles, fills,
566+
allocs, and other to figure out its value.
567+
max_recur : int
568+
The maximum number of recursion.
569+
raise_not_constant: bool, default True
570+
If True, raise a NotScalarConstantError if `v` does not have an
571+
underlying constant scalar value. If False, return `v` as is.
572+
573+
574+
Raises
575+
------
576+
NotScalarConstantError
577+
`v` does not have an underlying constant scalar value.
578+
Only rasise if raise_not_constant is True.
579+
580+
"""
581+
try:
582+
return _get_underlying_scalar_constant_value(
583+
v,
584+
elemwise=elemwise,
585+
only_process_constants=only_process_constants,
586+
max_recur=max_recur,
587+
)
588+
except NotScalarConstantError:
589+
if raise_not_constant:
590+
raise
591+
return v
592+
593+
594+
def get_scalar_constant_value(
595+
v,
596+
elemwise=True,
597+
only_process_constants=False,
598+
max_recur=10,
599+
raise_not_constant: bool = True,
600+
):
601+
"""
602+
Checks whether 'v' is a scalar (ndim = 0).
603+
604+
If 'v' is a scalar then this function fetches the underlying constant by calling
605+
'get_underlying_scalar_constant_value()'.
606+
607+
If 'v' is not a scalar, it raises a TypeError.
608+
609+
"""
610+
if isinstance(v, Variable | np.ndarray):
611+
if v.ndim != 0:
612+
raise TypeError()
613+
return get_underlying_scalar_constant_value(
614+
v,
615+
elemwise=elemwise,
616+
only_process_constants=only_process_constants,
617+
max_recur=max_recur,
618+
raise_not_constant=raise_not_constant,
619+
)
620+
621+
555622
class TensorFromScalar(COp):
556623
__props__ = ()
557624

@@ -2006,16 +2073,16 @@ def extract_constant(x, elemwise=True, only_process_constants=False):
20062073
ScalarVariable, we convert it to a tensor with tensor_from_scalar.
20072074
20082075
"""
2009-
try:
2010-
x = get_underlying_scalar_constant_value(x, elemwise, only_process_constants)
2011-
except NotScalarConstantError:
2012-
pass
2013-
if isinstance(x, ps.ScalarVariable | ps.sharedvar.ScalarSharedVariable):
2014-
if x.owner and isinstance(x.owner.op, ScalarFromTensor):
2015-
x = x.owner.inputs[0]
2016-
else:
2017-
x = tensor_from_scalar(x)
2018-
return x
2076+
warnings.warn(
2077+
"extract_constant is deprecated. Use `get_underlying_scalar_constant_value(..., raise_not_constant=False)`",
2078+
FutureWarning,
2079+
)
2080+
return get_underlying_scalar_constant_value(
2081+
x,
2082+
elemwise=elemwise,
2083+
only_process_constants=only_process_constants,
2084+
raise_not_constant=False,
2085+
)
20192086

20202087

20212088
def transpose(x, axes=None):
@@ -4394,7 +4461,6 @@ def ix_(*args):
43944461
"split",
43954462
"transpose",
43964463
"matrix_transpose",
4397-
"extract_constant",
43984464
"default",
43994465
"tensor_copy",
44004466
"transfer",

0 commit comments

Comments
 (0)