#include "infini_train/include/nn/parallel/process_group.h"

#include <algorithm>
+#include <chrono>
+#include <cstdio> // std::rename
+#include <filesystem>
+#include <fstream>
+#include <iterator>
#include <memory>
#include <numeric>
+#include <thread>
#include <vector>

#ifdef USE_NCCL
namespace infini_train {

namespace {
+using nn::parallel::function::ReduceOpType;
+
+#ifdef USE_NCCL
const std::unordered_map<DataType, ncclDataType_t> kNcclDtypeMap = {
    {DataType::kUINT8, ncclUint8},       {DataType::kINT8, ncclInt8},       {DataType::kUINT32, ncclUint32},
    {DataType::kINT32, ncclInt32},       {DataType::kUINT64, ncclUint64},   {DataType::kINT64, ncclInt64},
    {DataType::kBFLOAT16, ncclBfloat16}, {DataType::kFLOAT16, ncclHalf},    {DataType::kFLOAT32, ncclFloat32},
    {DataType::kFLOAT64, ncclFloat64},
};
2937
-using nn::parallel::function::ReduceOpType;
-
const std::unordered_map<ReduceOpType, ncclRedOp_t> kNcclReduceOpMap = {
    {ReduceOpType::kSum, ncclSum},
    {ReduceOpType::kProd, ncclProd},
    {ReduceOpType::kMax, ncclMax},
    {ReduceOpType::kAvg, ncclAvg},
};
+
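+// Publishes the NCCL unique id by writing it to a temp file and renaming it into
+// place; the rename is atomic on POSIX, so readers never see a partially written id.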
+void WriteNcclUniqueId(const ncclUniqueId &nccl_id, const std::string &filename) {
+    std::string tmp_path = filename + ".tmp";
+
+    std::ofstream ofs(tmp_path, std::ios::binary);
+    ofs.write(reinterpret_cast<const char *>(&nccl_id), sizeof(nccl_id));
+    ofs.close();
+
+    std::rename(tmp_path.c_str(), filename.c_str());
+}
+
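+// Deserializes the unique id; callers must first wait until `filename` exists.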
+void ReadNcclUniqueId(ncclUniqueId &nccl_id, const std::string &filename) {
+    std::ifstream ifs(filename, std::ios::binary);
+    ifs.read(reinterpret_cast<char *>(&nccl_id), sizeof(nccl_id));
+    ifs.close();
+}
+#endif
+
} // namespace

} // namespace infini_train

namespace infini_train::nn::parallel {

#ifdef USE_NCCL
+// NOTE(dcj): This constructor is used only for initializing an intra-node (single-machine) ProcessGroup.
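+// Typical use, via the factory below: GetOrCreate("tp_group", {0, 1, 2, 3})
+// (the group name and rank list here are illustrative).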
+ProcessGroup::ProcessGroup(const std::string &process_group_name, const std::vector<int> &ranks)
+    : world_size_(ranks.size()), name_(process_group_name) {
+    int n_threads = global::GetNthreadPerProc();
+    // NCCL_CHECK(ncclCommInitAll(comms_.data(), world_size_, device_indices.data()));
+    // Group rank 0 creates the NCCL unique id and broadcasts it to the other ranks via a shared file.
+
+    ncclUniqueId nccl_id;
+
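+    // The process that hosts the group's lowest global rank generates the id and
+    // publishes it under the group's name; all other processes poll for the file.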
+    const int min_rank = std::ranges::min(ranks);
+    if (min_rank >= global::GetGlobalProcRank() * global::GetNthreadPerProc()
+        && min_rank < (global::GetGlobalProcRank() + 1) * global::GetNthreadPerProc()) {
+        ncclGetUniqueId(&nccl_id);
+
+        WriteNcclUniqueId(nccl_id, name_);
+    } else {
+        while (!std::filesystem::exists(name_)) {
+            std::this_thread::sleep_for(std::chrono::microseconds(1000));
+        }
+        ReadNcclUniqueId(nccl_id, name_);
+    }
+
+    std::vector<int> device_indices;
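+    // A single process may drive several GPUs; batching the per-device
+    // ncclCommInitRank calls inside ncclGroupStart/ncclGroupEnd lets NCCL
+    // initialize them together without deadlocking.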
+    NCCL_CHECK(ncclGroupStart());
+    for (int i = 0; i < n_threads; ++i) {
+        int global_rank = global::GetGlobalProcRank() * global::GetNthreadPerProc() + i;
+        auto it = std::ranges::find(ranks, global_rank);
+        if (it != ranks.end()) {
+            cudaSetDevice(i);
+            ncclComm_t comm;
+            int group_rank = std::distance(ranks.begin(), it);
+            NCCL_CHECK(ncclCommInitRank(&comm, world_size_, nccl_id, group_rank));
+            comms_.push_back(comm);
+            device_indices.push_back(i);
+            // FIXME(dcj): fix Init function
+            thread_group_rank_map_[DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, i)->rank().thread_rank()]
+                = group_rank;
+        }
+    }
+    NCCL_CHECK(ncclGroupEnd());
+
+    Init(device_indices);
+}
+
ProcessGroup::ProcessGroup(const ncclUniqueId &nccl_id) : world_size_(global::GetWorldSize()) {
    int local_comm_size = global::GetNthreadPerProc();
    comms_.resize(local_comm_size);
@@ -63,12 +130,12 @@ ProcessGroup::ProcessGroup(const ncclUniqueId &nccl_id) : world_size_(global::Ge

void ProcessGroup::Init(const std::vector<int> &device_indices) {
    // FIXME(dcj): This is a temporary solution to get the device and comm for each thread.
-    int local_comm_size = std::min(static_cast<int>(device_indices.size()), global::GetNthreadPerProc());
-    for (int i = 0; i < local_comm_size; ++i) {
+    // int local_comm_size = std::min(static_cast<int>(device_indices.size()), global::GetNthreadPerProc());
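+    // Callers now pair every entry of device_indices with a comm, so the loop
+    // walks the whole vector instead of clamping to GetNthreadPerProc().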
+    for (int i = 0; i < static_cast<int>(device_indices.size()); ++i) {
        auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, device_indices[i]);
        devices_.push_back(device);
        device_comm_map_[device] = comms_[i];
-        thread_group_rank_map_[device->rank().thread_rank()] = i + global::GetGlobalProcRank() * local_comm_size;
+        // thread_group_rank_map_[device->rank().thread_rank()] = i;
    }
}

@@ -347,11 +414,11 @@ ProcessGroupFactory *ProcessGroupFactory::Instance() {
const ProcessGroup *ProcessGroupFactory::GetOrCreate(const std::string &name, int comm_size) {
    std::vector<int> device_indices(comm_size);
    std::iota(device_indices.begin(), device_indices.end(), 0);
-    return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(device_indices); });
+    return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(name, device_indices); });
}

const ProcessGroup *ProcessGroupFactory::GetOrCreate(const std::string &name, const std::vector<int> &device_indices) {
-    return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(device_indices); });
+    return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(name, device_indices); });
}

#ifdef USE_NCCL
@@ -370,10 +437,10 @@ const ProcessGroup *ProcessGroupFactory::GetDefaultProcessGroup() const {
}

ProcessGroupFactory::ProcessGroupFactory() {
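+    // The NCCL-id path is disabled: with the file-based rendezvous above, the
+    // default group can presumably be created by name and world size on every build.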
-#ifdef USE_NCCL
-    GetOrCreate(kDefaltProcessGroupName, global::GetNcclId());
-#else
+    // #ifdef USE_NCCL
+    //     GetOrCreate(kDefaltProcessGroupName, global::GetNcclId());
+    // #else
    GetOrCreate(kDefaltProcessGroupName, global::GetWorldSize());
-#endif
+    // #endif
}
} // namespace infini_train::nn::parallel