From ecdc5cec47655c7401a78d222cea605b07b024cf Mon Sep 17 00:00:00 2001
From: kilinchange
Date: Tue, 28 Oct 2025 12:53:07 +0000
Subject: [PATCH 1/8] feat: use shared file to distribute ncclUniqueId in
 infini_run

---
 infini_train/src/nn/parallel/global.cc | 15 +++++++
 tools/infini_run/infini_run.cc         | 62 +++++++++++++++++++++++---
 2 files changed, 70 insertions(+), 7 deletions(-)

diff --git a/infini_train/src/nn/parallel/global.cc b/infini_train/src/nn/parallel/global.cc
index 3feaab9c..424819f3 100644
--- a/infini_train/src/nn/parallel/global.cc
+++ b/infini_train/src/nn/parallel/global.cc
@@ -2,6 +2,7 @@
 #include 
 #include 
+#include <sstream>
 #include 
 
 #include "glog/logging.h"
@@ -13,6 +14,20 @@ int GetEnvAsInt(const std::string &name, int default_value) {
     return value ? std::atoi(value) : default_value;
 }
 
+#ifdef USE_NCCL
+ncclUniqueId StringToNcclId(const std::string &str) {
+    ncclUniqueId id;
+    for (int i = 0; i < NCCL_UNIQUE_ID_BYTES; ++i) {
+        unsigned int byte;
+        std::stringstream ss;
+        ss << std::hex << str.substr(i * 2, 2);
+        ss >> byte;
+        id.internal[i] = static_cast<char>(byte);
+    }
+    return id;
+}
+#endif
+
 } // namespace
 
 namespace infini_train::nn::parallel::global {

diff --git a/tools/infini_run/infini_run.cc b/tools/infini_run/infini_run.cc
index bcc3f25a..67aa372e 100644
--- a/tools/infini_run/infini_run.cc
+++ b/tools/infini_run/infini_run.cc
@@ -1,8 +1,17 @@
+#include 
+#include 
+#include 
+#include 
 #include 
 #include 
-#include 
-#include 
+#include 
 #include 
+#include 
+#include 
+
+#ifdef USE_NCCL
+#include <nccl.h>
+#endif
 
 #include "gflags/gflags.h"
 #include "glog/logging.h"
@@ -12,23 +21,48 @@ DEFINE_int32(nproc_per_node, 1, "Number of processes per node");
 DEFINE_int32(node_rank, 0, "Rank of this node");
 DEFINE_string(rdzv_endpoint, "127.0.0.1:29500", "Rendezvous endpoint (host:port)");
 
-int main(int argc, char** argv) {
+#ifdef USE_NCCL
+std::string NcclIdToString(const ncclUniqueId& id) {
+    std::ostringstream oss;
+    for (int i = 0; i < NCCL_UNIQUE_ID_BYTES; ++i) {
+        oss << std::hex << std::uppercase << (int)(unsigned char)id.internal[i];
+    }
+    return oss.str();
+}
+#endif
+
+int main(int argc, char **argv) {
     gflags::ParseCommandLineFlags(&argc, &argv, true);
     google::InitGoogleLogging(argv[0]);
 
     CHECK_GE(argc, 2) << "No training program specified!";
 
     std::string train_program = argv[1];
-    std::vector<char *> train_argv;
-    for (int i = 1; i < argc; ++i) {
-        train_argv.push_back(argv[i]);
-    }
+    std::vector<char *> train_argv;
+    for (int i = 1; i < argc; ++i) { train_argv.push_back(argv[i]); }
     train_argv.push_back(nullptr);
 
     int world_size = FLAGS_nnodes * FLAGS_nproc_per_node;
     std::string master_addr = FLAGS_rdzv_endpoint.substr(0, FLAGS_rdzv_endpoint.find(':'));
     std::string master_port = FLAGS_rdzv_endpoint.substr(FLAGS_rdzv_endpoint.find(':') + 1);
 
+    const char* nccl_id_path = "/data/shared/InfiniTrain-dev/data/nccl_id.bin";
+    const char* nccl_id_tmp_path = "/data/shared/InfiniTrain-dev/data/nccl_id.tmp";
+
+#ifdef USE_NCCL
+    if (FLAGS_node_rank == 0) {
+        ncclUniqueId id;
+        ncclGetUniqueId(&id);
+
+        std::ofstream ofs(nccl_id_tmp_path, std::ios::binary);
+        ofs.write((char *)&id, sizeof(id));
+        ofs.close();
+
+        // atomic operation
+        rename(nccl_id_tmp_path, nccl_id_path);
+    }
+#endif
+
     for (int local_proc_rank = 0; local_proc_rank < FLAGS_nproc_per_node; ++local_proc_rank) {
         pid_t pid = fork();
         if (pid == 0) {
@@ -40,6 +74,20 @@ int main(int argc, char **argv) {
             setenv("MASTER_ADDR", master_addr.c_str(), 1);
             setenv("MASTER_PORT", master_port.c_str(), 1);
 
+#ifdef USE_NCCL
+            struct stat st;
+            while (stat(nccl_id_path, &st) != 0) {
+                usleep(1000);
+            }
+
+            ncclUniqueId id;
+            std::ifstream ifs(nccl_id_path, std::ios::binary);
+            ifs.read((char*)&id, sizeof(id));
+
+            std::string id_str = NcclIdToString(id);
+            setenv("NCCL_UNIQUE_ID", id_str.c_str(), 1);
+#endif
+
             execvp(train_program.c_str(), train_argv.data());
             perror("exec failed");
             exit(1);
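For reference, the publish/poll file rendezvous this patch relies on can be distilled to the following standalone sketch. It is not code from the series; names and the payload type are illustrative, but the pattern is the same: the writer publishes via a temp file plus `rename()` (atomic within one POSIX filesystem, so readers never observe a half-written id), and readers poll until the final name appears.

```cpp
#include <sys/stat.h>
#include <unistd.h>

#include <cstdio>
#include <fstream>
#include <iterator>
#include <string>

// Publisher: write the payload to a temp file, then rename() it into place.
void Publish(const std::string &payload, const std::string &path) {
    const std::string tmp = path + ".tmp";
    std::ofstream ofs(tmp, std::ios::binary);
    ofs.write(payload.data(), payload.size());
    ofs.close();
    std::rename(tmp.c_str(), path.c_str()); // atomic publish
}

// Subscriber: poll until the file exists, then read the whole payload.
std::string Subscribe(const std::string &path) {
    struct stat st;
    while (stat(path.c_str(), &st) != 0) {
        usleep(1000); // 1 ms backoff, same cadence as the launcher above
    }
    std::ifstream ifs(path, std::ios::binary);
    return std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
}
```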
From 78047a7f8028293db26c2fd6b187dd3558c4eb6b Mon Sep 17 00:00:00 2001
From: kilinchange
Date: Wed, 29 Oct 2025 12:11:38 +0000
Subject: [PATCH 2/8] feat: add multi-node default ProcessGroup

---
 infini_train/include/nn/parallel/global.h        | 14 ++++
 .../include/nn/parallel/process_group.h          | 34 +++++++++-
 infini_train/src/nn/parallel/global.cc           | 15 +++++
 infini_train/src/nn/parallel/process_group.cc    | 64 +++++++++++--------
 tools/infini_run/CMakeLists.txt                  |  3 +
 tools/infini_run/infini_run.cc                   |  8 ++-
 6 files changed, 110 insertions(+), 28 deletions(-)

diff --git a/infini_train/include/nn/parallel/global.h b/infini_train/include/nn/parallel/global.h
index 0ae75995..b90a1b9a 100644
--- a/infini_train/include/nn/parallel/global.h
+++ b/infini_train/include/nn/parallel/global.h
@@ -4,6 +4,10 @@
 #include 
 #include 
 
+#ifdef USE_NCCL
+#include <nccl.h>
+#endif
+
 namespace infini_train::nn::parallel::global {
 
 enum Axis : uint8_t { DP = 0, TP = 1, PP = 2, AXIS_COUNT = 3 };
@@ -45,6 +49,9 @@ class GlobalEnv {
     int data_parallel_size() const;
 
     Layout layout() const;
+#ifdef USE_NCCL
+    ncclUniqueId nccl_id() const;
+#endif
 
 private:
     GlobalEnv() = default;
@@ -65,6 +72,10 @@ class GlobalEnv {
 
     int data_parallel_size_ = 1;
 
+#ifdef USE_NCCL
+    ncclUniqueId nccl_id_;
+#endif
+
     mutable std::mutex mutex_;
 
     bool initialized_ = false;
@@ -108,5 +119,8 @@ inline std::vector<int> GetGroupRanks(Axis target, int rank) {
 }
 
 std::string ProcessGroupOverview(const Layout &L = GlobalEnv::Instance().layout(), bool skip_trivial_axes = true);
+#ifdef USE_NCCL
+inline ncclUniqueId GetNcclId() { return GlobalEnv::Instance().nccl_id(); }
+#endif
 
 } // namespace infini_train::nn::parallel::global

diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h
index 5919e054..e4a5f307 100644
--- a/infini_train/include/nn/parallel/process_group.h
+++ b/infini_train/include/nn/parallel/process_group.h
@@ -3,6 +3,7 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 
@@ -28,6 +29,9 @@ class ProcessGroup {
 public:
     explicit ProcessGroup(const std::vector<int> &device_indices);
 
+    // support for multi-node distributed training
+    explicit ProcessGroup(const ncclUniqueId &nccl_id);
+
     int GetGroupRank(int thread_rank) const;
 
     void AllReduce(const std::shared_ptr<Tensor> &tensor, function::ReduceOpType reduce_op) const;
@@ -52,6 +56,9 @@ class ProcessGroup {
     std::vector<std::shared_ptr<Tensor>> NcclRecv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank) const;
 
+private:
+    void Init(const std::vector<int> &device_indices);
+
 private:
     std::vector<ncclComm_t> comms_;
     std::vector<const Device *> devices_;
 
     std::unordered_map<const Device *, ncclComm_t> device_comm_map_;
     std::unordered_map<int, int> thread_group_rank_map_; // thread_rank : group_rank
 
-    int comm_size_ = 0;
+    int world_size_ = 0;
 };
 #endif
@@ -73,12 +80,37 @@ class ProcessGroupFactory {
     const ProcessGroup *GetOrCreate(const std::string &name, const std::vector<int> &device_indices);
 
+#ifdef USE_NCCL
+    const ProcessGroup *GetOrCreate(const std::string &name, const ncclUniqueId &nccl_id);
+#endif
+
     const ProcessGroup *Get(const std::string &name) const;
     const ProcessGroup *GetDefaultProcessGroup() const;
 
 private:
     ProcessGroupFactory();
+
+    template <typename Creator,
+              typename = std::enable_if_t<std::is_invocable_r_v<std::unique_ptr<ProcessGroup>, Creator>>>
+    const ProcessGroup *GetOrCreate(const std::string &name, Creator &&creator) {
+        {
+            std::lock_guard lock(mutex_);
+            auto it = name_to_group_.find(name);
+            if (it != name_to_group_.end()) {
+                return it->second.get();
+            }
+        }
+
+        auto new_group = creator();
+
+        {
+            std::lock_guard lock(mutex_);
+            auto [it, inserted] = name_to_group_.emplace(name, std::move(new_group));
+            return it->second.get();
+        }
+    }
+
+private:
     // TODO(dcj): maybe RWLock later?
     mutable std::mutex mutex_;
     std::unordered_map<std::string, std::unique_ptr<ProcessGroup>> name_to_group_;

diff --git a/infini_train/src/nn/parallel/global.cc b/infini_train/src/nn/parallel/global.cc
index 424819f3..d34cbe18 100644
--- a/infini_train/src/nn/parallel/global.cc
+++ b/infini_train/src/nn/parallel/global.cc
@@ -14,6 +14,11 @@ int GetEnvAsInt(const std::string &name, int default_value) {
     return value ? std::atoi(value) : default_value;
 }
 
+std::string GetEnvAsStr(const std::string &name, const std::string &default_value) {
+    const char *value = std::getenv(name.c_str());
+    return value ? std::string(value) : default_value;
+}
+
 #ifdef USE_NCCL
 ncclUniqueId StringToNcclId(const std::string &str) {
     ncclUniqueId id;
@@ -120,6 +125,10 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled) {
     // FIXME(zbl): set PP size
     layout_.sizes[PP] = 1;
     layout_.InitStrides();
+    // FIXME(dcj): what if no nccl id?
+#ifdef USE_NCCL
+    nccl_id_ = StringToNcclId(GetEnvAsStr("NCCL_UNIQUE_ID", ""));
+#endif
 
     initialized_ = true;
 }
@@ -267,5 +276,11 @@ std::string ProcessGroupOverview(const Layout &L, bool skip_trivial_axes) {
     oss << "\n";
     return oss.str();
 }
+#ifdef USE_NCCL
+ncclUniqueId GlobalEnv::nccl_id() const {
+    CHECK(initialized_) << "GlobalEnv is not initialized!";
+    return nccl_id_;
+}
+#endif
 
 } // namespace infini_train::nn::parallel::global

diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc
index 3271331c..1d8ab2d0 100644
--- a/infini_train/src/nn/parallel/process_group.cc
+++ b/infini_train/src/nn/parallel/process_group.cc
@@ -1,5 +1,6 @@
 #include "infini_train/include/nn/parallel/process_group.h"
 
+#include 
 #include 
 #include 
 
@@ -40,11 +41,25 @@ const std::unordered_map<ReduceOpType, ncclRedOp_t> kNcclReduceOpMap = {
 
 namespace infini_train::nn::parallel {
 
 #ifdef USE_NCCL
-ProcessGroup::ProcessGroup(const std::vector<int> &device_indices) : comm_size_(device_indices.size()) {
-    comms_.resize(comm_size_);
-    NCCL_CHECK(ncclCommInitAll(comms_.data(), comm_size_, device_indices.data()));
+ProcessGroup::ProcessGroup(const ncclUniqueId &nccl_id) : world_size_(global::GetWorldSize()) {
+    int local_comm_size = global::GetNthreadPerProc();
+    comms_.resize(local_comm_size);
+    std::vector<int> device_indices(local_comm_size);
 
-    for (int i = 0; i < comm_size_; ++i) {
+    NCCL_CHECK(ncclGroupStart());
+    for (int i = 0; i < local_comm_size; ++i) {
+        device_indices[i] = i;
+
+        int global_rank = global::GetGlobalProcRank() * global::GetNthreadPerProc() + i;
+        NCCL_CHECK(ncclCommInitRank(&comms_[i], world_size_, nccl_id, global_rank));
+    }
+    NCCL_CHECK(ncclGroupEnd());
+
+    Init(device_indices);
+}
+
+void ProcessGroup::Init(const std::vector<int> &device_indices) {
+    for (int i = 0; i < world_size_; ++i) {
         auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, device_indices[i]);
         devices_.push_back(device);
         device_comm_map_[device] = comms_[i];
@@ -92,7 +107,9 @@ ProcessGroup::BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors) const {
     std::vector<ncclComm_t> comms;
     std::vector<const Device *> devices;
 
-    for (size_t i = 0; i < comm_size_; ++i) {
+    CHECK_EQ(world_size_, comms_.size());
+
+    for (size_t i = 0; i < world_size_; ++i) {
         auto device = devices_[i];
         for (const auto &input_tensor : input_tensors) {
             outputs.push_back(std::make_shared<Tensor>(input_tensor->Dims(), input_tensor->Dtype(), device));
@@ -323,31 +340,20 @@ ProcessGroupFactory *ProcessGroupFactory::Instance() {
 }
 
 const ProcessGroup *ProcessGroupFactory::GetOrCreate(const std::string &name, int comm_size) {
-    std::vector<int> devices(comm_size);
-    std::iota(devices.begin(), devices.end(), 0);
-    const std::vector<int> &device_indices = devices;
-
-    return GetOrCreate(name, device_indices);
+    std::vector<int> device_indices(comm_size);
+    std::iota(device_indices.begin(), device_indices.end(), 0);
+    return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(device_indices); });
 }
 
 const ProcessGroup *ProcessGroupFactory::GetOrCreate(const std::string &name, const std::vector<int> &device_indices) {
-    {
-        std::lock_guard lock(mutex_);
-        auto it = name_to_group_.find(name);
-        if (it != name_to_group_.end()) {
-            return it->second.get();
-        }
-    }
-
-    auto new_group = std::make_unique<ProcessGroup>(device_indices);
-
-    {
-        std::lock_guard lock(mutex_);
-
-        auto [it, inserted] = name_to_group_.emplace(name, std::move(new_group));
-        return it->second.get();
-    }
+    return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(device_indices); });
 }
 
+#ifdef USE_NCCL
+const ProcessGroup *ProcessGroupFactory::GetOrCreate(const std::string &name, const ncclUniqueId &nccl_id) {
+    return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(nccl_id); });
+}
+#endif
 
 const ProcessGroup *ProcessGroupFactory::Get(const std::string &name) const {
     std::lock_guard lock(mutex_);
     return name_to_group_.at(name).get();
@@ -358,5 +364,11 @@ const ProcessGroup *ProcessGroupFactory::GetDefaultProcessGroup() const {
     return name_to_group_.at(kDefaltProcessGroupName).get();
 }
 
-ProcessGroupFactory::ProcessGroupFactory() { GetOrCreate(kDefaltProcessGroupName, global::GetWorldSize()); }
+ProcessGroupFactory::ProcessGroupFactory() {
+#ifdef USE_NCCL
+    GetOrCreate(kDefaltProcessGroupName, global::GetNcclId());
+#else
+    GetOrCreate(kDefaltProcessGroupName, global::GetWorldSize());
+#endif
+}
 } // namespace infini_train::nn::parallel

diff --git a/tools/infini_run/CMakeLists.txt b/tools/infini_run/CMakeLists.txt
index edc01bab..a8d15bce 100644
--- a/tools/infini_run/CMakeLists.txt
+++ b/tools/infini_run/CMakeLists.txt
@@ -1,2 +1,5 @@
 add_executable(infini_run infini_run.cc)
 target_link_libraries(infini_run PRIVATE gflags glog)
+if (USE_NCCL)
+    target_link_libraries(infini_run PRIVATE nccl)
+endif()

diff --git a/tools/infini_run/infini_run.cc b/tools/infini_run/infini_run.cc
index 67aa372e..0ddef0b9 100644
--- a/tools/infini_run/infini_run.cc
+++ b/tools/infini_run/infini_run.cc
@@ -25,7 +25,7 @@ DEFINE_string(rdzv_endpoint, "127.0.0.1:29500", "Rendezvous endpoint (host:port)");
 std::string NcclIdToString(const ncclUniqueId& id) {
     std::ostringstream oss;
     for (int i = 0; i < NCCL_UNIQUE_ID_BYTES; ++i) {
-        oss << std::hex << std::uppercase << (int)(unsigned char)id.internal[i];
+        oss << std::hex << std::uppercase << std::setw(2) << std::setfill('0') << (int)(unsigned char)id.internal[i];
     }
     return oss.str();
 }
@@ -99,5 +99,11 @@ int main(int argc, char **argv) {
         wait(&status);
     }
 
+#ifdef USE_NCCL
+    if (FLAGS_node_rank == 0) {
+        std::remove(nccl_id_path);
+    }
+#endif
+
     return 0;
 }
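The `std::setw(2)`/`std::setfill('0')` fix above matters because the decoder (`StringToNcclId`) reads the string positionally with `substr(i * 2, 2)`: every byte must occupy exactly two hex characters, otherwise a byte like 0x0A serializes to a single "A" and shifts every later offset. A minimal round-trip sketch (plain bytes standing in for `ncclUniqueId`):

```cpp
#include <cassert>
#include <iomanip>
#include <sstream>
#include <string>
#include <vector>

// Encode: two fixed-width hex chars per byte.
std::string ToHex(const std::vector<unsigned char> &bytes) {
    std::ostringstream oss;
    for (unsigned char b : bytes) {
        oss << std::hex << std::uppercase << std::setw(2) << std::setfill('0') << (int)b;
    }
    return oss.str();
}

// Decode: positional, two chars per byte, mirroring StringToNcclId.
std::vector<unsigned char> FromHex(const std::string &str) {
    std::vector<unsigned char> bytes(str.size() / 2);
    for (size_t i = 0; i < bytes.size(); ++i) {
        unsigned int byte;
        std::stringstream ss;
        ss << std::hex << str.substr(i * 2, 2);
        ss >> byte;
        bytes[i] = static_cast<unsigned char>(byte);
    }
    return bytes;
}

int main() {
    std::vector<unsigned char> id = {0x0A, 0xFF, 0x03}; // 0x0A would print as "A" without setw/setfill
    assert(ToHex(id) == "0AFF03");
    assert(FromHex(ToHex(id)) == id); // decodes byte-for-byte
    return 0;
}
```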
From 6cc74b78e7a5ce4de7a51882d0188c3f6fea5656 Mon Sep 17 00:00:00 2001
From: kilinchange
Date: Mon, 3 Nov 2025 10:38:41 +0000
Subject: [PATCH 3/8] fix: make multi-node DDP precision work

---
 example/llama3/main.cc                           | 2 +-
 infini_train/include/nn/parallel/rank.h          | 2 ++
 .../src/nn/parallel/distributed_data_parallel.cc | 4 ++--
 infini_train/src/nn/parallel/process_group.cc    | 9 +++++++--
 infini_train/src/nn/parallel/rank.cc             | 2 ++
 5 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/example/llama3/main.cc b/example/llama3/main.cc
index 9dc4cf07..c3a2c3f6 100644
--- a/example/llama3/main.cc
+++ b/example/llama3/main.cc
@@ -107,7 +107,7 @@ void Train(const nn::parallel::Rank &rank) {
     if (ddp_world_size > 1) {
         ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.thread_rank()),
                                                               GetDataParallelGroupRanks(rank.thread_rank()));
-        ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank());
+        ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank());
     }
 
     if (tp_world_size > 1) {

diff --git a/infini_train/include/nn/parallel/rank.h b/infini_train/include/nn/parallel/rank.h
index 56b76b97..c5d9e185 100644
--- a/infini_train/include/nn/parallel/rank.h
+++ b/infini_train/include/nn/parallel/rank.h
@@ -10,6 +10,8 @@ class Rank {
     int process_size() const;
     int thread_size() const;
 
+    int GlobalRank() const;
+
     bool IsParallel() const;
     bool IsMainRank() const;

diff --git a/infini_train/src/nn/parallel/distributed_data_parallel.cc b/infini_train/src/nn/parallel/distributed_data_parallel.cc
index 26ca62c5..014a3f75 100644
--- a/infini_train/src/nn/parallel/distributed_data_parallel.cc
+++ b/infini_train/src/nn/parallel/distributed_data_parallel.cc
@@ -24,8 +24,8 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr<Module> module,
         auto ddp_pg
             = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(device->rank().thread_rank()));
 
-        auto hook = std::make_unique(function::ReduceOpType::kAvg,
-                                     ddp_pg);
+        // FIXME(dcj): use multi-node ddp_pg here
+        auto hook = std::make_unique(function::ReduceOpType::kAvg);
         param->RegisterPostAccumulateGradHook(std::move(hook));
     }
     for (auto &buffer : module->Buffers()) {

diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc
index 1d8ab2d0..075c6a66 100644
--- a/infini_train/src/nn/parallel/process_group.cc
+++ b/infini_train/src/nn/parallel/process_group.cc
@@ -1,5 +1,6 @@
 #include "infini_train/include/nn/parallel/process_group.h"
 
+#include 
 #include 
 #include 
 #include 
 
@@ -51,6 +52,8 @@ ProcessGroup::ProcessGroup(const ncclUniqueId &nccl_id) : world_size_(global::GetWorldSize()) {
         device_indices[i] = i;
 
         int global_rank = global::GetGlobalProcRank() * global::GetNthreadPerProc() + i;
+
+        cudaSetDevice(i);
         NCCL_CHECK(ncclCommInitRank(&comms_[i], world_size_, nccl_id, global_rank));
     }
     NCCL_CHECK(ncclGroupEnd());
@@ -59,11 +62,13 @@ ProcessGroup::ProcessGroup(const ncclUniqueId &nccl_id) : world_size_(global::GetWorldSize()) {
 }
 
 void ProcessGroup::Init(const std::vector<int> &device_indices) {
-    for (int i = 0; i < world_size_; ++i) {
+    // FIXME(dcj): This is a temporary solution to get the device and comm for each thread.
+    int local_comm_size = std::min(static_cast<int>(device_indices.size()), global::GetNthreadPerProc());
+    for (int i = 0; i < local_comm_size; ++i) {
         auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, device_indices[i]);
         devices_.push_back(device);
         device_comm_map_[device] = comms_[i];
-        thread_group_rank_map_[device->rank().thread_rank()] = i;
+        thread_group_rank_map_[device->rank().thread_rank()] = i + global::GetGlobalProcRank() * local_comm_size;
     }
 }

diff --git a/infini_train/src/nn/parallel/rank.cc b/infini_train/src/nn/parallel/rank.cc
index 73ecd5de..0ec36b8f 100644
--- a/infini_train/src/nn/parallel/rank.cc
+++ b/infini_train/src/nn/parallel/rank.cc
@@ -10,6 +10,8 @@ int Rank::thread_rank() const { return thread_rank_; }
 int Rank::process_size() const { return process_size_; }
 int Rank::thread_size() const { return thread_size_; }
 
+int Rank::GlobalRank() const { return process_rank_ * thread_size_ + thread_rank_; }
+
 bool Rank::IsParallel() const { return thread_size_ * process_size_ > 1; }
 bool Rank::IsMainRank() const { return thread_rank_ == 0; }
 } // namespace infini_train::nn::parallel
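The new `Rank::GlobalRank()` is the linchpin of this fix: group-rank lookups must key on a rank that is unique across the whole job, not just within one process. A quick standalone check of the arithmetic (2 processes of 4 threads each, so world size 8):

```cpp
#include <cassert>

// global rank = process_rank * thread_size + thread_rank
int GlobalRank(int process_rank, int thread_rank, int thread_size) {
    return process_rank * thread_size + thread_rank;
}

int main() {
    assert(GlobalRank(0, 0, 4) == 0); // first thread on node-local process 0
    assert(GlobalRank(0, 3, 4) == 3);
    assert(GlobalRank(1, 2, 4) == 6); // process 1, thread 2
    assert(GlobalRank(1, 3, 4) == 7); // last rank of world_size = 8
    return 0;
}
```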
+ */ inline std::vector GetGroupRanks(Axis target, int rank) { int dp, tp, pp; GetCoordOf(rank, dp, tp, pp); return GlobalEnv::Instance().layout().GroupRanks(target, dp, tp, pp); } +/** + * @brief Generate a human-readable overview of all parallel communication groups. + * + * The output is intended for debugging, logging, and runtime verification of + * distributed parallelism configuration. + * + * @param L The Layout describing DP / TP / PP sizes and axis ordering. + * @param skip_trivial_axes + * If true, axes whose size <= 1(i.e. parallel strategy that is not enabled) + * will be marked as "unenabled" and their detailed group listing will be skipped. + * + * @return A formatted string containing the full overview of process groups. + * + * Example: + * === Parallel Communication Groups === + * world_size = 8, config: {DP=2, TP=4, PP=1}, order: {DP -> TP -> PP} + * [DP] size=2, num_groups=4 + * - DP 0 (dp=-, tp=0, pp=0): [0, 4] + * - DP 1 (dp=-, tp=1, pp=0): [1, 5] + * - DP 2 (dp=-, tp=2, pp=0): [2, 6] + * - DP 3 (dp=-, tp=3, pp=0): [3, 7] + * + * [TP] size=4, num_groups=2 + * - TP 0 (dp=0, tp=-, pp=0): [0, 1, 2, 3] + * - TP 1 (dp=1, tp=-, pp=0): [4, 5, 6, 7] + * + * [PP] size=1, unenabled + */ std::string ProcessGroupOverview(const Layout &L = GlobalEnv::Instance().layout(), bool skip_trivial_axes = true); + #ifdef USE_NCCL inline ncclUniqueId GetNcclId() { return GlobalEnv::Instance().nccl_id(); } #endif From 8e0862f8752b9b4a59902790648ac7e81a347c68 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Tue, 11 Nov 2025 07:27:29 +0000 Subject: [PATCH 5/8] feat: support multi-node DDP + TP + SP parallel training --- example/gpt2/main.cc | 10 +- example/llama3/main.cc | 12 +-- infini_train/include/nn/parallel/global.h | 4 - .../include/nn/parallel/process_group.h | 34 +++---- infini_train/src/device.cc | 2 +- .../nn/parallel/distributed_data_parallel.cc | 5 +- infini_train/src/nn/parallel/global.cc | 52 ----------- infini_train/src/nn/parallel/process_group.cc | 93 ++++++++++++------- .../src/nn/parallel/tensor_parallel.cc | 13 ++- 9 files changed, 98 insertions(+), 127 deletions(-) diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index e3322d3d..71f6e76f 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -122,14 +122,14 @@ void Train(const nn::parallel::Rank &rank) { device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank()); if (ddp_world_size > 1) { - ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.thread_rank()), - GetDataParallelGroupRanks(rank.thread_rank())); + ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()), + GetDataParallelGroupRanks(rank.GlobalRank())); ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank()); } if (tp_world_size > 1) { - tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.thread_rank()), - GetTensorParallelGroupRanks(rank.thread_rank())); + tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()), + GetTensorParallelGroupRanks(rank.GlobalRank())); tp_rank = tp_pg->GetGroupRank(rank.thread_rank()); // NOTE(zbl): Reserved for VocabParallelEmbedding nn::parallel::tp_rank = tp_rank; @@ -312,7 +312,7 @@ int main(int argc, char *argv[]) { if (FLAGS_nthread_per_process > 1) { std::vector threads; for (int idx = 0; idx < FLAGS_nthread_per_process; ++idx) { - nn::parallel::Rank rank(nn::parallel::global::GetLocalProcRank(), idx, + 
From 8e0862f8752b9b4a59902790648ac7e81a347c68 Mon Sep 17 00:00:00 2001
From: kilinchange
Date: Tue, 11 Nov 2025 07:27:29 +0000
Subject: [PATCH 5/8] feat: support multi-node DDP + TP + SP parallel training

---
 example/gpt2/main.cc                             | 10 +-
 example/llama3/main.cc                           | 12 +--
 infini_train/include/nn/parallel/global.h        |  4 -
 .../include/nn/parallel/process_group.h          | 34 +++----
 infini_train/src/device.cc                       |  2 +-
 .../nn/parallel/distributed_data_parallel.cc     |  5 +-
 infini_train/src/nn/parallel/global.cc           | 52 -----------
 infini_train/src/nn/parallel/process_group.cc    | 93 ++++++++++++-------
 .../src/nn/parallel/tensor_parallel.cc           | 13 ++-
 9 files changed, 98 insertions(+), 127 deletions(-)

diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc
index e3322d3d..71f6e76f 100644
--- a/example/gpt2/main.cc
+++ b/example/gpt2/main.cc
@@ -122,14 +122,14 @@ void Train(const nn::parallel::Rank &rank) {
     device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank());
 
     if (ddp_world_size > 1) {
-        ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.thread_rank()),
-                                                              GetDataParallelGroupRanks(rank.thread_rank()));
+        ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
+                                                              GetDataParallelGroupRanks(rank.GlobalRank()));
         ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank());
     }
 
     if (tp_world_size > 1) {
-        tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.thread_rank()),
-                                                             GetTensorParallelGroupRanks(rank.thread_rank()));
+        tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
+                                                             GetTensorParallelGroupRanks(rank.GlobalRank()));
         tp_rank = tp_pg->GetGroupRank(rank.thread_rank());
         // NOTE(zbl): Reserved for VocabParallelEmbedding
         nn::parallel::tp_rank = tp_rank;
@@ -312,7 +312,7 @@ int main(int argc, char *argv[]) {
     if (FLAGS_nthread_per_process > 1) {
         std::vector<std::thread> threads;
         for (int idx = 0; idx < FLAGS_nthread_per_process; ++idx) {
-            nn::parallel::Rank rank(nn::parallel::global::GetLocalProcRank(), idx,
+            nn::parallel::Rank rank(nn::parallel::global::GetGlobalProcRank(), idx,
                                     nn::parallel::global::GetNprocPerNode(), FLAGS_nthread_per_process);
             threads.emplace_back(Train, rank);
         }

diff --git a/example/llama3/main.cc b/example/llama3/main.cc
index c3a2c3f6..4ac28506 100644
--- a/example/llama3/main.cc
+++ b/example/llama3/main.cc
@@ -105,14 +105,14 @@ void Train(const nn::parallel::Rank &rank) {
     device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank());
 
     if (ddp_world_size > 1) {
-        ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.thread_rank()),
-                                                              GetDataParallelGroupRanks(rank.thread_rank()));
-        ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank());
+        ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
+                                                              GetDataParallelGroupRanks(rank.GlobalRank()));
+        ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank());
     }
 
     if (tp_world_size > 1) {
-        tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.thread_rank()),
-                                                             GetTensorParallelGroupRanks(rank.thread_rank()));
+        tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
+                                                             GetTensorParallelGroupRanks(rank.GlobalRank()));
         tp_rank = tp_pg->GetGroupRank(rank.thread_rank());
         // NOTE(zbl): Reserved for VocabParallelEmbedding
         nn::parallel::tp_rank = tp_rank;
@@ -292,7 +292,7 @@ int main(int argc, char *argv[]) {
     if (FLAGS_nthread_per_process > 1) {
         std::vector<std::thread> threads;
         for (int idx = 0; idx < FLAGS_nthread_per_process; ++idx) {
-            nn::parallel::Rank rank(nn::parallel::global::GetLocalProcRank(), idx,
+            nn::parallel::Rank rank(nn::parallel::global::GetGlobalProcRank(), idx,
                                     nn::parallel::global::GetNprocPerNode(), FLAGS_nthread_per_process);
             threads.emplace_back(Train, rank);
         }

diff --git a/infini_train/include/nn/parallel/global.h b/infini_train/include/nn/parallel/global.h
index 73aac199..909d370e 100644
--- a/infini_train/include/nn/parallel/global.h
+++ b/infini_train/include/nn/parallel/global.h
@@ -174,8 +174,4 @@ inline std::vector<int> GetGroupRanks(Axis target, int rank) {
  */
 std::string ProcessGroupOverview(const Layout &L = GlobalEnv::Instance().layout(), bool skip_trivial_axes = true);
 
-#ifdef USE_NCCL
-inline ncclUniqueId GetNcclId() { return GlobalEnv::Instance().nccl_id(); }
-#endif
-
 } // namespace infini_train::nn::parallel::global

diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h
index e4a5f307..95e4de59 100644
--- a/infini_train/include/nn/parallel/process_group.h
+++ b/infini_train/include/nn/parallel/process_group.h
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <condition_variable>
 #include 
 #include 
 #include 
@@ -27,10 +28,7 @@ namespace infini_train::nn::parallel {
 #ifdef USE_NCCL
 class ProcessGroup {
 public:
-    explicit ProcessGroup(const std::vector<int> &device_indices);
-
-    // support for multi-node distributed training
-    explicit ProcessGroup(const ncclUniqueId &nccl_id);
+    explicit ProcessGroup(const std::string &process_group_name, const std::vector<int> &device_indices);
 
     int GetGroupRank(int thread_rank) const;
 
@@ -67,6 +65,8 @@ class ProcessGroup {
     std::unordered_map<int, int> thread_group_rank_map_; // thread_rank : group_rank
 
     int world_size_ = 0;
+
+    const std::string name_ = "";
 };
 #endif
@@ -80,10 +80,6 @@ class ProcessGroupFactory {
     const ProcessGroup *GetOrCreate(const std::string &name, const std::vector<int> &device_indices);
 
-#ifdef USE_NCCL
-    const ProcessGroup *GetOrCreate(const std::string &name, const ncclUniqueId &nccl_id);
-#endif
-
     const ProcessGroup *Get(const std::string &name) const;
     const ProcessGroup *GetDefaultProcessGroup() const;
 
@@ -93,26 +89,26 @@ class ProcessGroupFactory {
     template <typename Creator,
               typename = std::enable_if_t<std::is_invocable_r_v<std::unique_ptr<ProcessGroup>, Creator>>>
     const ProcessGroup *GetOrCreate(const std::string &name, Creator &&creator) {
-        {
-            std::lock_guard lock(mutex_);
-            auto it = name_to_group_.find(name);
-            if (it != name_to_group_.end()) {
-                return it->second.get();
-            }
+        std::unique_lock lock(mutex_);
+        auto [it, inserted] = name_to_group_.emplace(name, nullptr);
+        if (!inserted) {
+            while (it->second == nullptr) { cond_.wait(lock); }
+            return it->second.get();
         }
 
+        lock.unlock();
         auto new_group = creator();
+        lock.lock();
 
-        {
-            std::lock_guard lock(mutex_);
-            auto [it, inserted] = name_to_group_.emplace(name, std::move(new_group));
-            return it->second.get();
-        }
+        it->second = std::move(new_group);
+        cond_.notify_all();
+        return it->second.get();
     }
 
 private:
     // TODO(dcj): maybe RWLock later?
     mutable std::mutex mutex_;
+    std::condition_variable cond_;
     std::unordered_map<std::string, std::unique_ptr<ProcessGroup>> name_to_group_;
 };
 } // namespace infini_train::nn::parallel

diff --git a/infini_train/src/device.cc b/infini_train/src/device.cc
index 99bfc52e..4271ff97 100644
--- a/infini_train/src/device.cc
+++ b/infini_train/src/device.cc
@@ -64,7 +64,7 @@ nn::parallel::Rank CudaDevice::rank() const { return rank_; }
 
 CudaDevice::CudaDevice(int8_t index)
     : Device(DeviceType::kCUDA, index),
-      rank_({nn::parallel::global::GetLocalProcRank(), index, nn::parallel::global::GetNprocPerNode(),
+      rank_({nn::parallel::global::GetGlobalProcRank(), index, nn::parallel::global::GetNprocPerNode(),
              nn::parallel::global::GetNthreadPerProc()}) {
     // TODO(dcj): make CudaDevice initialization lazy to avoid allocating memory on all GPUs in single-GPU mode
     SetDevice();

diff --git a/infini_train/src/nn/parallel/distributed_data_parallel.cc b/infini_train/src/nn/parallel/distributed_data_parallel.cc
index 014a3f75..b7fc5d91 100644
--- a/infini_train/src/nn/parallel/distributed_data_parallel.cc
+++ b/infini_train/src/nn/parallel/distributed_data_parallel.cc
@@ -23,9 +23,10 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr<Module> module,
         CHECK_EQ(device->Index(), device_id) << "All parameters must be on the same device as the module";
 
         auto ddp_pg
-            = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(device->rank().thread_rank()));
+            = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(device->rank().GlobalRank()));
 
-        // FIXME(dcj): use multi-node ddp_pg here
-        auto hook = std::make_unique(function::ReduceOpType::kAvg);
+        auto hook = std::make_unique(function::ReduceOpType::kAvg,
+                                     ddp_pg);
         param->RegisterPostAccumulateGradHook(std::move(hook));
     }
     for (auto &buffer : module->Buffers()) {

diff --git a/infini_train/src/nn/parallel/global.cc b/infini_train/src/nn/parallel/global.cc
index d34cbe18..74edf50f 100644
--- a/infini_train/src/nn/parallel/global.cc
+++ b/infini_train/src/nn/parallel/global.cc
@@ -19,20 +19,6 @@ std::string GetEnvAsStr(const std::string &name, const std::string &default_value) {
     return value ? std::string(value) : default_value;
 }
 
-#ifdef USE_NCCL
-ncclUniqueId StringToNcclId(const std::string &str) {
-    ncclUniqueId id;
-    for (int i = 0; i < NCCL_UNIQUE_ID_BYTES; ++i) {
-        unsigned int byte;
-        std::stringstream ss;
-        ss << std::hex << str.substr(i * 2, 2);
-        ss >> byte;
-        id.internal[i] = static_cast<char>(byte);
-    }
-    return id;
-}
-#endif
-
 } // namespace
 
 namespace infini_train::nn::parallel::global {
@@ -125,10 +111,6 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled) {
     // FIXME(zbl): set PP size
     layout_.sizes[PP] = 1;
     layout_.InitStrides();
-    // FIXME(dcj): what if no nccl id?
-#ifdef USE_NCCL
-    nccl_id_ = StringToNcclId(GetEnvAsStr("NCCL_UNIQUE_ID", ""));
-#endif
 
     initialized_ = true;
 }
@@ -225,34 +207,6 @@ inline void AppendAxisGroups(std::ostringstream &oss, const Layout &L, Axis target) {
     }
 }
 
-/**
- * @brief Generate a human-readable overview of all parallel communication groups.
- *
- * The output is intended for debugging, logging, and runtime verification of
- * the distributed parallelism configuration.
- *
- * @param L The Layout describing DP / TP / PP sizes and axis ordering.
- * @param skip_trivial_axes
- *        If true, axes whose size <= 1 (i.e., a parallel strategy that is not enabled)
- *        will be marked as "unenabled" and their detailed group listing will be skipped.
- *
- * @return A formatted string containing the full overview of process groups.
- *
- * Example:
- *   === Parallel Communication Groups ===
- *   world_size = 8, config: {DP=2, TP=4, PP=1}, order: {DP -> TP -> PP}
- *   [DP] size=2, num_groups=4
- *     - DP 0 (dp=-, tp=0, pp=0): [0, 4]
- *     - DP 1 (dp=-, tp=1, pp=0): [1, 5]
- *     - DP 2 (dp=-, tp=2, pp=0): [2, 6]
- *     - DP 3 (dp=-, tp=3, pp=0): [3, 7]
- *
- *   [TP] size=4, num_groups=2
- *     - TP 0 (dp=0, tp=-, pp=0): [0, 1, 2, 3]
- *     - TP 1 (dp=1, tp=-, pp=0): [4, 5, 6, 7]
- *
- *   [PP] size=1, unenabled
- */
 std::string ProcessGroupOverview(const Layout &L, bool skip_trivial_axes) {
     std::ostringstream oss;
@@ -276,11 +230,5 @@ std::string ProcessGroupOverview(const Layout &L, bool skip_trivial_axes) {
     oss << "\n";
     return oss.str();
 }
-#ifdef USE_NCCL
-ncclUniqueId GlobalEnv::nccl_id() const {
-    CHECK(initialized_) << "GlobalEnv is not initialized!";
-    return nccl_id_;
-}
-#endif
 
 } // namespace infini_train::nn::parallel::global

diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc
index 075c6a66..ac9fcc82 100644
--- a/infini_train/src/nn/parallel/process_group.cc
+++ b/infini_train/src/nn/parallel/process_group.cc
@@ -1,8 +1,13 @@
 #include "infini_train/include/nn/parallel/process_group.h"
 
 #include 
+#include 
+#include 
+#include 
+#include 
 #include 
 #include 
+#include 
 #include 
 
 #ifdef USE_NCCL
@@ -20,6 +25,9 @@ namespace infini_train {
 
 namespace {
 
+using nn::parallel::function::ReduceOpType;
+
+#ifdef USE_NCCL
 const std::unordered_map<DataType, ncclDataType_t> kNcclDtypeMap = {
     {DataType::kUINT8, ncclUint8},   {DataType::kINT8, ncclInt8},     {DataType::kUINT32, ncclUint32},
     {DataType::kINT32, ncclInt32},   {DataType::kUINT64, ncclUint64}, {DataType::kINT64, ncclInt64},
     {DataType::kFLOAT64, ncclFloat64},
 };
 
-using nn::parallel::function::ReduceOpType;
-
 const std::unordered_map<ReduceOpType, ncclRedOp_t> kNcclReduceOpMap = {
     {ReduceOpType::kSum, ncclSum},
     {ReduceOpType::kProd, ncclProd},
     {ReduceOpType::kMax, ncclMax},
     {ReduceOpType::kAvg, ncclAvg},
 };
+
+void WriteNcclUniqueId(const ncclUniqueId &nccl_id, const std::string &filename) {
+    std::string tmp_path = filename + ".tmp";
+
+    std::ofstream ofs(tmp_path, std::ios::binary);
+    ofs.write(reinterpret_cast<const char *>(&nccl_id), sizeof(nccl_id));
+    ofs.close();
+
+    std::rename(tmp_path.c_str(), filename.c_str());
+}
+
+void ReadNcclUniqueId(ncclUniqueId &nccl_id, const std::string &filename) {
+    std::ifstream ifs(filename, std::ios::binary);
+    ifs.read(reinterpret_cast<char *>(&nccl_id), sizeof(nccl_id));
+    ifs.close();
+}
+#endif
 
 } // namespace
 
 } // namespace infini_train
@@ -42,19 +66,40 @@ namespace infini_train::nn::parallel {
 
 #ifdef USE_NCCL
-ProcessGroup::ProcessGroup(const ncclUniqueId &nccl_id) : world_size_(global::GetWorldSize()) {
-    int local_comm_size = global::GetNthreadPerProc();
-    comms_.resize(local_comm_size);
-    std::vector<int> device_indices(local_comm_size);
+ProcessGroup::ProcessGroup(const std::string &process_group_name, const std::vector<int> &ranks)
+    : world_size_(ranks.size()), name_(process_group_name) {
+    int n_threads = global::GetNthreadPerProc();
 
-    NCCL_CHECK(ncclGroupStart());
-    for (int i = 0; i < local_comm_size; ++i) {
-        device_indices[i] = i;
+    ncclUniqueId nccl_id;
 
-        int global_rank = global::GetGlobalProcRank() * global::GetNthreadPerProc() + i;
+    if (std::ranges::min(ranks) < (global::GetGlobalProcRank() + 1) * global::GetNthreadPerProc()
+        && std::ranges::min(ranks) >= global::GetGlobalProcRank() * global::GetNthreadPerProc()) {
+        ncclGetUniqueId(&nccl_id);
 
-        cudaSetDevice(i);
-        NCCL_CHECK(ncclCommInitRank(&comms_[i], world_size_, nccl_id, global_rank));
+        WriteNcclUniqueId(nccl_id, name_);
+    } else {
+        while (std::filesystem::exists(name_) == false) {
+            std::this_thread::sleep_for(std::chrono::microseconds(1000));
+        }
+        ReadNcclUniqueId(nccl_id, name_);
+    }
+
+    std::vector<int> device_indices;
+    NCCL_CHECK(ncclGroupStart());
+    for (int i = 0; i < n_threads; ++i) {
+        int global_rank = global::GetGlobalProcRank() * global::GetNthreadPerProc() + i;
+        auto it = std::ranges::find(ranks, global_rank);
+        if (it != ranks.end()) {
+            cudaSetDevice(i);
+            ncclComm_t comm;
+            int group_rank = std::distance(ranks.begin(), it);
+            NCCL_CHECK(ncclCommInitRank(&comm, world_size_, nccl_id, group_rank));
+            comms_.push_back(comm);
+            device_indices.push_back(i);
+            // FIXME(dcj): fix Init function
+            thread_group_rank_map_[DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, i)->rank().thread_rank()]
+                = group_rank;
+        }
     }
     NCCL_CHECK(ncclGroupEnd());
 
     Init(device_indices);
 }
 
 void ProcessGroup::Init(const std::vector<int> &device_indices) {
-    // FIXME(dcj): This is a temporary solution to get the device and comm for each thread.
-    int local_comm_size = std::min(static_cast<int>(device_indices.size()), global::GetNthreadPerProc());
-    for (int i = 0; i < local_comm_size; ++i) {
+    for (int i = 0; i < device_indices.size(); ++i) {
         auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, device_indices[i]);
         devices_.push_back(device);
         device_comm_map_[device] = comms_[i];
-        thread_group_rank_map_[device->rank().thread_rank()] = i + global::GetGlobalProcRank() * local_comm_size;
+        // thread_group_rank_map_[device->rank().thread_rank()] = i;
     }
 }
@@ -383,19 +396,13 @@ ProcessGroupFactory *ProcessGroupFactory::Instance() {
 }
 
 const ProcessGroup *ProcessGroupFactory::GetOrCreate(const std::string &name, int comm_size) {
     std::vector<int> device_indices(comm_size);
     std::iota(device_indices.begin(), device_indices.end(), 0);
-    return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(device_indices); });
+    return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(name, device_indices); });
 }
 
 const ProcessGroup *ProcessGroupFactory::GetOrCreate(const std::string &name, const std::vector<int> &device_indices) {
-    return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(device_indices); });
+    return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(name, device_indices); });
 }
 
-#ifdef USE_NCCL
-const ProcessGroup *ProcessGroupFactory::GetOrCreate(const std::string &name, const ncclUniqueId &nccl_id) {
-    return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(nccl_id); });
-}
-#endif
-
 const ProcessGroup *ProcessGroupFactory::Get(const std::string &name) const {
     std::lock_guard lock(mutex_);
     return name_to_group_.at(name).get();
@@ -406,11 +413,5 @@ const ProcessGroup *ProcessGroupFactory::GetDefaultProcessGroup() const {
     return name_to_group_.at(kDefaltProcessGroupName).get();
 }
 
-ProcessGroupFactory::ProcessGroupFactory() {
-#ifdef USE_NCCL
-    GetOrCreate(kDefaltProcessGroupName, global::GetNcclId());
-#else
-    GetOrCreate(kDefaltProcessGroupName, global::GetWorldSize());
-#endif
-}
+ProcessGroupFactory::ProcessGroupFactory() { GetOrCreate(kDefaltProcessGroupName, global::GetWorldSize()); }
 } // namespace infini_train::nn::parallel

diff --git a/infini_train/src/nn/parallel/tensor_parallel.cc b/infini_train/src/nn/parallel/tensor_parallel.cc
index 09f04c88..bc94ac8b 100644
--- a/infini_train/src/nn/parallel/tensor_parallel.cc
+++ b/infini_train/src/nn/parallel/tensor_parallel.cc
@@ -35,7 +35,7 @@ std::shared_ptr<Tensor> GatherAlongFirstDim(const std::shared_ptr<Tensor> &tensor) {
     auto device = tensor->GetDevice();
 
     auto tp_group
-        = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().thread_rank()));
+        = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().GlobalRank()));
 
     std::vector<int64_t> output_shape = tensor->Dims();
     output_shape[0] *= world_size;
@@ -55,7 +55,7 @@ std::shared_ptr<Tensor> GatherAlongLastDim(const std::shared_ptr<Tensor> &tensor) {
     auto device = tensor->GetDevice();
 
     auto tp_group
-        = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().thread_rank()));
+        = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().GlobalRank()));
 
     std::vector<int64_t> output_shape = tensor->Dims();
     output_shape[0] *= world_size;
@@ -80,7 +80,7 @@ std::shared_ptr<Tensor> SplitAlongLastDim(const std::shared_ptr<Tensor> &tensor) {
     auto device = tensor->GetDevice();
 
     auto tp_group
-        = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().thread_rank()));
+        = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().GlobalRank()));
     auto rank = tp_group->GetGroupRank(device->rank().thread_rank());
 
     auto last_dim_size = tensor->Dims().back() / world_size;
@@ -98,7 +98,7 @@ std::shared_ptr<Tensor> Reduce(const std::shared_ptr<Tensor> &tensor) {
     auto device = tensor->GetDevice();
 
     auto tp_group
-        = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().thread_rank()));
+        = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().GlobalRank()));
 
     auto output = std::make_shared<Tensor>(*tensor);
@@ -116,7 +116,7 @@ std::shared_ptr<Tensor> ReduceScatterAlongFirstDim(const std::shared_ptr<Tensor> &tensor) {
     auto device = tensor->GetDevice();
 
     auto tp_group
-        = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().thread_rank()));
+        = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().GlobalRank()));
 
     auto output_shape = tensor->Dims();
     CHECK_EQ(output_shape[0] % world_size, 0) << "First dimension of the tensor should be divisible by TP world size";
@@ -435,8 +435,7 @@ VocabParallelCrossEntropy::Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
     const ProcessGroup *tp_group = nullptr;
     int rank = 0;
     if (tp_size > 1) {
-        tp_group
-            = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().thread_rank()));
+        tp_group = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().GlobalRank()));
         rank = tp_group->GetGroupRank(device->rank().thread_rank());
     }
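The factory rewrite above replaces double-checked locking with a placeholder entry plus a condition variable: the creating thread runs the (potentially blocking) constructor outside the lock, while any other thread that finds the placeholder waits for it to be filled. A generic, self-contained sketch of that idiom (`Widget` standing in for `ProcessGroup`; this is an illustration, not the library's code):

```cpp
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

struct Widget { int value = 0; };

class Registry {
public:
    const Widget *GetOrCreate(const std::string &name, std::function<std::unique_ptr<Widget>()> creator) {
        std::unique_lock lock(mutex_);
        auto [it, inserted] = map_.emplace(name, nullptr); // nullptr marks "being created"
        auto &slot = it->second; // references stay valid even if the map rehashes
        if (!inserted) {
            cond_.wait(lock, [&] { return slot != nullptr; }); // wait for the creator
            return slot.get();
        }

        lock.unlock();           // run the expensive construction unlocked
        auto widget = creator(); // (ncclCommInitRank blocks until all peers join)
        lock.lock();

        slot = std::move(widget);
        cond_.notify_all();
        return slot.get();
    }

private:
    std::mutex mutex_;
    std::condition_variable cond_;
    std::unordered_map<std::string, std::unique_ptr<Widget>> map_;
};
```

Running the creator unlocked is what makes this safe for NCCL: `ncclCommInitRank` rendezvouses with every peer in the group, so holding the registry mutex across it would serialize (or deadlock) concurrent group creation.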
"tmp" : "bin"); +} + +void WriteNcclUniqueId(const ncclUniqueId &nccl_id, const std::string &pg_name) { + std::string tmp_path = NcclFileName(pg_name, true); std::ofstream ofs(tmp_path, std::ios::binary); ofs.write(reinterpret_cast(&nccl_id), sizeof(nccl_id)); ofs.close(); - std::rename(tmp_path.c_str(), filename.c_str()); + std::rename(tmp_path.c_str(), NcclFileName(pg_name).c_str()); } -void ReadNcclUniqueId(ncclUniqueId &nccl_id, const std::string &filename) { - std::ifstream ifs(filename, std::ios::binary); +void ReadNcclUniqueId(ncclUniqueId &nccl_id, const std::string &pg_name) { + std::string file_path = NcclFileName(pg_name); + + while (std::filesystem::exists(file_path) == false) { + std::this_thread::sleep_for(std::chrono::microseconds(1000)); + } + + std::ifstream ifs(file_path, std::ios::binary); ifs.read(reinterpret_cast(&nccl_id), sizeof(nccl_id)); ifs.close(); } @@ -78,9 +89,6 @@ ProcessGroup::ProcessGroup(const std::string &process_group_name, const std::vec WriteNcclUniqueId(nccl_id, name_); } else { - while (std::filesystem::exists(name_) == false) { - std::this_thread::sleep_for(std::chrono::microseconds(1000)); - } ReadNcclUniqueId(nccl_id, name_); } diff --git a/tools/infini_run/infini_run.cc b/tools/infini_run/infini_run.cc index 0ddef0b9..5ea3b7d0 100644 --- a/tools/infini_run/infini_run.cc +++ b/tools/infini_run/infini_run.cc @@ -1,13 +1,11 @@ +#include #include -#include -#include -#include +#include +#include #include #include -#include #include #include -#include #ifdef USE_NCCL #include @@ -21,15 +19,19 @@ DEFINE_int32(nproc_per_node, 1, "Number of processes per node"); DEFINE_int32(node_rank, 0, "Rank of this node"); DEFINE_string(rdzv_endpoint, "127.0.0.1:29500", "Rendezvous endpoint (host:port)"); -#ifdef USE_NCCL -std::string NcclIdToString(const ncclUniqueId& id) { - std::ostringstream oss; - for (int i = 0; i < NCCL_UNIQUE_ID_BYTES; ++i) { - oss << std::hex << std::uppercase << std::setw(2) << std::setfill('0') << (int)(unsigned char)id.internal[i]; +void CleanupNcclIdFiles() { + const std::filesystem::path cwd = std::filesystem::current_path(); + std::regex pattern(R"(ncclUniqueId_.*\.bin)"); + + for (const auto &entry : std::filesystem::directory_iterator(cwd)) { + if (entry.is_regular_file()) { + const std::string filename = entry.path().filename().string(); + if (std::regex_match(filename, pattern)) { + std::filesystem::remove(entry.path()); + } + } } - return oss.str(); } -#endif int main(int argc, char **argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -46,23 +48,6 @@ int main(int argc, char **argv) { std::string master_addr = FLAGS_rdzv_endpoint.substr(0, FLAGS_rdzv_endpoint.find(':')); std::string master_port = FLAGS_rdzv_endpoint.substr(FLAGS_rdzv_endpoint.find(':') + 1); - const char* nccl_id_path = "/data/shared/InfiniTrain-dev/data/nccl_id.bin"; - const char* nccl_id_tmp_path = "/data/shared/InfiniTrain-dev/data/nccl_id.tmp"; - -#ifdef USE_NCCL - if (FLAGS_node_rank == 0) { - ncclUniqueId id; - ncclGetUniqueId(&id); - - std::ofstream ofs(nccl_id_tmp_path, std::ios::binary); - ofs.write((char *)&id, sizeof(id)); - ofs.close(); - - // atomic operation - rename(nccl_id_tmp_path, nccl_id_path); - } -#endif - for (int local_proc_rank = 0; local_proc_rank < FLAGS_nproc_per_node; ++local_proc_rank) { pid_t pid = fork(); if (pid == 0) { @@ -74,20 +59,6 @@ int main(int argc, char **argv) { setenv("MASTER_ADDR", master_addr.c_str(), 1); setenv("MASTER_PORT", master_port.c_str(), 1); -#ifdef USE_NCCL - struct stat st; - while 
(stat(nccl_id_path, &st) != 0) { - usleep(1000); - } - - ncclUniqueId id; - std::ifstream ifs(nccl_id_path, std::ios::binary); - ifs.read((char*)&id, sizeof(id)); - - std::string id_str = NcclIdToString(id); - setenv("NCCL_UNIQUE_ID", id_str.c_str(), 1); -#endif - execvp(train_program.c_str(), train_argv.data()); perror("exec failed"); exit(1); @@ -101,7 +72,7 @@ int main(int argc, char **argv) { #ifdef USE_NCCL if (FLAGS_node_rank == 0) { - std::remove(nccl_id_path); + CleanupNcclIdFiles(); } #endif From 41882307f263be14721c436c1d23b4893f14731d Mon Sep 17 00:00:00 2001 From: kilinchange Date: Tue, 11 Nov 2025 09:41:18 +0000 Subject: [PATCH 7/8] feat: Separate ProcessGroup initialization for single-node multi-thread and multi-node multi-process scenarios --- infini_train/include/nn/parallel/global.h | 22 ++++------ .../include/nn/parallel/process_group.h | 4 +- infini_train/src/nn/parallel/global.cc | 28 ++++++++----- infini_train/src/nn/parallel/process_group.cc | 41 ++++++++++++------- tools/infini_run/CMakeLists.txt | 3 -- tools/infini_run/infini_run.cc | 16 ++++---- 6 files changed, 61 insertions(+), 53 deletions(-) diff --git a/infini_train/include/nn/parallel/global.h b/infini_train/include/nn/parallel/global.h index 909d370e..93014e2a 100644 --- a/infini_train/include/nn/parallel/global.h +++ b/infini_train/include/nn/parallel/global.h @@ -4,10 +4,6 @@ #include #include -#ifdef USE_NCCL -#include -#endif - namespace infini_train::nn::parallel::global { enum Axis : uint8_t { DP = 0, TP = 1, PP = 2, AXIS_COUNT = 3 }; @@ -32,14 +28,16 @@ class GlobalEnv { void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled = false); + int nnodes() const; + + int nproc_per_node() const; + int world_size() const; int global_proc_rank() const; int local_proc_rank() const; - int nproc_per_node() const; - int nthread_per_process() const; int tensor_parallel_size() const; @@ -49,9 +47,6 @@ class GlobalEnv { int data_parallel_size() const; Layout layout() const; -#ifdef USE_NCCL - ncclUniqueId nccl_id() const; -#endif private: GlobalEnv() = default; @@ -61,9 +56,11 @@ class GlobalEnv { GlobalEnv &operator=(const GlobalEnv &) = delete; private: - int world_size_ = 1; + int nnodes_ = 1; int nproc_per_node_ = 1; int nthread_per_process_ = 1; + int world_size_ = 1; + int global_proc_rank_ = 0; int local_proc_rank_ = 0; @@ -72,10 +69,6 @@ class GlobalEnv { int data_parallel_size_ = 1; -#ifdef USE_NCCL - ncclUniqueId nccl_id_; -#endif - mutable std::mutex mutex_; bool initialized_ = false; @@ -86,6 +79,7 @@ inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool s GlobalEnv::Instance().Init(nthread_per_process, tensor_parallel_size, sequence_parallel_enabled); } +inline int GetNnodes() { return GlobalEnv::Instance().nnodes(); } inline int GetWorldSize() { return GlobalEnv::Instance().world_size(); } inline int GetNprocPerNode() { return GlobalEnv::Instance().nproc_per_node(); } inline int GetNthreadPerProc() { return GlobalEnv::Instance().nthread_per_process(); } diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h index 95e4de59..e21b4db3 100644 --- a/infini_train/include/nn/parallel/process_group.h +++ b/infini_train/include/nn/parallel/process_group.h @@ -55,7 +55,9 @@ class ProcessGroup { std::vector> NcclRecv(std::vector> tensors, int src_rank) const; private: - void Init(const std::vector &device_indices); + void InitSingleProcess(const std::vector &ranks); + + void InitMultiProcess(const 
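With per-group file names, concurrent process groups rendezvous through distinct files instead of one hard-coded path. A small sketch of the naming scheme introduced above (`std::format` requires C++20; the group names here are hypothetical, the real ones come from the process-group factory):

```cpp
#include <cassert>
#include <format>
#include <string>

// Same shape as the NcclFileName helper in the patch.
std::string NcclFileName(const std::string &name, bool tmp = false) {
    return std::format("ncclUniqueId_{}.{}", name, tmp ? "tmp" : "bin");
}

int main() {
    assert(NcclFileName("default") == "ncclUniqueId_default.bin");
    assert(NcclFileName("dp_0", true) == "ncclUniqueId_dp_0.tmp"); // staged, then published via rename()
    assert(NcclFileName("dp_0") != NcclFileName("tp_0"));          // no cross-group collisions
    return 0;
}
```

The `.tmp`/`.bin` split keeps the earlier atomic-publish invariant: readers only ever poll for the `.bin` name, which appears in a single `rename()` step.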
From 41882307f263be14721c436c1d23b4893f14731d Mon Sep 17 00:00:00 2001
From: kilinchange
Date: Tue, 11 Nov 2025 09:41:18 +0000
Subject: [PATCH 7/8] feat: Separate ProcessGroup initialization for
 single-node multi-thread and multi-node multi-process scenarios

---
 infini_train/include/nn/parallel/global.h        | 22 ++++------
 infini_train/include/nn/parallel/process_group.h |  4 +-
 infini_train/src/nn/parallel/global.cc           | 28 ++++++-----
 infini_train/src/nn/parallel/process_group.cc    | 41 ++++++++++-------
 tools/infini_run/CMakeLists.txt                  |  3 --
 tools/infini_run/infini_run.cc                   | 16 ++++----
 6 files changed, 61 insertions(+), 53 deletions(-)

diff --git a/infini_train/include/nn/parallel/global.h b/infini_train/include/nn/parallel/global.h
index 909d370e..93014e2a 100644
--- a/infini_train/include/nn/parallel/global.h
+++ b/infini_train/include/nn/parallel/global.h
@@ -4,10 +4,6 @@
 #include 
 #include 
 
-#ifdef USE_NCCL
-#include <nccl.h>
-#endif
-
 namespace infini_train::nn::parallel::global {
 
 enum Axis : uint8_t { DP = 0, TP = 1, PP = 2, AXIS_COUNT = 3 };
@@ -32,14 +28,16 @@ class GlobalEnv {
 
     void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled = false);
 
+    int nnodes() const;
+
+    int nproc_per_node() const;
+
     int world_size() const;
 
     int global_proc_rank() const;
 
     int local_proc_rank() const;
 
-    int nproc_per_node() const;
-
     int nthread_per_process() const;
 
     int tensor_parallel_size() const;
@@ -49,9 +47,6 @@ class GlobalEnv {
     int data_parallel_size() const;
 
     Layout layout() const;
-#ifdef USE_NCCL
-    ncclUniqueId nccl_id() const;
-#endif
 
 private:
     GlobalEnv() = default;
@@ -61,9 +56,11 @@ class GlobalEnv {
     GlobalEnv &operator=(const GlobalEnv &) = delete;
 
 private:
-    int world_size_ = 1;
+    int nnodes_ = 1;
     int nproc_per_node_ = 1;
     int nthread_per_process_ = 1;
+    int world_size_ = 1;
+
     int global_proc_rank_ = 0;
     int local_proc_rank_ = 0;
 
@@ -72,10 +69,6 @@ class GlobalEnv {
 
     int data_parallel_size_ = 1;
 
-#ifdef USE_NCCL
-    ncclUniqueId nccl_id_;
-#endif
-
     mutable std::mutex mutex_;
 
     bool initialized_ = false;
@@ -86,6 +79,7 @@ inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled) {
     GlobalEnv::Instance().Init(nthread_per_process, tensor_parallel_size, sequence_parallel_enabled);
 }
 
+inline int GetNnodes() { return GlobalEnv::Instance().nnodes(); }
 inline int GetWorldSize() { return GlobalEnv::Instance().world_size(); }
 inline int GetNprocPerNode() { return GlobalEnv::Instance().nproc_per_node(); }
 inline int GetNthreadPerProc() { return GlobalEnv::Instance().nthread_per_process(); }

diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h
index 95e4de59..e21b4db3 100644
--- a/infini_train/include/nn/parallel/process_group.h
+++ b/infini_train/include/nn/parallel/process_group.h
@@ -55,7 +55,9 @@ class ProcessGroup {
     std::vector<std::shared_ptr<Tensor>> NcclRecv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank) const;
 
 private:
-    void Init(const std::vector<int> &device_indices);
+    void InitSingleProcess(const std::vector<int> &ranks);
+
+    void InitMultiProcess(const std::vector<int> &ranks);
 
 private:
     std::vector<ncclComm_t> comms_;

diff --git a/infini_train/src/nn/parallel/global.cc b/infini_train/src/nn/parallel/global.cc
index 74edf50f..f5f68157 100644
--- a/infini_train/src/nn/parallel/global.cc
+++ b/infini_train/src/nn/parallel/global.cc
@@ -95,8 +95,9 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled) {
 
     CHECK(!initialized_) << "Repeated initialization of GlobalEnv!";
 
-    world_size_ = GetEnvAsInt("PROC_WORLD_SIZE", 1) * nthread_per_process;
+    nnodes_ = GetEnvAsInt("NNODES", 1);
     nproc_per_node_ = GetEnvAsInt("NPROC_PER_NODE", 1);
+    world_size_ = GetEnvAsInt("PROC_WORLD_SIZE", 1) * nthread_per_process;
 
     global_proc_rank_ = GetEnvAsInt("GLOBAL_PROC_RANK", 0);
     local_proc_rank_ = GetEnvAsInt("LOCAL_PROC_RANK", 0);
@@ -115,29 +116,34 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled) {
     initialized_ = true;
 }
 
-int GlobalEnv::world_size() const {
+int GlobalEnv::nnodes() const {
     CHECK(initialized_) << "GlobalEnv is not initialized!";
-    return world_size_;
+    return nnodes_;
 }
 
-int GlobalEnv::global_proc_rank() const {
+int GlobalEnv::nproc_per_node() const {
     CHECK(initialized_) << "GlobalEnv is not initialized!";
-    return global_proc_rank_;
+    return nproc_per_node_;
 }
 
-int GlobalEnv::local_proc_rank() const {
+int GlobalEnv::nthread_per_process() const {
     CHECK(initialized_) << "GlobalEnv is not initialized!";
-    return local_proc_rank_;
+    return nthread_per_process_;
 }
 
-int GlobalEnv::nproc_per_node() const {
+int GlobalEnv::world_size() const {
     CHECK(initialized_) << "GlobalEnv is not initialized!";
-    return nproc_per_node_;
+    return world_size_;
 }
 
-int GlobalEnv::nthread_per_process() const {
+int GlobalEnv::global_proc_rank() const {
     CHECK(initialized_) << "GlobalEnv is not initialized!";
-    return nthread_per_process_;
+    return global_proc_rank_;
+}
+
+int GlobalEnv::local_proc_rank() const {
+    CHECK(initialized_) << "GlobalEnv is not initialized!";
+    return local_proc_rank_;
 }
 
 int GlobalEnv::tensor_parallel_size() const {

diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc
index 67dae1b5..649c4446 100644
--- a/infini_train/src/nn/parallel/process_group.cc
+++ b/infini_train/src/nn/parallel/process_group.cc
@@ -79,6 +79,26 @@ namespace infini_train::nn::parallel {
 #ifdef USE_NCCL
 ProcessGroup::ProcessGroup(const std::string &process_group_name, const std::vector<int> &ranks)
     : world_size_(ranks.size()), name_(process_group_name) {
+    if (global::GetNnodes() == 1 && global::GetNprocPerNode() == 1) {
+        InitSingleProcess(ranks);
+    } else {
+        InitMultiProcess(ranks);
+    }
+}
+
+void ProcessGroup::InitSingleProcess(const std::vector<int> &ranks) {
+    comms_.resize(world_size_);
+    NCCL_CHECK(ncclCommInitAll(comms_.data(), world_size_, ranks.data()));
+
+    for (int i = 0; i < ranks.size(); ++i) {
+        auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, ranks[i]);
+        devices_.push_back(device);
+        device_comm_map_[device] = comms_[i];
+        thread_group_rank_map_[device->rank().thread_rank()] = i;
+    }
+}
+
+void ProcessGroup::InitMultiProcess(const std::vector<int> &ranks) {
     int n_threads = global::GetNthreadPerProc();
 
     ncclUniqueId nccl_id;
@@ -99,28 +119,19 @@ void ProcessGroup::InitMultiProcess(const std::vector<int> &ranks) {
         auto it = std::ranges::find(ranks, global_rank);
         if (it != ranks.end()) {
             cudaSetDevice(i);
+
             ncclComm_t comm;
             int group_rank = std::distance(ranks.begin(), it);
             NCCL_CHECK(ncclCommInitRank(&comm, world_size_, nccl_id, group_rank));
             comms_.push_back(comm);
-            device_indices.push_back(i);
-            // FIXME(dcj): fix Init function
-            thread_group_rank_map_[DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, i)->rank().thread_rank()]
-                = group_rank;
+
+            auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, i);
+            thread_group_rank_map_[device->rank().thread_rank()] = group_rank;
+            devices_.push_back(device);
+            device_comm_map_[device] = comm;
         }
     }
     NCCL_CHECK(ncclGroupEnd());
-
-    Init(device_indices);
-}
-
-void ProcessGroup::Init(const std::vector<int> &device_indices) {
-    for (int i = 0; i < device_indices.size(); ++i) {
-        auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, device_indices[i]);
-        devices_.push_back(device);
-        device_comm_map_[device] = comms_[i];
-        // thread_group_rank_map_[device->rank().thread_rank()] = i;
-    }
 }
 
 int ProcessGroup::GetGroupRank(int thread_rank) const { return thread_group_rank_map_.at(thread_rank); }

diff --git a/tools/infini_run/CMakeLists.txt b/tools/infini_run/CMakeLists.txt
index a8d15bce..edc01bab 100644
--- a/tools/infini_run/CMakeLists.txt
+++ b/tools/infini_run/CMakeLists.txt
@@ -1,5 +1,2 @@
 add_executable(infini_run infini_run.cc)
 target_link_libraries(infini_run PRIVATE gflags glog)
-if (USE_NCCL)
-    target_link_libraries(infini_run PRIVATE nccl)
-endif()

diff --git a/tools/infini_run/infini_run.cc b/tools/infini_run/infini_run.cc
index 5ea3b7d0..d41cd4dd 100644
--- a/tools/infini_run/infini_run.cc
+++ b/tools/infini_run/infini_run.cc
@@ -7,10 +7,6 @@
 #include 
 #include 
 
-#ifdef USE_NCCL
-#include <nccl.h>
-#endif
-
 #include "gflags/gflags.h"
 #include "glog/logging.h"
 
@@ -52,13 +48,17 @@ int main(int argc, char **argv) {
         pid_t pid = fork();
         if (pid == 0) {
             int global_proc_rank = FLAGS_node_rank * FLAGS_nproc_per_node + local_proc_rank;
-            setenv("GLOBAL_PROC_RANK", std::to_string(global_proc_rank).c_str(), 1);
-            setenv("LOCAL_PROC_RANK", std::to_string(local_proc_rank).c_str(), 1);
-            setenv("PROC_WORLD_SIZE", std::to_string(world_size).c_str(), 1);
+            setenv("NNODES", std::to_string(FLAGS_nnodes).c_str(), 1);
             setenv("NPROC_PER_NODE", std::to_string(FLAGS_nproc_per_node).c_str(), 1);
+
             setenv("MASTER_ADDR", master_addr.c_str(), 1);
             setenv("MASTER_PORT", master_port.c_str(), 1);
 
+            setenv("GLOBAL_PROC_RANK", std::to_string(global_proc_rank).c_str(), 1);
+            setenv("LOCAL_PROC_RANK", std::to_string(local_proc_rank).c_str(), 1);
+
+            setenv("PROC_WORLD_SIZE", std::to_string(world_size).c_str(), 1);
+
             execvp(train_program.c_str(), train_argv.data());
             perror("exec failed");
             exit(1);
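The membership test in `InitMultiProcess` hinges on a rank window: each process owns global thread ranks `[proc_rank * n_threads, (proc_rank + 1) * n_threads)`, and it generates and publishes the ncclUniqueId only if the group's smallest rank falls inside its window (exactly one process per group qualifies). A quick standalone check of that arithmetic:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// True iff this process hosts the group's smallest global rank,
// i.e. it should call ncclGetUniqueId and publish the id file.
bool OwnsGroupLeader(const std::vector<int> &ranks, int proc_rank, int n_threads) {
    int lower = proc_rank * n_threads;
    int upper = (proc_rank + 1) * n_threads;
    int min_rank = *std::min_element(ranks.begin(), ranks.end());
    return min_rank >= lower && min_rank < upper;
}

int main() {
    // Two nodes x 4 threads; a DP group spanning both nodes: ranks {2, 6}.
    std::vector<int> dp_group = {2, 6};
    assert(OwnsGroupLeader(dp_group, /*proc_rank=*/0, /*n_threads=*/4));  // rank 2 lives on process 0
    assert(!OwnsGroupLeader(dp_group, /*proc_rank=*/1, /*n_threads=*/4)); // process 1 reads the file instead
    return 0;
}
```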
From 30640d68865e3d0ca20af61097757af5e4d0e585 Mon Sep 17 00:00:00 2001
From: kilinchange
Date: Wed, 12 Nov 2025 06:05:17 +0000
Subject: [PATCH 8/8] feat: Move the logic for cleaning up the ncclUniqueId
 file into the ProcessGroup destructor

---
 .../include/nn/parallel/process_group.h       |  3 ++
 infini_train/src/nn/parallel/process_group.cc | 29 +++++++++++++++----
 tools/infini_run/infini_run.cc                | 18 ------------
 3 files changed, 27 insertions(+), 23 deletions(-)

diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h
index e21b4db3..0258aac7 100644
--- a/infini_train/include/nn/parallel/process_group.h
+++ b/infini_train/include/nn/parallel/process_group.h
@@ -29,6 +29,7 @@ namespace infini_train::nn::parallel {
 class ProcessGroup {
 public:
     explicit ProcessGroup(const std::string &process_group_name, const std::vector<int> &device_indices);
+    ~ProcessGroup();
 
     int GetGroupRank(int thread_rank) const;
 
@@ -69,6 +70,8 @@ class ProcessGroup {
     int world_size_ = 0;
 
     const std::string name_ = "";
+
+    bool is_main_process_ = false;
 };
 #endif

diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc
index 649c4446..eee60401 100644
--- a/infini_train/src/nn/parallel/process_group.cc
+++ b/infini_train/src/nn/parallel/process_group.cc
@@ -68,6 +68,15 @@ void ReadNcclUniqueId(ncclUniqueId &nccl_id, const std::string &pg_name) {
     ifs.read(reinterpret_cast<char *>(&nccl_id), sizeof(nccl_id));
     ifs.close();
 }
+
+void CleanupNcclIdFile(const std::string &pg_name) {
+    const std::filesystem::path cwd = std::filesystem::current_path();
+    std::string file_path = NcclFileName(pg_name);
+
+    if (std::filesystem::exists(file_path)) {
+        std::filesystem::remove(file_path);
+    }
+}
 #endif
 
 } // namespace
@@ -86,6 +95,12 @@ ProcessGroup::ProcessGroup(const std::string &process_group_name, const std::vector<int> &ranks)
     }
 }
 
+ProcessGroup::~ProcessGroup() {
+    if (is_main_process_) {
+        CleanupNcclIdFile(name_);
+    }
+}
+
 void ProcessGroup::InitSingleProcess(const std::vector<int> &ranks) {
     comms_.resize(world_size_);
     NCCL_CHECK(ncclCommInitAll(comms_.data(), world_size_, ranks.data()));
@@ -100,13 +115,17 @@ void ProcessGroup::InitSingleProcess(const std::vector<int> &ranks) {
 
 void ProcessGroup::InitMultiProcess(const std::vector<int> &ranks) {
     int n_threads = global::GetNthreadPerProc();
+    int global_proc_rank = global::GetGlobalProcRank();
+    int lower_rank = global_proc_rank * n_threads;
+    int upper_rank = (global_proc_rank + 1) * n_threads;
 
     ncclUniqueId nccl_id;
 
-    if (std::ranges::min(ranks) < (global::GetGlobalProcRank() + 1) * global::GetNthreadPerProc()
-        && std::ranges::min(ranks) >= global::GetGlobalProcRank() * global::GetNthreadPerProc()) {
-        ncclGetUniqueId(&nccl_id);
+    int min_rank = std::ranges::min(ranks);
+    if (min_rank < upper_rank && min_rank >= lower_rank) {
+        is_main_process_ = true;
 
+        ncclGetUniqueId(&nccl_id);
         WriteNcclUniqueId(nccl_id, name_);
     } else {
         ReadNcclUniqueId(nccl_id, name_);
@@ -115,8 +134,8 @@ void ProcessGroup::InitMultiProcess(const std::vector<int> &ranks) {
     std::vector<int> device_indices;
     NCCL_CHECK(ncclGroupStart());
     for (int i = 0; i < n_threads; ++i) {
-        int global_rank = global::GetGlobalProcRank() * global::GetNthreadPerProc() + i;
-        auto it = std::ranges::find(ranks, global_rank);
+        int global_thread_rank = lower_rank + i;
+        auto it = std::ranges::find(ranks, global_thread_rank);
         if (it != ranks.end()) {
             cudaSetDevice(i);
 

diff --git a/tools/infini_run/infini_run.cc b/tools/infini_run/infini_run.cc
index d41cd4dd..86604f54 100644
--- a/tools/infini_run/infini_run.cc
+++ b/tools/infini_run/infini_run.cc
@@ -15,20 +15,6 @@ DEFINE_int32(nproc_per_node, 1, "Number of processes per node");
 DEFINE_int32(node_rank, 0, "Rank of this node");
 DEFINE_string(rdzv_endpoint, "127.0.0.1:29500", "Rendezvous endpoint (host:port)");
 
-void CleanupNcclIdFiles() {
-    const std::filesystem::path cwd = std::filesystem::current_path();
-    std::regex pattern(R"(ncclUniqueId_.*\.bin)");
-
-    for (const auto &entry : std::filesystem::directory_iterator(cwd)) {
-        if (entry.is_regular_file()) {
-            const std::string filename = entry.path().filename().string();
-            if (std::regex_match(filename, pattern)) {
-                std::filesystem::remove(entry.path());
-            }
-        }
-    }
-}
-
 int main(int argc, char **argv) {
     gflags::ParseCommandLineFlags(&argc, &argv, true);
     google::InitGoogleLogging(argv[0]);
@@ -56,9 +42,5 @@ int main(int argc, char **argv) {
         wait(&status);
     }
 
-    if (FLAGS_node_rank == 0) {
-        CleanupNcclIdFiles();
-    }
-
     return 0;
 }