Skip to content

Commit 70c64a0

Browse files
committed
Use more strict get_scalar_constant_value when the input must be a scalar
1 parent 96dbda4 commit 70c64a0

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

pytensor/tensor/shape.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from pytensor.tensor import (
1919
_get_vector_length,
2020
as_tensor_variable,
21-
get_scalar_constant_value,
2221
get_vector_length,
2322
)
2423
from pytensor.tensor import basic as ptb
@@ -433,7 +432,7 @@ def make_node(self, x, *shape):
433432
type_shape[i] = xts
434433
else:
435434
try:
436-
type_s = get_scalar_constant_value(s)
435+
type_s = ptb.get_scalar_constant_value(s)
437436
if type_s is not None:
438437
type_shape[i] = int(type_s)
439438
except NotScalarConstantError:

0 commit comments

Comments
 (0)