Conversation

@zhengchenyu
Contributor

Using replica groups offers the following advantages:

  • For ZeRO stage 3, it ensures that parameter gathering during the forward and backward passes occurs only within the replica group.

  • Checkpointing is performed only on ranks with replica_group_rank=0, which keeps the checkpoint world size constant and avoids universal-checkpoint transformations when scaling up or down.

Gradient all-reduce within the replica group can be performed after the backward pass and before optimizer.step, but it must wait for all buckets to complete, so it cannot leverage the concurrency of overlapping communication with computation.
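A minimal sketch of how such replica groups might be laid out. This is an illustration only, assuming contiguous rank partitioning; names like `replica_group_size` and the helper functions are hypothetical and may not match the PR's actual implementation:

```python
# Hypothetical sketch of replica-group layout (not the PR's actual API).
# Ranks are partitioned contiguously; in a real setup each sublist would
# back a torch.distributed process group for group-local gather/all-reduce.

def replica_groups(world_size: int, replica_group_size: int) -> list[list[int]]:
    """Split ranks [0, world_size) into contiguous replica groups."""
    assert world_size % replica_group_size == 0
    return [list(range(start, start + replica_group_size))
            for start in range(0, world_size, replica_group_size)]

def replica_group_rank(rank: int, replica_group_size: int) -> int:
    """Rank of a process inside its replica group."""
    return rank % replica_group_size

# Example: 8 ranks, replica groups of size 4.
groups = replica_groups(8, 4)
# Parameter gathering (stage 3) stays inside each sublist of `groups`.
# Only ranks with replica_group_rank == 0 write checkpoints, so the
# checkpoint world size is constant: world_size // replica_group_size.
checkpoint_ranks = [r for r in range(8) if replica_group_rank(r, 4) == 0]
```

With this layout, scaling the number of replica groups up or down leaves `checkpoint_ranks` with the same per-group structure, which is what makes the constant checkpoint world size possible.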

I know MICS has similar functionality, but it currently supports only ZeRO stage 3. Additionally, I want to use this feature for compatibility with architectures like TorchFT.
