Skip to content

Commit 1ec0585

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Fix process_allgather of global jax.Arrays with shardy
PiperOrigin-RevId: 738823617
1 parent dad1b41 commit 1ec0585

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

jax/experimental/multihost_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,11 @@ def _identity_fn(x):
9999

100100
def _handle_array_process_allgather(inp, tiled):
101101
if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable:
102-
reps = sharding_impls.GSPMDSharding.get_replicated(
103-
inp.sharding._device_assignment)
102+
if isinstance(inp.sharding, sharding_impls.NamedSharding):
103+
reps = inp.sharding.with_spec(P())
104+
else:
105+
reps = sharding_impls.GSPMDSharding.get_replicated(
106+
inp.sharding._device_assignment, memory_kind=inp.sharding.memory_kind)
104107
out = jax.jit(_identity_fn, out_shardings=reps)(inp)
105108
else:
106109
# All inputs here will be fully addressable.

0 commit comments

Comments
 (0)