File tree Expand file tree Collapse file tree 1 file changed +8
-4
lines changed Expand file tree Collapse file tree 1 file changed +8
-4
lines changed Original file line number Diff line number Diff line change @@ -589,11 +589,15 @@ def specify_shape(
589
589
x = ptb .as_tensor_variable (x ) # type: ignore[arg-type,unused-ignore]
590
590
# The above is a type error in Python 3.9 but not 3.12.
591
591
# Thus we need to ignore unused-ignore on 3.12.
592
- new_shape_info = any (
593
- s != xts for (s , xts ) in zip (shape , x .type .shape , strict = False ) if s is not None
594
- )
592
+
595
593
# If shape does not match x.ndim, we rely on the `Op` to raise a ValueError
596
- if not new_shape_info and len (shape ) == x .type .ndim :
594
+ if len (shape ) != x .type .ndim :
595
+ return _specify_shape (x , * shape )
596
+
597
+ new_shape_matches = all (
598
+ s == xts for (s , xts ) in zip (shape , x .type .shape , strict = True ) if s is not None
599
+ )
600
+ if new_shape_matches :
597
601
return x
598
602
599
603
return _specify_shape (x , * shape )
You can’t perform that action at this time.
0 commit comments