ShardedDeviceArray from tf.data.Dataset.shard #11272
krzysztofrusek asked this question in Q&A (unanswered)
Hi, I have a question about the data pipeline in multi-host training. In particular, I have multiple GPU-equipped workers, and each worker can access a central data storage. I would like to use `tf.data.Dataset.shard` to load part of the batch independently on each worker and join the shards into a single `ShardedDeviceArray` that can be handled by `pmap`. It looks like `jax.device_put_sharded` does the job, but it requires a list of shards on the host, and I want to load them independently on the workers. I imagine my loop to be something like the sketch below.
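A minimal sketch of such a loop, assuming one JAX process per worker and a per-host batch that divides evenly across the local devices; the toy `tf.data` pipeline and the toy `pmap`ped step are placeholders for the real ones:

```python
import jax
import numpy as np
import tensorflow as tf

num_workers = jax.process_count()   # one JAX process per worker (assumption)
worker_id = jax.process_index()

# Toy stand-in for the real pipeline over the central storage: each worker
# keeps only every num_workers-th record via tf.data.Dataset.shard.
ds = tf.data.Dataset.range(1024).map(lambda i: tf.cast(i, tf.float32))
ds = ds.shard(num_shards=num_workers, index=worker_id)
ds = ds.batch(8 * jax.local_device_count(), drop_remainder=True)

# Toy pmapped step; the real update function would go here.
step = jax.pmap(lambda x: jax.lax.pmean(x.mean(), axis_name="i"),
                axis_name="i")

for batch in ds.as_numpy_iterator():
    # One shard per local device, committed as a single ShardedDeviceArray.
    shards = np.split(batch, jax.local_device_count())
    sharded_batch = jax.device_put_sharded(shards, jax.local_devices())
    out = step(sharded_batch)
```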
What is the most efficient way to do it?