
Commit ccfef7a

Merge pull request jax-ml#25424 from jakevdp:dedupe-broadcast
PiperOrigin-RevId: 705261094
2 parents: e92ca9b + c40780b

File tree: 1 file changed (+20 −35)

jax/_src/lax/lax.py

Lines changed: 20 additions & 35 deletions
@@ -102,27 +102,31 @@ def _check_static_shape(shape: Shape):
   else:
     map(_check_static_shape, shapes)
 
-def _try_broadcast_shapes(
-    shapes: Sequence[tuple[int, ...]]) -> tuple[int, ...] | None:
-  if len(shapes) == 1: return shapes[0]
+def _try_broadcast_shapes(*shapes: tuple[int, ...], name: str) -> tuple[int, ...]:
+  """
+  Attempt to broadcast shapes, raising a TypeError if broadcasting fails.
+  """
+  if not shapes:
+    raise TypeError(f"{name}: At least one shape is required.")
   ranks = {len(shape) for shape in shapes}
-  if len(ranks) > 1: return None  # must have consistent rank
-  rank = ranks.pop()
-  if not rank: return ()  # scalar case
+  if len(ranks) != 1:
+    raise TypeError(f'{name}: arrays must have the same number of dimensions,'
+                    f' got {ranks}')
   result_shape = []
-  for ds in unsafe_zip(*shapes):
+  for ds in zip(*shapes):
     if all(core.same_referent(d, ds[0]) for d in ds[1:]):
       # if all axes are identical objects, the resulting size is the object
       result_shape.append(ds[0])
     else:
-      # if all dims are equal (or 1), the result is the non-1 size (or 1)
+      # if all dims are equal (or 1), the result is the non-1 size
       non_1s = [d for d in ds if not core.definitely_equal(d, 1)]
       if not non_1s:
         result_shape.append(1)
       elif all(core.definitely_equal(non_1s[0], d) for d in non_1s[1:]):
         result_shape.append(non_1s[0])
       else:
-        return None
+        raise TypeError(f'{name} got incompatible shapes for broadcasting: '
+                        f'{", ".join(map(str, map(tuple, shapes)))}.')
   return tuple(result_shape)
 
 def asarray(x: ArrayLike) -> Array:
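
For readers skimming the hunk: the reworked helper enforces standard NumPy-style broadcasting over equal-rank shapes, now reporting failures as a TypeError instead of returning None. A minimal stand-alone sketch of the same rule, assuming plain int dimensions (it omits the core.same_referent / core.definitely_equal handling the real helper uses for traced and symbolic dimensions):

def try_broadcast_shapes(*shapes: tuple[int, ...], name: str) -> tuple[int, ...]:
  # Simplified mirror of the diff's logic: concrete int dims only.
  if not shapes:
    raise TypeError(f"{name}: At least one shape is required.")
  ranks = {len(shape) for shape in shapes}
  if len(ranks) != 1:
    raise TypeError(f"{name}: arrays must have the same number of dimensions, got {ranks}")
  result_shape = []
  for ds in zip(*shapes):
    non_1s = [d for d in ds if d != 1]  # size-1 dims broadcast away
    if not non_1s:
      result_shape.append(1)          # every entry was 1
    elif all(d == non_1s[0] for d in non_1s):
      result_shape.append(non_1s[0])  # one consistent non-1 size
    else:
      raise TypeError(f"{name} got incompatible shapes for broadcasting: "
                      f"{', '.join(map(str, shapes))}.")
  return tuple(result_shape)

print(try_broadcast_shapes((3, 1), (1, 4), name="example"))  # (3, 4)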
@@ -165,11 +169,12 @@ def _broadcast_shapes_uncached(*shapes):
   return result_shape
 
   # Next try singleton-broadcasting, padding out ranks using singletons.
-  shape_list = [(1,) * (ndim - len(shape)) + shape for shape in shapes]
-  result_shape = _try_broadcast_shapes(shape_list)
-  if result_shape is None:
-    raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
-  return result_shape
+  rank_promoted_shapes = tuple((*((1,) * (ndim - len(shape))), *shape) for shape in shapes)
+  try:
+    return _try_broadcast_shapes(*rank_promoted_shapes, name='broadcast_shapes')
+  except TypeError as err:
+    # Raise ValueError here for backward compatibility.
+    raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}") from err
 
 def _identity(x): return x
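
Note the compatibility shim: the helper now raises TypeError, but broadcast_shapes re-raises it as ValueError so existing callers see the same exception type as before. Assuming the public jax.numpy.broadcast_shapes wrapper reaches this code path, the observable behavior would look like:

import jax.numpy as jnp

# Rank promotion pads (4,) out to (1, 4); singleton-broadcasting against (3, 1) then gives (3, 4).
print(jnp.broadcast_shapes((3, 1), (4,)))  # (3, 4)

try:
  jnp.broadcast_shapes((3, 2), (4, 5))
except ValueError as err:
  print(err)  # Incompatible shapes for broadcasting: shapes=[(3, 2), (4, 5)]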

@@ -2133,27 +2138,7 @@ def broadcasting_shape_rule(name, *avals):
   shapes = [aval.shape for aval in avals if aval.shape]
   if not shapes:
     return ()
-  if len({len(shape) for shape in shapes}) != 1:
-    msg = '{}: arrays must have same number of dimensions, got {}.'
-    raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
-  # TODO(mattjj): de-duplicate with _try_broadcast_shapes
-  result_shape = []
-  for ds in zip(*shapes):
-    if all(core.same_referent(d, ds[0]) for d in ds[1:]):
-      # if all axes are identical objects, the resulting size is the object
-      result_shape.append(ds[0])
-    else:
-      # if all dims are equal (or 1), the result is the non-1 size
-      non_1s = [d for d in ds if not core.definitely_equal(d, 1)]
-      if not non_1s:
-        result_shape.append(1)
-      elif all(core.definitely_equal(non_1s[0], d) for d in non_1s[1:]):
-        result_shape.append(non_1s[0])
-      else:
-        raise TypeError(f'{name} got incompatible shapes for broadcasting: '
-                        f'{", ".join(map(str, map(tuple, shapes)))}.')
-
-  return tuple(result_shape)
+  return _try_broadcast_shapes(*shapes, name=name)
 
 
 def broadcasting_sharding_rule(name, *avals):
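
The twenty-odd deleted lines were a verbatim copy of the helper's loop, which the TODO(mattjj) comment had flagged for de-duplication. After this change, every broadcasting shape rule shares one implementation and one error message. A hypothetical rule built on the try_broadcast_shapes sketch above (my_shape_rule and my_op are illustrative names, not JAX API) would behave like:

def my_shape_rule(name, *shapes):
  # Drop empty (scalar) shapes first, as broadcasting_shape_rule does with its avals.
  shapes = [shape for shape in shapes if shape]
  if not shapes:
    return ()
  return try_broadcast_shapes(*shapes, name=name)

print(my_shape_rule("my_op", (2, 1, 5), (2, 3, 5)))  # (2, 3, 5)

try:
  my_shape_rule("my_op", (2, 3), (3, 2))
except TypeError as err:
  print(err)  # my_op got incompatible shapes for broadcasting: (2, 3), (3, 2).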
