Skip to content

Commit 16e2d77

Browse files
committed
refactor: refactor the naming of ncclUniqueId files and add a unified cleanup logic
1 parent 8e0862f commit 16e2d77

File tree

2 files changed

+31
-52
lines changed

2 files changed

+31
-52
lines changed

infini_train/src/nn/parallel/process_group.cc

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <algorithm>
44
#include <chrono>
55
#include <filesystem>
6+
#include <format>
67
#include <fstream>
78
#include <iterator>
89
#include <memory>
@@ -42,18 +43,28 @@ const std::unordered_map<ReduceOpType, ncclRedOp_t> kNcclReduceOpMap = {
4243
{ReduceOpType::kAvg, ncclAvg},
4344
};
4445

45-
void WriteNcclUniqueId(const ncclUniqueId &nccl_id, const std::string &filename) {
46-
std::string tmp_path = filename + ".tmp";
46+
inline std::string NcclFileName(const std::string &name, bool tmp = false) {
47+
return std::format("ncclUniqueId_{}.{}", name, tmp ? "tmp" : "bin");
48+
}
49+
50+
void WriteNcclUniqueId(const ncclUniqueId &nccl_id, const std::string &pg_name) {
51+
std::string tmp_path = NcclFileName(pg_name, true);
4752

4853
std::ofstream ofs(tmp_path, std::ios::binary);
4954
ofs.write(reinterpret_cast<const char *>(&nccl_id), sizeof(nccl_id));
5055
ofs.close();
5156

52-
std::rename(tmp_path.c_str(), filename.c_str());
57+
std::rename(tmp_path.c_str(), NcclFileName(pg_name).c_str());
5358
}
5459

55-
void ReadNcclUniqueId(ncclUniqueId &nccl_id, const std::string &filename) {
56-
std::ifstream ifs(filename, std::ios::binary);
60+
void ReadNcclUniqueId(ncclUniqueId &nccl_id, const std::string &pg_name) {
61+
std::string file_path = NcclFileName(pg_name);
62+
63+
while (std::filesystem::exists(file_path) == false) {
64+
std::this_thread::sleep_for(std::chrono::microseconds(1000));
65+
}
66+
67+
std::ifstream ifs(file_path, std::ios::binary);
5768
ifs.read(reinterpret_cast<char *>(&nccl_id), sizeof(nccl_id));
5869
ifs.close();
5970
}
@@ -78,9 +89,6 @@ ProcessGroup::ProcessGroup(const std::string &process_group_name, const std::vec
7889

7990
WriteNcclUniqueId(nccl_id, name_);
8091
} else {
81-
while (std::filesystem::exists(name_) == false) {
82-
std::this_thread::sleep_for(std::chrono::microseconds(1000));
83-
}
8492
ReadNcclUniqueId(nccl_id, name_);
8593
}
8694

tools/infini_run/infini_run.cc

Lines changed: 15 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1+
#include <cstdio>
12
#include <cstdlib>
2-
#include <fstream>
3-
#include <ios>
4-
#include <sstream>
3+
#include <filesystem>
4+
#include <regex>
55
#include <string>
66
#include <sys/wait.h>
7-
#include <sys/stat.h>
87
#include <unistd.h>
98
#include <vector>
10-
#include <cstdio>
119

1210
#ifdef USE_NCCL
1311
#include <nccl.h>
@@ -21,15 +19,19 @@ DEFINE_int32(nproc_per_node, 1, "Number of processes per node");
2119
DEFINE_int32(node_rank, 0, "Rank of this node");
2220
DEFINE_string(rdzv_endpoint, "127.0.0.1:29500", "Rendezvous endpoint (host:port)");
2321

24-
#ifdef USE_NCCL
25-
std::string NcclIdToString(const ncclUniqueId& id) {
26-
std::ostringstream oss;
27-
for (int i = 0; i < NCCL_UNIQUE_ID_BYTES; ++i) {
28-
oss << std::hex << std::uppercase << std::setw(2) << std::setfill('0') << (int)(unsigned char)id.internal[i];
22+
void CleanupNcclIdFiles() {
23+
const std::filesystem::path cwd = std::filesystem::current_path();
24+
std::regex pattern(R"(ncclUniqueId_.*\.bin)");
25+
26+
for (const auto &entry : std::filesystem::directory_iterator(cwd)) {
27+
if (entry.is_regular_file()) {
28+
const std::string filename = entry.path().filename().string();
29+
if (std::regex_match(filename, pattern)) {
30+
std::filesystem::remove(entry.path());
31+
}
32+
}
2933
}
30-
return oss.str();
3134
}
32-
#endif
3335

3436
int main(int argc, char **argv) {
3537
gflags::ParseCommandLineFlags(&argc, &argv, true);
@@ -46,23 +48,6 @@ int main(int argc, char **argv) {
4648
std::string master_addr = FLAGS_rdzv_endpoint.substr(0, FLAGS_rdzv_endpoint.find(':'));
4749
std::string master_port = FLAGS_rdzv_endpoint.substr(FLAGS_rdzv_endpoint.find(':') + 1);
4850

49-
const char* nccl_id_path = "/data/shared/InfiniTrain-dev/data/nccl_id.bin";
50-
const char* nccl_id_tmp_path = "/data/shared/InfiniTrain-dev/data/nccl_id.tmp";
51-
52-
#ifdef USE_NCCL
53-
if (FLAGS_node_rank == 0) {
54-
ncclUniqueId id;
55-
ncclGetUniqueId(&id);
56-
57-
std::ofstream ofs(nccl_id_tmp_path, std::ios::binary);
58-
ofs.write((char *)&id, sizeof(id));
59-
ofs.close();
60-
61-
// atomic operation
62-
rename(nccl_id_tmp_path, nccl_id_path);
63-
}
64-
#endif
65-
6651
for (int local_proc_rank = 0; local_proc_rank < FLAGS_nproc_per_node; ++local_proc_rank) {
6752
pid_t pid = fork();
6853
if (pid == 0) {
@@ -74,20 +59,6 @@ int main(int argc, char **argv) {
7459
setenv("MASTER_ADDR", master_addr.c_str(), 1);
7560
setenv("MASTER_PORT", master_port.c_str(), 1);
7661

77-
#ifdef USE_NCCL
78-
struct stat st;
79-
while (stat(nccl_id_path, &st) != 0) {
80-
usleep(1000);
81-
}
82-
83-
ncclUniqueId id;
84-
std::ifstream ifs(nccl_id_path, std::ios::binary);
85-
ifs.read((char*)&id, sizeof(id));
86-
87-
std::string id_str = NcclIdToString(id);
88-
setenv("NCCL_UNIQUE_ID", id_str.c_str(), 1);
89-
#endif
90-
9162
execvp(train_program.c_str(), train_argv.data());
9263
perror("exec failed");
9364
exit(1);
@@ -101,7 +72,7 @@ int main(int argc, char **argv) {
10172

10273
#ifdef USE_NCCL
10374
if (FLAGS_node_rank == 0) {
104-
std::remove(nccl_id_path);
75+
CleanupNcclIdFiles();
10576
}
10677
#endif
10778

0 commit comments

Comments
 (0)