The choice of sharding APIs #31597
Replies: 2 comments 1 reply
-
There are a lot of questions here, so I'll do my best to answer as much as I can. Before, when we had only Auto and Manual modes (no Explicit), there was no way to just control sharding propagation and leave partitioning to the compiler. But now, with Explicit mode, you can take control over sharding propagation while leaving partitioning of the computation to XLA. For FSDP, this ends up being enough without needing to use shard_map, but there is nothing wrong with dropping into full shmap mode if you want full control. There are complicated FSDP cases where shmap helps more, e.g. if you want to take over control of communication/compute overlap and schedule them as you please. These kinds of decisions are very subjective and up to the taste of the user, which is why JAX doesn't enforce any opinion here. We do have a bias towards using Explicit and Manual modes only, and dropping into Auto mode where required, rather than being Auto by default. So I would say: use what works for you and what you are comfortable with :) You can mix and match all 3 modes as you please.
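To make the contrast concrete, here is a minimal sketch, assuming 8 faked host devices via `XLA_FLAGS` and using the long-standing `Mesh`/`NamedSharding`/`shard_map` APIs (not the newer Explicit-mode ones): the first version only pins the input sharding and lets the compiler partition the computation, the second drops into `shard_map` so each device runs on its local block.

```python
import os
# Fake 8 devices on a single CPU host, for demonstration only.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Input committed to a sharding; under jit, XLA propagates it and
# decides how to partition the computation and place collectives.
x = jax.device_put(jnp.arange(32.0).reshape(8, 4),
                   NamedSharding(mesh, P("data", None)))

@jax.jit
def scaled(v):
    return v * 2.0

# Manual mode: shard_map hands each device its local (1, 4) block; we
# own any cross-device communication (none is needed here).
@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=P("data", None), out_specs=P("data", None))
def scaled_manual(v):
    return v * 2.0

assert jnp.allclose(scaled(x), scaled_manual(x))
```

Both versions compute the same thing; the only difference is who is responsible for the partitioning decisions.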
Yes, that's correct.
Maybe we should. That might just be a typo/bug. Does this help?
-
Yes, yes! I know that there is a much cleaner mental model for mixing explicit and shard_map. But as I said, I couldn't find any examples demonstrating the best practices for that. I like the
Agreed, but I am not saying you should enforce it on everyone. All I am asking for is your opinion on the simple example I provided above.
Yeah, maybe we should rewrite that example, because it does not make sense to shard the arrays in one way (especially when you know what you want to do) and then flip the PartitionSpec in shmap. It will end up confusing a lot of people. Also, can you please comment on the 8-way mesh-product sharding in that example, which is a bit different from the normal FSDP+TP?
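For readers following along, here is a small sketch of what "mesh-product" sharding means (device count and axis names are hypothetical): a single array axis is laid out over the *product* of two mesh axes, so a 4x2 mesh shards that one axis 8 ways instead of the usual "batch over one axis, features over the other" layout.

```python
import os
# Fake 8 devices on a single CPU host, for demonstration only.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 4x2 mesh; axis names chosen for illustration.
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ("batch", "feats"))

x = jnp.arange(16.0).reshape(16, 1)
# P(('batch', 'feats')) shards axis 0 over BOTH mesh axes jointly,
# i.e. 8 ways: each device ends up with a (2, 1) block.
y = jax.device_put(x, NamedSharding(mesh, P(("batch", "feats"), None)))
shard_shapes = {s.data.shape for s in y.addressable_shards}
print(shard_shapes)  # {(2, 1)}
```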
-
I think it is great that we now have a choice of explicit and manual sharding APIs, but one thing has left me confused. Every time someone asks me which API they should choose for their JAX workflows, I almost always think of the caveats, and as a result I lean towards `shard_map`, because it is as transparent as it gets. But `shard_map` also comes with a lot of verbosity for simple workflows, and it is not always clear what the best strategy to opt for is, especially if you are running things at a small scale but the workflow has to be scaled up for the final run. To elaborate on that last point, let us look at an example (directly from the JAX docs) demonstrating the FSDP + TP workflow. We will modify this example into a classification problem, but on a toy dataset before we run this pipeline at scale. For the toy dataset, you can consider any dataset; MNIST and CIFAR-10 are not something people really work on anymore, but they serve the demonstration purpose:
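Something like the following sketch of what I have in mind (layer sizes, mesh shape, and axis names are placeholders, not the docs' actual code): a toy MLP classifier whose hidden weights are sharded FSDP+TP style with `NamedSharding`, biases replicated.

```python
import os
# Fake 8 devices on a single CPU host, for demonstration only.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 4x2 mesh: 'data' for FSDP, 'model' for tensor parallelism.
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ("data", "model"))

def init_params(key, sizes):
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        params.append((jax.random.normal(sub, (d_in, d_out)) * 0.02,
                       jnp.zeros((d_out,))))
    return params

# Toy "MNIST-like" problem: 64 flattened input features, 10 classes.
params = init_params(jax.random.PRNGKey(0), [64, 128, 128, 10])

# Hidden-layer weights: colwise then rowwise over 'model' (TP), plus
# 'data' for FSDP. Biases and the final layer stay replicated.
weight_specs = [P("data", "model"), P("model", "data"), P(None, None)]
params = [(jax.device_put(w, NamedSharding(mesh, s)),
           jax.device_put(b, NamedSharding(mesh, P(None))))
          for (w, b), s in zip(params, weight_specs)]

@jax.jit
def predict(params, x):
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b  # final (replicated) layer

x = jax.device_put(jnp.ones((32, 64)),
                   NamedSharding(mesh, P("data", None)))
logits = predict(params, x)
print(logits.shape)  # (32, 10)
```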
In this setup, everything except the out-features of the final layer can be sharded easily. We will shard only the weights, not the biases, of all the hidden layers, and we will keep the last layer replicated for simplicity. What's the ideal strategy here? Should one use `shard_map` for the hidden layers and separate out the final layer's forward pass? Or should we use explicit sharding for it?
A few other comments related to the original example:
- The example uses an 8-way mesh-product sharding instead of the general colwise-rowwise pattern. I haven't done the math for the collectives on paper, so I can't say much about the communication time. Is there any special reason behind this partitioning?
- The arrays are sharded with `P(('batch', 'feats'))`, but `in_specs` is set to `P(('feats', 'batch'))`. Does that mean that, irrespective of the original sharding of an array, the blocks will always be laid out according to `in_specs`? If not, then why change the PartitionSpec?