
Conversation

@zhengchenyu
Contributor

Using replica groups offers the following advantages:

  • For stage 3, it ensures that parameter gather during forward and backward occurs only within the replica group.

  • Checkpointing is performed only on replica_group_rank=0, which keeps the checkpoint world size constant and avoids the universal checkpoint conversion when scaling up or down (see the sketch below).
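A minimal sketch of the intended layout, assuming `torch.distributed` is already initialized; the contiguous-rank grouping, the `replica_group_size` value, and the helper name are illustrative, not the actual implementation in this PR:

```python
import torch.distributed as dist

def build_replica_group(replica_group_size: int):
    """Partition the world into contiguous replica groups and return the
    group this rank belongs to (illustrative layout only)."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert world_size % replica_group_size == 0

    my_group = None
    for start in range(0, world_size, replica_group_size):
        ranks = list(range(start, start + replica_group_size))
        # Every rank must take part in every new_group() call.
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            my_group = group
    return my_group

replica_group_size = 8                                      # hypothetical value
replica_group = build_replica_group(replica_group_size)     # ZeRO partitions inside this group
replica_group_rank = dist.get_rank() // replica_group_size  # which replica this rank is in

# Only the first replica writes checkpoints: the checkpoint world size stays
# fixed at replica_group_size regardless of how many replicas are running.
should_checkpoint = (replica_group_rank == 0)
```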

We can already perform the gradient all-reduce within the replica group after backward and before optimizer.step, but we must wait for all buckets to complete, so we cannot leverage the concurrency benefit of overlapping communication with backward.
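For illustration, a rough sketch of that manual all-reduce, run only after backward has fully finished and over the `replica_group` from the sketch above; note that with a ZeRO engine the reduced gradients actually live in the optimizer's internal buffers (e.g. averaged_gradients), so the plain `param.grad` access here is schematic:

```python
import torch
import torch.distributed as dist

def allreduce_grads(model: torch.nn.Module, group) -> None:
    """Average gradients over `group`. Because this runs only after backward
    has completed, the communication cannot overlap with backward the way
    DeepSpeed's bucketed reduction does."""
    group_size = dist.get_world_size(group=group)
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=group)
            param.grad.div_(group_size)

# Typical usage in the training loop (sketch):
#   loss.backward()
#   allreduce_grads(model, replica_group)
#   optimizer.step()
```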

I know MICS has similar functionality, but it currently supports only ZeRO stage 3. Additionally, I want to use this feature for compatibility with architectures like TorchFT.

@sfc-gh-truwase
Collaborator

@zhengchenyu thanks for the PR. Can you provide some clarification for the motivation?

  • For stage 3, it ensures that parameter gather during forward and backward occurs only within the replica group.

We already provide a form of this functionality in the hpZ component of ZeRO++. Have you explored whether hpZ would meet your needs?
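For reference, a minimal sketch of enabling hpZ, assuming the ZeRO++ config key `zero_hpz_partition_size`; the toy model and the partition size of 8 (typically GPUs per node) are placeholders:

```python
import torch
import deepspeed

# hpZ keeps a secondary partition of the parameters within each node, so the
# parameter all-gathers in forward/backward stay node-local.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3,
        "zero_hpz_partition_size": 8,  # placeholder: GPUs per node
    },
}

model = torch.nn.Linear(1024, 1024)  # placeholder model
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```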

I know MICS has similar functionality, but it currently supports only ZeRO stage 3.

My understanding is that replica groups are only relevant for ZeRO stage 3, since the lower stages don't partition parameters. Can you explain how replica groups arise in your workload?

@zhengchenyu
Contributor Author

@sfc-gh-truwase Thanks for your review!
My main motivation is to support torchft to achieve fault tolerance. At the same time, I aim to solve the following two problems:

  • (1) During forward and backward, the parameter gather occurs across all machines.
  • (2) The ZeRO checkpoint layout changes with the world size, which forces a universal checkpoint conversion.

Regarding ZeRO++: it cannot solve problem (2). It can solve problem (1), but at a cost: we must introduce an extra ds_secondary_tensor. Moreover, in the first forward of each step, parameters still need to be gathered across all machines.

Regarding MICS: for ZeRO stage 3, these two problems do not exist. For stages 1/2, problem (1) does not arise, but if the optimizer state is taken into account when loading the checkpoint, problem (2) remains.

@sfc-gh-truwase
Collaborator

Thanks for sharing more details.

  1. ft replicas: While I appreciate the importance of this work, I am concerned about the cost of such an intrusive change while your project is still evolving. My understanding is that you want to synchronize gradients across independent training replicas. In that case, why not perform the synchronization explicitly at the client level? You could use these APIs to obtain the gradients from each engine.

  2. zero_checkpoints: Restricting checkpoint creation to one (or a few) replicas will unacceptably slow down checkpointing, especially at large scale. We consider the one-time cost of universal checkpoint conversion a reasonable trade-off, as discussed in the paper.

@zhengchenyu
Contributor Author

@sfc-gh-truwase Thanks for your reply.
For 1: Initially, I did use averaged_gradients for the all-reduce at the client level, but I felt it was somewhat inefficient because it had to wait for all buckets to be collected.
For 2: You're right. However, I think if fault tolerance is handled well, we can appropriately reduce the checkpoint frequency.
Since torchft is still evolving, we can postpone this PR.

@sfc-gh-truwase
Collaborator

For 1: Initially, I did use averaged_gradients for the all-reduce at the client level, but I felt it was somewhat inefficient because it had to wait for all buckets to be collected.

Yes, I agree that existing options like averaged_gradients will be inefficient. But we can revisit proper DeepSpeed support once your work matures and questions such as end-to-end performance are figured out.
