
Commit 39e4f7f

yashk2810 authored and Google-ML-Automation committed
[sharding_in_types] Make jnp.where broadcast shardings properly when a scalar exists
PiperOrigin-RevId: 705283318
1 parent ccfef7a · commit 39e4f7f

File tree: 3 files changed (+41, -4 lines)

jax/_src/lax/lax.py

Lines changed: 21 additions & 0 deletions

@@ -176,6 +176,27 @@ def _broadcast_shapes_uncached(*shapes):
     # Raise ValueError here for backward compatibility.
     raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}") from err
 
+def broadcast_shardings(*avals) -> NamedSharding:
+  fst, *rst = avals
+  if not rst:
+    return fst.sharding
+
+  # First check if we need only rank promotion (and not singleton-broadcasting).
+  res_aval = _max(avals, key=lambda a: a.ndim)
+  ndim = res_aval.ndim
+  if ndim == 0 or all(
+      res_aval.sharding.spec[ndim - a.ndim:] == a.sharding.spec for a in avals):
+    return res_aval.sharding
+
+  # Next try singleton-broadcasting, padding out ranks using singletons.
+  aval_list = []
+  for a in avals:
+    new_spec = P(*(None,) * (ndim - a.ndim) + a.sharding.spec)
+    new_shape = (1,) * (ndim - a.ndim) + a.shape
+    aval_list.append(a.update(shape=new_shape,
+                              sharding=a.sharding.with_spec(new_spec)))
+  return broadcasting_sharding_rule('broadcast_shardings', *aval_list)
+
 def _identity(x): return x
 
 def _extract_tracers_dyn_shape(
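For intuition, the following is a self-contained sketch of the broadcasting rule the new broadcast_shardings helper applies: return early if only rank promotion is needed, otherwise pad every operand to full rank with size-1 dimensions and unsharded (None) spec entries and resolve the specs dimension by dimension. The function name and the plain-tuple stand-ins for avals and PartitionSpecs below are hypothetical illustrations, not JAX internals, and the per-dimension conflict handling is only an approximation of broadcasting_sharding_rule.

# Illustrative sketch only: (shape, spec) pairs stand in for JAX avals, with
# spec holding one mesh-axis name (or None) per dimension.
def broadcast_specs(*operands):
  if len(operands) == 1:
    return operands[0][1]

  # Rank promotion only: every spec must already match the trailing dims of
  # the highest-rank operand's spec.
  ndim = max(len(shape) for shape, _ in operands)
  _, res_spec = max(operands, key=lambda o: len(o[0]))
  if ndim == 0 or all(res_spec[ndim - len(shape):] == spec
                      for shape, spec in operands):
    return res_spec

  # Otherwise pad ranks with size-1 dims / None entries and resolve each
  # dimension: a named axis wins over None; two different names conflict.
  padded = [((1,) * (ndim - len(shape)) + shape,
             (None,) * (ndim - len(spec)) + spec) for shape, spec in operands]
  out = []
  for d in range(ndim):
    names = {spec[d] for shape, spec in padded
             if shape[d] != 1 and spec[d] is not None}
    if len(names) > 1:
      raise ValueError(f"conflicting shardings on dimension {d}: {names}")
    out.append(names.pop() if names else None)
  return tuple(out)

# A scalar (rank 0, empty spec) against an ('x', 'y')-sharded (8, 2) array
# keeps the array's sharding, which is the case the jnp.where fix needs.
assert broadcast_specs(((8, 2), ('x', 'y')), ((), ())) == ('x', 'y')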

jax/_src/numpy/util.py

Lines changed: 10 additions & 4 deletions

@@ -23,6 +23,7 @@
 from jax._src import config
 from jax._src import core
 from jax._src import dtypes
+from jax._src import api_util
 from jax._src.lax import lax
 from jax._src.util import safe_zip, safe_map
 from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
@@ -213,14 +214,18 @@ def promote_args_inexact(fun_name: str, *args: ArrayLike) -> list[Array]:
 @partial(api.jit, inline=True)
 def _broadcast_arrays(*args: ArrayLike) -> list[Array]:
   """Like Numpy's broadcast_arrays but doesn't return views."""
-  shapes = [np.shape(arg) for arg in args]
+  avals = [api_util.shaped_abstractify(arg) for arg in args]
+  shapes = [a.shape for a in avals]
   if not shapes or all(core.definitely_equal_shape(shapes[0], s) for s in shapes):
     return [lax.asarray(arg) for arg in args]
   result_shape = lax.broadcast_shapes(*shapes)
-  return [_broadcast_to(arg, result_shape) for arg in args]
+  result_sharding = (lax.broadcast_shardings(*avals)  # type: ignore
+                     if config.sharding_in_types.value else None)
+  return [_broadcast_to(arg, result_shape, result_sharding) for arg in args]
 
 
-def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape) -> Array:
+def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape, sharding=None
+                  ) -> Array:
   check_arraylike("broadcast_to", arr)
   arr = arr if isinstance(arr, Array) else lax.asarray(arr)
   if not isinstance(shape, tuple) and np.ndim(shape) == 0:
@@ -240,7 +245,8 @@ def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape) -> Array:
   if nlead < 0 or not compatible:
     msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
     raise ValueError(msg.format(arr_shape, shape))
-  return lax.broadcast_in_dim(arr, shape, tuple(range(nlead, len(shape))))
+  return lax.broadcast_in_dim(arr, shape, tuple(range(nlead, len(shape))),
+                              sharding=sharding)
 
 
 # The `jit` on `where` exists to avoid materializing constants in cases like
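The util.py change is plumbing: _broadcast_arrays now computes a target sharding alongside the target shape (only when the sharding-in-types config is on), and _broadcast_to forwards it to lax.broadcast_in_dim. A minimal sketch of that pattern follows; the wrapper name broadcast_all is hypothetical, and the sharding keyword on broadcast_in_dim is the one used by this change, so it may not exist in older JAX releases.

# Hypothetical sketch of the plumbing pattern above: compute the broadcast
# shape (and, optionally, a target sharding) once, then pass both down to
# the primitive that actually does the broadcasting.
import jax.numpy as jnp
from jax import lax

def broadcast_all(arrays, target_shape, target_sharding=None):
  out = []
  for a in arrays:
    a = jnp.asarray(a)
    nlead = len(target_shape) - a.ndim               # dims added on the left
    dims = tuple(range(nlead, len(target_shape)))    # where existing dims land
    out.append(lax.broadcast_in_dim(a, target_shape, dims,
                                    sharding=target_sharding))
  return out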

tests/pjit_test.py

Lines changed: 10 additions & 0 deletions

@@ -5548,6 +5548,16 @@ def f(x, x2):
         "AxisTypes should be the same in a tuple subset of PartitionSpec"):
       NamedSharding(mesh2, P(('x', 'y')))
 
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_where_with_scalar(self, mesh):
+    np_inp = np.arange(16.).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    x = jax.device_put(np_inp, s)
+
+    out = jnp.where(x > 0, x, 0)
+    self.assertArraysEqual(out, x)
+    self.assertEqual(out.sharding, s)
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
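In end-user terms, the new test checks that jnp.where with a scalar branch preserves the sharding of the array operand instead of failing to broadcast the specs. A rough standalone equivalent is sketched below; the mesh construction via jax.make_mesh and the availability of the experimental sharding-in-types mode are assumptions for illustration, not part of the commit.

# Rough standalone version of the new test (assumes 4 devices and the
# experimental sharding-in-types mode that this commit targets).
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
x = jax.device_put(np.arange(16.).reshape(8, 2), s)

out = jnp.where(x > 0, x, 0)   # the scalar 0 is broadcast against the sharded x
print(out.sharding)            # expected to match s after this fix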
