Skip to content

Commit 78047a7

Browse files
committed
feat: add multi-node default ProcessGroup
1 parent ecdc5ce commit 78047a7

File tree

6 files changed

+110
-28
lines changed

6 files changed

+110
-28
lines changed

infini_train/include/nn/parallel/global.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
#include <string>
55
#include <vector>
66

7+
#ifdef USE_NCCL
8+
#include <nccl.h>
9+
#endif
10+
711
namespace infini_train::nn::parallel::global {
812

913
enum Axis : uint8_t { DP = 0, TP = 1, PP = 2, AXIS_COUNT = 3 };
@@ -45,6 +49,9 @@ class GlobalEnv {
4549
int data_parallel_size() const;
4650

4751
Layout layout() const;
52+
#ifdef USE_NCCL
53+
ncclUniqueId nccl_id() const;
54+
#endif
4855

4956
private:
5057
GlobalEnv() = default;
@@ -65,6 +72,10 @@ class GlobalEnv {
6572

6673
int data_parallel_size_ = 1;
6774

75+
#ifdef USE_NCCL
76+
ncclUniqueId nccl_id_;
77+
#endif
78+
6879
mutable std::mutex mutex_;
6980
bool initialized_ = false;
7081

@@ -108,5 +119,8 @@ inline std::vector<int> GetGroupRanks(Axis target, int rank) {
108119
}
109120

110121
std::string ProcessGroupOverview(const Layout &L = GlobalEnv::Instance().layout(), bool skip_trivial_axes = true);
122+
#ifdef USE_NCCL
123+
inline ncclUniqueId GetNcclId() { return GlobalEnv::Instance().nccl_id(); }
124+
#endif
111125

112126
} // namespace infini_train::nn::parallel::global

infini_train/include/nn/parallel/process_group.h

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <memory>
44
#include <mutex>
55
#include <string>
6+
#include <type_traits>
67
#include <unordered_map>
78
#include <vector>
89

@@ -28,6 +29,9 @@ class ProcessGroup {
2829
public:
2930
explicit ProcessGroup(const std::vector<int> &device_indices);
3031

32+
// support for multi-node distributed training
33+
explicit ProcessGroup(const ncclUniqueId &nccl_id);
34+
3135
int GetGroupRank(int thread_rank) const;
3236

3337
void AllReduce(const std::shared_ptr<Tensor> &tensor, function::ReduceOpType reduce_op) const;
@@ -52,14 +56,17 @@ class ProcessGroup {
5256

5357
std::vector<std::shared_ptr<Tensor>> NcclRecv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank) const;
5458

59+
private:
60+
void Init(const std::vector<int> &device_indices);
61+
5562
private:
5663
std::vector<ncclComm_t> comms_;
5764
std::vector<const Device *> devices_;
5865

5966
std::unordered_map<const Device *, ncclComm_t> device_comm_map_;
6067
std::unordered_map<int, int> thread_group_rank_map_; // thread_rank : group_rank
6168

62-
int comm_size_ = 0;
69+
int world_size_ = 0;
6370
};
6471
#endif
6572

@@ -73,12 +80,37 @@ class ProcessGroupFactory {
7380

7481
const ProcessGroup *GetOrCreate(const std::string &name, const std::vector<int> &device_indices);
7582

83+
#ifdef USE_NCCL
84+
const ProcessGroup *GetOrCreate(const std::string &name, const ncclUniqueId &nccl_id);
85+
#endif
86+
7687
const ProcessGroup *Get(const std::string &name) const;
7788

7889
const ProcessGroup *GetDefaultProcessGroup() const;
7990

8091
private:
8192
ProcessGroupFactory();
93+
94+
template <typename Creator, typename = std::enable_if_t<std::is_invocable_v<Creator>>>
95+
const ProcessGroup *GetOrCreate(const std::string &name, Creator &&creator) {
96+
{
97+
std::lock_guard<std::mutex> lock(mutex_);
98+
auto it = name_to_group_.find(name);
99+
if (it != name_to_group_.end()) {
100+
return it->second.get();
101+
}
102+
}
103+
104+
auto new_group = creator();
105+
106+
{
107+
std::lock_guard<std::mutex> lock(mutex_);
108+
auto [it, inserted] = name_to_group_.emplace(name, std::move(new_group));
109+
return it->second.get();
110+
}
111+
}
112+
113+
private:
82114
// TODO(dcj): maybe RWLock later?
83115
mutable std::mutex mutex_;
84116
std::unordered_map<std::string, std::unique_ptr<ProcessGroup>> name_to_group_;

infini_train/src/nn/parallel/global.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ int GetEnvAsInt(const std::string &name, int default_value) {
1414
return value ? std::atoi(value) : default_value;
1515
}
1616

17+
std::string GetEnvAsStr(const std::string &name, const std::string &default_value) {
18+
const char *value = std::getenv(name.c_str());
19+
return value ? std::string(value) : default_value;
20+
}
21+
1722
#ifdef USE_NCCL
1823
ncclUniqueId StringToNcclId(const std::string &str) {
1924
ncclUniqueId id;
@@ -120,6 +125,10 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool seq
120125
// FIXME(zbl): set PP size
121126
layout_.sizes[PP] = 1;
122127
layout_.InitStrides();
128+
// FIXME(dcj): what if no nccl id?
129+
#ifdef USE_NCCL
130+
nccl_id_ = StringToNcclId(GetEnvAsStr("NCCL_UNIQUE_ID", ""));
131+
#endif
123132

124133
initialized_ = true;
125134
}
@@ -267,5 +276,11 @@ std::string ProcessGroupOverview(const Layout &L, bool skip_trivial_axes) {
267276
oss << "\n";
268277
return oss.str();
269278
}
279+
#ifdef USE_NCCL
280+
ncclUniqueId GlobalEnv::nccl_id() const {
281+
CHECK(initialized_) << "GlobalEnv is not initialized!";
282+
return nccl_id_;
283+
}
284+
#endif
270285

271286
} // namespace infini_train::nn::parallel::global

infini_train/src/nn/parallel/process_group.cc

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "infini_train/include/nn/parallel/process_group.h"
22

3+
#include <memory>
34
#include <numeric>
45
#include <vector>
56

@@ -40,11 +41,25 @@ const std::unordered_map<ReduceOpType, ncclRedOp_t> kNcclReduceOpMap = {
4041
namespace infini_train::nn::parallel {
4142

4243
#ifdef USE_NCCL
43-
ProcessGroup::ProcessGroup(const std::vector<int> &device_indices) : comm_size_(device_indices.size()) {
44-
comms_.resize(comm_size_);
45-
NCCL_CHECK(ncclCommInitAll(comms_.data(), comm_size_, device_indices.data()));
44+
ProcessGroup::ProcessGroup(const ncclUniqueId &nccl_id) : world_size_(global::GetWorldSize()) {
45+
int local_comm_size = global::GetNthreadPerProc();
46+
comms_.resize(local_comm_size);
47+
std::vector<int> device_indices(local_comm_size);
4648

47-
for (int i = 0; i < comm_size_; ++i) {
49+
NCCL_CHECK(ncclGroupStart());
50+
for (int i = 0; i < local_comm_size; ++i) {
51+
device_indices[i] = i;
52+
53+
int global_rank = global::GetGlobalProcRank() * global::GetNthreadPerProc() + i;
54+
NCCL_CHECK(ncclCommInitRank(&comms_[i], world_size_, nccl_id, global_rank));
55+
}
56+
NCCL_CHECK(ncclGroupEnd());
57+
58+
Init(device_indices);
59+
}
60+
61+
void ProcessGroup::Init(const std::vector<int> &device_indices) {
62+
for (int i = 0; i < world_size_; ++i) {
4863
auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, device_indices[i]);
4964
devices_.push_back(device);
5065
device_comm_map_[device] = comms_[i];
@@ -92,7 +107,9 @@ ProcessGroup::BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensor
92107
std::vector<ncclComm_t> comms;
93108
std::vector<const Device *> devices;
94109

95-
for (size_t i = 0; i < comm_size_; ++i) {
110+
CHECK_EQ(world_size_, comms_.size());
111+
112+
for (size_t i = 0; i < world_size_; ++i) {
96113
auto device = devices_[i];
97114
for (const auto &input_tensor : input_tensors) {
98115
outputs.push_back(std::make_shared<Tensor>(input_tensor->Dims(), input_tensor->Dtype(), device));
@@ -323,31 +340,20 @@ ProcessGroupFactory *ProcessGroupFactory::Instance() {
323340
}
324341

325342
const ProcessGroup *ProcessGroupFactory::GetOrCreate(const std::string &name, int comm_size) {
326-
std::vector<int> devices(comm_size);
327-
std::iota(devices.begin(), devices.end(), 0);
328-
const std::vector<int> &device_indices = devices;
329-
330-
return GetOrCreate(name, device_indices);
343+
std::vector<int> device_indices(comm_size);
344+
std::iota(device_indices.begin(), device_indices.end(), 0);
345+
return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(device_indices); });
331346
}
332347

333348
const ProcessGroup *ProcessGroupFactory::GetOrCreate(const std::string &name, const std::vector<int> &device_indices) {
334-
{
335-
std::lock_guard<std::mutex> lock(mutex_);
336-
auto it = name_to_group_.find(name);
337-
if (it != name_to_group_.end()) {
338-
return it->second.get();
339-
}
340-
}
341-
342-
auto new_group = std::make_unique<ProcessGroup>(device_indices);
343-
344-
{
345-
std::lock_guard<std::mutex> lock(mutex_);
349+
return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(device_indices); });
350+
}
346351

347-
auto [it, inserted] = name_to_group_.emplace(name, std::move(new_group));
348-
return it->second.get();
349-
}
352+
#ifdef USE_NCCL
353+
const ProcessGroup *ProcessGroupFactory::GetOrCreate(const std::string &name, const ncclUniqueId &nccl_id) {
354+
return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(nccl_id); });
350355
}
356+
#endif
351357

352358
const ProcessGroup *ProcessGroupFactory::Get(const std::string &name) const {
353359
std::lock_guard<std::mutex> lock(mutex_);
@@ -358,5 +364,11 @@ const ProcessGroup *ProcessGroupFactory::GetDefaultProcessGroup() const {
358364
return name_to_group_.at(kDefaltProcessGroupName).get();
359365
}
360366

361-
ProcessGroupFactory::ProcessGroupFactory() { GetOrCreate(kDefaltProcessGroupName, global::GetWorldSize()); }
367+
ProcessGroupFactory::ProcessGroupFactory() {
368+
#ifdef USE_NCCL
369+
GetOrCreate(kDefaltProcessGroupName, global::GetNcclId());
370+
#else
371+
GetOrCreate(kDefaltProcessGroupName, global::GetWorldSize());
372+
#endif
373+
}
362374
} // namespace infini_train::nn::parallel

tools/infini_run/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
add_executable(infini_run infini_run.cc)
22
target_link_libraries(infini_run PRIVATE gflags glog)
3+
if (USE_NCCL)
4+
target_link_libraries(infini_run PRIVATE nccl)
5+
endif()

tools/infini_run/infini_run.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ DEFINE_string(rdzv_endpoint, "127.0.0.1:29500", "Rendezvous endpoint (host:port)
2525
std::string NcclIdToString(const ncclUniqueId& id) {
2626
std::ostringstream oss;
2727
for (int i = 0; i < NCCL_UNIQUE_ID_BYTES; ++i) {
28-
oss << std::hex << std::uppercase << (int)(unsigned char)id.internal[i];
28+
oss << std::hex << std::uppercase << std::setw(2) << std::setfill('0') << (int)(unsigned char)id.internal[i];
2929
}
3030
return oss.str();
3131
}
@@ -99,5 +99,11 @@ int main(int argc, char **argv) {
9999
wait(&status);
100100
}
101101

102+
#ifdef USE_NCCL
103+
if (FLAGS_node_rank == 0) {
104+
std::remove(nccl_id_path);
105+
}
106+
#endif
107+
102108
return 0;
103109
}

0 commit comments

Comments
 (0)