We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent dad1b41 commit 1ec0585Copy full SHA for 1ec0585
jax/experimental/multihost_utils.py
@@ -99,8 +99,11 @@ def _identity_fn(x):
99
100
def _handle_array_process_allgather(inp, tiled):
101
if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable:
102
- reps = sharding_impls.GSPMDSharding.get_replicated(
103
- inp.sharding._device_assignment)
+ if isinstance(inp.sharding, sharding_impls.NamedSharding):
+ reps = inp.sharding.with_spec(P())
104
+ else:
105
+ reps = sharding_impls.GSPMDSharding.get_replicated(
106
+ inp.sharding._device_assignment, memory_kind=inp.sharding.memory_kind)
107
out = jax.jit(_identity_fn, out_shardings=reps)(inp)
108
else:
109
# All inputs here will be fully addressable.
0 commit comments