@@ -79,6 +79,26 @@ namespace infini_train::nn::parallel {
7979#ifdef USE_NCCL
8080ProcessGroup::ProcessGroup (const std::string &process_group_name, const std::vector<int > &ranks)
8181 : world_size_(ranks.size()), name_(process_group_name) {
82+ if (global::GetNnodes () == 1 && global::GetNprocPerNode () == 1 ) {
83+ InitSingleProcess (ranks);
84+ } else {
85+ InitMultiProcess (ranks);
86+ }
87+ }
88+
89+ void ProcessGroup::InitSingleProcess (const std::vector<int > &ranks) {
90+ comms_.resize (world_size_);
91+ NCCL_CHECK (ncclCommInitAll (comms_.data (), world_size_, ranks.data ()));
92+
93+ for (int i = 0 ; i < ranks.size (); ++i) {
94+ auto device = DeviceManager::Instance ()->GetDevice (DeviceType::kCUDA , ranks[i]);
95+ devices_.push_back (device);
96+ device_comm_map_[device] = comms_[i];
97+ thread_group_rank_map_[device->rank ().thread_rank ()] = i;
98+ }
99+ }
100+
101+ void ProcessGroup::InitMultiProcess (const std::vector<int > &ranks) {
82102 int n_threads = global::GetNthreadPerProc ();
83103
84104 ncclUniqueId nccl_id;
@@ -99,28 +119,19 @@ ProcessGroup::ProcessGroup(const std::string &process_group_name, const std::vec
99119 auto it = std::ranges::find (ranks, global_rank);
100120 if (it != ranks.end ()) {
101121 cudaSetDevice (i);
122+
102123 ncclComm_t comm;
103124 int group_rank = std::distance (ranks.begin (), it);
104125 NCCL_CHECK (ncclCommInitRank (&comm, world_size_, nccl_id, group_rank));
105126 comms_.push_back (comm);
106- device_indices.push_back (i);
107- // FIXME(dcj): fix Init function
108- thread_group_rank_map_[DeviceManager::Instance ()->GetDevice (DeviceType::kCUDA , i)->rank ().thread_rank ()]
109- = group_rank;
127+
128+ auto device = DeviceManager::Instance ()->GetDevice (DeviceType::kCUDA , i);
129+ thread_group_rank_map_[device->rank ().thread_rank ()] = group_rank;
130+ devices_.push_back (device);
131+ device_comm_map_[device] = comm;
110132 }
111133 }
112134 NCCL_CHECK (ncclGroupEnd ());
113-
114- Init (device_indices);
115- }
116-
117- void ProcessGroup::Init (const std::vector<int > &device_indices) {
118- for (int i = 0 ; i < device_indices.size (); ++i) {
119- auto device = DeviceManager::Instance ()->GetDevice (DeviceType::kCUDA , device_indices[i]);
120- devices_.push_back (device);
121- device_comm_map_[device] = comms_[i];
122- // thread_group_rank_map_[device->rank().thread_rank()] = i;
123- }
124135}
125136
126137int ProcessGroup::GetGroupRank (int thread_rank) const { return thread_group_rank_map_.at (thread_rank); }
0 commit comments