Skip to content

Commit a283a96

Browse files
committed
Use more strict get_scalar_constant_value when the input must be a scalar
1 parent 64d7f3f commit a283a96

File tree

12 files changed

+50
-66
lines changed

12 files changed

+50
-66
lines changed

pytensor/link/jax/dispatch/tensor_basic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
Split,
1919
TensorFromScalar,
2020
Tri,
21-
get_underlying_scalar_constant_value,
21+
get_scalar_constant_value,
2222
)
2323
from pytensor.tensor.exceptions import NotScalarConstantError
2424
from pytensor.tensor.shape import Shape_i
@@ -103,7 +103,7 @@ def join(axis, *tensors):
103103
def jax_funcify_Split(op: Split, node, **kwargs):
104104
_, axis, splits = node.inputs
105105
try:
106-
constant_axis = get_underlying_scalar_constant_value(axis)
106+
constant_axis = get_scalar_constant_value(axis)
107107
except NotScalarConstantError:
108108
constant_axis = None
109109
warnings.warn(
@@ -113,7 +113,7 @@ def jax_funcify_Split(op: Split, node, **kwargs):
113113
try:
114114
constant_splits = np.array(
115115
[
116-
get_underlying_scalar_constant_value(splits[i])
116+
get_scalar_constant_value(splits[i])
117117
for i in range(get_vector_length(splits))
118118
]
119119
)

pytensor/scan/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def wrap_into_list(x):
484484
n_fixed_steps = int(n_steps)
485485
else:
486486
try:
487-
n_fixed_steps = pt.get_underlying_scalar_constant_value(n_steps)
487+
n_fixed_steps = pt.get_scalar_constant_value(n_steps)
488488
except NotScalarConstantError:
489489
n_fixed_steps = None
490490

pytensor/scan/rewriting.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@
5555
Alloc,
5656
AllocEmpty,
5757
get_scalar_constant_value,
58-
get_underlying_scalar_constant_value,
5958
)
6059
from pytensor.tensor.elemwise import DimShuffle, Elemwise
6160
from pytensor.tensor.exceptions import NotScalarConstantError
@@ -1976,13 +1975,13 @@ def belongs_to_set(self, node, set_nodes):
19761975

19771976
nsteps = node.inputs[0]
19781977
try:
1979-
nsteps = int(get_underlying_scalar_constant_value(nsteps))
1978+
nsteps = int(get_scalar_constant_value(nsteps))
19801979
except NotScalarConstantError:
19811980
pass
19821981

19831982
rep_nsteps = rep_node.inputs[0]
19841983
try:
1985-
rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps))
1984+
rep_nsteps = int(get_scalar_constant_value(rep_nsteps))
19861985
except NotScalarConstantError:
19871986
pass
19881987

pytensor/tensor/basic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1803,7 +1803,7 @@ def do_constant_folding(self, fgraph, node):
18031803
@_get_vector_length.register(Alloc)
18041804
def _get_vector_length_Alloc(var_inst, var):
18051805
try:
1806-
return get_underlying_scalar_constant_value(var.owner.inputs[1])
1806+
return get_scalar_constant_value(var.owner.inputs[1])
18071807
except NotScalarConstantError:
18081808
raise ValueError(f"Length of {var} cannot be determined")
18091809

@@ -2504,7 +2504,7 @@ def make_node(self, axis, *tensors):
25042504

25052505
if not isinstance(axis, int):
25062506
try:
2507-
axis = int(get_underlying_scalar_constant_value(axis))
2507+
axis = int(get_scalar_constant_value(axis))
25082508
except NotScalarConstantError:
25092509
pass
25102510

@@ -2748,7 +2748,7 @@ def infer_shape(self, fgraph, node, ishapes):
27482748
def _get_vector_length_Join(op, var):
27492749
axis, *arrays = var.owner.inputs
27502750
try:
2751-
axis = get_underlying_scalar_constant_value(axis)
2751+
axis = get_scalar_constant_value(axis)
27522752
assert axis == 0 and builtins.all(a.ndim == 1 for a in arrays)
27532753
return builtins.sum(get_vector_length(a) for a in arrays)
27542754
except NotScalarConstantError:
@@ -4141,7 +4141,7 @@ def make_node(self, a, choices):
41414141
static_out_shape = ()
41424142
for s in out_shape:
41434143
try:
4144-
s_val = get_underlying_scalar_constant_value(s)
4144+
s_val = get_scalar_constant_value(s)
41454145
except (NotScalarConstantError, AttributeError):
41464146
s_val = None
41474147

pytensor/tensor/conv/abstract_conv.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from pytensor.raise_op import Assert
2626
from pytensor.tensor.basic import (
2727
as_tensor_variable,
28-
get_underlying_scalar_constant_value,
28+
get_scalar_constant_value,
2929
)
3030
from pytensor.tensor.exceptions import NotScalarConstantError
3131
from pytensor.tensor.variable import TensorConstant, TensorVariable
@@ -497,8 +497,8 @@ def check_dim(given, computed):
497497
if given is None or computed is None:
498498
return True
499499
try:
500-
given = get_underlying_scalar_constant_value(given)
501-
computed = get_underlying_scalar_constant_value(computed)
500+
given = get_scalar_constant_value(given)
501+
computed = get_scalar_constant_value(computed)
502502
return int(given) == int(computed)
503503
except NotScalarConstantError:
504504
# no answer possible, accept for now
@@ -534,7 +534,7 @@ def assert_conv_shape(shape):
534534
out_shape = []
535535
for i, n in enumerate(shape):
536536
try:
537-
const_n = get_underlying_scalar_constant_value(n)
537+
const_n = get_scalar_constant_value(n)
538538
if i < 2:
539539
if const_n < 0:
540540
raise ValueError(
@@ -2203,9 +2203,7 @@ def __init__(
22032203
if imshp_i is not None:
22042204
# Components of imshp should be constant or ints
22052205
try:
2206-
get_underlying_scalar_constant_value(
2207-
imshp_i, only_process_constants=True
2208-
)
2206+
get_scalar_constant_value(imshp_i, only_process_constants=True)
22092207
except NotScalarConstantError:
22102208
raise ValueError(
22112209
"imshp should be None or a tuple of constant int values"
@@ -2218,9 +2216,7 @@ def __init__(
22182216
if kshp_i is not None:
22192217
# Components of kshp should be constant or ints
22202218
try:
2221-
get_underlying_scalar_constant_value(
2222-
kshp_i, only_process_constants=True
2223-
)
2219+
get_scalar_constant_value(kshp_i, only_process_constants=True)
22242220
except NotScalarConstantError:
22252221
raise ValueError(
22262222
"kshp should be None or a tuple of constant int values"

pytensor/tensor/extra_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ def make_node(self, x, repeats):
678678
out_shape = [None]
679679
else:
680680
try:
681-
const_reps = ptb.get_underlying_scalar_constant_value(repeats)
681+
const_reps = ptb.get_scalar_constant_value(repeats)
682682
except NotScalarConstantError:
683683
const_reps = None
684684
if const_reps == 1:

pytensor/tensor/rewriting/basic.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757
cast,
5858
fill,
5959
get_scalar_constant_value,
60-
get_underlying_scalar_constant_value,
6160
join,
6261
ones_like,
6362
register_infer_shape,
@@ -739,7 +738,7 @@ def local_remove_useless_assert(fgraph, node):
739738
n_conds = len(node.inputs[1:])
740739
for c in node.inputs[1:]:
741740
try:
742-
const = get_underlying_scalar_constant_value(c)
741+
const = get_scalar_constant_value(c)
743742

744743
if 0 != const.ndim or const == 0:
745744
# Should we raise an error here? How to be sure it
@@ -834,7 +833,7 @@ def local_join_empty(fgraph, node):
834833
return
835834
new_inputs = []
836835
try:
837-
join_idx = get_underlying_scalar_constant_value(
836+
join_idx = get_scalar_constant_value(
838837
node.inputs[0], only_process_constants=True
839838
)
840839
except NotScalarConstantError:

pytensor/tensor/rewriting/math.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -153,18 +153,16 @@ def local_0_dot_x(fgraph, node):
153153

154154
x = node.inputs[0]
155155
y = node.inputs[1]
156-
replace = False
157-
try:
158-
if get_underlying_scalar_constant_value(x, only_process_constants=True) == 0:
159-
replace = True
160-
except NotScalarConstantError:
161-
pass
162-
163-
try:
164-
if get_underlying_scalar_constant_value(y, only_process_constants=True) == 0:
165-
replace = True
166-
except NotScalarConstantError:
167-
pass
156+
replace = (
157+
get_underlying_scalar_constant_value(
158+
x, only_process_constants=True, raise_not_constant=False
159+
)
160+
== 0
161+
or get_underlying_scalar_constant_value(
162+
y, only_process_constants=True, raise_not_constant=False
163+
)
164+
== 0
165+
)
168166

169167
if replace:
170168
constant_zero = constant(0, dtype=node.outputs[0].type.dtype)
@@ -2096,7 +2094,7 @@ def local_add_remove_zeros(fgraph, node):
20962094
y = get_underlying_scalar_constant_value(inp)
20972095
except NotScalarConstantError:
20982096
y = inp
2099-
if np.all(y == 0.0):
2097+
if y == 0.0:
21002098
continue
21012099
new_inputs.append(inp)
21022100

@@ -2194,7 +2192,7 @@ def local_abs_merge(fgraph, node):
21942192
)
21952193
except NotScalarConstantError:
21962194
return False
2197-
if not (const >= 0).all():
2195+
if not const >= 0:
21982196
return False
21992197
inputs.append(i)
22002198
else:
@@ -2844,7 +2842,7 @@ def _is_1(expr):
28442842
"""
28452843
try:
28462844
v = get_underlying_scalar_constant_value(expr)
2847-
return np.allclose(v, 1)
2845+
return np.isclose(v, 1)
28482846
except NotScalarConstantError:
28492847
return False
28502848

@@ -3012,7 +3010,7 @@ def is_neg(var):
30123010
for idx, mul_input in enumerate(var_node.inputs):
30133011
try:
30143012
constant = get_underlying_scalar_constant_value(mul_input)
3015-
is_minus_1 = np.allclose(constant, -1)
3013+
is_minus_1 = np.isclose(constant, -1)
30163014
except NotScalarConstantError:
30173015
is_minus_1 = False
30183016
if is_minus_1:

pytensor/tensor/rewriting/shape.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
cast,
2424
constant,
2525
get_scalar_constant_value,
26-
get_underlying_scalar_constant_value,
2726
register_infer_shape,
2827
stack,
2928
)
@@ -213,7 +212,7 @@ def shape_ir(self, i, r):
213212
# Do not call make_node for test_value
214213
s = Shape_i(i)(r)
215214
try:
216-
s = get_underlying_scalar_constant_value(s)
215+
s = get_scalar_constant_value(s)
217216
except NotScalarConstantError:
218217
pass
219218
return s
@@ -297,7 +296,7 @@ def unpack(self, s_i, var):
297296
assert len(idx) == 1
298297
idx = idx[0]
299298
try:
300-
i = get_underlying_scalar_constant_value(idx)
299+
i = get_scalar_constant_value(idx)
301300
except NotScalarConstantError:
302301
pass
303302
else:
@@ -452,7 +451,7 @@ def update_shape(self, r, other_r):
452451
)
453452
or self.lscalar_one.equals(merged_shape[i])
454453
or self.lscalar_one.equals(
455-
get_underlying_scalar_constant_value(
454+
get_scalar_constant_value(
456455
merged_shape[i],
457456
only_process_constants=True,
458457
raise_not_constant=False,
@@ -481,9 +480,7 @@ def set_shape_i(self, r, i, s_i):
481480
or r.type.shape[idx] != 1
482481
or self.lscalar_one.equals(new_shape[idx])
483482
or self.lscalar_one.equals(
484-
get_underlying_scalar_constant_value(
485-
new_shape[idx], raise_not_constant=False
486-
)
483+
get_scalar_constant_value(new_shape[idx], raise_not_constant=False)
487484
)
488485
for idx in range(r.type.ndim)
489486
)

pytensor/tensor/rewriting/subtensor.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -997,7 +997,7 @@ def local_useless_subtensor(fgraph, node):
997997
if isinstance(idx.stop, int | np.integer):
998998
length_pos_data = sys.maxsize
999999
try:
1000-
length_pos_data = get_underlying_scalar_constant_value(
1000+
length_pos_data = get_scalar_constant_value(
10011001
length_pos, only_process_constants=True
10021002
)
10031003
except NotScalarConstantError:
@@ -1062,7 +1062,7 @@ def local_useless_AdvancedSubtensor1(fgraph, node):
10621062

10631063
# get length of the indexed tensor along the first axis
10641064
try:
1065-
length = get_underlying_scalar_constant_value(
1065+
length = get_scalar_constant_value(
10661066
shape_of[node.inputs[0]][0], only_process_constants=True
10671067
)
10681068
except NotScalarConstantError:
@@ -1734,7 +1734,7 @@ def local_join_subtensors(fgraph, node):
17341734
axis, tensors = node.inputs[0], node.inputs[1:]
17351735

17361736
try:
1737-
axis = get_underlying_scalar_constant_value(axis)
1737+
axis = get_scalar_constant_value(axis)
17381738
except NotScalarConstantError:
17391739
return
17401740

@@ -1795,12 +1795,7 @@ def local_join_subtensors(fgraph, node):
17951795
if step is None:
17961796
continue
17971797
try:
1798-
if (
1799-
get_underlying_scalar_constant_value(
1800-
step, only_process_constants=True
1801-
)
1802-
!= 1
1803-
):
1798+
if get_scalar_constant_value(step, only_process_constants=True) != 1:
18041799
return None
18051800
except NotScalarConstantError:
18061801
return None

0 commit comments

Comments
 (0)