@@ -68,6 +68,15 @@ void ReadNcclUniqueId(ncclUniqueId &nccl_id, const std::string &pg_name) {
6868 ifs.read (reinterpret_cast <char *>(&nccl_id), sizeof (nccl_id));
6969 ifs.close ();
7070}
71+
72+ void CleanupNcclIdFile (const std::string &pg_name) {
73+ const std::filesystem::path cwd = std::filesystem::current_path ();
74+ std::string file_path = NcclFileName (pg_name);
75+
76+ if (std::filesystem::exists (file_path)) {
77+ std::filesystem::remove (file_path);
78+ }
79+ }
7180#endif
7281
7382} // namespace
@@ -86,6 +95,12 @@ ProcessGroup::ProcessGroup(const std::string &process_group_name, const std::vec
8695 }
8796}
8897
98+ ProcessGroup::~ProcessGroup () {
99+ if (is_main_process_) {
100+ CleanupNcclIdFile (name_);
101+ }
102+ }
103+
89104void ProcessGroup::InitSingleProcess (const std::vector<int > &ranks) {
90105 comms_.resize (world_size_);
91106 NCCL_CHECK (ncclCommInitAll (comms_.data (), world_size_, ranks.data ()));
@@ -100,13 +115,17 @@ void ProcessGroup::InitSingleProcess(const std::vector<int> &ranks) {
100115
101116void ProcessGroup::InitMultiProcess (const std::vector<int > &ranks) {
102117 int n_threads = global::GetNthreadPerProc ();
118+ int global_proc_rank = global::GetGlobalProcRank ();
119+ int lower_rank = global_proc_rank * n_threads;
120+ int upper_rank = (global_proc_rank + 1 ) * n_threads;
103121
104122 ncclUniqueId nccl_id;
105123
106- if ( std::ranges::min (ranks) < ( global::GetGlobalProcRank () + 1 ) * global::GetNthreadPerProc ()
107- && std::ranges::min (ranks) >= global::GetGlobalProcRank () * global::GetNthreadPerProc () ) {
108- ncclGetUniqueId (&nccl_id) ;
124+ int min_rank = std::ranges::min (ranks);
125+ if (min_rank < upper_rank && min_rank >= lower_rank ) {
126+ is_main_process_ = true ;
109127
128+ ncclGetUniqueId (&nccl_id);
110129 WriteNcclUniqueId (nccl_id, name_);
111130 } else {
112131 ReadNcclUniqueId (nccl_id, name_);
@@ -115,8 +134,8 @@ void ProcessGroup::InitMultiProcess(const std::vector<int> &ranks) {
115134 std::vector<int > device_indices;
116135 NCCL_CHECK (ncclGroupStart ());
117136 for (int i = 0 ; i < n_threads; ++i) {
118- int global_rank = global::GetGlobalProcRank () * global::GetNthreadPerProc () + i;
119- auto it = std::ranges::find (ranks, global_rank );
137+ int global_thread_rank = lower_rank + i;
138+ auto it = std::ranges::find (ranks, global_thread_rank );
120139 if (it != ranks.end ()) {
121140 cudaSetDevice (i);
122141
0 commit comments