11#include " infini_train/include/nn/parallel/process_group.h"
22
3+ #include < memory>
34#include < numeric>
45#include < vector>
56
@@ -40,11 +41,25 @@ const std::unordered_map<ReduceOpType, ncclRedOp_t> kNcclReduceOpMap = {
4041namespace infini_train ::nn::parallel {
4142
4243#ifdef USE_NCCL
43- ProcessGroup::ProcessGroup (const std::vector<int > &device_indices) : comm_size_(device_indices.size()) {
44- comms_.resize (comm_size_);
45- NCCL_CHECK (ncclCommInitAll (comms_.data (), comm_size_, device_indices.data ()));
44+ ProcessGroup::ProcessGroup (const ncclUniqueId &nccl_id) : world_size_(global::GetWorldSize()) {
45+ int local_comm_size = global::GetNthreadPerProc ();
46+ comms_.resize (local_comm_size);
47+ std::vector<int > device_indices (local_comm_size);
4648
47- for (int i = 0 ; i < comm_size_; ++i) {
49+ NCCL_CHECK (ncclGroupStart ());
50+ for (int i = 0 ; i < local_comm_size; ++i) {
51+ device_indices[i] = i;
52+
53+ int global_rank = global::GetGlobalProcRank () * global::GetNthreadPerProc () + i;
54+ NCCL_CHECK (ncclCommInitRank (&comms_[i], world_size_, nccl_id, global_rank));
55+ }
56+ NCCL_CHECK (ncclGroupEnd ());
57+
58+ Init (device_indices);
59+ }
60+
61+ void ProcessGroup::Init (const std::vector<int > &device_indices) {
62+ for (int i = 0 ; i < world_size_; ++i) {
4863 auto device = DeviceManager::Instance ()->GetDevice (DeviceType::kCUDA , device_indices[i]);
4964 devices_.push_back (device);
5065 device_comm_map_[device] = comms_[i];
@@ -92,7 +107,9 @@ ProcessGroup::BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensor
92107 std::vector<ncclComm_t> comms;
93108 std::vector<const Device *> devices;
94109
95- for (size_t i = 0 ; i < comm_size_; ++i) {
110+ CHECK_EQ (world_size_, comms_.size ());
111+
112+ for (size_t i = 0 ; i < world_size_; ++i) {
96113 auto device = devices_[i];
97114 for (const auto &input_tensor : input_tensors) {
98115 outputs.push_back (std::make_shared<Tensor>(input_tensor->Dims (), input_tensor->Dtype (), device));
@@ -323,31 +340,20 @@ ProcessGroupFactory *ProcessGroupFactory::Instance() {
323340}
324341
325342const ProcessGroup *ProcessGroupFactory::GetOrCreate (const std::string &name, int comm_size) {
326- std::vector<int > devices (comm_size);
327- std::iota (devices.begin (), devices.end (), 0 );
328- const std::vector<int > &device_indices = devices;
329-
330- return GetOrCreate (name, device_indices);
343+ std::vector<int > device_indices (comm_size);
344+ std::iota (device_indices.begin (), device_indices.end (), 0 );
345+ return GetOrCreate (name, [&]() { return std::make_unique<ProcessGroup>(device_indices); });
331346}
332347
333348const ProcessGroup *ProcessGroupFactory::GetOrCreate (const std::string &name, const std::vector<int > &device_indices) {
334- {
335- std::lock_guard<std::mutex> lock (mutex_);
336- auto it = name_to_group_.find (name);
337- if (it != name_to_group_.end ()) {
338- return it->second .get ();
339- }
340- }
341-
342- auto new_group = std::make_unique<ProcessGroup>(device_indices);
343-
344- {
345- std::lock_guard<std::mutex> lock (mutex_);
349+ return GetOrCreate (name, [&]() { return std::make_unique<ProcessGroup>(device_indices); });
350+ }
346351
347- auto [it, inserted] = name_to_group_. emplace (name, std::move (new_group));
348- return it-> second . get ();
349- }
#ifdef USE_NCCL
/**
 * Requests the group `name` built from an NCCL unique id (NCCL builds only).
 *
 * Delegates to the creator-taking GetOrCreate overload; the lambda constructs the
 * ProcessGroup from `nccl_id` when that overload invokes it.
 *
 * @param name    Key under which the group is requested.
 * @param nccl_id Unique id forwarded to the ProcessGroup constructor.
 * @return Whatever the creator-taking GetOrCreate overload returns for `name`.
 */
const ProcessGroup *ProcessGroupFactory::GetOrCreate(const std::string &name, const ncclUniqueId &nccl_id) {
    return GetOrCreate(name, [&nccl_id]() { return std::make_unique<ProcessGroup>(nccl_id); });
}
#endif
351357
352358const ProcessGroup *ProcessGroupFactory::Get (const std::string &name) const {
353359 std::lock_guard<std::mutex> lock (mutex_);
@@ -358,5 +364,11 @@ const ProcessGroup *ProcessGroupFactory::GetDefaultProcessGroup() const {
358364 return name_to_group_.at (kDefaltProcessGroupName ).get ();
359365}
360366
361- ProcessGroupFactory::ProcessGroupFactory () { GetOrCreate (kDefaltProcessGroupName , global::GetWorldSize ()); }
367+ ProcessGroupFactory::ProcessGroupFactory () {
368+ #ifdef USE_NCCL
369+ GetOrCreate (kDefaltProcessGroupName , global::GetNcclId ());
370+ #else
371+ GetOrCreate (kDefaltProcessGroupName , global::GetWorldSize ());
372+ #endif
373+ }
362374} // namespace infini_train::nn::parallel
0 commit comments