1919#include < cub/cub.cuh>
2020
2121#include " tensorrt_llm/common/cudaUtils.h"
22+ #include " tensorrt_llm/common/envUtils.h"
2223#include " tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h"
2324
2425namespace cg = cooperative_groups;
@@ -28,6 +29,8 @@ namespace tensorrt_llm
2829namespace kernels
2930{
3031
32+ using tensorrt_llm::common::launchWithPdlWhenEnabled;
33+
3134int getOwnerDevice (unsigned long long int stepAndOwner)
3235{
3336 return static_cast <int >(stepAndOwner & MoeLoadBalanceSingleLayerSignal::kDevice );
@@ -71,6 +74,11 @@ __device__ __forceinline__ void moeWaitSignalForGpuStageFunc(MoeLoadBalanceSingl
7174
7275__global__ void moeWaitSignalForGpuStageKernel (MoeLoadBalanceSingleLayerSignal* signal, int * enabled)
7376{
77+
78+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
79+ cudaGridDependencySynchronize ();
80+ cudaTriggerProgrammaticLaunchCompletion ();
81+ #endif
7482 if (threadIdx .x == 0 and blockIdx .x == 0 )
7583 {
7684 moeWaitSignalForGpuStageFunc (signal, enabled);
@@ -79,6 +87,11 @@ __global__ void moeWaitSignalForGpuStageKernel(MoeLoadBalanceSingleLayerSignal*
7987
8088__global__ void moeSetSignalForCpuStageKernel (MoeLoadBalanceSingleLayerSignal* signal)
8189{
90+
91+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
92+ cudaGridDependencySynchronize ();
93+ cudaTriggerProgrammaticLaunchCompletion ();
94+ #endif
8295 if (threadIdx .x == 0 and blockIdx .x == 0 )
8396 {
8497 unsigned long long int loaded = signal->stepAndOwner ;
@@ -91,7 +104,8 @@ __global__ void moeSetSignalForCpuStageKernel(MoeLoadBalanceSingleLayerSignal* s
91104
// Enqueue the GPU-stage wait kernel on `stream` as a single-thread launch
// (grid = 1, block = 1): only block 0 / thread 0 does work inside the kernel.
//
// `signal`  - device pointer to the per-layer CPU/GPU handshake word.
// `enabled` - device pointer receiving the enable/disable flag for this step.
//
// The launch goes through launchWithPdlWhenEnabled so that Programmatic
// Dependent Launch is used when enabled (presumably via an environment
// toggle from envUtils — confirm); otherwise it degenerates to a plain
// kernel launch with the same configuration.
void moeWaitSignalForGpuStageDevice(MoeLoadBalanceSingleLayerSignal* signal, int* enabled, cudaStream_t stream)
{
    launchWithPdlWhenEnabled(
        "moeWaitSignalForGpuStage", moeWaitSignalForGpuStageKernel, 1, 1, 0, stream, signal, enabled);
}
96110
97111void moeWaitSignalForGpuStageForTest (MoeLoadBalanceSingleLayerSignal* signal, int * enabled)
@@ -119,7 +133,7 @@ void moeWaitSignalForGpuStageForTest(MoeLoadBalanceSingleLayerSignal* signal, in
119133
// Enqueue the kernel that hands the signal back to the CPU stage, as a
// single-thread launch (grid = 1, block = 1) on `stream`.
//
// `signal` - device pointer to the per-layer CPU/GPU handshake word that the
//            kernel updates (block 0 / thread 0 only).
//
// Launched via launchWithPdlWhenEnabled so Programmatic Dependent Launch is
// used when enabled at runtime; otherwise this is equivalent to a plain
// <<<1, 1, 0, stream>>> launch.
void moeSetSignalForCpuStageDevice(MoeLoadBalanceSingleLayerSignal* signal, cudaStream_t stream)
{
    launchWithPdlWhenEnabled("moeSetSignalForCpuStage", moeSetSignalForCpuStageKernel, 1, 1, 0, stream, signal);
}
124138
125139void moeSetSignalForCpuStageForTest (MoeLoadBalanceSingleLayerSignal* signal)
@@ -134,6 +148,10 @@ __global__ void zeroExpertTokenCountKernel(MoeLoadBalanceMetaInfo metaInfo, int*
134148 TYPE oldExpertTokenCount = {0 };
135149 int * expertTokenCountPtr = expertTokenCount + metaInfo.expertCount * blockIdx .x ;
136150 TYPE* typedExpertTokenCountPtr = reinterpret_cast <TYPE*>(expertTokenCountPtr);
151+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
152+ cudaGridDependencySynchronize ();
153+ cudaTriggerProgrammaticLaunchCompletion ();
154+ #endif
137155 typedExpertTokenCountPtr[threadIdx .x ] = oldExpertTokenCount;
138156}
139157
@@ -145,6 +163,10 @@ __global__ void shiftWindowKernel(MoeLoadBalanceMetaInfo metaInfo, int* const en
145163 return ;
146164 }
147165 TYPE oldExpertTokenCount = {0 };
166+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
167+ cudaGridDependencySynchronize ();
168+ cudaTriggerProgrammaticLaunchCompletion ();
169+ #endif
148170 if (blockIdx .x > 0 )
149171 {
150172 int * oldExpertTokenCountPtr = expertTokenCount + metaInfo.expertCount * (blockIdx .x - 1 );
@@ -173,6 +195,10 @@ __global__ void statisticKernel(MoeLoadBalanceMetaInfo metaInfo, int* expertToke
173195 sharedExpertCount[i] = 0 ;
174196 }
175197 __syncthreads ();
198+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
199+ cudaGridDependencySynchronize ();
200+ cudaTriggerProgrammaticLaunchCompletion ();
201+ #endif
176202 for (int idx = threadIdx .x + blockIdx .x * blockDim .x ; idx < totalEltCount; idx += gridDim .x * blockDim .x )
177203 {
178204 int expertId = gatheredRawExpertIds[idx];
@@ -196,6 +222,10 @@ __global__ void updateLoadFactorKernel(MoeLoadBalanceMetaInfo metaInfo, MoeLoadB
196222 return ;
197223 }
198224 int expertIdx = blockIdx .x * blockDim .x + threadIdx .x ;
225+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
226+ cudaGridDependencySynchronize ();
227+ cudaTriggerProgrammaticLaunchCompletion ();
228+ #endif
199229 int expertTokenCount = expertTokenCountPtr[expertIdx];
200230 float * loadFactor = statisticInfo.expertLoadFactor ;
201231 loadFactor[expertIdx] = loadFactor[expertIdx] * statisticInfo.decayFactor + expertTokenCount;
@@ -228,6 +258,7 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic
228258 = {&metaInfo, static_cast <void *>(const_cast <int **>(&enabled)), static_cast <void *>(&expertTokenCount)};
229259 TLLM_CHECK_WITH_INFO (
230260 threadCount <= 1024 , " expertCount=%d is too large and not supported now." , metaInfo.expertCount );
261+ // TODO: add PDL support with cooperative launch
231262 TLLM_CUDA_CHECK (cudaLaunchCooperativeKernel (kernelFunc, gridDim , blockDim , &args[0 ], 0 , stream));
232263 }
233264
@@ -241,7 +272,7 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic
241272 blockCount = smCount;
242273 }
243274 int sharedMemorySize = metaInfo.expertCount * sizeof (int );
244- statisticKernel<<< blockCount, threadCount, sharedMemorySize, stream>>> (
275+ launchWithPdlWhenEnabled ( " statisticKernel" , statisticKernel, blockCount, threadCount, sharedMemorySize, stream,
245276 metaInfo, statisticInfo.expertTokenCount , totalEltCount, enabled, gatheredRawExpertIds);
246277 }
247278
@@ -250,7 +281,7 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic
250281 // only last stage need update load factor.
251282 int threadCount = 128 ;
252283 int blockCount = (metaInfo.expertCount + threadCount - 1 ) / threadCount;
253- updateLoadFactorKernel<<< blockCount, threadCount, 0 , stream>>> (
284+ launchWithPdlWhenEnabled ( " updateLoadFactor " , updateLoadFactorKernel, blockCount, threadCount, 0 , stream,
254285 metaInfo, statisticInfo, statisticInfo.expertTokenCount , enabled);
255286 }
256287}
@@ -278,11 +309,10 @@ void moeHierarchicalStatisticLocalDevice(MoeLoadBalanceMetaInfo metaInfo, int nu
278309 }
279310 dim3 gridDim (1 );
280311 dim3 blockDim (threadCount);
281- void * args[]
282- = {&metaInfo, static_cast <void *>(const_cast <int **>(&enabled)), static_cast <void *>(&localExpertTokenCount)};
283312 TLLM_CHECK_WITH_INFO (
284313 threadCount <= 1024 , " expertCount=%d is too large and not supported now." , metaInfo.expertCount );
285- TLLM_CUDA_CHECK (cudaLaunchKernel (kernelFunc, gridDim , blockDim , &args[0 ], 0 , stream));
314+ launchWithPdlWhenEnabled (
315+ " zeroExpertTokenCount" , kernelFunc, gridDim , blockDim , 0 , stream, metaInfo, enabled, localExpertTokenCount);
286316 }
287317
288318 {
@@ -295,7 +325,7 @@ void moeHierarchicalStatisticLocalDevice(MoeLoadBalanceMetaInfo metaInfo, int nu
295325 blockCount = smCount;
296326 }
297327 int sharedMemorySize = metaInfo.expertCount * sizeof (int );
298- statisticKernel<<< blockCount, threadCount, sharedMemorySize, stream>>> (
328+ launchWithPdlWhenEnabled ( " statisticKernel" , statisticKernel, blockCount, threadCount, sharedMemorySize, stream,
299329 metaInfo, localExpertTokenCount, totalEltCount, enabled, localRawExpertIds);
300330 }
301331}
@@ -305,8 +335,8 @@ void moeHierarchicalStatisticUpdate(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBala
305335{
306336 int threadCount = 128 ;
307337 int blockCount = (metaInfo.expertCount + threadCount - 1 ) / threadCount;
308- updateLoadFactorKernel<<< blockCount, threadCount, 0 , stream>>> (
309- metaInfo, statisticInfo, globalExpertTokenCount, enabled);
338+ launchWithPdlWhenEnabled ( " updateLoadFactor " , updateLoadFactorKernel, blockCount, threadCount, 0 , stream, metaInfo,
339+ statisticInfo, globalExpertTokenCount, enabled);
310340}
311341
312342template <int MAX_EXPERT_COUNT = 1024 , int THREAD_COUNT = 256 , int ITEM_PER_THREAD = 4 >
@@ -316,13 +346,18 @@ __global__ void moeComputeRouteNoRedundantKernel(MoeLoadBalanceMetaInfo metaInfo
316346 extern __shared__ int16_t sharedGlobalSlotIdsInfo[];
317347 int expertIds[ITEM_PER_THREAD];
318348 int slotIds[ITEM_PER_THREAD];
349+
350+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
351+ cudaGridDependencySynchronize ();
352+ cudaTriggerProgrammaticLaunchCompletion ();
353+ #endif
354+
319355 for (int slotId = threadIdx .x ; slotId < metaInfo.epSize * metaInfo.slotCountPerRank ; slotId += THREAD_COUNT)
320356 {
321357 sharedGlobalSlotIdsInfo[slotId] = placementInfo.globalSlotIds [slotId];
322358 }
323359
324360 int blockOffset = blockIdx .x * THREAD_COUNT * ITEM_PER_THREAD;
325-
326361 for (; blockOffset < tokenCount * metaInfo.topK ; blockOffset += gridDim .x * THREAD_COUNT * ITEM_PER_THREAD)
327362 {
328363 int tokenIdxBase = blockOffset + threadIdx .x ;
@@ -375,6 +410,12 @@ __global__ void moeComputeRouteKernel(MoeLoadBalanceMetaInfo metaInfo, MoePlacem
375410
376411 __shared__ int sharedArbitrateExpertId[THREAD_COUNT * ITEM_PER_THREAD];
377412 __shared__ int sharedExpertCount[MAX_EXPERT_COUNT];
413+
414+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
415+ cudaGridDependencySynchronize ();
416+ cudaTriggerProgrammaticLaunchCompletion ();
417+ #endif
418+
378419 for (int expertIdx = threadIdx .x ; expertIdx < metaInfo.expertCount ; expertIdx += THREAD_COUNT)
379420 {
380421 int replicaCount = placementInfo.expertReplicaCount [expertIdx];
@@ -480,6 +521,11 @@ __global__ void moeComputeRouteSortKernel(MoeLoadBalanceMetaInfo metaInfo, MoePl
480521 __shared__ int sharedSortedExpertId[THREAD_COUNT * ITEM_PER_THREAD];
481522 __shared__ int sharedExpertStartThread[MAX_EXPERT_COUNT];
482523
524+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
525+ cudaGridDependencySynchronize ();
526+ cudaTriggerProgrammaticLaunchCompletion ();
527+ #endif
528+
483529 for (int expertIdx = threadIdx .x ; expertIdx < metaInfo.expertCount ; expertIdx += THREAD_COUNT)
484530 {
485531 sharedExpertTokenCount[expertIdx] = 0 ;
@@ -496,7 +542,6 @@ __global__ void moeComputeRouteSortKernel(MoeLoadBalanceMetaInfo metaInfo, MoePl
496542 __syncthreads ();
497543
498544 int expertIds[ITEM_PER_THREAD];
499-
500545 for (int blockOffset = blockIdx .x * THREAD_COUNT * ITEM_PER_THREAD; blockOffset < tokenCount * metaInfo.topK ;
501546 blockOffset += gridDim .x * THREAD_COUNT * ITEM_PER_THREAD)
502547 {
@@ -582,14 +627,15 @@ void moeComputeRouteDevice(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo pla
582627 int dynamicShmSize = sizeof (int16_t ) * metaInfo.epSize * metaInfo.slotCountPerRank ;
583628 if (metaInfo.expertCount == metaInfo.epSize * metaInfo.slotCountPerRank )
584629 {
630+ auto * kernelFn = moeComputeRouteNoRedundantKernel<1024 , kThreadCount , kEltPerThread >;
585631 // no redundant expert, so we don't need complex routing, but just assign to the correct solt.
586- moeComputeRouteNoRedundantKernel<1024 , kThreadCount , kEltPerThread >
587- <<<blockCount, kThreadCount , dynamicShmSize, stream>>> (
588- metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount);
632+ launchWithPdlWhenEnabled (" moeComputeRouteNoRedundant" , kernelFn, blockCount, kThreadCount , dynamicShmSize,
633+ stream, metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount);
589634 }
590635 else
591636 {
592- moeComputeRouteKernel<1024 , kThreadCount , kEltPerThread ><<<blockCount, kThreadCount , dynamicShmSize, stream>>> (
637+ auto * kernelFn = moeComputeRouteKernel<1024 , kThreadCount , kEltPerThread >;
638+ launchWithPdlWhenEnabled (" moeComputeRoute" , kernelFn, blockCount, kThreadCount , dynamicShmSize, stream,
593639 metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount, offsetByEpRank);
594640 }
595641}
0 commit comments