Skip to content

Commit 76d8b9c

Browse files
committed
internal: simplify broadcast_shapes logic
1 parent 5fe8bcc commit 76d8b9c

File tree

1 file changed

+4
-11
lines changed

1 file changed

+4
-11
lines changed

jax/_src/lax/lax.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -159,25 +159,18 @@ def _broadcast_shapes_uncached(*shapes):
159159
if not rst: return fst
160160

161161
# First check if we need only rank promotion (and not singleton-broadcasting).
162-
try: return _reduce(_broadcast_ranks, rst, fst)
163-
except ValueError: pass
162+
result_shape = _max(shapes, key=len)
163+
ndim = len(result_shape)
164+
if ndim == 0 or all(core.definitely_equal_shape(result_shape[ndim - len(s):], s) for s in shapes):
165+
return result_shape
164166

165167
# Next try singleton-broadcasting, padding out ranks using singletons.
166-
ndim = _max(len(shape) for shape in shapes)
167168
shape_list = [(1,) * (ndim - len(shape)) + shape for shape in shapes]
168169
result_shape = _try_broadcast_shapes(shape_list)
169170
if result_shape is None:
170171
raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
171172
return result_shape
172173

173-
def _broadcast_ranks(s1, s2):
174-
if len(s1) > len(s2):
175-
s1, s2 = s2, s1
176-
assert len(s1) <= len(s2)
177-
s1_ = s2[len(s2) - len(s1):]
178-
if core.definitely_equal_shape(s1_, s1): return s2
179-
else: raise ValueError
180-
181174
def _identity(x): return x
182175

183176
def _extract_tracers_dyn_shape(

0 commit comments

Comments
 (0)