
Commit 424596b

feat: Separate ProcessGroup initialization for single-node multi-thread and multi-node multi-process scenarios

1 parent 16e2d77 commit 424596b

File tree

6 files changed: +61 −50 lines

  infini_train/include/nn/parallel/global.h
  infini_train/include/nn/parallel/process_group.h
  infini_train/src/nn/parallel/global.cc
  infini_train/src/nn/parallel/process_group.cc
  tools/infini_run/CMakeLists.txt
  tools/infini_run/infini_run.cc

infini_train/include/nn/parallel/global.h

Lines changed: 8 additions & 11 deletions
@@ -4,10 +4,6 @@
 #include <string>
 #include <vector>
 
-#ifdef USE_NCCL
-#include <nccl.h>
-#endif
-
 namespace infini_train::nn::parallel::global {
 
 enum Axis : uint8_t { DP = 0, TP = 1, PP = 2, AXIS_COUNT = 3 };
@@ -32,14 +28,16 @@ class GlobalEnv {
 
     void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled = false);
 
+    int nnodes() const;
+
+    int nproc_per_node() const;
+
     int world_size() const;
 
     int global_proc_rank() const;
 
     int local_proc_rank() const;
 
-    int nproc_per_node() const;
-
     int nthread_per_process() const;
 
     int tensor_parallel_size() const;
@@ -61,9 +59,11 @@ class GlobalEnv {
     GlobalEnv &operator=(const GlobalEnv &) = delete;
 
 private:
-    int world_size_ = 1;
+    int nnodes_ = 1;
     int nproc_per_node_ = 1;
     int nthread_per_process_ = 1;
+    int world_size_ = 1;
+
     int global_proc_rank_ = 0;
     int local_proc_rank_ = 0;
 
@@ -72,10 +72,6 @@ class GlobalEnv {
 
     int data_parallel_size_ = 1;
 
-#ifdef USE_NCCL
-    ncclUniqueId nccl_id_;
-#endif
-
     mutable std::mutex mutex_;
     bool initialized_ = false;
 
@@ -86,6 +82,7 @@ inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled = false) {
     GlobalEnv::Instance().Init(nthread_per_process, tensor_parallel_size, sequence_parallel_enabled);
 }
 
+inline int GetNnodes() { return GlobalEnv::Instance().nnodes(); }
 inline int GetWorldSize() { return GlobalEnv::Instance().world_size(); }
 inline int GetNprocPerNode() { return GlobalEnv::Instance().nproc_per_node(); }
 inline int GetNthreadPerProc() { return GlobalEnv::Instance().nthread_per_process(); }
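
For orientation, here is a minimal sketch of a call site using the new accessor. This is not part of the commit; the include path, flag values, and launcher-exported environment are illustrative assumptions.

#include <iostream>

#include "nn/parallel/global.h"  // assumed include path

using namespace infini_train::nn::parallel;

int main() {
    // Illustrative: 2 worker threads per process, no tensor parallelism.
    global::InitAllEnv(/*nthread_per_process=*/2, /*tensor_parallel_size=*/1);

    // With NNODES=2, NPROC_PER_NODE=4 and PROC_WORLD_SIZE=8 exported by the
    // launcher, world_size() is 8 * 2 = 16, since Init() folds threads in.
    std::cout << global::GetNnodes() << " nodes, " << global::GetNprocPerNode()
              << " procs/node, " << global::GetWorldSize() << " ranks total\n";
    return 0;
}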

infini_train/include/nn/parallel/process_group.h

Lines changed: 3 additions & 1 deletion
@@ -55,7 +55,9 @@ class ProcessGroup {
     std::vector<std::shared_ptr<Tensor>> NcclRecv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank) const;
 
 private:
-    void Init(const std::vector<int> &device_indices);
+    void InitSingleProcess(const std::vector<int> &ranks);
+
+    void InitMultiProcess(const std::vector<int> &ranks);
 
 private:
     std::vector<ncclComm_t> comms_;
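
The public surface is unchanged; only the private initializer is split in two. A hypothetical construction site (group name and rank list are illustrative) still looks like:

// The constructor now dispatches internally: InitSingleProcess() when the
// whole job is a single process (NNODES == 1 && NPROC_PER_NODE == 1),
// InitMultiProcess() otherwise; see process_group.cc below.
std::vector<int> ranks = {0, 1, 2, 3};
ProcessGroup dp_group("dp", ranks);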

infini_train/src/nn/parallel/global.cc

Lines changed: 17 additions & 11 deletions
@@ -95,8 +95,9 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled) {
 
     CHECK(!initialized_) << "Repeated initialization of GlobalEnv!";
 
-    world_size_ = GetEnvAsInt("PROC_WORLD_SIZE", 1) * nthread_per_process;
+    nnodes_ = GetEnvAsInt("NNODES", 1);
     nproc_per_node_ = GetEnvAsInt("NPROC_PER_NODE", 1);
+    world_size_ = GetEnvAsInt("PROC_WORLD_SIZE", 1) * nthread_per_process;
     global_proc_rank_ = GetEnvAsInt("GLOBAL_PROC_RANK", 0);
     local_proc_rank_ = GetEnvAsInt("LOCAL_PROC_RANK", 0);
 
@@ -115,29 +116,34 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled) {
     initialized_ = true;
 }
 
-int GlobalEnv::world_size() const {
+int GlobalEnv::nnodes() const {
     CHECK(initialized_) << "GlobalEnv is not initialized!";
-    return world_size_;
+    return nnodes_;
 }
 
-int GlobalEnv::global_proc_rank() const {
+int GlobalEnv::nproc_per_node() const {
     CHECK(initialized_) << "GlobalEnv is not initialized!";
-    return global_proc_rank_;
+    return nproc_per_node_;
 }
 
-int GlobalEnv::local_proc_rank() const {
+int GlobalEnv::nthread_per_process() const {
     CHECK(initialized_) << "GlobalEnv is not initialized!";
-    return local_proc_rank_;
+    return nthread_per_process_;
 }
 
-int GlobalEnv::nproc_per_node() const {
+int GlobalEnv::world_size() const {
     CHECK(initialized_) << "GlobalEnv is not initialized!";
-    return nproc_per_node_;
+    return world_size_;
 }
 
-int GlobalEnv::nthread_per_process() const {
+int GlobalEnv::global_proc_rank() const {
     CHECK(initialized_) << "GlobalEnv is not initialized!";
-    return nthread_per_process_;
+    return global_proc_rank_;
+}
+
+int GlobalEnv::local_proc_rank() const {
+    CHECK(initialized_) << "GlobalEnv is not initialized!";
+    return local_proc_rank_;
 }
 
 int GlobalEnv::tensor_parallel_size() const {
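
GetEnvAsInt itself is not part of this diff; for readers following the Init() changes, a minimal sketch of its assumed semantics (read an integer environment variable, falling back to a default):

#include <cstdlib>
#include <string>

// Assumed helper semantics, not the repository's actual definition.
static int GetEnvAsInt(const std::string &name, int default_value) {
    const char *value = std::getenv(name.c_str());
    return value ? std::atoi(value) : default_value;
}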

infini_train/src/nn/parallel/process_group.cc

Lines changed: 26 additions & 15 deletions
@@ -79,6 +79,26 @@ namespace infini_train::nn::parallel {
 #ifdef USE_NCCL
 ProcessGroup::ProcessGroup(const std::string &process_group_name, const std::vector<int> &ranks)
     : world_size_(ranks.size()), name_(process_group_name) {
+    if (global::GetNnodes() == 1 && global::GetNprocPerNode() == 1) {
+        InitSingleProcess(ranks);
+    } else {
+        InitMultiProcess(ranks);
+    }
+}
+
+void ProcessGroup::InitSingleProcess(const std::vector<int> &ranks) {
+    comms_.resize(world_size_);
+    NCCL_CHECK(ncclCommInitAll(comms_.data(), world_size_, ranks.data()));
+
+    for (int i = 0; i < ranks.size(); ++i) {
+        auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, ranks[i]);
+        devices_.push_back(device);
+        device_comm_map_[device] = comms_[i];
+        thread_group_rank_map_[device->rank().thread_rank()] = i;
+    }
+}
+
+void ProcessGroup::InitMultiProcess(const std::vector<int> &ranks) {
     int n_threads = global::GetNthreadPerProc();
 
     ncclUniqueId nccl_id;
@@ -99,28 +119,19 @@ ProcessGroup::ProcessGroup(const std::string &process_group_name, const std::vector<int> &ranks)
         auto it = std::ranges::find(ranks, global_rank);
         if (it != ranks.end()) {
             cudaSetDevice(i);
+
             ncclComm_t comm;
             int group_rank = std::distance(ranks.begin(), it);
             NCCL_CHECK(ncclCommInitRank(&comm, world_size_, nccl_id, group_rank));
             comms_.push_back(comm);
-            device_indices.push_back(i);
-            // FIXME(dcj): fix Init function
-            thread_group_rank_map_[DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, i)->rank().thread_rank()]
-                = group_rank;
+
+            auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, i);
+            thread_group_rank_map_[device->rank().thread_rank()] = group_rank;
+            devices_.push_back(device);
+            device_comm_map_[device] = comm;
         }
     }
     NCCL_CHECK(ncclGroupEnd());
-
-    Init(device_indices);
-}
-
-void ProcessGroup::Init(const std::vector<int> &device_indices) {
-    for (int i = 0; i < device_indices.size(); ++i) {
-        auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, device_indices[i]);
-        devices_.push_back(device);
-        device_comm_map_[device] = comms_[i];
-        // thread_group_rank_map_[device->rank().thread_rank()] = i;
-    }
 }
 
 int ProcessGroup::GetGroupRank(int thread_rank) const { return thread_group_rank_map_.at(thread_rank); }
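
The split mirrors NCCL's two bootstrap modes: ncclCommInitAll creates every communicator from a single thread when all GPUs belong to one process, while ncclCommInitRank is the cooperative path where each rank joins using the same ncclUniqueId distributed out-of-band. A standalone sketch of both (error handling omitted; function names are illustrative, not the repository's):

#include <nccl.h>
#include <vector>

// Single-process bootstrap: one call builds a communicator per local device.
void SingleProcessBootstrap(int ndev) {
    std::vector<int> devs(ndev);
    for (int i = 0; i < ndev; ++i) { devs[i] = i; }
    std::vector<ncclComm_t> comms(ndev);
    ncclCommInitAll(comms.data(), ndev, devs.data());
}

// Multi-process bootstrap: every rank calls this with the same id (created
// once via ncclGetUniqueId and shared out-of-band, e.g. through the nccl-id
// files the launcher below cleans up) plus its own rank within the group.
void MultiProcessBootstrap(int nranks, ncclUniqueId id, int rank) {
    ncclComm_t comm;
    ncclCommInitRank(&comm, nranks, id, rank);
}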

tools/infini_run/CMakeLists.txt

Lines changed: 0 additions & 3 deletions
@@ -1,5 +1,2 @@
 add_executable(infini_run infini_run.cc)
 target_link_libraries(infini_run PRIVATE gflags glog)
-if (USE_NCCL)
-    target_link_libraries(infini_run PRIVATE nccl)
-endif()

tools/infini_run/infini_run.cc

Lines changed: 7 additions & 9 deletions
@@ -7,10 +7,6 @@
 #include <unistd.h>
 #include <vector>
 
-#ifdef USE_NCCL
-#include <nccl.h>
-#endif
-
 #include "gflags/gflags.h"
 #include "glog/logging.h"
 
@@ -52,13 +48,17 @@ int main(int argc, char **argv) {
         pid_t pid = fork();
         if (pid == 0) {
            int global_proc_rank = FLAGS_node_rank * FLAGS_nproc_per_node + local_proc_rank;
-            setenv("GLOBAL_PROC_RANK", std::to_string(global_proc_rank).c_str(), 1);
-            setenv("LOCAL_PROC_RANK", std::to_string(local_proc_rank).c_str(), 1);
-            setenv("PROC_WORLD_SIZE", std::to_string(world_size).c_str(), 1);
+            setenv("NNODES", std::to_string(FLAGS_nnodes).c_str(), 1);
             setenv("NPROC_PER_NODE", std::to_string(FLAGS_nproc_per_node).c_str(), 1);
+
             setenv("MASTER_ADDR", master_addr.c_str(), 1);
             setenv("MASTER_PORT", master_port.c_str(), 1);
 
+            setenv("GLOBAL_PROC_RANK", std::to_string(global_proc_rank).c_str(), 1);
+            setenv("LOCAL_PROC_RANK", std::to_string(local_proc_rank).c_str(), 1);
+
+            setenv("PROC_WORLD_SIZE", std::to_string(world_size).c_str(), 1);
+
             execvp(train_program.c_str(), train_argv.data());
             perror("exec failed");
             exit(1);
@@ -70,11 +70,9 @@
         wait(&status);
     }
 
-#ifdef USE_NCCL
     if (FLAGS_node_rank == 0) {
         CleanupNcclIdFiles();
     }
-#endif
 
     return 0;
 }
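
Read together with global.cc above, the environment exported here is exactly what GlobalEnv::Init() consumes. As a worked example, a hypothetical 2-node, 4-process-per-node job would give the child forked with node_rank=1 and local_proc_rank=3:

NNODES=2
NPROC_PER_NODE=4
GLOBAL_PROC_RANK=7    (1 * 4 + 3, per the arithmetic above)
LOCAL_PROC_RANK=3
PROC_WORLD_SIZE=8     (presumably nnodes * nproc_per_node; Init() later multiplies in the threads per process)
MASTER_ADDR / MASTER_PORT as supplied to the launcher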
