1+ #include < cstdio>
12#include < cstdlib>
2- #include < fstream>
3- #include < ios>
4- #include < sstream>
3+ #include < filesystem>
4+ #include < regex>
55#include < string>
66#include < sys/wait.h>
7- #include < sys/stat.h>
87#include < unistd.h>
98#include < vector>
10- #include < cstdio>
119
1210#ifdef USE_NCCL
1311#include < nccl.h>
@@ -21,15 +19,19 @@ DEFINE_int32(nproc_per_node, 1, "Number of processes per node");
2119DEFINE_int32 (node_rank, 0 , " Rank of this node" );
2220DEFINE_string (rdzv_endpoint, " 127.0.0.1:29500" , " Rendezvous endpoint (host:port)" );
2321
24- #ifdef USE_NCCL
25- std::string NcclIdToString (const ncclUniqueId& id) {
26- std::ostringstream oss;
27- for (int i = 0 ; i < NCCL_UNIQUE_ID_BYTES; ++i) {
28- oss << std::hex << std::uppercase << std::setw (2 ) << std::setfill (' 0' ) << (int )(unsigned char )id.internal [i];
22+ void CleanupNcclIdFiles () {
23+ const std::filesystem::path cwd = std::filesystem::current_path ();
24+ std::regex pattern (R"( ncclUniqueId_.*\.bin)" );
25+
26+ for (const auto &entry : std::filesystem::directory_iterator (cwd)) {
27+ if (entry.is_regular_file ()) {
28+ const std::string filename = entry.path ().filename ().string ();
29+ if (std::regex_match (filename, pattern)) {
30+ std::filesystem::remove (entry.path ());
31+ }
32+ }
2933 }
30- return oss.str ();
3134}
32- #endif
3335
3436int main (int argc, char **argv) {
3537 gflags::ParseCommandLineFlags (&argc, &argv, true );
@@ -46,23 +48,6 @@ int main(int argc, char **argv) {
4648 std::string master_addr = FLAGS_rdzv_endpoint.substr (0 , FLAGS_rdzv_endpoint.find (' :' ));
4749 std::string master_port = FLAGS_rdzv_endpoint.substr (FLAGS_rdzv_endpoint.find (' :' ) + 1 );
4850
49- const char * nccl_id_path = " /data/shared/InfiniTrain-dev/data/nccl_id.bin" ;
50- const char * nccl_id_tmp_path = " /data/shared/InfiniTrain-dev/data/nccl_id.tmp" ;
51-
52- #ifdef USE_NCCL
53- if (FLAGS_node_rank == 0 ) {
54- ncclUniqueId id;
55- ncclGetUniqueId (&id);
56-
57- std::ofstream ofs (nccl_id_tmp_path, std::ios::binary);
58- ofs.write ((char *)&id, sizeof (id));
59- ofs.close ();
60-
61- // atomic operation
62- rename (nccl_id_tmp_path, nccl_id_path);
63- }
64- #endif
65-
6651 for (int local_proc_rank = 0 ; local_proc_rank < FLAGS_nproc_per_node; ++local_proc_rank) {
6752 pid_t pid = fork ();
6853 if (pid == 0 ) {
@@ -74,20 +59,6 @@ int main(int argc, char **argv) {
7459 setenv (" MASTER_ADDR" , master_addr.c_str (), 1 );
7560 setenv (" MASTER_PORT" , master_port.c_str (), 1 );
7661
77- #ifdef USE_NCCL
78- struct stat st;
79- while (stat (nccl_id_path, &st) != 0 ) {
80- usleep (1000 );
81- }
82-
83- ncclUniqueId id;
84- std::ifstream ifs (nccl_id_path, std::ios::binary);
85- ifs.read ((char *)&id, sizeof (id));
86-
87- std::string id_str = NcclIdToString (id);
88- setenv (" NCCL_UNIQUE_ID" , id_str.c_str (), 1 );
89- #endif
90-
9162 execvp (train_program.c_str (), train_argv.data ());
9263 perror (" exec failed" );
9364 exit (1 );
@@ -101,7 +72,7 @@ int main(int argc, char **argv) {
10172
10273#ifdef USE_NCCL
10374 if (FLAGS_node_rank == 0 ) {
104- std::remove (nccl_id_path );
75+ CleanupNcclIdFiles ( );
10576 }
10677#endif
10778
0 commit comments