Hi!

The problem is that jax.device_put cannot transfer data across hosts (we know about this and are looking into improving it). On a single host it works, as you've seen, but it fails on multiple hosts.
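To make the single-host vs. multi-host distinction concrete, here is a minimal sketch, assuming the device count divides 32 and using a hypothetical mesh axis name "x". It checks `is_fully_addressable` on the sharding, which is true when every device in the mesh belongs to the current process (the single-host case), and only then calls `jax.device_put`. This is an illustration of the situation described above, not an official workaround.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

arr = np.arange(32 * 4).reshape(32, 4)

# Assumption: jax.device_count() divides 32 so rows split evenly.
# "x" is a hypothetical axis name chosen for this sketch.
mesh = Mesh(np.array(jax.devices()), ("x",))
sharding = NamedSharding(mesh, P("x"))

# True when this process can address every device in the mesh (single host).
if sharding.is_fully_addressable:
    placed = jax.device_put(arr, sharding)
    print(placed.shape, len(placed.addressable_shards))
else:
    # Multi-host case: some devices in `mesh` belong to other processes,
    # which is where device_put runs into the limitation described above.
    print("sharding spans other hosts; build the array per host instead")
```

On a single CPU host this takes the first branch and prints the global shape and the number of local shards.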

A better approach is jax.make_array_from_callback, because the input on every host is the same, i.e. arr = np.arange(32*4).reshape(32, 4). make_array_from_callback carves out the shard that each device on that host needs. Here is how the code looks:

```python
import jax
from jax.experimental import mesh_utils
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

arr = np.arange(32*4).reshape(32, 4)
n_devices = jax.device_count(…
```
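The snippet above is cut off, so here is a self-contained sketch of the same idea. It assumes the device count divides 32, uses a 1-D mesh built directly from `jax.devices()` instead of `mesh_utils`, and names the axis "x" and the callback `data_callback` purely for illustration. The callback receives the index (a tuple of slices) of the shard one local device needs and returns that slice of the host's full copy of `arr`.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Every host builds the same full array, as in the question.
arr = np.arange(32 * 4).reshape(32, 4)

n_devices = jax.device_count()  # global device count across all hosts
# Assumption: n_devices divides 32 so the rows split evenly across devices.
mesh = Mesh(np.array(jax.devices()).reshape(n_devices), ("x",))
sharding = NamedSharding(mesh, P("x"))  # shard rows along mesh axis "x"

def data_callback(index):
    # `index` is a tuple of slices for one addressable device's shard;
    # slicing the host-local copy yields exactly that shard's data.
    return arr[index]

global_arr = jax.make_array_from_callback(arr.shape, sharding, data_callback)

# Each process only materializes the shards for its own devices.
for shard in global_arr.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)
```

Because the callback is only invoked for devices local to the calling process, no data has to cross hosts; each host slices its own identical copy.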

Answer selected by yixiaoer
Category: Q&A