@@ -102,27 +102,31 @@ def _check_static_shape(shape: Shape):
102102 else :
103103 map (_check_static_shape , shapes )
104104
105- def _try_broadcast_shapes (
106- shapes : Sequence [tuple [int , ...]]) -> tuple [int , ...] | None :
107- if len (shapes ) == 1 : return shapes [0 ]
105+ def _try_broadcast_shapes (* shapes : tuple [int , ...], name : str ) -> tuple [int , ...]:
106+ """
107+ Attempt to broadcast shapes, raising a TypeError if broadcasting fails.
108+ """
109+ if not shapes :
110+ raise TypeError (f"{ name } : At least one shape is required." )
108111 ranks = {len (shape ) for shape in shapes }
109- if len (ranks ) > 1 : return None # must have consistent rank
110- rank = ranks . pop ()
111- if not rank : return () # scalar case
112+ if len (ranks ) != 1 :
113+ raise TypeError ( f' { name } : arrays must have the same number of dimensions,'
114+ f' got { ranks } ' )
112115 result_shape = []
113- for ds in unsafe_zip (* shapes ):
116+ for ds in zip (* shapes ):
114117 if all (core .same_referent (d , ds [0 ]) for d in ds [1 :]):
115118 # if all axes are identical objects, the resulting size is the object
116119 result_shape .append (ds [0 ])
117120 else :
118- # if all dims are equal (or 1), the result is the non-1 size (or 1)
121+ # if all dims are equal (or 1), the result is the non-1 size
119122 non_1s = [d for d in ds if not core .definitely_equal (d , 1 )]
120123 if not non_1s :
121124 result_shape .append (1 )
122125 elif all (core .definitely_equal (non_1s [0 ], d ) for d in non_1s [1 :]):
123126 result_shape .append (non_1s [0 ])
124127 else :
125- return None
128+ raise TypeError (f'{ name } got incompatible shapes for broadcasting: '
129+ f'{ ", " .join (map (str , map (tuple , shapes )))} .' )
126130 return tuple (result_shape )
127131
128132def asarray (x : ArrayLike ) -> Array :
@@ -165,11 +169,12 @@ def _broadcast_shapes_uncached(*shapes):
165169 return result_shape
166170
167171 # Next try singleton-broadcasting, padding out ranks using singletons.
168- shape_list = [(1 ,) * (ndim - len (shape )) + shape for shape in shapes ]
169- result_shape = _try_broadcast_shapes (shape_list )
170- if result_shape is None :
171- raise ValueError (f"Incompatible shapes for broadcasting: shapes={ list (shapes )} " )
172- return result_shape
172+ rank_promoted_shapes = tuple ((* ((1 ,) * (ndim - len (shape ))), * shape ) for shape in shapes )
173+ try :
174+ return _try_broadcast_shapes (* rank_promoted_shapes , name = 'broadcast_shapes' )
175+ except TypeError as err :
176+ # Raise ValueError here for backward compatibility.
177+ raise ValueError (f"Incompatible shapes for broadcasting: shapes={ list (shapes )} " ) from err
173178
174179def _identity (x ): return x
175180
@@ -2133,27 +2138,7 @@ def broadcasting_shape_rule(name, *avals):
21332138 shapes = [aval .shape for aval in avals if aval .shape ]
21342139 if not shapes :
21352140 return ()
2136- if len ({len (shape ) for shape in shapes }) != 1 :
2137- msg = '{}: arrays must have same number of dimensions, got {}.'
2138- raise TypeError (msg .format (name , ', ' .join (map (str , map (tuple , shapes )))))
2139- # TODO(mattjj): de-duplicate with _try_broadcast_shapes
2140- result_shape = []
2141- for ds in zip (* shapes ):
2142- if all (core .same_referent (d , ds [0 ]) for d in ds [1 :]):
2143- # if all axes are identical objects, the resulting size is the object
2144- result_shape .append (ds [0 ])
2145- else :
2146- # if all dims are equal (or 1), the result is the non-1 size
2147- non_1s = [d for d in ds if not core .definitely_equal (d , 1 )]
2148- if not non_1s :
2149- result_shape .append (1 )
2150- elif all (core .definitely_equal (non_1s [0 ], d ) for d in non_1s [1 :]):
2151- result_shape .append (non_1s [0 ])
2152- else :
2153- raise TypeError (f'{ name } got incompatible shapes for broadcasting: '
2154- f'{ ", " .join (map (str , map (tuple , shapes )))} .' )
2155-
2156- return tuple (result_shape )
2141+ return _try_broadcast_shapes (* shapes , name = name )
21572142
21582143
21592144def broadcasting_sharding_rule (name , * avals ):
0 commit comments