Skip to content

Commit b4b571c

Browse files
committed
feat: Move the logic for cleaning up the ncclUniqueId file into the ProcessGroup destructor.
1 parent 4188230 commit b4b571c

File tree

3 files changed

+20
-19
lines changed

3 files changed

+20
-19
lines changed

infini_train/include/nn/parallel/process_group.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ namespace infini_train::nn::parallel {
2929
class ProcessGroup {
3030
public:
3131
explicit ProcessGroup(const std::string &process_group_name, const std::vector<int> &device_indices);
32+
~ProcessGroup();
3233

3334
int GetGroupRank(int thread_rank) const;
3435

@@ -69,6 +70,8 @@ class ProcessGroup {
6970
int world_size_ = 0;
7071

7172
const std::string name_ = "";
73+
74+
bool is_main_process_ = false;
7275
};
7376
#endif
7477

infini_train/src/nn/parallel/process_group.cc

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ void ReadNcclUniqueId(ncclUniqueId &nccl_id, const std::string &pg_name) {
6868
ifs.read(reinterpret_cast<char *>(&nccl_id), sizeof(nccl_id));
6969
ifs.close();
7070
}
71+
72+
void CleanupNcclIdFile(const std::string &pg_name) {
73+
const std::filesystem::path cwd = std::filesystem::current_path();
74+
std::string file_path = NcclFileName(pg_name);
75+
76+
if (std::filesystem::exists(file_path)) {
77+
std::filesystem::remove(file_path);
78+
}
79+
}
7180
#endif
7281

7382
} // namespace
@@ -86,6 +95,12 @@ ProcessGroup::ProcessGroup(const std::string &process_group_name, const std::vec
8695
}
8796
}
8897

98+
// Removes this group's ncclUniqueId file on destruction, but only in the
// process flagged by is_main_process_; every other process leaves the file alone.
ProcessGroup::~ProcessGroup() {
    if (!is_main_process_) {
        return;
    }

    CleanupNcclIdFile(name_);
}
103+
89104
void ProcessGroup::InitSingleProcess(const std::vector<int> &ranks) {
90105
comms_.resize(world_size_);
91106
NCCL_CHECK(ncclCommInitAll(comms_.data(), world_size_, ranks.data()));
@@ -105,8 +120,9 @@ void ProcessGroup::InitMultiProcess(const std::vector<int> &ranks) {
105120

106121
if (std::ranges::min(ranks) < (global::GetGlobalProcRank() + 1) * global::GetNthreadPerProc()
107122
&& std::ranges::min(ranks) >= global::GetGlobalProcRank() * global::GetNthreadPerProc()) {
108-
ncclGetUniqueId(&nccl_id);
123+
is_main_process_ = true;
109124

125+
ncclGetUniqueId(&nccl_id);
110126
WriteNcclUniqueId(nccl_id, name_);
111127
} else {
112128
ReadNcclUniqueId(nccl_id, name_);

tools/infini_run/infini_run.cc

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,6 @@ DEFINE_int32(nproc_per_node, 1, "Number of processes per node");
1515
DEFINE_int32(node_rank, 0, "Rank of this node");
1616
DEFINE_string(rdzv_endpoint, "127.0.0.1:29500", "Rendezvous endpoint (host:port)");
1717

18-
// Deletes every regular file in the current working directory whose file name
// fully matches "ncclUniqueId_.*\.bin". Directories and other non-regular
// entries are left untouched, as are files whose names only partially match.
void CleanupNcclIdFiles() {
    const std::filesystem::path working_dir = std::filesystem::current_path();
    const std::regex id_file_pattern(R"(ncclUniqueId_.*\.bin)");

    for (const auto &dir_entry : std::filesystem::directory_iterator(working_dir)) {
        if (!dir_entry.is_regular_file()) {
            continue;
        }

        // regex_match requires the whole file name to match, not just a prefix.
        if (std::regex_match(dir_entry.path().filename().string(), id_file_pattern)) {
            std::filesystem::remove(dir_entry.path());
        }
    }
}
31-
3218
int main(int argc, char **argv) {
3319
gflags::ParseCommandLineFlags(&argc, &argv, true);
3420
google::InitGoogleLogging(argv[0]);
@@ -70,9 +56,5 @@ int main(int argc, char **argv) {
7056
wait(&status);
7157
}
7258

73-
if (FLAGS_node_rank == 0) {
74-
CleanupNcclIdFiles();
75-
}
76-
7759
return 0;
7860
}

0 commit comments

Comments
 (0)