Skip to content

Commit 8699f5d

Browse files
yashk2810Google-ML-Automation
authored andcommitted
When host local inputs on all hosts are the same, use _DeferredShardArg to do the transfers instead of jit to avoid blocking.
PiperOrigin-RevId: 699336402
1 parent 030ee4a commit 8699f5d

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

jax/_src/dispatch.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -411,15 +411,12 @@ def _device_put_sharding_impl(x, aval, device, copy):
411411
if not s.is_fully_addressable:
412412
if ((isinstance(x, array.ArrayImpl) and not x._committed) or
413413
type(x) in array_types):
414-
# TODO(yashkatariya): Move this check to `jit`.
415414
multihost_utils.assert_equal(
416415
x, fail_message=(
417416
f"{type(x)} passed to device_put is not the same on each"
418417
" process. Make sure you are passing the same value of"
419418
f" {type(x)} on each process."))
420-
return api.jit(
421-
_identity_fn, out_shardings=s,
422-
donate_argnums=(0 if copy == CopySemantics.DONATE else None))(x)
419+
return _DeferredShardArg(x, s, aval, True, copy)
423420
# TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array.
424421
raise ValueError(
425422
"device_put's second argument must be a Device or a Sharding which"

0 commit comments

Comments
 (0)