Skip to content

Commit ed9fba4

Browse files
committed
chore: add comments in global.h
1 parent 6cc74b7 commit ed9fba4

File tree

1 file changed

+55
-0
lines changed
  • infini_train/include/nn/parallel

1 file changed

+55
-0
lines changed

infini_train/include/nn/parallel/global.h

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,29 +96,84 @@ inline int GetTensorParallelSize() { return GlobalEnv::Instance().tensor_paralle
9696
/**
 * @brief Query the global environment's sequence-parallel flag.
 *
 * @return The value reported by GlobalEnv::sequence_parallel_enabled().
 */
inline bool GetSequenceParallelEnabled() {
    auto &env = GlobalEnv::Instance();
    return env.sequence_parallel_enabled();
}
9797
/**
 * @brief Query the global environment's data-parallel size.
 *
 * @return The value reported by GlobalEnv::data_parallel_size().
 */
inline int GetDataParallelSize() {
    auto &env = GlobalEnv::Instance();
    return env.data_parallel_size();
}
9898

99+
// =========================
99100
// Layout Helper Functions
101+
// =========================
102+
103+
/**
104+
* @brief Get the global rank corresponding to the given (dp, tp, pp) coordinate.
105+
*/
100106
inline int GetRankOf(int dp, int tp, int pp) { return GlobalEnv::Instance().layout().RankOf(dp, tp, pp); }
107+
/**
108+
* @brief Get the (dp, tp, pp) coordinate corresponding to the given global rank.
109+
*/
101110
inline void GetCoordOf(int rank, int &dp, int &tp, int &pp) {
102111
return GlobalEnv::Instance().layout().CoordOf(rank, dp, tp, pp);
103112
}
113+
114+
/**
115+
* @brief Get the group ID that the (dp, tp, pp) coordinate belongs to along a given parallel axis.
116+
*/
104117
inline int GetGroupId(Axis target, int dp, int tp, int pp) {
105118
return GlobalEnv::Instance().layout().GroupId(target, dp, tp, pp);
106119
}
120+
/**
121+
* @brief Get the group ID that a given rank belongs to along a specific parallel axis.
122+
*/
107123
inline int GetGroupId(Axis target, int rank) {
108124
int dp, tp, pp;
109125
GetCoordOf(rank, dp, tp, pp);
110126
return GlobalEnv::Instance().layout().GroupId(target, dp, tp, pp);
111127
}
128+
129+
/**
130+
* @brief Get all ranks that belong to the same group as the given (dp, tp, pp) coordinate
131+
* along a specified parallel axis (e.g., all ranks in the same TP group).
132+
*/
112133
inline std::vector<int> GetGroupRanks(Axis target, int dp, int tp, int pp) {
113134
return GlobalEnv::Instance().layout().GroupRanks(target, dp, tp, pp);
114135
}
136+
137+
/**
138+
* @brief Get all ranks that belong to the same group as the given rank
139+
* along a specified parallel axis (e.g., all ranks in the same DP group).
140+
*/
115141
inline std::vector<int> GetGroupRanks(Axis target, int rank) {
116142
int dp, tp, pp;
117143
GetCoordOf(rank, dp, tp, pp);
118144
return GlobalEnv::Instance().layout().GroupRanks(target, dp, tp, pp);
119145
}
120146

147+
/**
148+
* @brief Generate a human-readable overview of all parallel communication groups.
149+
*
150+
* The output is intended for debugging, logging, and runtime verification of
151+
* distributed parallelism configuration.
152+
*
153+
* @param L The Layout describing DP / TP / PP sizes and axis ordering.
154+
* @param skip_trivial_axes
155+
* If true, axes whose size <= 1 (i.e., parallel strategies that are not enabled)
156+
* will be marked as "unenabled" and their detailed group listing will be skipped.
157+
*
158+
* @return A formatted string containing the full overview of process groups.
159+
*
160+
* Example:
161+
* === Parallel Communication Groups ===
162+
* world_size = 8, config: {DP=2, TP=4, PP=1}, order: {DP -> TP -> PP}
163+
* [DP] size=2, num_groups=4
164+
* - DP 0 (dp=-, tp=0, pp=0): [0, 4]
165+
* - DP 1 (dp=-, tp=1, pp=0): [1, 5]
166+
* - DP 2 (dp=-, tp=2, pp=0): [2, 6]
167+
* - DP 3 (dp=-, tp=3, pp=0): [3, 7]
168+
*
169+
* [TP] size=4, num_groups=2
170+
* - TP 0 (dp=0, tp=-, pp=0): [0, 1, 2, 3]
171+
* - TP 1 (dp=1, tp=-, pp=0): [4, 5, 6, 7]
172+
*
173+
* [PP] size=1, unenabled
174+
*/
121175
std::string ProcessGroupOverview(const Layout &L = GlobalEnv::Instance().layout(), bool skip_trivial_axes = true);
176+
122177
#ifdef USE_NCCL
123178
/**
 * @brief Fetch the NCCL unique ID held by the global environment.
 *
 * @return The value reported by GlobalEnv::nccl_id().
 */
inline ncclUniqueId GetNcclId() {
    auto &env = GlobalEnv::Instance();
    return env.nccl_id();
}
124179
#endif

0 commit comments

Comments
 (0)