Skip to content

Commit 55f3cd0

Browse files
committed
Use more strict get_scalar_constant_value when the input must be a scalar
1 parent 32aadc8 commit 55f3cd0

File tree

12 files changed

+50
-66
lines changed

12 files changed

+50
-66
lines changed

pytensor/link/jax/dispatch/tensor_basic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
Split,
1919
TensorFromScalar,
2020
Tri,
21-
get_underlying_scalar_constant_value,
21+
get_scalar_constant_value,
2222
)
2323
from pytensor.tensor.exceptions import NotScalarConstantError
2424
from pytensor.tensor.shape import Shape_i
@@ -103,7 +103,7 @@ def join(axis, *tensors):
103103
def jax_funcify_Split(op: Split, node, **kwargs):
104104
_, axis, splits = node.inputs
105105
try:
106-
constant_axis = get_underlying_scalar_constant_value(axis)
106+
constant_axis = get_scalar_constant_value(axis)
107107
except NotScalarConstantError:
108108
constant_axis = None
109109
warnings.warn(
@@ -113,7 +113,7 @@ def jax_funcify_Split(op: Split, node, **kwargs):
113113
try:
114114
constant_splits = np.array(
115115
[
116-
get_underlying_scalar_constant_value(splits[i])
116+
get_scalar_constant_value(splits[i])
117117
for i in range(get_vector_length(splits))
118118
]
119119
)

pytensor/scan/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def wrap_into_list(x):
484484
n_fixed_steps = int(n_steps)
485485
else:
486486
try:
487-
n_fixed_steps = pt.get_underlying_scalar_constant_value(n_steps)
487+
n_fixed_steps = pt.get_scalar_constant_value(n_steps)
488488
except NotScalarConstantError:
489489
n_fixed_steps = None
490490

pytensor/scan/rewriting.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@
5555
Alloc,
5656
AllocEmpty,
5757
get_scalar_constant_value,
58-
get_underlying_scalar_constant_value,
5958
)
6059
from pytensor.tensor.elemwise import DimShuffle, Elemwise
6160
from pytensor.tensor.exceptions import NotScalarConstantError
@@ -1976,13 +1975,13 @@ def belongs_to_set(self, node, set_nodes):
19761975

19771976
nsteps = node.inputs[0]
19781977
try:
1979-
nsteps = int(get_underlying_scalar_constant_value(nsteps))
1978+
nsteps = int(get_scalar_constant_value(nsteps))
19801979
except NotScalarConstantError:
19811980
pass
19821981

19831982
rep_nsteps = rep_node.inputs[0]
19841983
try:
1985-
rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps))
1984+
rep_nsteps = int(get_scalar_constant_value(rep_nsteps))
19861985
except NotScalarConstantError:
19871986
pass
19881987

pytensor/tensor/basic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1808,7 +1808,7 @@ def do_constant_folding(self, fgraph, node):
18081808
@_get_vector_length.register(Alloc)
18091809
def _get_vector_length_Alloc(var_inst, var):
18101810
try:
1811-
return get_underlying_scalar_constant_value(var.owner.inputs[1])
1811+
return get_scalar_constant_value(var.owner.inputs[1])
18121812
except NotScalarConstantError:
18131813
raise ValueError(f"Length of {var} cannot be determined")
18141814

@@ -2509,7 +2509,7 @@ def make_node(self, axis, *tensors):
25092509

25102510
if not isinstance(axis, int):
25112511
try:
2512-
axis = int(get_underlying_scalar_constant_value(axis))
2512+
axis = int(get_scalar_constant_value(axis))
25132513
except NotScalarConstantError:
25142514
pass
25152515

@@ -2753,7 +2753,7 @@ def infer_shape(self, fgraph, node, ishapes):
27532753
def _get_vector_length_Join(op, var):
27542754
axis, *arrays = var.owner.inputs
27552755
try:
2756-
axis = get_underlying_scalar_constant_value(axis)
2756+
axis = get_scalar_constant_value(axis)
27572757
assert axis == 0 and builtins.all(a.ndim == 1 for a in arrays)
27582758
return builtins.sum(get_vector_length(a) for a in arrays)
27592759
except NotScalarConstantError:
@@ -4146,7 +4146,7 @@ def make_node(self, a, choices):
41464146
static_out_shape = ()
41474147
for s in out_shape:
41484148
try:
4149-
s_val = get_underlying_scalar_constant_value(s)
4149+
s_val = get_scalar_constant_value(s)
41504150
except (NotScalarConstantError, AttributeError):
41514151
s_val = None
41524152

pytensor/tensor/conv/abstract_conv.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from pytensor.raise_op import Assert
2626
from pytensor.tensor.basic import (
2727
as_tensor_variable,
28-
get_underlying_scalar_constant_value,
28+
get_scalar_constant_value,
2929
)
3030
from pytensor.tensor.exceptions import NotScalarConstantError
3131
from pytensor.tensor.variable import TensorConstant, TensorVariable
@@ -497,8 +497,8 @@ def check_dim(given, computed):
497497
if given is None or computed is None:
498498
return True
499499
try:
500-
given = get_underlying_scalar_constant_value(given)
501-
computed = get_underlying_scalar_constant_value(computed)
500+
given = get_scalar_constant_value(given)
501+
computed = get_scalar_constant_value(computed)
502502
return int(given) == int(computed)
503503
except NotScalarConstantError:
504504
# no answer possible, accept for now
@@ -534,7 +534,7 @@ def assert_conv_shape(shape):
534534
out_shape = []
535535
for i, n in enumerate(shape):
536536
try:
537-
const_n = get_underlying_scalar_constant_value(n)
537+
const_n = get_scalar_constant_value(n)
538538
if i < 2:
539539
if const_n < 0:
540540
raise ValueError(
@@ -2203,9 +2203,7 @@ def __init__(
22032203
if imshp_i is not None:
22042204
# Components of imshp should be constant or ints
22052205
try:
2206-
get_underlying_scalar_constant_value(
2207-
imshp_i, only_process_constants=True
2208-
)
2206+
get_scalar_constant_value(imshp_i, only_process_constants=True)
22092207
except NotScalarConstantError:
22102208
raise ValueError(
22112209
"imshp should be None or a tuple of constant int values"
@@ -2218,9 +2216,7 @@ def __init__(
22182216
if kshp_i is not None:
22192217
# Components of kshp should be constant or ints
22202218
try:
2221-
get_underlying_scalar_constant_value(
2222-
kshp_i, only_process_constants=True
2223-
)
2219+
get_scalar_constant_value(kshp_i, only_process_constants=True)
22242220
except NotScalarConstantError:
22252221
raise ValueError(
22262222
"kshp should be None or a tuple of constant int values"

pytensor/tensor/extra_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ def make_node(self, x, repeats):
678678
out_shape = [None]
679679
else:
680680
try:
681-
const_reps = ptb.get_underlying_scalar_constant_value(repeats)
681+
const_reps = ptb.get_scalar_constant_value(repeats)
682682
except NotScalarConstantError:
683683
const_reps = None
684684
if const_reps == 1:

pytensor/tensor/rewriting/basic.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757
cast,
5858
fill,
5959
get_scalar_constant_value,
60-
get_underlying_scalar_constant_value,
6160
join,
6261
ones_like,
6362
register_infer_shape,
@@ -739,7 +738,7 @@ def local_remove_useless_assert(fgraph, node):
739738
n_conds = len(node.inputs[1:])
740739
for c in node.inputs[1:]:
741740
try:
742-
const = get_underlying_scalar_constant_value(c)
741+
const = get_scalar_constant_value(c)
743742

744743
if 0 != const.ndim or const == 0:
745744
# Should we raise an error here? How to be sure it
@@ -834,7 +833,7 @@ def local_join_empty(fgraph, node):
834833
return
835834
new_inputs = []
836835
try:
837-
join_idx = get_underlying_scalar_constant_value(
836+
join_idx = get_scalar_constant_value(
838837
node.inputs[0], only_process_constants=True
839838
)
840839
except NotScalarConstantError:

pytensor/tensor/rewriting/math.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -153,18 +153,16 @@ def local_0_dot_x(fgraph, node):
153153

154154
x = node.inputs[0]
155155
y = node.inputs[1]
156-
replace = False
157-
try:
158-
if get_underlying_scalar_constant_value(x, only_process_constants=True) == 0:
159-
replace = True
160-
except NotScalarConstantError:
161-
pass
162-
163-
try:
164-
if get_underlying_scalar_constant_value(y, only_process_constants=True) == 0:
165-
replace = True
166-
except NotScalarConstantError:
167-
pass
156+
replace = (
157+
get_underlying_scalar_constant_value(
158+
x, only_process_constants=True, raise_not_constant=False
159+
)
160+
== 0
161+
or get_underlying_scalar_constant_value(
162+
y, only_process_constants=True, raise_not_constant=False
163+
)
164+
== 0
165+
)
168166

169167
if replace:
170168
constant_zero = constant(0, dtype=node.outputs[0].type.dtype)
@@ -2111,7 +2109,7 @@ def local_add_remove_zeros(fgraph, node):
21112109
y = get_underlying_scalar_constant_value(inp)
21122110
except NotScalarConstantError:
21132111
y = inp
2114-
if np.all(y == 0.0):
2112+
if y == 0.0:
21152113
continue
21162114
new_inputs.append(inp)
21172115

@@ -2209,7 +2207,7 @@ def local_abs_merge(fgraph, node):
22092207
)
22102208
except NotScalarConstantError:
22112209
return False
2212-
if not (const >= 0).all():
2210+
if not const >= 0:
22132211
return False
22142212
inputs.append(i)
22152213
else:
@@ -2861,7 +2859,7 @@ def _is_1(expr):
28612859
"""
28622860
try:
28632861
v = get_underlying_scalar_constant_value(expr)
2864-
return np.allclose(v, 1)
2862+
return np.isclose(v, 1)
28652863
except NotScalarConstantError:
28662864
return False
28672865

@@ -3029,7 +3027,7 @@ def is_neg(var):
30293027
for idx, mul_input in enumerate(var_node.inputs):
30303028
try:
30313029
constant = get_underlying_scalar_constant_value(mul_input)
3032-
is_minus_1 = np.allclose(constant, -1)
3030+
is_minus_1 = np.isclose(constant, -1)
30333031
except NotScalarConstantError:
30343032
is_minus_1 = False
30353033
if is_minus_1:

pytensor/tensor/rewriting/shape.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
cast,
2424
constant,
2525
get_scalar_constant_value,
26-
get_underlying_scalar_constant_value,
2726
register_infer_shape,
2827
stack,
2928
)
@@ -213,7 +212,7 @@ def shape_ir(self, i, r):
213212
# Do not call make_node for test_value
214213
s = Shape_i(i)(r)
215214
try:
216-
s = get_underlying_scalar_constant_value(s)
215+
s = get_scalar_constant_value(s)
217216
except NotScalarConstantError:
218217
pass
219218
return s
@@ -297,7 +296,7 @@ def unpack(self, s_i, var):
297296
assert len(idx) == 1
298297
idx = idx[0]
299298
try:
300-
i = get_underlying_scalar_constant_value(idx)
299+
i = get_scalar_constant_value(idx)
301300
except NotScalarConstantError:
302301
pass
303302
else:
@@ -452,7 +451,7 @@ def update_shape(self, r, other_r):
452451
)
453452
or self.lscalar_one.equals(merged_shape[i])
454453
or self.lscalar_one.equals(
455-
get_underlying_scalar_constant_value(
454+
get_scalar_constant_value(
456455
merged_shape[i],
457456
only_process_constants=True,
458457
raise_not_constant=False,
@@ -481,9 +480,7 @@ def set_shape_i(self, r, i, s_i):
481480
or r.type.shape[idx] != 1
482481
or self.lscalar_one.equals(new_shape[idx])
483482
or self.lscalar_one.equals(
484-
get_underlying_scalar_constant_value(
485-
new_shape[idx], raise_not_constant=False
486-
)
483+
get_scalar_constant_value(new_shape[idx], raise_not_constant=False)
487484
)
488485
for idx in range(r.type.ndim)
489486
)

pytensor/tensor/rewriting/subtensor.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -999,7 +999,7 @@ def local_useless_subtensor(fgraph, node):
999999
if isinstance(idx.stop, int | np.integer):
10001000
length_pos_data = sys.maxsize
10011001
try:
1002-
length_pos_data = get_underlying_scalar_constant_value(
1002+
length_pos_data = get_scalar_constant_value(
10031003
length_pos, only_process_constants=True
10041004
)
10051005
except NotScalarConstantError:
@@ -1064,7 +1064,7 @@ def local_useless_AdvancedSubtensor1(fgraph, node):
10641064

10651065
# get length of the indexed tensor along the first axis
10661066
try:
1067-
length = get_underlying_scalar_constant_value(
1067+
length = get_scalar_constant_value(
10681068
shape_of[node.inputs[0]][0], only_process_constants=True
10691069
)
10701070
except NotScalarConstantError:
@@ -1736,7 +1736,7 @@ def local_join_subtensors(fgraph, node):
17361736
axis, tensors = node.inputs[0], node.inputs[1:]
17371737

17381738
try:
1739-
axis = get_underlying_scalar_constant_value(axis)
1739+
axis = get_scalar_constant_value(axis)
17401740
except NotScalarConstantError:
17411741
return
17421742

@@ -1797,12 +1797,7 @@ def local_join_subtensors(fgraph, node):
17971797
if step is None:
17981798
continue
17991799
try:
1800-
if (
1801-
get_underlying_scalar_constant_value(
1802-
step, only_process_constants=True
1803-
)
1804-
!= 1
1805-
):
1800+
if get_scalar_constant_value(step, only_process_constants=True) != 1:
18061801
return None
18071802
except NotScalarConstantError:
18081803
return None

0 commit comments

Comments
 (0)