[model parallelism] Is it possible to pre-shard JAX arrays without relying on annotation `with_sharding_constraint`?
#8597
Unanswered
sudhakarsingh27 asked this question in Q&A
Replies: 0 comments
Currently, users need to give JAX sharding information through `with_sharding_constraint` to achieve model parallelism. JAX can do the sharding automatically, but what if we want to shard the arrays manually, a priori, so that JAX doesn't spend time sharding them at runtime (at least on the first call)? Is it possible to do so?
I think this would basically mean that we pre-shard the arrays and provide that information through something like the `with_sharding_constraint` API, so that JAX wouldn't have to do the sharding itself. Taking it a step further, could we pre-shard arrays separately from JAX (not just the first time, but during the model run as well) and then provide this sharding information to JAX, so that it can still add any necessary communication primitives without worrying about the sharding? I'm not sure how this would be compatible with `jit`, though. I would like to hear any thoughts on this as well. Thanks!
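For concreteness, here is a minimal sketch of the kind of pre-sharding described above. It assumes a recent JAX release that provides `jax.sharding.NamedSharding` and lets `jax.device_put` take a sharding; that API is not mentioned in this question and may not match the JAX version the question was written against. The mesh axis name `"model"`, the array `w_host`, and the function `scale` are hypothetical names chosen only for illustration.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D device mesh over whatever devices are available.
# This also works with a single device (the mesh then has size 1).
devices = np.array(jax.devices())
n_dev = devices.size
mesh = Mesh(devices, axis_names=("model",))  # "model" is an assumed axis name

# Hypothetical weight matrix on the host; the row count is a multiple of
# the device count so the rows split evenly across the mesh.
w_host = np.arange(n_dev * 4 * 8, dtype=np.float32).reshape(n_dev * 4, 8)

# Pre-shard: place the array on the devices, split along rows,
# before any jitted function is called.
row_sharding = NamedSharding(mesh, P("model", None))
w = jax.device_put(w_host, row_sharding)

@jax.jit
def scale(x):
    # No with_sharding_constraint here; the input arrives already sharded.
    return x * 2.0

y = scale(w)

# Inspect how the pre-sharded input is laid out across devices.
for shard in w.addressable_shards:
    print(shard.device, shard.data.shape)
```

In this sketch the jitted function receives an array that is already distributed, so the layout of the input is decided once, up front, rather than inside the function. Whether this removes all runtime resharding for a real model depends on the computation and on the layouts the compiler picks for intermediates and outputs.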