@@ -30,6 +30,8 @@ namespace tensorrt_llm::kernels
3030namespace moe_prepare
3131{
3232
33+ using tensorrt_llm::common::launchWithPdlWhenEnabled;
34+
3335__device__ __forceinline__ void st_release_sys_global (uint64_t volatile * ptr, uint64_t val)
3436{
3537 asm volatile (" st.release.sys.global.u64 [%0], %1;" ::" l" (ptr), " l" (val) : " memory" );
@@ -110,6 +112,10 @@ __device__ __forceinline__ void computeCountAndSendStatics(int* experts, int tok
110112 int * localSendIndice = sendIndiceWorkspace + targetRankId * maxTokenCountPerRank;
111113 int * localBackwardIndice = backwardIndiceWorkspace + targetRankId * maxTokenCountPerRank;
112114
115+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
116+ cudaGridDependencySynchronize ();
117+ #endif
118+
113119 for (int i = tileId; i < readRankTokenCount; i += tileCountPerBlock)
114120 {
115121 int expertRankId = laneInTile < topK ? experts[i * topK + laneInTile] / expertCountPerRank : epSize;
@@ -163,6 +169,11 @@ __device__ __forceinline__ void recvCountAndStatics(int* recvIndiceWorkspace, in
163169
164170 CounterCommunicator counter (workspace.getFifoConnInfo (false , rankId, targetRankId, 0 , rankCount, 1 ));
165171 int communicationCount = gatheredExpertStatics == nullptr ? 1 : expertCount + 1 ;
172+
173+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
174+ cudaGridDependencySynchronize ();
175+ #endif
176+
166177 for (int i = rankTile.thread_rank (); i < communicationCount; i += THREADS_PER_PIPELINE)
167178 {
168179 int recvValue = counter.acquireValue (i);
@@ -218,6 +229,9 @@ __global__ void moveIndiceDevice(int* sendCountsCumsum, int* recvCountsCumsum, i
218229 int count = endIndex - startIndex;
219230 int * localSendIndice = sendIndice + targetRankId * maxTokenCountPerRank;
220231 int * localBackwardIndice = backwardIndice + targetRankId * maxTokenCountPerRank;
232+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
233+ cudaGridDependencySynchronize ();
234+ #endif
221235 for (int localIdx = threadIdx .x ; localIdx < count; localIdx += blockDim .x )
222236 {
223237 gatherSendIndice[startIndex + localIdx] = localSendIndice[localIdx];
@@ -230,6 +244,9 @@ __global__ void moveIndiceDevice(int* sendCountsCumsum, int* recvCountsCumsum, i
230244 int startIndex = targetRankId == 0 ? 0 : recvCountsCumsum[targetRankId - 1 ];
231245 int endIndex = recvCountsCumsum[targetRankId];
232246 int count = endIndex - startIndex;
247+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
248+ cudaGridDependencySynchronize ();
249+ #endif
233250 for (int localIdx = threadIdx .x ; localIdx < count; localIdx += blockDim .x )
234251 {
235252 gatherRecvIndice[startIndex + localIdx] = startIndex + localIdx;
@@ -249,6 +266,10 @@ __global__ void computeCumsumDevice(int* sendCountsCumsum, int* recvCountsCumsum
249266 int threadData = tid < rankCount ? inputOutputPtr[tid] : 0 ;
250267 __syncthreads ();
251268
269+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
270+ cudaGridDependencySynchronize ();
271+ #endif
272+
252273 BlockScan (temp_storage).InclusiveSum (threadData, threadData);
253274 if (tid < rankCount)
254275 {
@@ -261,6 +282,9 @@ __global__ void memsetExpertIdsDevice(
261282{
262283 int maxTokenCount = maxTokenCountPerRank * rankCount;
263284 int totalRecvTokenCount = *(recvCountsCumsum + rankCount - 1 );
285+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
286+ cudaGridDependencySynchronize ();
287+ #endif
264288 for (int i = blockIdx .x * blockDim .x + threadIdx .x ; i + totalRecvTokenCount * topK < maxTokenCount * topK;
265289 i += gridDim .x * blockDim .x )
266290 {
@@ -300,17 +324,20 @@ void computeCountAndIndice(int* experts, int* sendCounts, int* recvCounts, int*
300324 {
301325 kernelFn = computeCountAndIndiceDevice<2 >;
302326 }
303- kernelFn<<<grid, block, 0 , stream>>> (experts, sendCounts, recvCounts, sendIndiceWorkspace, backwardIndiceWorkspace,
304- recvIndiceWorkspace, expertStatics, gatheredExpertStatics, workspace, tokenCount, maxTokenCountPerRank, topK,
305- slotCount, expertCount, rankId, rankCount);
327+
328+ launchWithPdlWhenEnabled (" computeCountAndIndice" , kernelFn, grid, block, 0 , stream, experts, sendCounts, recvCounts,
329+ sendIndiceWorkspace, backwardIndiceWorkspace, recvIndiceWorkspace, expertStatics, gatheredExpertStatics,
330+ workspace, tokenCount, maxTokenCountPerRank, topK, slotCount, expertCount, rankId, rankCount);
306331}
307332
// Host launcher for computeCumsumDevice: turns the per-rank send/recv count
// arrays into inclusive cumulative sums in place. grid.x == 2, presumably one
// block for the send cumsum and one for the recv cumsum -- confirm against
// computeCumsumDevice. Launched via launchWithPdlWhenEnabled so Programmatic
// Dependent Launch is used when enabled (the device side pairs this with
// cudaGridDependencySynchronize on SM90+).
void computeCumsum(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount, cudaStream_t stream)
{
    dim3 const launchBlock(CUMSUM_THREADS_PER_BLOCK);
    dim3 const launchGrid(2);

    launchWithPdlWhenEnabled("computeCumsum", computeCumsumDevice, launchGrid, launchBlock, /*dynamicSmemBytes=*/0,
        stream, sendCountsCumsum, recvCountsCumsum, rankId, rankCount);
}
315342
316343void moveIndice (int * sendCountsCumsum, int * recvCountsCumsum, int * sendIndice, int * gatherSendIndice,
@@ -319,17 +346,22 @@ void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, i
319346{
320347 dim3 block (512 );
321348 dim3 grid (rankCount, 2 );
322- moveIndiceDevice<<<grid, block, 0 , stream>>> (sendCountsCumsum, recvCountsCumsum, sendIndice, gatherSendIndice,
323- backwardIndice, gatherBackwardIndice, recvIndice, gatherRecvIndice, maxTokenCountPerRank);
349+
350+ launchWithPdlWhenEnabled (" moveIndice" , moveIndiceDevice, grid, block, 0 , stream, sendCountsCumsum, recvCountsCumsum,
351+ sendIndice, gatherSendIndice, backwardIndice, gatherBackwardIndice, recvIndice, gatherRecvIndice,
352+ maxTokenCountPerRank);
324353}
325354
// Host launcher for memsetExpertIdsDevice: resets the tail of the expertIds
// buffer -- the slots beyond the actually received token count (read from the
// last entry of recvCountsCumsum on the device) -- TODO confirm fill semantics
// against memsetExpertIdsDevice. Launches one block per SM, 256 threads each,
// via launchWithPdlWhenEnabled so PDL is used when enabled.
void memsetExpertIds(int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int slotCount,
    int rankCount, cudaStream_t stream)
{
    constexpr int kThreadsPerBlock = 256;
    dim3 const launchBlock(kThreadsPerBlock);
    dim3 const launchGrid(tensorrt_llm::common::getMultiProcessorCount());

    launchWithPdlWhenEnabled("memsetExpertIds", memsetExpertIdsDevice, launchGrid, launchBlock,
        /*dynamicSmemBytes=*/0, stream, expertIds, recvCountsCumsum, maxTokenCountPerRank, topK, slotCount, rankCount);
}
334366
335367size_t getMoePrepareWorkspaceSize (int epSize)
0 commit comments