
Commit 8e0862f

feat: support multi-node DDP + TP + SP parallel training

1 parent ed9fba4

File tree

9 files changed: +98 -127 lines changed

example/gpt2/main.cc

Lines changed: 5 additions & 5 deletions
@@ -122,14 +122,14 @@ void Train(const nn::parallel::Rank &rank) {
     device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank());
 
     if (ddp_world_size > 1) {
-        ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.thread_rank()),
-                                                              GetDataParallelGroupRanks(rank.thread_rank()));
+        ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
+                                                              GetDataParallelGroupRanks(rank.GlobalRank()));
         ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank());
     }
 
     if (tp_world_size > 1) {
-        tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.thread_rank()),
-                                                             GetTensorParallelGroupRanks(rank.thread_rank()));
+        tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
+                                                             GetTensorParallelGroupRanks(rank.GlobalRank()));
         tp_rank = tp_pg->GetGroupRank(rank.thread_rank());
         // NOTE(zbl): Reserved for VocabParallelEmbedding
         nn::parallel::tp_rank = tp_rank;

@@ -312,7 +312,7 @@ int main(int argc, char *argv[]) {
     if (FLAGS_nthread_per_process > 1) {
         std::vector<std::thread> threads;
         for (int idx = 0; idx < FLAGS_nthread_per_process; ++idx) {
-            nn::parallel::Rank rank(nn::parallel::global::GetLocalProcRank(), idx,
+            nn::parallel::Rank rank(nn::parallel::global::GetGlobalProcRank(), idx,
                                     nn::parallel::global::GetNprocPerNode(), FLAGS_nthread_per_process);
             threads.emplace_back(Train, rank);
         }

example/llama3/main.cc

Lines changed: 6 additions & 6 deletions
@@ -105,14 +105,14 @@ void Train(const nn::parallel::Rank &rank) {
     device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank());
 
     if (ddp_world_size > 1) {
-        ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.thread_rank()),
-                                                              GetDataParallelGroupRanks(rank.thread_rank()));
-        ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank());
+        ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
+                                                              GetDataParallelGroupRanks(rank.GlobalRank()));
+        ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank());
     }
 
     if (tp_world_size > 1) {
-        tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.thread_rank()),
-                                                             GetTensorParallelGroupRanks(rank.thread_rank()));
+        tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
+                                                             GetTensorParallelGroupRanks(rank.GlobalRank()));
         tp_rank = tp_pg->GetGroupRank(rank.thread_rank());
         // NOTE(zbl): Reserved for VocabParallelEmbedding
         nn::parallel::tp_rank = tp_rank;

@@ -292,7 +292,7 @@ int main(int argc, char *argv[]) {
     if (FLAGS_nthread_per_process > 1) {
         std::vector<std::thread> threads;
         for (int idx = 0; idx < FLAGS_nthread_per_process; ++idx) {
-            nn::parallel::Rank rank(nn::parallel::global::GetLocalProcRank(), idx,
+            nn::parallel::Rank rank(nn::parallel::global::GetGlobalProcRank(), idx,
                                     nn::parallel::global::GetNprocPerNode(), FLAGS_nthread_per_process);
             threads.emplace_back(Train, rank);
         }
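In both example trainers, the switch from rank.thread_rank() to rank.GlobalRank() when naming and populating process groups is what makes multi-node grouping work: a thread rank is only unique within one process, so every rank that belongs to the same DP or TP group must derive the group's name and member list from a node-spanning index. Below is a minimal sketch of that indexing, assuming a global rank of proc_rank * nthread_per_proc + thread_rank and the DP -> TP axis order from the world_size = 8, DP = 2, TP = 4 example in the ProcessGroupOverview comment further down; the helper names are illustrative, not the repository's API.

    #include <cstdio>
    #include <string>

    // Hypothetical helpers, for illustration only (not the repository's API).
    int GlobalRank(int proc_rank, int nthread_per_proc, int thread_rank) {
        return proc_rank * nthread_per_proc + thread_rank; // assumption: one compute thread == one rank
    }

    // With TP as the innermost axis, ranks sharing a TP index form a DP group,
    // and ranks sharing a DP index form a TP group.
    std::string DataParallelGroupName(int global_rank, int tp_size) {
        return "dp_" + std::to_string(global_rank % tp_size);
    }
    std::string TensorParallelGroupName(int global_rank, int tp_size) {
        return "tp_" + std::to_string(global_rank / tp_size);
    }

    int main() {
        const int tp_size = 4; // world_size = 8, DP = 2, TP = 4
        for (int global_rank = 0; global_rank < 8; ++global_rank) {
            std::printf("rank %d -> %s / %s\n", global_rank,
                        DataParallelGroupName(global_rank, tp_size).c_str(),
                        TensorParallelGroupName(global_rank, tp_size).c_str());
        }
        // DP groups: [0,4] [1,5] [2,6] [3,7]; TP groups: [0,1,2,3] [4,5,6,7]
        return 0;
    }

Note that GetGroupRank() still takes the thread rank: group membership is derived from the global rank, while a caller's position inside its group is resolved from the local thread (which is also what the llama3 hunk above fixes).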

infini_train/include/nn/parallel/global.h

Lines changed: 0 additions & 4 deletions
@@ -174,8 +174,4 @@ inline std::vector<int> GetGroupRanks(Axis target, int rank) {
  */
 std::string ProcessGroupOverview(const Layout &L = GlobalEnv::Instance().layout(), bool skip_trivial_axes = true);
 
-#ifdef USE_NCCL
-inline ncclUniqueId GetNcclId() { return GlobalEnv::Instance().nccl_id(); }
-#endif
-
 } // namespace infini_train::nn::parallel::global

infini_train/include/nn/parallel/process_group.h

Lines changed: 15 additions & 19 deletions
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <condition_variable>
 #include <memory>
 #include <mutex>
 #include <string>

@@ -27,10 +28,7 @@ namespace infini_train::nn::parallel {
 #ifdef USE_NCCL
 class ProcessGroup {
 public:
-    explicit ProcessGroup(const std::vector<int> &device_indices);
-
-    // support for multi-node distributed training
-    explicit ProcessGroup(const ncclUniqueId &nccl_id);
+    explicit ProcessGroup(const std::string &process_group_name, const std::vector<int> &device_indices);
 
     int GetGroupRank(int thread_rank) const;
 
@@ -67,6 +65,8 @@ class ProcessGroup {
     std::unordered_map<int, int> thread_group_rank_map_; // thread_rank : group_rank
 
     int world_size_ = 0;
+
+    const std::string name_ = "";
 };
 #endif

@@ -80,10 +80,6 @@ class ProcessGroupFactory {
 
     const ProcessGroup *GetOrCreate(const std::string &name, const std::vector<int> &device_indices);
 
-#ifdef USE_NCCL
-    const ProcessGroup *GetOrCreate(const std::string &name, const ncclUniqueId &nccl_id);
-#endif
-
     const ProcessGroup *Get(const std::string &name) const;
 
     const ProcessGroup *GetDefaultProcessGroup() const;

@@ -93,26 +89,26 @@ class ProcessGroupFactory {
     template <typename Creator, typename = std::enable_if_t<std::is_invocable_v<Creator>>>
     const ProcessGroup *GetOrCreate(const std::string &name, Creator &&creator) {
-        {
-            std::lock_guard<std::mutex> lock(mutex_);
-            auto it = name_to_group_.find(name);
-            if (it != name_to_group_.end()) {
-                return it->second.get();
-            }
+        std::unique_lock<std::mutex> lock(mutex_);
+        auto [it, inserted] = name_to_group_.emplace(name, nullptr);
+        if (!inserted) {
+            while (it->second == nullptr) { cond_.wait(lock); }
+            return it->second.get();
         }
 
+        lock.unlock();
         auto new_group = creator();
+        lock.lock();
 
-        {
-            std::lock_guard<std::mutex> lock(mutex_);
-            auto [it, inserted] = name_to_group_.emplace(name, std::move(new_group));
-            return it->second.get();
-        }
+        it->second = std::move(new_group);
+        cond_.notify_all();
+        return it->second.get();
     }
 
 private:
     // TODO(dcj): maybe RWLock later?
     mutable std::mutex mutex_;
+    std::condition_variable cond_;
     std::unordered_map<std::string, std::unique_ptr<ProcessGroup>> name_to_group_;
 };
 } // namespace infini_train::nn::parallel
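The rewritten GetOrCreate closes a gap in the old two-lock version: two threads asking for the same name could both miss the lookup and both run creator(). The new scheme inserts a placeholder under the lock, lets exactly one thread build the group with the lock released, and parks any other caller on a condition variable until the pointer is published. A standalone sketch of the same pattern, with a stand-in Widget type instead of the repository's ProcessGroup:

    #include <condition_variable>
    #include <memory>
    #include <mutex>
    #include <string>
    #include <unordered_map>

    struct Widget { std::string name; }; // stand-in for the real resource (e.g. a communicator)

    class WidgetFactory {
    public:
        template <typename Creator>
        const Widget *GetOrCreate(const std::string &name, Creator &&creator) {
            std::unique_lock<std::mutex> lock(mutex_);
            auto [it, inserted] = map_.emplace(name, nullptr); // nullptr == "under construction"
            auto &slot = it->second; // element references survive rehashing; iterators may not
            if (!inserted) {
                // Another thread owns construction; sleep until the pointer is published.
                while (slot == nullptr) { cond_.wait(lock); }
                return slot.get();
            }
            lock.unlock();            // run the (possibly slow) creator without holding the lock
            auto widget = creator();
            lock.lock();
            slot = std::move(widget); // publish, then wake every waiter
            cond_.notify_all();
            return slot.get();
        }

    private:
        std::mutex mutex_;
        std::condition_variable cond_;
        std::unordered_map<std::string, std::unique_ptr<Widget>> map_;
    };

    int main() {
        WidgetFactory factory;
        const Widget *w = factory.GetOrCreate(
            "dp_0", [] { return std::make_unique<Widget>(Widget{"dp_0"}); });
        return w == nullptr; // 0 on success
    }

notify_all (rather than notify_one) is used because a single condition variable serves every name; a waiter woken for the wrong entry sees its own slot still null and goes back to sleep, which is why the wait sits in a loop.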

infini_train/src/device.cc

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ nn::parallel::Rank CudaDevice::rank() const { return rank_; }
 
 CudaDevice::CudaDevice(int8_t index)
     : Device(DeviceType::kCUDA, index),
-      rank_({nn::parallel::global::GetLocalProcRank(), index, nn::parallel::global::GetNprocPerNode(),
+      rank_({nn::parallel::global::GetGlobalProcRank(), index, nn::parallel::global::GetNprocPerNode(),
             nn::parallel::global::GetNthreadPerProc()}) {
     // TODO(dcj): make CudaDevice initialization lazy to avoid allocating memory on all GPUs in single-GPU mode
     SetDevice();

infini_train/src/nn/parallel/distributed_data_parallel.cc

Lines changed: 3 additions & 2 deletions
@@ -23,9 +23,10 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr<nn::Module> mod
        CHECK_EQ(device->Index(), device_id) << "All parameters must be on the same device as the module";
 
        auto ddp_pg
-            = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(device->rank().thread_rank()));
+            = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(device->rank().GlobalRank()));
        // FIXME(dcj): use multi-node ddp_pg here
-        auto hook = std::make_unique<infini_train::autograd::AllReducePostAccumulateHook>(function::ReduceOpType::kAvg);
+        auto hook = std::make_unique<infini_train::autograd::AllReducePostAccumulateHook>(function::ReduceOpType::kAvg,
+                                                                                          ddp_pg);
        param->RegisterPostAccumulateGradHook(std::move(hook));
    }
    for (auto &buffer : module->Buffers()) {
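Passing ddp_pg into the hook is what scopes the gradient all-reduce to the data-parallel group instead of a default world-wide group. The hook's implementation is not shown in this commit; as a rough illustration of what the reduction boils down to, here is a minimal raw-NCCL sketch that averages a device-resident gradient buffer over a communicator (the comm and stream handles are assumed to come from the process group, and are not an API shown above):

    #include <cstddef>
    #include <cuda_runtime.h>
    #include <nccl.h>

    // Average `count` floats in-place across every rank of `comm`.
    // Illustrative only; error handling omitted, and the real hook drives this
    // through the ProcessGroup abstraction rather than raw NCCL calls.
    void AllReduceAvg(float *grad, size_t count, ncclComm_t comm, cudaStream_t stream) {
        // ncclAvg (NCCL >= 2.10) sums across ranks and divides by the group size in one pass.
        ncclAllReduce(grad, grad, count, ncclFloat, ncclAvg, comm, stream);
        cudaStreamSynchronize(stream); // synchronous only to keep the sketch simple
    }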

infini_train/src/nn/parallel/global.cc

Lines changed: 0 additions & 52 deletions
@@ -19,20 +19,6 @@ std::string GetEnvAsStr(const std::string &name, const std::string &default_valu
     return value ? std::string(value) : default_value;
 }
 
-#ifdef USE_NCCL
-ncclUniqueId StringToNcclId(const std::string &str) {
-    ncclUniqueId id;
-    for (int i = 0; i < NCCL_UNIQUE_ID_BYTES; ++i) {
-        unsigned int byte;
-        std::stringstream ss;
-        ss << std::hex << str.substr(i * 2, 2);
-        ss >> byte;
-        id.internal[i] = static_cast<char>(byte);
-    }
-    return id;
-}
-#endif
-
 } // namespace
 
 namespace infini_train::nn::parallel::global {

@@ -125,10 +111,6 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool seq
     // FIXME(zbl): set PP size
     layout_.sizes[PP] = 1;
     layout_.InitStrides();
-    // FIXME(dcj): what if no nccl id?
-#ifdef USE_NCCL
-    nccl_id_ = StringToNcclId(GetEnvAsStr("NCCL_UNIQUE_ID", ""));
-#endif
 
     initialized_ = true;
 }

@@ -225,34 +207,6 @@ inline void AppendAxisGroups(std::ostringstream &oss, const Layout &L, Axis targ
     }
 }
 
-/**
- * @brief Generate a human-readable overview of all parallel communication groups.
- *
- * The output is intended for debugging, logging, and runtime verification of
- * distributed parallelism configuration.
- *
- * @param L The Layout describing DP / TP / PP sizes and axis ordering.
- * @param skip_trivial_axes
- *        If true, axes whose size <= 1 (i.e. parallel strategy that is not enabled)
- *        will be marked as "unenabled" and their detailed group listing will be skipped.
- *
- * @return A formatted string containing the full overview of process groups.
- *
- * Example:
- *     === Parallel Communication Groups ===
- *     world_size = 8, config: {DP=2, TP=4, PP=1}, order: {DP -> TP -> PP}
- *     [DP] size=2, num_groups=4
- *       - DP 0 (dp=-, tp=0, pp=0): [0, 4]
- *       - DP 1 (dp=-, tp=1, pp=0): [1, 5]
- *       - DP 2 (dp=-, tp=2, pp=0): [2, 6]
- *       - DP 3 (dp=-, tp=3, pp=0): [3, 7]
- *
- *     [TP] size=4, num_groups=2
- *       - TP 0 (dp=0, tp=-, pp=0): [0, 1, 2, 3]
- *       - TP 1 (dp=1, tp=-, pp=0): [4, 5, 6, 7]
- *
- *     [PP] size=1, unenabled
- */
 std::string ProcessGroupOverview(const Layout &L, bool skip_trivial_axes) {
     std::ostringstream oss;
     oss << std::format("\n=== Parallel Communication Groups ===\n"

@@ -276,11 +230,5 @@ std::string ProcessGroupOverview(const Layout &L, bool skip_trivial_axes) {
     oss << "\n";
     return oss.str();
 }
-#ifdef USE_NCCL
-ncclUniqueId GlobalEnv::nccl_id() const {
-    CHECK(initialized_) << "GlobalEnv is not initialized!";
-    return nccl_id_;
-}
-#endif
 
 } // namespace infini_train::nn::parallel::global
