@@ -96,29 +96,84 @@ inline int GetTensorParallelSize() { return GlobalEnv::Instance().tensor_paralle
9696inline bool GetSequenceParallelEnabled () { return GlobalEnv::Instance ().sequence_parallel_enabled (); }
9797inline int GetDataParallelSize () { return GlobalEnv::Instance ().data_parallel_size (); }
9898
99+ // =========================
99100// Layout Helper Functions
101+ // =========================
102+
103+ /* *
104+ * @brief Get the global rank corresponding to the given (dp, tp, pp) coordinate.
105+ */
100106inline int GetRankOf (int dp, int tp, int pp) { return GlobalEnv::Instance ().layout ().RankOf (dp, tp, pp); }
107+ /* *
108+ * @brief Get the (dp, tp, pp) coordinate corresponding to the given global rank.
109+ */
101110inline void GetCoordOf (int rank, int &dp, int &tp, int &pp) {
102111 return GlobalEnv::Instance ().layout ().CoordOf (rank, dp, tp, pp);
103112}
113+
114+ /* *
115+ * @brief Get the group ID that the (dp, tp, pp) coordinate belongs to along a given parallel axis.
116+ */
104117inline int GetGroupId (Axis target, int dp, int tp, int pp) {
105118 return GlobalEnv::Instance ().layout ().GroupId (target, dp, tp, pp);
106119}
120+ /* *
121+ * @brief Get the group ID that a given rank belongs to along a specific parallel axis.
122+ */
107123inline int GetGroupId (Axis target, int rank) {
108124 int dp, tp, pp;
109125 GetCoordOf (rank, dp, tp, pp);
110126 return GlobalEnv::Instance ().layout ().GroupId (target, dp, tp, pp);
111127}
128+
129+ /* *
130+ * @brief Get all ranks that belong to the same group as the given (dp, tp, pp) coordinate
131+ * along a specified parallel axis (e.g., all ranks in the same TP group).
132+ */
112133inline std::vector<int > GetGroupRanks (Axis target, int dp, int tp, int pp) {
113134 return GlobalEnv::Instance ().layout ().GroupRanks (target, dp, tp, pp);
114135}
136+
137+ /* *
138+ * @brief Get all ranks that belong to the same group as the given rank
139+ * along a specified parallel axis (e.g., all ranks in the same DP group).
140+ */
115141inline std::vector<int > GetGroupRanks (Axis target, int rank) {
116142 int dp, tp, pp;
117143 GetCoordOf (rank, dp, tp, pp);
118144 return GlobalEnv::Instance ().layout ().GroupRanks (target, dp, tp, pp);
119145}
120146
147+ /* *
148+ * @brief Generate a human-readable overview of all parallel communication groups.
149+ *
150+ * The output is intended for debugging, logging, and runtime verification of
151+ * the distributed parallelism configuration.
152+ *
153+ * @param L The Layout describing DP / TP / PP sizes and axis ordering. Defaults to GlobalEnv::Instance().layout().
154+ * @param skip_trivial_axes
155+ * If true (the default), axes whose size <= 1 (i.e., parallel strategies that are not enabled)
156+ * will be marked as "unenabled" and their detailed group listing will be skipped.
157+ *
158+ * @return A formatted string containing the full overview of process groups.
159+ *
160+ * Example:
161+ * === Parallel Communication Groups ===
162+ * world_size = 8, config: {DP=2, TP=4, PP=1}, order: {DP -> TP -> PP}
163+ * [DP] size=2, num_groups=4
164+ * - DP 0 (dp=-, tp=0, pp=0): [0, 4]
165+ * - DP 1 (dp=-, tp=1, pp=0): [1, 5]
166+ * - DP 2 (dp=-, tp=2, pp=0): [2, 6]
167+ * - DP 3 (dp=-, tp=3, pp=0): [3, 7]
168+ *
169+ * [TP] size=4, num_groups=2
170+ * - TP 0 (dp=0, tp=-, pp=0): [0, 1, 2, 3]
171+ * - TP 1 (dp=1, tp=-, pp=0): [4, 5, 6, 7]
172+ *
173+ * [PP] size=1, unenabled
174+ */
121175std::string ProcessGroupOverview (const Layout &L = GlobalEnv::Instance().layout(), bool skip_trivial_axes = true);
176+
#ifdef USE_NCCL
/**
 * @brief Fetch the NCCL unique ID held by the global environment.
 *
 * Only compiled when USE_NCCL is defined.
 *
 * @return The value of GlobalEnv::Instance().nccl_id().
 */
inline ncclUniqueId GetNcclId() {
    auto &env = GlobalEnv::Instance();
    return env.nccl_id();
}
#endif
0 commit comments