
Commit 6cc74b7

fix: make multi-node DDP precision work
1 parent 78047a7 commit 6cc74b7

File tree

5 files changed: +14 -5 lines changed

example/llama3/main.cc

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ void Train(const nn::parallel::Rank &rank) {
     if (ddp_world_size > 1) {
         ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.thread_rank()),
                                                               GetDataParallelGroupRanks(rank.thread_rank()));
-        ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank());
+        ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank());
     }
 
     if (tp_world_size > 1) {
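
For context, a standalone sketch (hypothetical 2-process x 4-thread layout, not repo code): once a second process joins data parallelism, the per-process thread rank repeats on every node, so keying the group-rank lookup on it makes workers on different nodes collide on the same entry, while the global rank stays unique.

    // Standalone sketch: a map keyed by thread rank is clobbered across
    // processes, while one keyed by global rank keeps one entry per worker.
    #include <cstdio>
    #include <unordered_map>

    int main() {
        const int threads_per_proc = 4;
        std::unordered_map<int, int> by_thread_rank;  // thread rank -> group rank
        std::unordered_map<int, int> by_global_rank;  // global rank -> group rank

        for (int proc = 0; proc < 2; ++proc) {
            for (int thread = 0; thread < threads_per_proc; ++thread) {
                const int global = proc * threads_per_proc + thread;
                by_thread_rank[thread] = global;  // proc 1 overwrites proc 0's entries
                by_global_rank[global] = global;  // every worker keeps its own entry
            }
        }
        std::printf("thread-rank keys: %zu, global-rank keys: %zu\n",
                    by_thread_rank.size(), by_global_rank.size());  // prints 4 vs 8
        return 0;
    }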

infini_train/include/nn/parallel/rank.h

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,8 @@ class Rank {
     int process_size() const;
     int thread_size() const;
 
+    int GlobalRank() const;
+
     bool IsParallel() const;
 
     bool IsMainRank() const;

infini_train/src/nn/parallel/distributed_data_parallel.cc

Lines changed: 2 additions & 2 deletions
@@ -24,8 +24,8 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr<nn::Module> mod
 
         auto ddp_pg
             = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(device->rank().thread_rank()));
-        auto hook = std::make_unique<infini_train::autograd::AllReducePostAccumulateHook>(function::ReduceOpType::kAvg,
-                                                                                          ddp_pg);
+        // FIXME(dcj): use multi-node ddp_pg here
+        auto hook = std::make_unique<infini_train::autograd::AllReducePostAccumulateHook>(function::ReduceOpType::kAvg);
         param->RegisterPostAccumulateGradHook(std::move(hook));
     }
     for (auto &buffer : module->Buffers()) {
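
For readers new to the hook: a post-accumulate gradient hook fires once a parameter's gradient has been fully accumulated and then averages it across the data-parallel workers (the kAvg reduction); per the FIXME above, the explicit process-group argument is dropped for now until the multi-node ddp_pg is wired in. Below is a conceptual sketch only, using hypothetical stand-in types rather than the repo's autograd or ProcessGroup API.

    // Conceptual sketch with hypothetical stand-in types (not the repo's autograd
    // API): a hook that averages a gradient across DDP workers after accumulation.
    #include <vector>

    struct GradTensor { std::vector<float> data; };  // stand-in for a gradient tensor

    // Stand-in for the collective; a real build would issue an all-reduce over the
    // data-parallel group (e.g. ncclAllReduce with ncclAvg, or a sum then divide).
    void AllReduceAvg(GradTensor &grad, int ddp_world_size) {
        for (auto &v : grad.data) {
            v /= static_cast<float>(ddp_world_size);  // the "Avg" part of kAvg
        }
    }

    struct PostAccumulateAllReduceHook {
        int ddp_world_size;
        void operator()(GradTensor &grad) const { AllReduceAvg(grad, ddp_world_size); }
    };

    int main() {
        GradTensor grad{{2.0f, 4.0f}};
        PostAccumulateAllReduceHook hook{2};  // pretend there are two DDP workers
        hook(grad);                           // would run right after accumulation
        return 0;
    }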

infini_train/src/nn/parallel/process_group.cc

Lines changed: 7 additions & 2 deletions
@@ -1,5 +1,6 @@
 #include "infini_train/include/nn/parallel/process_group.h"
 
+#include <algorithm>
 #include <memory>
 #include <numeric>
 #include <vector>
@@ -51,6 +52,8 @@ ProcessGroup::ProcessGroup(const ncclUniqueId &nccl_id) : world_size_(global::Ge
         device_indices[i] = i;
 
         int global_rank = global::GetGlobalProcRank() * global::GetNthreadPerProc() + i;
+
+        cudaSetDevice(i);
         NCCL_CHECK(ncclCommInitRank(&comms_[i], world_size_, nccl_id, global_rank));
     }
     NCCL_CHECK(ncclGroupEnd());
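
On the added cudaSetDevice(i): when one process creates several NCCL communicators (one per local GPU) inside a single ncclGroupStart()/ncclGroupEnd() pair, the CUDA device that is current at each ncclCommInitRank call determines which GPU that communicator binds to, so the device has to be switched before every init. A minimal sketch of that pattern, assuming the CUDA runtime and NCCL headers and omitting error checking:

    // Sketch of per-device communicator setup in one process (assumes <nccl.h> and
    // <cuda_runtime.h>; proc_rank / nthread_per_proc are illustrative names).
    #include <vector>
    #include <cuda_runtime.h>
    #include <nccl.h>

    void InitLocalComms(const ncclUniqueId &nccl_id, int proc_rank, int nthread_per_proc,
                        int world_size, std::vector<ncclComm_t> &comms) {
        comms.resize(nthread_per_proc);
        ncclGroupStart();                      // batch the per-rank initializations
        for (int i = 0; i < nthread_per_proc; ++i) {
            const int global_rank = proc_rank * nthread_per_proc + i;
            cudaSetDevice(i);                  // bind the next communicator to local GPU i
            ncclCommInitRank(&comms[i], world_size, nccl_id, global_rank);
        }
        ncclGroupEnd();                        // ranks finish initialization together
    }
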
@@ -59,11 +62,13 @@ ProcessGroup::ProcessGroup(const ncclUniqueId &nccl_id) : world_size_(global::Ge
 }
 
 void ProcessGroup::Init(const std::vector<int> &device_indices) {
-    for (int i = 0; i < world_size_; ++i) {
+    // FIXME(dcj): This is a temporary solution to get the device and comm for each thread.
+    int local_comm_size = std::min(static_cast<int>(device_indices.size()), global::GetNthreadPerProc());
+    for (int i = 0; i < local_comm_size; ++i) {
         auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, device_indices[i]);
         devices_.push_back(device);
         device_comm_map_[device] = comms_[i];
-        thread_group_rank_map_[device->rank().thread_rank()] = i;
+        thread_group_rank_map_[device->rank().thread_rank()] = i + global::GetGlobalProcRank() * local_comm_size;
     }
 }
 
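
On the Init() change: each process now walks only the communicators it created locally (local_comm_size, clamped with std::min against the per-process thread count) and offsets its group ranks by GetGlobalProcRank() * local_comm_size, so the entries written into thread_group_rank_map_ stay unique across nodes. A standalone sketch of the resulting mapping under an assumed 2-process, 4-GPU-per-process layout:

    // Standalone sketch (assumed layout, not repo code): with 2 processes and 4
    // local communicators each, process 0 records group ranks 0..3 and process 1
    // records 4..7, so no two workers share a group rank.
    #include <algorithm>
    #include <cstdio>

    int main() {
        const int nthread_per_proc = 4;
        const int device_count = 4;  // size of the device_indices handed to Init
        const int local_comm_size = std::min(device_count, nthread_per_proc);

        for (int proc_rank = 0; proc_rank < 2; ++proc_rank) {
            for (int i = 0; i < local_comm_size; ++i) {
                const int group_rank = i + proc_rank * local_comm_size;
                std::printf("proc %d, local comm %d -> group rank %d\n", proc_rank, i, group_rank);
            }
        }
        return 0;
    }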

infini_train/src/nn/parallel/rank.cc

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,8 @@ int Rank::thread_rank() const { return thread_rank_; }
 int Rank::process_size() const { return process_size_; }
 int Rank::thread_size() const { return thread_size_; }
 
+int Rank::GlobalRank() const { return process_rank_ * thread_size_ + thread_rank_; }
+
 bool Rank::IsParallel() const { return thread_size_ * process_size_ > 1; }
 bool Rank::IsMainRank() const { return thread_rank_ == 0; }
 } // namespace infini_train::nn::parallel
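
A quick worked example of the new accessor: with process_size = 2 and thread_size = 4, process 0's threads get global ranks 0 through 3 and process 1's threads get 4 through 7 (process 1, thread 2 maps to 1 * 4 + 2 = 6). Assuming thread_size_ equals GetNthreadPerProc(), this mirrors the global_rank = GetGlobalProcRank() * GetNthreadPerProc() + i formula used when the NCCL communicators are created in process_group.cc.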
