@@ -146,13 +146,103 @@ commResult_t ctranAlltoallvDynamicSplit(
     CtranComm* comm,
     cudaStream_t stream);
 
-// Note: we support the combine and dispatch APIs by keeping both recvbuffs (for
-// dispatch) and recvbuff (for combine), for implementation simplicity, as we
-// now only have two moving variables for the two APIs: recvbuffs and
-// recvAllSplitLengths.
-// In the future, if the moving varibles increase and it requires more
-// extendiblity, we should implement a new class/struct like Hints, to support
-// various metadata type.
+/**
+ * Note: we support the combine and dispatch APIs by keeping both recvbuffs (for
+ * dispatch) and recvbuff (for combine) for implementation simplicity, as we
+ * now only have two moving variables for the two APIs: recvbuffs and
+ * recvAllSplitLengths.
+ * In the future, if the number of moving variables grows and more
+ * extensibility is required, we should implement a new class/struct like
+ * Hints to support various metadata types.
+ *
+ * All-to-all communication with dynamic split lengths and non-contiguous
+ * receive buffers. Designed for expert-parallel workloads (e.g., Mixture of
+ * Experts) where data needs to be routed to different experts across ranks.
+ *
+ * EXAMPLE SCENARIO:
+ * ----------------
+ * 4 ranks.
+ *
+ * Rank 0 wants to send a different amount of data to each rank:
+ *
+ * SEND-SIDE (What Rank 0 sends):
+ * ------------------------------
+ * sendbuff:  [100 ints | 200 ints | 150 ints | 50 ints | 300 ints | 100 ints]
+ * chunk-id:       0          1          2         3          4          5
+ *
+ * sendSplitLengths = [100, 200, 150, 50, 300, 100]
+ *   Meaning: how many elements are in each chunk
+ * numSendSplitLengths = 6 (total number of chunks)
+ *
+ * sendIndices = [0, 1, 4, 3, 5]
+ *   Meaning: a list of chunk ids to send
+ *   (NOTE: chunk 2 is not sent at all)
+ *
+ * sendIndicesBlockLengths = [2, 0, 1, 2]  (#entries = #ranks)
+ *   Meaning: chunks 0,1 go to rank 0
+ *            NOTHING   goes to rank 1
+ *            chunk 4   goes to rank 2
+ *            chunks 3,5 go to rank 3
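+ *
+ * Putting the two arrays together, the per-destination element counts for
+ * Rank 0 in this example work out to:
+ *   rank 0: chunks 0,1  -> 100 + 200 = 300 elements
+ *   rank 1: no chunks   -> 0 elements
+ *   rank 2: chunk 4     -> 300 elements
+ *   rank 3: chunks 3,5  -> 50 + 100 = 150 elements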
+ *
+ * RECEIVE-SIDE (What Rank 0 receives):
+ * ------------------------------------
+ * recvbuffs[0]: Data from Rank 0
+ * recvbuffs[1]: Data from Rank 1
+ * recvbuffs[2]: Data from Rank 2
+ * recvbuffs[3]: Data from Rank 3
+ *
+ * recvAllSplitLengths (output): allgathered sendSplitLengths from all ranks,
+ * e.g., [sendSplitLengths-rank0, sendSplitLengths-rank1,
+ * sendSplitLengths-rank2, sendSplitLengths-rank3].
+ * This is used in Dispatch(), but not in Combine().
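+ * In the example above, recvAllSplitLengths holds
+ * 4 ranks * 6 split lengths = 24 entries.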
+ *
+ *
+ * PARAMETERS:
+ * -----------
+ * @param sendbuff                 GPU buffer containing all data to send
+ *                                 Layout: concatenated chunks in order of
+ *                                 sendSplitLengths
+ *
+ * @param sendSplitLengths         GPU array of size numSendSplitLengths
+ *                                 Specifies the number of elements in each
+ *                                 chunk
+ *
+ * @param numSendSplitLengths      Total number of chunks
+ *
+ * @param sendIndices              GPU array of chunk ids to send
+ *
+ * @param sendIndicesBlockLengths  GPU array of size numRanks
+ *                                 Number of chunks sent to each rank;
+ *                                 partitions the sendIndices array by
+ *                                 destination
+ *
+ * @param recvbuffs                Array of numRanks GPU buffer pointers
+ *                                 recvbuffs[i] receives all data from rank i
+ *                                 Non-contiguous: separate buffer per sender
+ *
+ * @param maxSendcount             Maximum total elements that can be sent
+ *                                 (buffer capacity)
+ *
+ * @param maxRecvcount             Maximum total elements that can be received
+ *                                 (buffer capacity)
+ *
+ * @param hints                    Communication hints for optimization
+ *
+ * @param datatype                 Data type of elements (e.g., commInt32)
+ *
+ * @param comm                     Ctran communicator
+ *
+ * @param stream                   CUDA stream for async operations
+ *
+ * @param combine                  false = dispatch mode (scatter TO experts)
+ *                                 true = combine mode (gather FROM experts)
+ *
+ * @param recvAllSplitLengths      Optional GPU output buffer of size
+ *                                 (numRanks * numSendSplitLengths)
+ *                                 Stores the actual received sizes for each
+ *                                 chunk from each sender
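+ *
+ * EXAMPLE USAGE (a minimal, hypothetical sketch, not a verbatim snippet from
+ * the tests: the d_-prefixed names are assumed device allocations matching
+ * the scenario above, and the argument order follows the parameter list):
+ *
+ *   // Metadata mirrors the example scenario:
+ *   //   d_sendSplitLengths        = {100, 200, 150, 50, 300, 100}
+ *   //   d_sendIndices             = {0, 1, 4, 3, 5}
+ *   //   d_sendIndicesBlockLengths = {2, 0, 1, 2}
+ *   commResult_t res = ctranAlltoallvDynamicSplitNonContig(
+ *       d_sendbuff,
+ *       d_sendSplitLengths,
+ *       6,                         // numSendSplitLengths
+ *       d_sendIndices,
+ *       d_sendIndicesBlockLengths,
+ *       d_recvbuffs,               // numRanks GPU buffer pointers
+ *       maxSendcount,
+ *       maxRecvcount,
+ *       hints,
+ *       commInt32,
+ *       comm,
+ *       stream,
+ *       false,                     // dispatch mode
+ *       d_recvAllSplitLengths);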
+ */
 commResult_t ctranAlltoallvDynamicSplitNonContig(
     const void * sendbuff,
     const size_t * sendSplitLengths,