-
Following up on the issue reported in #16788, I now want to distribute an array over the devices and apply a function to it as a simple test.

```python
import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

jax.distributed.initialize()
jax.config.update("jax_enable_x64", True)
print(jax.devices(), jax.local_devices())

devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, axis_names=('a', 'b'))

@pjit
def func(x):
    return jax.numpy.sqrt(x)

with mesh:
    x = jax.random.uniform(jax.random.PRNGKey(0), (2 ** 8, 2 ** 8))
    y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
    z = func(x)
    jax.debug.visualize_array_sharding(z)
    jax.debug.visualize_array_sharding(x)
    result = np.asarray(z)
    start = np.asarray(x)
    assert np.allclose(np.sqrt(start), result)
    print(result.size)
```

But this fails (when run with a job script like the one in #16788) with the following error:
Traceback
I see why this happens, but I am not sure what I am supposed to do differently. Does anyone know how to apply the idea of distributed arrays (and, later, shard_map) to multi-host setups?
-
I found the following workaround using `with jax.spmd_mode('allow_all'):`, from here:
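In case it helps, here is a minimal sketch of how I understand that workaround. It assumes a multi-host run launched like in #16788 (so `jax.distributed.initialize()` can find the other processes) and a recent JAX version where `jax.spmd_mode` is available; `jax.spmd_mode('allow_all')` lifts the check that normally rejects eager (non-jit) operations on arrays that are not fully addressable by the current process, so every process must still execute the same operations in the same order.

```python
# Minimal sketch (assumes a working multi-host setup with >1 process).
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

jax.distributed.initialize()

# One mesh axis over all devices of all processes.
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=('a',))

# Each process contributes its local rows of the global (256, 256) array.
rows_per_process = 256 // jax.process_count()
local_block = np.random.default_rng(jax.process_index()).uniform(size=(rows_per_process, 256))
x = multihost_utils.host_local_array_to_global_array(local_block, mesh, P('a'))

with jax.spmd_mode('allow_all'):
    # Outside this context, eager (non-jit) operations on arrays that are not
    # fully addressable by this process are rejected in multi-process runs.
    z = jnp.sqrt(x)

# Each process can still only read the shards it holds locally.
local_results = [shard.data for shard in z.addressable_shards]
```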
-
If your setup is fully data parallel then you can use this: https://github.com/google/jax/blob/main/jax/_src/array.py#L647-L659
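In case the linked lines move, here is a hedged sketch of a common fully data-parallel pattern (not necessarily what the linked code implements): each process builds its slice of a batch-sharded global array, computes with jit, and uses `jax.experimental.multihost_utils.process_allgather` to get the full result back as NumPy on every process, since plain `np.asarray` fails on arrays that span devices the process cannot address.

```python
# Hedged sketch of a fully data-parallel multi-host run (assumes >1 process).
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

jax.distributed.initialize()

# One mesh axis over all devices of all processes: pure data parallelism.
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=('batch',))

# Each process owns a distinct slice of the global batch.
per_process = 64
local_x = np.random.default_rng(jax.process_index()).uniform(size=(per_process, 128))
x = multihost_utils.host_local_array_to_global_array(local_x, mesh, P('batch'))

z = jax.jit(jnp.sqrt)(x)  # every process runs the same program; output stays sharded

# np.asarray(z) would fail here because z spans non-addressable devices;
# process_allgather returns the full result as a NumPy array on every process.
result = multihost_utils.process_allgather(z)
print(jax.process_index(), result.shape)  # full global shape on every process
```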