Skip to content

Commit b3833dc

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Match the behavior on single host wrt multi-host if tiled=False. Fixes #25783
PiperOrigin-RevId: 713398173
1 parent 6e1f060 commit b3833dc

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

jax/experimental/multihost_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ def _handle_array_process_allgather(inp, tiled):
103103
else:
104104
# All inputs here will be fully addressable.
105105
if jax.process_count() == 1:
106-
return np.asarray(inp)
106+
out = np.asarray(inp)
107+
return np.expand_dims(out, axis=0) if not tiled else out
107108

108109
devices = np.array(jax.devices()).reshape(jax.process_count(),
109110
jax.local_device_count())

tests/array_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -700,10 +700,14 @@ def test_array_is_ready(self):
700700

701701
def test_process_allgather_single_host(self):
702702
x = jnp.arange(8.)
703-
out = multihost_utils.process_allgather(x)
703+
out = multihost_utils.process_allgather(x, tiled=True)
704704
self.assertEqual(out.shape, x.shape)
705705
self.assertArraysEqual(out, x)
706706

707+
out = multihost_utils.process_allgather(x)
708+
self.assertEqual(out.shape, (1, x.shape[0]))
709+
self.assertArraysEqual(out, np.expand_dims(x, axis=0))
710+
707711
@jtu.sample_product(
708712
dtype=jtu.dtypes.all,
709713
shape=[(), (10), (2, 3)],

0 commit comments

Comments
 (0)