Skip to content

Commit 8c35191

Browse files
emilyfertigGoogle-ML-Automation
authored andcommitted
Enable jax.device_put to a sharding with no local devices.
PiperOrigin-RevId: 737797815
1 parent 051687d commit 8c35191

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

jax/_src/dispatch.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -466,11 +466,14 @@ def _device_put_sharding_impl(x, aval, device, copy):
466466
if not s.is_fully_addressable:
467467
if ((isinstance(x, array.ArrayImpl) and not x._committed) or
468468
type(x) in array_types):
469-
multihost_utils.assert_equal(
470-
x, fail_message=(
471-
f"{type(x)} passed to device_put is not the same on each"
472-
" process. Make sure you are passing the same value of"
473-
f" {type(x)} on each process."))
469+
# TODO(emilyaf): Remove this condition when jit works when a sharding
470+
# has no local devices.
471+
if not config.enable_empty_arrays.value:
472+
multihost_utils.assert_equal(
473+
x, fail_message=(
474+
f"{type(x)} passed to device_put is not the same on each"
475+
" process. Make sure you are passing the same value of"
476+
f" {type(x)} on each process."))
474477
return _DeferredShardArg(x, s, aval, True, copy)
475478
# TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array.
476479
raise ValueError(

jax/_src/interpreters/pxla.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def batched_device_put(aval: core.ShapedArray,
237237
if (isinstance(x, array.ArrayImpl) and
238238
dispatch.is_single_device_sharding(x.sharding) and
239239
x.devices() == {d})]
240-
if len(bufs) == len(xs):
240+
if len(bufs) == len(xs) > 0:
241241
return array.ArrayImpl(
242242
aval, sharding, bufs, committed=committed, _skip_checks=True)
243243
return xc.batched_device_put(aval, sharding, xs, list(devices), committed)

0 commit comments

Comments
 (0)