File tree Expand file tree Collapse file tree 1 file changed +8
-4
lines changed
Expand file tree Collapse file tree 1 file changed +8
-4
lines changed Original file line number Diff line number Diff line change @@ -578,11 +578,15 @@ def specify_shape(
578578 x = ptb .as_tensor_variable (x ) # type: ignore[arg-type,unused-ignore]
579579 # The above is a type error in Python 3.9 but not 3.12.
580580 # Thus we need to ignore unused-ignore on 3.12.
581- new_shape_info = any (
582- s != xts for (s , xts ) in zip (shape , x .type .shape , strict = False ) if s is not None
583- )
581+
584582 # If shape does not match x.ndim, we rely on the `Op` to raise a ValueError
585- if not new_shape_info and len (shape ) == x .type .ndim :
583+ if len (shape ) != x .type .ndim :
584+ return _specify_shape (x , * shape )
585+
586+ new_shape_matches = all (
587+ s == xts for (s , xts ) in zip (shape , x .type .shape , strict = True ) if s is not None
588+ )
589+ if new_shape_matches :
586590 return x
587591
588592 return _specify_shape (x , * shape )
You can’t perform that action at this time.
0 commit comments