-
Notifications
You must be signed in to change notification settings - Fork 4.7k
Introduce all_reduce_hook to support gradient aggregation across replica groups. #7764
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
…ica groups. Signed-off-by: zhengchenyu <[email protected]>
|
@zhengchenyu thanks for the PR. Can you provide some clarification for the motivation?
We already provide a form of this functionality in hpZ component of ZeRO++. Have you explored whether hpZ would meet your needs?
My understanding replica groups is only relevant for zero stage 3 since lower stages don't do parameter partitioning. Can you explain how replica groups exist in your workload? |
|
@sfc-gh-truwase Thanks for your review!
Regarding zero++. It cannot solve problem (2). It can solve problem (1), but there is a cost involved, we must introduce extra Regarding MICS. For zero stage 3, these two problem do not exist. For stage 1/2, there are no problems (1), but if the optimizer parameters are considered when loading the checkpoint, there will be problem for issue (2). |
|
Thanks for sharing more details.
|
|
@sfc-gh-truwase Thanks for your reply. |
Yes, I agree that existing options like |
Using replica groups offers the following advantages:
For stage 3, it ensures that parameter gather during forward and backward occurs only within the replica group.
Checkpointing is performed only on
replica_group_rank=0, guaranteeing constant checkpoint world size and avoiding the universal checkpoint transformations during scaling up or down.We can achieve gradient all reduce within the replica group after backward and before optimizer.step, but we must wait for all buckets to complete, thus can not leverage concurrency advantages.
I know MICS has similar functionality, but currently only supports zero stage 3. Additionally, I want to use this feature for compatibility with architectures like TorchFT.