Skip to content

Commit 96dbda4

Browse files
committed
Use more strict get_scalar_constant_value when the input must be a scalar
1 parent 40e0011 commit 96dbda4

File tree

12 files changed

+58
-71
lines changed

12 files changed

+58
-71
lines changed

pytensor/link/jax/dispatch/tensor_basic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
Split,
1919
TensorFromScalar,
2020
Tri,
21-
get_underlying_scalar_constant_value,
21+
get_scalar_constant_value,
2222
)
2323
from pytensor.tensor.exceptions import NotScalarConstantError
2424
from pytensor.tensor.shape import Shape_i
@@ -103,7 +103,7 @@ def join(axis, *tensors):
103103
def jax_funcify_Split(op: Split, node, **kwargs):
104104
_, axis, splits = node.inputs
105105
try:
106-
constant_axis = get_underlying_scalar_constant_value(axis)
106+
constant_axis = get_scalar_constant_value(axis)
107107
except NotScalarConstantError:
108108
constant_axis = None
109109
warnings.warn(
@@ -113,7 +113,7 @@ def jax_funcify_Split(op: Split, node, **kwargs):
113113
try:
114114
constant_splits = np.array(
115115
[
116-
get_underlying_scalar_constant_value(splits[i])
116+
get_scalar_constant_value(splits[i])
117117
for i in range(get_vector_length(splits))
118118
]
119119
)

pytensor/scan/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def wrap_into_list(x):
484484
n_fixed_steps = int(n_steps)
485485
else:
486486
try:
487-
n_fixed_steps = pt.get_underlying_scalar_constant_value(n_steps)
487+
n_fixed_steps = pt.get_scalar_constant_value(n_steps)
488488
except NotScalarConstantError:
489489
n_fixed_steps = None
490490

pytensor/scan/rewriting.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@
5555
Alloc,
5656
AllocEmpty,
5757
get_scalar_constant_value,
58-
get_underlying_scalar_constant_value,
5958
)
6059
from pytensor.tensor.elemwise import DimShuffle, Elemwise
6160
from pytensor.tensor.exceptions import NotScalarConstantError
@@ -1976,13 +1975,13 @@ def belongs_to_set(self, node, set_nodes):
19761975

19771976
nsteps = node.inputs[0]
19781977
try:
1979-
nsteps = int(get_underlying_scalar_constant_value(nsteps))
1978+
nsteps = int(get_scalar_constant_value(nsteps))
19801979
except NotScalarConstantError:
19811980
pass
19821981

19831982
rep_nsteps = rep_node.inputs[0]
19841983
try:
1985-
rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps))
1984+
rep_nsteps = int(get_scalar_constant_value(rep_nsteps))
19861985
except NotScalarConstantError:
19871986
pass
19881987

pytensor/tensor/basic.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ def get_scalar_constant_value(
609609
"""
610610
if isinstance(v, Variable | np.ndarray):
611611
if v.ndim != 0:
612-
raise TypeError()
612+
raise TypeError("Input is not a scalar")
613613
return get_underlying_scalar_constant_value(
614614
v,
615615
elemwise=elemwise,
@@ -1801,7 +1801,7 @@ def do_constant_folding(self, fgraph, node):
18011801
@_get_vector_length.register(Alloc)
18021802
def _get_vector_length_Alloc(var_inst, var):
18031803
try:
1804-
return get_underlying_scalar_constant_value(var.owner.inputs[1])
1804+
return get_scalar_constant_value(var.owner.inputs[1])
18051805
except NotScalarConstantError:
18061806
raise ValueError(f"Length of {var} cannot be determined")
18071807

@@ -2502,7 +2502,7 @@ def make_node(self, axis, *tensors):
25022502

25032503
if not isinstance(axis, int):
25042504
try:
2505-
axis = int(get_underlying_scalar_constant_value(axis))
2505+
axis = int(get_scalar_constant_value(axis))
25062506
except NotScalarConstantError:
25072507
pass
25082508

@@ -2746,7 +2746,7 @@ def infer_shape(self, fgraph, node, ishapes):
27462746
def _get_vector_length_Join(op, var):
27472747
axis, *arrays = var.owner.inputs
27482748
try:
2749-
axis = get_underlying_scalar_constant_value(axis)
2749+
axis = get_scalar_constant_value(axis)
27502750
assert axis == 0 and builtins.all(a.ndim == 1 for a in arrays)
27512751
return builtins.sum(get_vector_length(a) for a in arrays)
27522752
except NotScalarConstantError:
@@ -4138,7 +4138,7 @@ def make_node(self, a, choices):
41384138
static_out_shape = ()
41394139
for s in out_shape:
41404140
try:
4141-
s_val = get_underlying_scalar_constant_value(s)
4141+
s_val = get_scalar_constant_value(s)
41424142
except (NotScalarConstantError, AttributeError):
41434143
s_val = None
41444144

pytensor/tensor/conv/abstract_conv.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from pytensor.raise_op import Assert
2626
from pytensor.tensor.basic import (
2727
as_tensor_variable,
28-
get_underlying_scalar_constant_value,
28+
get_scalar_constant_value,
2929
)
3030
from pytensor.tensor.exceptions import NotScalarConstantError
3131
from pytensor.tensor.variable import TensorConstant, TensorVariable
@@ -497,8 +497,8 @@ def check_dim(given, computed):
497497
if given is None or computed is None:
498498
return True
499499
try:
500-
given = get_underlying_scalar_constant_value(given)
501-
computed = get_underlying_scalar_constant_value(computed)
500+
given = get_scalar_constant_value(given)
501+
computed = get_scalar_constant_value(computed)
502502
return int(given) == int(computed)
503503
except NotScalarConstantError:
504504
# no answer possible, accept for now
@@ -534,7 +534,7 @@ def assert_conv_shape(shape):
534534
out_shape = []
535535
for i, n in enumerate(shape):
536536
try:
537-
const_n = get_underlying_scalar_constant_value(n)
537+
const_n = get_scalar_constant_value(n)
538538
if i < 2:
539539
if const_n < 0:
540540
raise ValueError(
@@ -2203,9 +2203,7 @@ def __init__(
22032203
if imshp_i is not None:
22042204
# Components of imshp should be constant or ints
22052205
try:
2206-
get_underlying_scalar_constant_value(
2207-
imshp_i, only_process_constants=True
2208-
)
2206+
get_scalar_constant_value(imshp_i, only_process_constants=True)
22092207
except NotScalarConstantError:
22102208
raise ValueError(
22112209
"imshp should be None or a tuple of constant int values"
@@ -2218,9 +2216,7 @@ def __init__(
22182216
if kshp_i is not None:
22192217
# Components of kshp should be constant or ints
22202218
try:
2221-
get_underlying_scalar_constant_value(
2222-
kshp_i, only_process_constants=True
2223-
)
2219+
get_scalar_constant_value(kshp_i, only_process_constants=True)
22242220
except NotScalarConstantError:
22252221
raise ValueError(
22262222
"kshp should be None or a tuple of constant int values"

pytensor/tensor/extra_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ def make_node(self, x, repeats):
678678
out_shape = [None]
679679
else:
680680
try:
681-
const_reps = ptb.get_underlying_scalar_constant_value(repeats)
681+
const_reps = ptb.get_scalar_constant_value(repeats)
682682
except NotScalarConstantError:
683683
const_reps = None
684684
if const_reps == 1:

pytensor/tensor/rewriting/basic.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656
cast,
5757
fill,
5858
get_scalar_constant_value,
59-
get_underlying_scalar_constant_value,
6059
join,
6160
ones_like,
6261
register_infer_shape,
@@ -738,7 +737,7 @@ def local_remove_useless_assert(fgraph, node):
738737
n_conds = len(node.inputs[1:])
739738
for c in node.inputs[1:]:
740739
try:
741-
const = get_underlying_scalar_constant_value(c)
740+
const = get_scalar_constant_value(c)
742741

743742
if 0 != const.ndim or const == 0:
744743
# Should we raise an error here? How to be sure it
@@ -833,7 +832,7 @@ def local_join_empty(fgraph, node):
833832
return
834833
new_inputs = []
835834
try:
836-
join_idx = get_underlying_scalar_constant_value(
835+
join_idx = get_scalar_constant_value(
837836
node.inputs[0], only_process_constants=True
838837
)
839838
except NotScalarConstantError:

pytensor/tensor/rewriting/math.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -152,18 +152,16 @@ def local_0_dot_x(fgraph, node):
152152

153153
x = node.inputs[0]
154154
y = node.inputs[1]
155-
replace = False
156-
try:
157-
if get_underlying_scalar_constant_value(x, only_process_constants=True) == 0:
158-
replace = True
159-
except NotScalarConstantError:
160-
pass
161-
162-
try:
163-
if get_underlying_scalar_constant_value(y, only_process_constants=True) == 0:
164-
replace = True
165-
except NotScalarConstantError:
166-
pass
155+
replace = (
156+
get_underlying_scalar_constant_value(
157+
x, only_process_constants=True, raise_not_constant=False
158+
)
159+
== 0
160+
or get_underlying_scalar_constant_value(
161+
y, only_process_constants=True, raise_not_constant=False
162+
)
163+
== 0
164+
)
167165

168166
if replace:
169167
constant_zero = constant(0, dtype=node.outputs[0].type.dtype)
@@ -2135,7 +2133,7 @@ def local_add_remove_zeros(fgraph, node):
21352133
y = get_underlying_scalar_constant_value(inp)
21362134
except NotScalarConstantError:
21372135
y = inp
2138-
if np.all(y == 0.0):
2136+
if y == 0.0:
21392137
continue
21402138
new_inputs.append(inp)
21412139

@@ -2233,7 +2231,7 @@ def local_abs_merge(fgraph, node):
22332231
)
22342232
except NotScalarConstantError:
22352233
return False
2236-
if not (const >= 0).all():
2234+
if not const >= 0:
22372235
return False
22382236
inputs.append(i)
22392237
else:
@@ -2881,7 +2879,7 @@ def _is_1(expr):
28812879
"""
28822880
try:
28832881
v = get_underlying_scalar_constant_value(expr)
2884-
return np.allclose(v, 1)
2882+
return np.isclose(v, 1)
28852883
except NotScalarConstantError:
28862884
return False
28872885

@@ -3049,7 +3047,7 @@ def is_neg(var):
30493047
for idx, mul_input in enumerate(var_node.inputs):
30503048
try:
30513049
constant = get_underlying_scalar_constant_value(mul_input)
3052-
is_minus_1 = np.allclose(constant, -1)
3050+
is_minus_1 = np.isclose(constant, -1)
30533051
except NotScalarConstantError:
30543052
is_minus_1 = False
30553053
if is_minus_1:

pytensor/tensor/rewriting/shape.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
cast,
2424
constant,
2525
get_scalar_constant_value,
26-
get_underlying_scalar_constant_value,
2726
register_infer_shape,
2827
stack,
2928
)
@@ -213,7 +212,7 @@ def shape_ir(self, i, r):
213212
# Do not call make_node for test_value
214213
s = Shape_i(i)(r)
215214
try:
216-
s = get_underlying_scalar_constant_value(s)
215+
s = get_scalar_constant_value(s)
217216
except NotScalarConstantError:
218217
pass
219218
return s
@@ -297,7 +296,7 @@ def unpack(self, s_i, var):
297296
assert len(idx) == 1
298297
idx = idx[0]
299298
try:
300-
i = get_underlying_scalar_constant_value(idx)
299+
i = get_scalar_constant_value(idx)
301300
except NotScalarConstantError:
302301
pass
303302
else:
@@ -452,7 +451,7 @@ def update_shape(self, r, other_r):
452451
)
453452
or self.lscalar_one.equals(merged_shape[i])
454453
or self.lscalar_one.equals(
455-
get_underlying_scalar_constant_value(
454+
get_scalar_constant_value(
456455
merged_shape[i],
457456
only_process_constants=True,
458457
raise_not_constant=False,
@@ -481,9 +480,7 @@ def set_shape_i(self, r, i, s_i):
481480
or r.type.shape[idx] != 1
482481
or self.lscalar_one.equals(new_shape[idx])
483482
or self.lscalar_one.equals(
484-
get_underlying_scalar_constant_value(
485-
new_shape[idx], raise_not_constant=False
486-
)
483+
get_scalar_constant_value(new_shape[idx], raise_not_constant=False)
487484
)
488485
for idx in range(r.type.ndim)
489486
)

pytensor/tensor/rewriting/subtensor.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -995,7 +995,7 @@ def local_useless_subtensor(fgraph, node):
995995
if isinstance(idx.stop, int | np.integer):
996996
length_pos_data = sys.maxsize
997997
try:
998-
length_pos_data = get_underlying_scalar_constant_value(
998+
length_pos_data = get_scalar_constant_value(
999999
length_pos, only_process_constants=True
10001000
)
10011001
except NotScalarConstantError:
@@ -1060,7 +1060,7 @@ def local_useless_AdvancedSubtensor1(fgraph, node):
10601060

10611061
# get length of the indexed tensor along the first axis
10621062
try:
1063-
length = get_underlying_scalar_constant_value(
1063+
length = get_scalar_constant_value(
10641064
shape_of[node.inputs[0]][0], only_process_constants=True
10651065
)
10661066
except NotScalarConstantError:
@@ -1732,7 +1732,7 @@ def local_join_subtensors(fgraph, node):
17321732
axis, tensors = node.inputs[0], node.inputs[1:]
17331733

17341734
try:
1735-
axis = get_underlying_scalar_constant_value(axis)
1735+
axis = get_scalar_constant_value(axis)
17361736
except NotScalarConstantError:
17371737
return
17381738

@@ -1793,12 +1793,7 @@ def local_join_subtensors(fgraph, node):
17931793
if step is None:
17941794
continue
17951795
try:
1796-
if (
1797-
get_underlying_scalar_constant_value(
1798-
step, only_process_constants=True
1799-
)
1800-
!= 1
1801-
):
1796+
if get_scalar_constant_value(step, only_process_constants=True) != 1:
18021797
return None
18031798
except NotScalarConstantError:
18041799
return None

0 commit comments

Comments
 (0)