Skip to content

Commit 6a15a26

Browse files
committed
Make non-strict zip strict in tensor/shape.py
1 parent 9e9d65d commit 6a15a26

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

pytensor/tensor/shape.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -589,11 +589,15 @@ def specify_shape(
589589
x = ptb.as_tensor_variable(x) # type: ignore[arg-type,unused-ignore]
590590
# The above is a type error in Python 3.9 but not 3.12.
591591
# Thus we need to ignore unused-ignore on 3.12.
592-
new_shape_info = any(
593-
s != xts for (s, xts) in zip(shape, x.type.shape, strict=False) if s is not None
594-
)
592+
595593
# If shape does not match x.ndim, we rely on the `Op` to raise a ValueError
596-
if not new_shape_info and len(shape) == x.type.ndim:
594+
if len(shape) != x.type.ndim:
595+
return _specify_shape(x, *shape)
596+
597+
new_shape_matches = all(
598+
s == xts for (s, xts) in zip(shape, x.type.shape, strict=True) if s is not None
599+
)
600+
if new_shape_matches:
597601
return x
598602

599603
return _specify_shape(x, *shape)

0 commit comments

Comments
 (0)