Skip to content

Commit 78a5f68

Browse files
committed
tmp
1 parent 6cc74b7 commit 78a5f68

File tree

8 files changed

+132
-60
lines changed

8 files changed

+132
-60
lines changed

example/gpt2/main.cc

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -122,14 +122,14 @@ void Train(const nn::parallel::Rank &rank) {
122122
device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank());
123123

124124
if (ddp_world_size > 1) {
125-
ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.thread_rank()),
126-
GetDataParallelGroupRanks(rank.thread_rank()));
125+
ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
126+
GetDataParallelGroupRanks(rank.GlobalRank()));
127127
ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank());
128128
}
129129

130130
if (tp_world_size > 1) {
131-
tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.thread_rank()),
132-
GetTensorParallelGroupRanks(rank.thread_rank()));
131+
tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
132+
GetTensorParallelGroupRanks(rank.GlobalRank()));
133133
tp_rank = tp_pg->GetGroupRank(rank.thread_rank());
134134
// NOTE(zbl): Reserved for VocabParallelEmbedding
135135
nn::parallel::tp_rank = tp_rank;
@@ -312,7 +312,7 @@ int main(int argc, char *argv[]) {
312312
if (FLAGS_nthread_per_process > 1) {
313313
std::vector<std::thread> threads;
314314
for (int idx = 0; idx < FLAGS_nthread_per_process; ++idx) {
315-
nn::parallel::Rank rank(nn::parallel::global::GetLocalProcRank(), idx,
315+
nn::parallel::Rank rank(nn::parallel::global::GetGlobalProcRank(), idx,
316316
nn::parallel::global::GetNprocPerNode(), FLAGS_nthread_per_process);
317317
threads.emplace_back(Train, rank);
318318
}

example/llama3/main.cc

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -105,14 +105,14 @@ void Train(const nn::parallel::Rank &rank) {
105105
device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank());
106106

107107
if (ddp_world_size > 1) {
108-
ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.thread_rank()),
109-
GetDataParallelGroupRanks(rank.thread_rank()));
110-
ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank());
108+
ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
109+
GetDataParallelGroupRanks(rank.GlobalRank()));
110+
ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank());
111111
}
112112

113113
if (tp_world_size > 1) {
114-
tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.thread_rank()),
115-
GetTensorParallelGroupRanks(rank.thread_rank()));
114+
tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
115+
GetTensorParallelGroupRanks(rank.GlobalRank()));
116116
tp_rank = tp_pg->GetGroupRank(rank.thread_rank());
117117
// NOTE(zbl): Reserved for VocabParallelEmbedding
118118
nn::parallel::tp_rank = tp_rank;
@@ -292,7 +292,7 @@ int main(int argc, char *argv[]) {
292292
if (FLAGS_nthread_per_process > 1) {
293293
std::vector<std::thread> threads;
294294
for (int idx = 0; idx < FLAGS_nthread_per_process; ++idx) {
295-
nn::parallel::Rank rank(nn::parallel::global::GetLocalProcRank(), idx,
295+
nn::parallel::Rank rank(nn::parallel::global::GetGlobalProcRank(), idx,
296296
nn::parallel::global::GetNprocPerNode(), FLAGS_nthread_per_process);
297297
threads.emplace_back(Train, rank);
298298
}

infini_train/include/nn/parallel/process_group.h

Lines changed: 17 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <condition_variable>
34
#include <memory>
45
#include <mutex>
56
#include <string>
@@ -11,6 +12,8 @@
1112
#include <nccl.h>
1213
#endif
1314

15+
#include "glog/logging.h"
16+
1417
#include "infini_train/include/nn/parallel/reduce_op_type.h"
1518

1619
namespace infini_train {
@@ -27,7 +30,7 @@ namespace infini_train::nn::parallel {
2730
#ifdef USE_NCCL
2831
class ProcessGroup {
2932
public:
30-
explicit ProcessGroup(const std::vector<int> &device_indices);
33+
explicit ProcessGroup(const std::string &process_group_name, const std::vector<int> &device_indices);
3134

3235
// support for multi-node distributed training
3336
explicit ProcessGroup(const ncclUniqueId &nccl_id);
@@ -67,6 +70,8 @@ class ProcessGroup {
6770
std::unordered_map<int, int> thread_group_rank_map_; // thread_rank : group_rank
6871

6972
int world_size_ = 0;
73+
74+
const std::string name_ = "";
7075
};
7176
#endif
7277

@@ -93,26 +98,26 @@ class ProcessGroupFactory {
9398

9499
template <typename Creator, typename = std::enable_if_t<std::is_invocable_v<Creator>>>
95100
const ProcessGroup *GetOrCreate(const std::string &name, Creator &&creator) {
96-
{
97-
std::lock_guard<std::mutex> lock(mutex_);
98-
auto it = name_to_group_.find(name);
99-
if (it != name_to_group_.end()) {
100-
return it->second.get();
101-
}
101+
std::unique_lock<std::mutex> lock(mutex_);
102+
auto [it, inserted] = name_to_group_.emplace(name, nullptr);
103+
if (!inserted) {
104+
while (it->second == nullptr) { cond_.wait(lock); }
105+
return it->second.get();
102106
}
103107

108+
lock.unlock();
104109
auto new_group = creator();
110+
lock.lock();
105111

106-
{
107-
std::lock_guard<std::mutex> lock(mutex_);
108-
auto [it, inserted] = name_to_group_.emplace(name, std::move(new_group));
109-
return it->second.get();
110-
}
112+
it->second = std::move(new_group);
113+
cond_.notify_all();
114+
return it->second.get();
111115
}
112116

113117
private:
114118
// TODO(dcj): maybe RWLock later?
115119
mutable std::mutex mutex_;
120+
std::condition_variable cond_;
116121
std::unordered_map<std::string, std::unique_ptr<ProcessGroup>> name_to_group_;
117122
};
118123
} // namespace infini_train::nn::parallel

infini_train/src/device.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,7 @@ nn::parallel::Rank CudaDevice::rank() const { return rank_; }
6464

6565
CudaDevice::CudaDevice(int8_t index)
6666
: Device(DeviceType::kCUDA, index),
67-
rank_({nn::parallel::global::GetLocalProcRank(), index, nn::parallel::global::GetNprocPerNode(),
67+
rank_({nn::parallel::global::GetGlobalProcRank(), index, nn::parallel::global::GetNprocPerNode(),
6868
nn::parallel::global::GetNthreadPerProc()}) {
6969
// TODO(dcj): make CudaDevice initialization lazy to avoid allocating memory on all GPUs in single-GPU mode
7070
SetDevice();

infini_train/src/nn/parallel/distributed_data_parallel.cc

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,9 +23,10 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr<nn::Module> mod
2323
CHECK_EQ(device->Index(), device_id) << "All parameters must be on the same device as the module";
2424

2525
auto ddp_pg
26-
= ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(device->rank().thread_rank()));
26+
= ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(device->rank().GlobalRank()));
2727
// FIXME(dcj): use multi-node ddp_pg here
28-
auto hook = std::make_unique<infini_train::autograd::AllReducePostAccumulateHook>(function::ReduceOpType::kAvg);
28+
auto hook = std::make_unique<infini_train::autograd::AllReducePostAccumulateHook>(function::ReduceOpType::kAvg,
29+
ddp_pg);
2930
param->RegisterPostAccumulateGradHook(std::move(hook));
3031
}
3132
for (auto &buffer : module->Buffers()) {

infini_train/src/nn/parallel/global.cc

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -19,19 +19,19 @@ std::string GetEnvAsStr(const std::string &name, const std::string &default_valu
1919
return value ? std::string(value) : default_value;
2020
}
2121

22-
#ifdef USE_NCCL
23-
ncclUniqueId StringToNcclId(const std::string &str) {
24-
ncclUniqueId id;
25-
for (int i = 0; i < NCCL_UNIQUE_ID_BYTES; ++i) {
26-
unsigned int byte;
27-
std::stringstream ss;
28-
ss << std::hex << str.substr(i * 2, 2);
29-
ss >> byte;
30-
id.internal[i] = static_cast<char>(byte);
31-
}
32-
return id;
33-
}
34-
#endif
22+
// #ifdef USE_NCCL
23+
// ncclUniqueId StringToNcclId(const std::string &str) {
24+
// ncclUniqueId id;
25+
// for (int i = 0; i < NCCL_UNIQUE_ID_BYTES; ++i) {
26+
// unsigned int byte;
27+
// std::stringstream ss;
28+
// ss << std::hex << str.substr(i * 2, 2);
29+
// ss >> byte;
30+
// id.internal[i] = static_cast<char>(byte);
31+
// }
32+
// return id;
33+
// }
34+
// #endif
3535

3636
} // namespace
3737

@@ -126,9 +126,9 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool seq
126126
layout_.sizes[PP] = 1;
127127
layout_.InitStrides();
128128
// FIXME(dcj): what if no nccl id?
129-
#ifdef USE_NCCL
130-
nccl_id_ = StringToNcclId(GetEnvAsStr("NCCL_UNIQUE_ID", ""));
131-
#endif
129+
// #ifdef USE_NCCL
130+
// nccl_id_ = StringToNcclId(GetEnvAsStr("NCCL_UNIQUE_ID", ""));
131+
// #endif
132132

133133
initialized_ = true;
134134
}

infini_train/src/nn/parallel/process_group.cc

Lines changed: 78 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,13 @@
11
#include "infini_train/include/nn/parallel/process_group.h"
22

33
#include <algorithm>
4+
#include <chrono>
5+
#include <filesystem>
6+
#include <fstream>
7+
#include <iterator>
48
#include <memory>
59
#include <numeric>
10+
#include <thread>
611
#include <vector>
712

813
#ifdef USE_NCCL
@@ -20,28 +25,90 @@
2025
namespace infini_train {
2126

2227
namespace {
28+
using nn::parallel::function::ReduceOpType;
29+
30+
#ifdef USE_NCCL
2331
const std::unordered_map<DataType, ncclDataType_t> kNcclDtypeMap = {
2432
{DataType::kUINT8, ncclUint8}, {DataType::kINT8, ncclInt8}, {DataType::kUINT32, ncclUint32},
2533
{DataType::kINT32, ncclInt32}, {DataType::kUINT64, ncclUint64}, {DataType::kINT64, ncclInt64},
2634
{DataType::kBFLOAT16, ncclBfloat16}, {DataType::kFLOAT16, ncclHalf}, {DataType::kFLOAT32, ncclFloat32},
2735
{DataType::kFLOAT64, ncclFloat64},
2836
};
2937

30-
using nn::parallel::function::ReduceOpType;
31-
3238
const std::unordered_map<ReduceOpType, ncclRedOp_t> kNcclReduceOpMap = {
3339
{ReduceOpType::kSum, ncclSum},
3440
{ReduceOpType::kProd, ncclProd},
3541
{ReduceOpType::kMax, ncclMax},
3642
{ReduceOpType::kAvg, ncclAvg},
3743
};
44+
45+
void WriteNcclUniqueId(const ncclUniqueId &nccl_id, const std::string &filename) {
46+
std::string tmp_path = filename + ".tmp";
47+
48+
std::ofstream ofs(tmp_path, std::ios::binary);
49+
ofs.write(reinterpret_cast<const char *>(&nccl_id), sizeof(nccl_id));
50+
ofs.close();
51+
52+
std::rename(tmp_path.c_str(), filename.c_str());
53+
}
54+
55+
void ReadNcclUniqueId(ncclUniqueId &nccl_id, const std::string &filename) {
56+
std::ifstream ifs(filename, std::ios::binary);
57+
ifs.read(reinterpret_cast<char *>(&nccl_id), sizeof(nccl_id));
58+
ifs.close();
59+
}
60+
#endif
61+
3862
} // namespace
3963

4064
} // namespace infini_train
4165

4266
namespace infini_train::nn::parallel {
4367

4468
#ifdef USE_NCCL
69+
// NOTE(dcj): This constructor is used only for initializing intra-node (single-machine) ProcessGroup.
70+
ProcessGroup::ProcessGroup(const std::string &process_group_name, const std::vector<int> &ranks)
71+
: world_size_(ranks.size()), name_(process_group_name) {
72+
int n_threads = global::GetNthreadPerProc();
73+
// NCCL_CHECK(ncclCommInitAll(comms_.data(), world_size_, device_indices.data()));
74+
// group-rank 0 create nccl unique id and broadcast to other ranks
75+
76+
ncclUniqueId nccl_id;
77+
78+
if (std::ranges::min(ranks) < (global::GetGlobalProcRank() + 1) * global::GetNthreadPerProc()
79+
&& std::ranges::min(ranks) >= global::GetGlobalProcRank() * global::GetNthreadPerProc()) {
80+
ncclGetUniqueId(&nccl_id);
81+
82+
WriteNcclUniqueId(nccl_id, name_);
83+
} else {
84+
while (std::filesystem::exists(name_) == false) {
85+
std::this_thread::sleep_for(std::chrono::microseconds(1000));
86+
}
87+
ReadNcclUniqueId(nccl_id, name_);
88+
}
89+
90+
std::vector<int> device_indices;
91+
NCCL_CHECK(ncclGroupStart());
92+
for (int i = 0; i < n_threads; ++i) {
93+
int global_rank = global::GetGlobalProcRank() * global::GetNthreadPerProc() + i;
94+
auto it = std::ranges::find(ranks, global_rank);
95+
if (it != ranks.end()) {
96+
cudaSetDevice(i);
97+
ncclComm_t comm;
98+
int group_rank = std::distance(ranks.begin(), it);
99+
NCCL_CHECK(ncclCommInitRank(&comm, world_size_, nccl_id, group_rank));
100+
comms_.push_back(comm);
101+
device_indices.push_back(i);
102+
// FIXME(dcj): fix Init function
103+
thread_group_rank_map_[DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, i)->rank().thread_rank()]
104+
= group_rank;
105+
}
106+
}
107+
NCCL_CHECK(ncclGroupEnd());
108+
109+
Init(device_indices);
110+
}
111+
45112
ProcessGroup::ProcessGroup(const ncclUniqueId &nccl_id) : world_size_(global::GetWorldSize()) {
46113
int local_comm_size = global::GetNthreadPerProc();
47114
comms_.resize(local_comm_size);
@@ -63,12 +130,12 @@ ProcessGroup::ProcessGroup(const ncclUniqueId &nccl_id) : world_size_(global::Ge
63130

64131
void ProcessGroup::Init(const std::vector<int> &device_indices) {
65132
// FIXME(dcj): This is a temporary solution to get the device and comm for each thread.
66-
int local_comm_size = std::min(static_cast<int>(device_indices.size()), global::GetNthreadPerProc());
67-
for (int i = 0; i < local_comm_size; ++i) {
133+
// int local_comm_size = std::min(static_cast<int>(device_indices.size()), global::GetNthreadPerProc());
134+
for (int i = 0; i < device_indices.size(); ++i) {
68135
auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, device_indices[i]);
69136
devices_.push_back(device);
70137
device_comm_map_[device] = comms_[i];
71-
thread_group_rank_map_[device->rank().thread_rank()] = i + global::GetGlobalProcRank() * local_comm_size;
138+
// thread_group_rank_map_[device->rank().thread_rank()] = i;
72139
}
73140
}
74141

@@ -347,11 +414,11 @@ ProcessGroupFactory *ProcessGroupFactory::Instance() {
347414
const ProcessGroup *ProcessGroupFactory::GetOrCreate(const std::string &name, int comm_size) {
348415
std::vector<int> device_indices(comm_size);
349416
std::iota(device_indices.begin(), device_indices.end(), 0);
350-
return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(device_indices); });
417+
return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(name, device_indices); });
351418
}
352419

353420
const ProcessGroup *ProcessGroupFactory::GetOrCreate(const std::string &name, const std::vector<int> &device_indices) {
354-
return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(device_indices); });
421+
return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(name, device_indices); });
355422
}
356423

357424
#ifdef USE_NCCL
@@ -370,10 +437,10 @@ const ProcessGroup *ProcessGroupFactory::GetDefaultProcessGroup() const {
370437
}
371438

372439
ProcessGroupFactory::ProcessGroupFactory() {
373-
#ifdef USE_NCCL
374-
GetOrCreate(kDefaltProcessGroupName, global::GetNcclId());
375-
#else
440+
// #ifdef USE_NCCL
441+
// GetOrCreate(kDefaltProcessGroupName, global::GetNcclId());
442+
// #else
376443
GetOrCreate(kDefaltProcessGroupName, global::GetWorldSize());
377-
#endif
444+
// #endif
378445
}
379446
} // namespace infini_train::nn::parallel

0 commit comments

Comments (0)