10 changes: 5 additions & 5 deletions example/gpt2/main.cc
@@ -122,14 +122,14 @@ void Train(const nn::parallel::Rank &rank) {
device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank());

if (ddp_world_size > 1) {
- ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.thread_rank()),
- GetDataParallelGroupRanks(rank.thread_rank()));
+ ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
+ GetDataParallelGroupRanks(rank.GlobalRank()));
ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank());
}

if (tp_world_size > 1) {
- tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.thread_rank()),
- GetTensorParallelGroupRanks(rank.thread_rank()));
+ tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
+ GetTensorParallelGroupRanks(rank.GlobalRank()));
tp_rank = tp_pg->GetGroupRank(rank.thread_rank());
// NOTE(zbl): Reserved for VocabParallelEmbedding
nn::parallel::tp_rank = tp_rank;
@@ -312,7 +312,7 @@ int main(int argc, char *argv[]) {
if (FLAGS_nthread_per_process > 1) {
std::vector<std::thread> threads;
for (int idx = 0; idx < FLAGS_nthread_per_process; ++idx) {
- nn::parallel::Rank rank(nn::parallel::global::GetLocalProcRank(), idx,
+ nn::parallel::Rank rank(nn::parallel::global::GetGlobalProcRank(), idx,
nn::parallel::global::GetNprocPerNode(), FLAGS_nthread_per_process);
threads.emplace_back(Train, rank);
}
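Worked example of why these call sites switch from thread_rank() to GlobalRank() (and from GetLocalProcRank() to GetGlobalProcRank()): on a 2-node x 4-GPU job with one thread per GPU, rank.thread_rank() is 0-3 on every node, so keying process-group names and member lists by it would make ranks on different nodes collide on the same group. rank.GlobalRank() is presumably unique across the whole job (0-7 in this layout), giving each DP/TP group a node-independent key. The same change is applied to the llama3 example below.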
10 changes: 5 additions & 5 deletions example/llama3/main.cc
@@ -105,14 +105,14 @@ void Train(const nn::parallel::Rank &rank) {
device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank());

if (ddp_world_size > 1) {
- ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.thread_rank()),
- GetDataParallelGroupRanks(rank.thread_rank()));
+ ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
+ GetDataParallelGroupRanks(rank.GlobalRank()));
ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank());
}

if (tp_world_size > 1) {
- tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.thread_rank()),
- GetTensorParallelGroupRanks(rank.thread_rank()));
+ tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
+ GetTensorParallelGroupRanks(rank.GlobalRank()));
tp_rank = tp_pg->GetGroupRank(rank.thread_rank());
// NOTE(zbl): Reserved for VocabParallelEmbedding
nn::parallel::tp_rank = tp_rank;
@@ -292,7 +292,7 @@ int main(int argc, char *argv[]) {
if (FLAGS_nthread_per_process > 1) {
std::vector<std::thread> threads;
for (int idx = 0; idx < FLAGS_nthread_per_process; ++idx) {
- nn::parallel::Rank rank(nn::parallel::global::GetLocalProcRank(), idx,
+ nn::parallel::Rank rank(nn::parallel::global::GetGlobalProcRank(), idx,
nn::parallel::global::GetNprocPerNode(), FLAGS_nthread_per_process);
threads.emplace_back(Train, rank);
}
65 changes: 62 additions & 3 deletions infini_train/include/nn/parallel/global.h
@@ -28,14 +28,16 @@ class GlobalEnv {

void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled = false);

+ int nnodes() const;
+
+ int nproc_per_node() const;
+
int world_size() const;

int global_proc_rank() const;

int local_proc_rank() const;

- int nproc_per_node() const;
-
int nthread_per_process() const;

int tensor_parallel_size() const;
@@ -54,9 +56,11 @@
GlobalEnv &operator=(const GlobalEnv &) = delete;

private:
- int world_size_ = 1;
+ int nnodes_ = 1;
int nproc_per_node_ = 1;
int nthread_per_process_ = 1;
+ int world_size_ = 1;
+
int global_proc_rank_ = 0;
int local_proc_rank_ = 0;

@@ -75,6 +79,7 @@ inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool s
GlobalEnv::Instance().Init(nthread_per_process, tensor_parallel_size, sequence_parallel_enabled);
}

inline int GetNnodes() { return GlobalEnv::Instance().nnodes(); }
inline int GetWorldSize() { return GlobalEnv::Instance().world_size(); }
inline int GetNprocPerNode() { return GlobalEnv::Instance().nproc_per_node(); }
inline int GetNthreadPerProc() { return GlobalEnv::Instance().nthread_per_process(); }
@@ -85,28 +90,82 @@ inline int GetTensorParallelSize() { return GlobalEnv::Instance().tensor_paralle
inline bool GetSequenceParallelEnabled() { return GlobalEnv::Instance().sequence_parallel_enabled(); }
inline int GetDataParallelSize() { return GlobalEnv::Instance().data_parallel_size(); }

// =========================
// Layout Helper Functions
// =========================

/**
* @brief Get the global rank corresponding to the given (dp, tp, pp) coordinate.
*/
inline int GetRankOf(int dp, int tp, int pp) { return GlobalEnv::Instance().layout().RankOf(dp, tp, pp); }
/**
* @brief Get the (dp, tp, pp) coordinate corresponding to the given global rank.
*/
inline void GetCoordOf(int rank, int &dp, int &tp, int &pp) {
return GlobalEnv::Instance().layout().CoordOf(rank, dp, tp, pp);
}

/**
* @brief Get the group ID that the (dp, tp, pp) coordinate belongs to along a given parallel axis.
*/
inline int GetGroupId(Axis target, int dp, int tp, int pp) {
return GlobalEnv::Instance().layout().GroupId(target, dp, tp, pp);
}
/**
* @brief Get the group ID that a given rank belongs to along a specific parallel axis.
*/
inline int GetGroupId(Axis target, int rank) {
int dp, tp, pp;
GetCoordOf(rank, dp, tp, pp);
return GlobalEnv::Instance().layout().GroupId(target, dp, tp, pp);
}

/**
* @brief Get all ranks that belong to the same group as the given (dp, tp, pp) coordinate
* along a specified parallel axis (e.g., all ranks in the same TP group).
*/
inline std::vector<int> GetGroupRanks(Axis target, int dp, int tp, int pp) {
return GlobalEnv::Instance().layout().GroupRanks(target, dp, tp, pp);
}

/**
* @brief Get all ranks that belong to the same group as the given rank
* along a specified parallel axis (e.g., all ranks in the same DP group).
*/
inline std::vector<int> GetGroupRanks(Axis target, int rank) {
int dp, tp, pp;
GetCoordOf(rank, dp, tp, pp);
return GlobalEnv::Instance().layout().GroupRanks(target, dp, tp, pp);
}

/**
* @brief Generate a human-readable overview of all parallel communication groups.
*
* The output is intended for debugging, logging, and runtime verification of
* distributed parallelism configuration.
*
* @param L The Layout describing DP / TP / PP sizes and axis ordering.
* @param skip_trivial_axes
*        If true, axes whose size <= 1 (i.e., parallel strategies that are not enabled)
*        are marked as "unenabled" and their detailed group listing is skipped.
*
* @return A formatted string containing the full overview of process groups.
*
* Example:
* === Parallel Communication Groups ===
* world_size = 8, config: {DP=2, TP=4, PP=1}, order: {DP -> TP -> PP}
* [DP] size=2, num_groups=4
* - DP 0 (dp=-, tp=0, pp=0): [0, 4]
* - DP 1 (dp=-, tp=1, pp=0): [1, 5]
* - DP 2 (dp=-, tp=2, pp=0): [2, 6]
* - DP 3 (dp=-, tp=3, pp=0): [3, 7]
*
* [TP] size=4, num_groups=2
* - TP 0 (dp=0, tp=-, pp=0): [0, 1, 2, 3]
* - TP 1 (dp=1, tp=-, pp=0): [4, 5, 6, 7]
*
* [PP] size=1, unenabled
*/
std::string ProcessGroupOverview(const Layout &L = GlobalEnv::Instance().layout(), bool skip_trivial_axes = true);

} // namespace infini_train::nn::parallel::global
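
A minimal usage sketch of the layout helpers declared above, for orientation only. The Axis enumerator spelling (Axis::kTP), its namespace, and the include path are assumptions made for this sketch, not taken from the diff:

// Sketch only -- not part of this PR; adjust the Axis spelling and include path.
#include <iostream>
#include <vector>

#include "infini_train/include/nn/parallel/global.h"

using namespace infini_train::nn::parallel;

void LogRankLayout(int global_rank) {
    int dp = 0, tp = 0, pp = 0;
    global::GetCoordOf(global_rank, dp, tp, pp); // global rank -> (dp, tp, pp)

    // All ranks sharing this rank's tensor-parallel group (assumed enumerator name).
    std::vector<int> tp_peers = global::GetGroupRanks(global::Axis::kTP, global_rank);

    std::cout << "rank " << global_rank << " -> dp=" << dp << " tp=" << tp << " pp=" << pp
              << ", tp peers=" << tp_peers.size() << '\n';

    // Human-readable dump of every DP/TP/PP group, e.g. for startup logging.
    std::cout << global::ProcessGroupOverview() << std::endl;
}

With the layout from the comment block above (world_size = 8, DP=2, TP=4, PP=1), rank 5 would resolve to (dp=1, tp=1, pp=0), its TP peers would be [4, 5, 6, 7], and its DP peers [1, 5].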
37 changes: 35 additions & 2 deletions infini_train/include/nn/parallel/process_group.h
@@ -1,8 +1,10 @@
#pragma once

#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

@@ -26,7 +28,8 @@ namespace infini_train::nn::parallel {
#ifdef USE_NCCL
class ProcessGroup {
public:
- explicit ProcessGroup(const std::vector<int> &device_indices);
+ explicit ProcessGroup(const std::string &process_group_name, const std::vector<int> &device_indices);
+ ~ProcessGroup();

int GetGroupRank(int thread_rank) const;

@@ -52,14 +55,23 @@ class ProcessGroup {

std::vector<std::shared_ptr<Tensor>> NcclRecv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank) const;

private:
void InitSingleProcess(const std::vector<int> &ranks);

void InitMultiProcess(const std::vector<int> &ranks);

private:
std::vector<ncclComm_t> comms_;
std::vector<const Device *> devices_;

std::unordered_map<const Device *, ncclComm_t> device_comm_map_;
std::unordered_map<int, int> thread_group_rank_map_; // thread_rank : group_rank

int comm_size_ = 0;
int world_size_ = 0;

const std::string name_ = "";

bool is_main_process_ = false;
};
#endif

@@ -79,8 +91,29 @@ class ProcessGroupFactory {

private:
ProcessGroupFactory();

template <typename Creator, typename = std::enable_if_t<std::is_invocable_v<Creator>>>
const ProcessGroup *GetOrCreate(const std::string &name, Creator &&creator) {
std::unique_lock<std::mutex> lock(mutex_);
auto [it, inserted] = name_to_group_.emplace(name, nullptr);
if (!inserted) {
while (it->second == nullptr) { cond_.wait(lock); }
return it->second.get();
}

lock.unlock();
auto new_group = creator();
lock.lock();

it->second = std::move(new_group);
cond_.notify_all();
return it->second.get();
}

private:
// TODO(dcj): maybe RWLock later?
mutable std::mutex mutex_;
std::condition_variable cond_;
std::unordered_map<std::string, std::unique_ptr<ProcessGroup>> name_to_group_;
};
} // namespace infini_train::nn::parallel
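
Call-site sketch for the factory path above, mirroring the updated example/gpt2/main.cc. The GetTensorParallelProcessGroupName / GetTensorParallelGroupRanks helpers and the include paths come from that example, not from this header, so treat them as assumptions here:

// Sketch only -- shows the intended GetOrCreate call pattern, not new API.
#include "infini_train/include/nn/parallel/process_group.h"
#include "infini_train/include/nn/parallel/rank.h"

using namespace infini_train::nn::parallel;

const ProcessGroup *GetOrCreateTpGroup(const Rank &rank) {
    // Name and member ranks are both keyed by the global rank, so every thread
    // of the same TP group, on any node, resolves to the same factory entry.
    return ProcessGroupFactory::Instance()->GetOrCreate(
        GetTensorParallelProcessGroupName(rank.GlobalRank()),
        GetTensorParallelGroupRanks(rank.GlobalRank()));
}

Note the shape of the private GetOrCreate template: the first caller inserts a null placeholder, drops the mutex while the potentially slow creator() (e.g. NCCL communicator setup) runs, then publishes the pointer and calls notify_all(); concurrent callers that hit the placeholder wait on cond_ instead of building a duplicate group.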
2 changes: 2 additions & 0 deletions infini_train/include/nn/parallel/rank.h
@@ -10,6 +10,8 @@ class Rank {
int process_size() const;
int thread_size() const;

int GlobalRank() const;

bool IsParallel() const;

bool IsMainRank() const;
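Only the declaration is added here; a plausible definition, consistent with the (global process rank, thread rank, processes per node, threads per process) constructor arguments used in the examples and in device.cc, might look like the line below. Both the formula and the process_rank() accessor are assumptions, not the PR's actual implementation:

// Assumed sketch: one device per thread, so the job-wide rank is the process
// rank scaled by threads-per-process plus the thread rank within the process.
int Rank::GlobalRank() const { return process_rank() * thread_size() + thread_rank(); }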
2 changes: 1 addition & 1 deletion infini_train/src/device.cc
@@ -64,7 +64,7 @@ nn::parallel::Rank CudaDevice::rank() const { return rank_; }

CudaDevice::CudaDevice(int8_t index)
: Device(DeviceType::kCUDA, index),
- rank_({nn::parallel::global::GetLocalProcRank(), index, nn::parallel::global::GetNprocPerNode(),
+ rank_({nn::parallel::global::GetGlobalProcRank(), index, nn::parallel::global::GetNprocPerNode(),
nn::parallel::global::GetNthreadPerProc()}) {
// TODO(dcj): make CudaDevice initialization lazy to avoid allocating memory on all GPUs in single-GPU mode
SetDevice();
3 changes: 2 additions & 1 deletion infini_train/src/nn/parallel/distributed_data_parallel.cc
@@ -23,7 +23,8 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr<nn::Module> mod
CHECK_EQ(device->Index(), device_id) << "All parameters must be on the same device as the module";

auto ddp_pg
- = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(device->rank().thread_rank()));
+ = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(device->rank().GlobalRank()));
+ // FIXME(dcj): use multi-node ddp_pg here
auto hook = std::make_unique<infini_train::autograd::AllReducePostAccumulateHook>(function::ReduceOpType::kAvg,
ddp_pg);
param->RegisterPostAccumulateGradHook(std::move(hook));
62 changes: 23 additions & 39 deletions infini_train/src/nn/parallel/global.cc
@@ -2,6 +2,7 @@

#include <cstdlib>
#include <format>
#include <nccl.h>
#include <string>

#include "glog/logging.h"
@@ -13,6 +14,11 @@ int GetEnvAsInt(const std::string &name, int default_value) {
return value ? std::atoi(value) : default_value;
}

std::string GetEnvAsStr(const std::string &name, const std::string &default_value) {
const char *value = std::getenv(name.c_str());
return value ? std::string(value) : default_value;
}

} // namespace

namespace infini_train::nn::parallel::global {
@@ -89,8 +95,9 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool seq

CHECK(!initialized_) << "Repeated initialization of GlobalEnv!";

world_size_ = GetEnvAsInt("PROC_WORLD_SIZE", 1) * nthread_per_process;
nnodes_ = GetEnvAsInt("NNODES", 1);
nproc_per_node_ = GetEnvAsInt("NPROC_PER_NODE", 1);
world_size_ = GetEnvAsInt("PROC_WORLD_SIZE", 1) * nthread_per_process;
global_proc_rank_ = GetEnvAsInt("GLOBAL_PROC_RANK", 0);
local_proc_rank_ = GetEnvAsInt("LOCAL_PROC_RANK", 0);

@@ -109,29 +116,34 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool seq
initialized_ = true;
}

- int GlobalEnv::world_size() const {
+ int GlobalEnv::nnodes() const {
CHECK(initialized_) << "GlobalEnv is not initialized!";
- return world_size_;
+ return nnodes_;
}

- int GlobalEnv::global_proc_rank() const {
+ int GlobalEnv::nproc_per_node() const {
CHECK(initialized_) << "GlobalEnv is not initialized!";
- return global_proc_rank_;
+ return nproc_per_node_;
}

- int GlobalEnv::local_proc_rank() const {
+ int GlobalEnv::nthread_per_process() const {
CHECK(initialized_) << "GlobalEnv is not initialized!";
- return local_proc_rank_;
+ return nthread_per_process_;
}

- int GlobalEnv::nproc_per_node() const {
+ int GlobalEnv::world_size() const {
CHECK(initialized_) << "GlobalEnv is not initialized!";
- return nproc_per_node_;
+ return world_size_;
}

- int GlobalEnv::nthread_per_process() const {
+ int GlobalEnv::global_proc_rank() const {
CHECK(initialized_) << "GlobalEnv is not initialized!";
- return nthread_per_process_;
+ return global_proc_rank_;
}

+ int GlobalEnv::local_proc_rank() const {
+ CHECK(initialized_) << "GlobalEnv is not initialized!";
+ return local_proc_rank_;
+ }

int GlobalEnv::tensor_parallel_size() const {
@@ -201,34 +213,6 @@ inline void AppendAxisGroups(std::ostringstream &oss, const Layout &L, Axis targ
}
}

- /**
- * @brief Generate a human-readable overview of all parallel communication groups.
- *
- * The output is intended for debugging, logging, and runtime verification of
- * distributed parallelism configuration.
- *
- * @param L The Layout describing DP / TP / PP sizes and axis ordering.
- * @param skip_trivial_axes
- * If true, axes whose size <= 1(i.e. parallel strategy that is not enabled)
- * will be marked as "unenabled" and their detailed group listing will be skipped.
- *
- * @return A formatted string containing the full overview of process groups.
- *
- * Example:
- * === Parallel Communication Groups ===
- * world_size = 8, config: {DP=2, TP=4, PP=1}, order: {DP -> TP -> PP}
- * [DP] size=2, num_groups=4
- * - DP 0 (dp=-, tp=0, pp=0): [0, 4]
- * - DP 1 (dp=-, tp=1, pp=0): [1, 5]
- * - DP 2 (dp=-, tp=2, pp=0): [2, 6]
- * - DP 3 (dp=-, tp=3, pp=0): [3, 7]
- *
- * [TP] size=4, num_groups=2
- * - TP 0 (dp=0, tp=-, pp=0): [0, 1, 2, 3]
- * - TP 1 (dp=1, tp=-, pp=0): [4, 5, 6, 7]
- *
- * [PP] size=1, unenabled
- */
std::string ProcessGroupOverview(const Layout &L, bool skip_trivial_axes) {
std::ostringstream oss;
oss << std::format("\n=== Parallel Communication Groups ===\n"
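Worked example of the Init() bookkeeping above, assuming PROC_WORLD_SIZE counts launched processes: with NNODES=2, NPROC_PER_NODE=4 (so PROC_WORLD_SIZE=8) and nthread_per_process=1, world_size_ = 8 * 1 = 8, which matches the 8-rank layout shown in the ProcessGroupOverview example now documented in global.h.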