
Commit 20e2a2c

dc3671 authored and dominicshanshan committed
[TRTLLM-6748][feat] add PDL support for more kernels (NVIDIA#7977)
Signed-off-by: Zhenhuan Chen <chenzhh3671@gmail.com>
1 parent 587638a commit 20e2a2c

File tree

9 files changed, +144 -28 lines changed


cpp/tensorrt_llm/common/envUtils.h

Lines changed: 22 additions & 0 deletions
@@ -16,7 +16,9 @@
  */

 #pragma once
+#include "tensorrt_llm/common/cudaUtils.h"
 #include <cstdint>
+#include <cuda_runtime.h>
 #include <optional>
 #include <string>

@@ -55,6 +57,26 @@ int getEnvMmhaKernelBlockSize();
 // Whether PDL is enabled.
 bool getEnvEnablePDL();

+template <typename KernelFn, typename... Args>
+inline void launchWithPdlWhenEnabled(char const* name, KernelFn kernelFn, dim3 grid, dim3 block, size_t dynamicShmSize,
+    cudaStream_t stream, Args&&... args)
+{
+    TLLM_LOG_DEBUG("Enable PDL in %s", name);
+    cudaLaunchConfig_t kernelConfig;
+    kernelConfig.gridDim = grid;
+    kernelConfig.blockDim = block;
+    kernelConfig.dynamicSmemBytes = dynamicShmSize;
+    kernelConfig.stream = stream;
+
+    cudaLaunchAttribute attrs[1];
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
+    kernelConfig.attrs = attrs;
+    kernelConfig.numAttrs = 1;
+
+    TLLM_CUDA_CHECK(cudaLaunchKernelEx(&kernelConfig, kernelFn, std::forward<Args>(args)...));
+}
+
 bool getEnvUseUCXKvCache();

 bool getEnvUseMPIKvCache();
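
For illustration, a minimal sketch of how a call site migrates to this helper; fillKernel, fillAsync, and the launch shape are hypothetical names for this sketch, not code from the commit:

    // Hypothetical kernel, used only to show the helper's calling convention.
    __global__ void fillKernel(int* out, int value)
    {
        out[threadIdx.x] = value;
    }

    void fillAsync(int* out, int value, cudaStream_t stream)
    {
        // Before: fillKernel<<<1, 32, 0, stream>>>(out, value);
        // After: the same launch goes through cudaLaunchKernelEx, with the
        // PDL attribute taken from getEnvEnablePDL() inside the helper.
        tensorrt_llm::common::launchWithPdlWhenEnabled("fillKernel", fillKernel, 1, 32, 0, stream, out, value);
    }

Note that the helper always launches via cudaLaunchKernelEx; when PDL is disabled, the attribute is simply set to false, which matches the default launch behavior.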

cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu

Lines changed: 12 additions & 1 deletion
@@ -27,6 +27,8 @@ namespace tensorrt_llm
 namespace kernels
 {

+using tensorrt_llm::common::launchWithPdlWhenEnabled;
+
 // Quantize a contiguous shared-memory buffer containing elements of DType into NVFP4 with per-16-element FP8 scales.
 // Output layout (repeated per 16-element group per lane), followed by one global scale float:
 // [WARP_SIZE * 8 bytes packed e2m1 values] [WARP_SIZE * 1 byte E4M3 per-group scales] ... [global_scale (4 bytes)]
@@ -1069,6 +1071,10 @@ public:

         int sendIndex = mPairInfo.channel;
         uint32_t phaseParity = 0;
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+        cudaGridDependencySynchronize();
+        cudaTriggerProgrammaticLaunchCompletion();
+#endif
         for (; sendIndex < tokenCount; sendIndex += mPairInfo.runChannelCount)
         {
             int tokenIndex = sendIndexMapping == nullptr ? sendIndex : sendIndexMapping[sendIndex];
@@ -1140,6 +1146,10 @@ public:
         int recvIndex = mPairInfo.channel;
         uint32_t phaseParity = 0;
         bool needRelease = false;
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+        cudaGridDependencySynchronize();
+        cudaTriggerProgrammaticLaunchCompletion();
+#endif
         for (; recvIndex < tokenCount; recvIndex += mPairInfo.runChannelCount)
         {
             int tokenIndex = recvIndexMapping == nullptr ? recvIndex : recvIndexMapping[recvIndex];
@@ -1459,7 +1469,8 @@ void moeAllToAll(FusedMoeCommKernelParam params, FusedMoeWorkspace workspace, cu

     dim3 block = FusedMoeCommunicator::getLaunchBlockDim(groupCountPerCta);
     dim3 grid = FusedMoeCommunicator::getLaunchGridDim(params.worldInfo.epInfo.epSize, groupCountPerCta);
-    kernelFn<<<grid, block, totalDynamicShmSize, stream>>>(params, workspace, hasBasicFields);
+    launchWithPdlWhenEnabled(
+        "moeAllToAll", kernelFn, grid, block, totalDynamicShmSize, stream, params, workspace, hasBasicFields);
     TLLM_CUDA_CHECK(cudaGetLastError());
 }
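
The two device-side calls added above are the standard PDL pairing on compute capability 9.0 and newer: cudaGridDependencySynchronize() waits until the grids this kernel depends on have completed their writes, and cudaTriggerProgrammaticLaunchCompletion() signals that the next PDL-enabled grid on the stream may start launching. A minimal standalone sketch of the same pattern; consumerKernel and its buffers are illustrative, not from the commit:

    // Illustrative consumer kernel: it must not read upstream results until
    // cudaGridDependencySynchronize() has returned.
    __global__ void consumerKernel(float const* producedUpstream, float* out)
    {
    #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        // Wait for all grids this launch depends on to finish their writes.
        cudaGridDependencySynchronize();
        // Let the next PDL-enabled grid on the stream start launching early.
        cudaTriggerProgrammaticLaunchCompletion();
    #endif
        out[threadIdx.x] = 2.0f * producedUpstream[threadIdx.x];
    }

Placing both calls at the top of the kernel body, as this commit does, lets the next kernel's launch overlap with this kernel's entire body; the overlap only materializes when the launch carries the programmatic-stream-serialization attribute, which the host-side helper sets.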

cpp/tensorrt_llm/kernels/fusedMoeCommKernels.h

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@
 #include <cuda_runtime_api.h>

 #include "tensorrt_llm/common/cudaUtils.h"
+#include "tensorrt_llm/common/envUtils.h"
 #include "tensorrt_llm/kernels/moeCommKernelsCommon.h"

 namespace tensorrt_llm

cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu

Lines changed: 62 additions & 16 deletions
@@ -19,6 +19,7 @@
 #include <cub/cub.cuh>

 #include "tensorrt_llm/common/cudaUtils.h"
+#include "tensorrt_llm/common/envUtils.h"
 #include "tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h"

 namespace cg = cooperative_groups;
@@ -28,6 +29,8 @@ namespace tensorrt_llm
 namespace kernels
 {

+using tensorrt_llm::common::launchWithPdlWhenEnabled;
+
 int getOwnerDevice(unsigned long long int stepAndOwner)
 {
     return static_cast<int>(stepAndOwner & MoeLoadBalanceSingleLayerSignal::kDevice);
@@ -71,6 +74,11 @@ __device__ __forceinline__ void moeWaitSignalForGpuStageFunc(MoeLoadBalanceSingl

 __global__ void moeWaitSignalForGpuStageKernel(MoeLoadBalanceSingleLayerSignal* signal, int* enabled)
 {
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
     if (threadIdx.x == 0 and blockIdx.x == 0)
     {
         moeWaitSignalForGpuStageFunc(signal, enabled);
@@ -79,6 +87,11 @@ __global__ void moeWaitSignalForGpuStageKernel(MoeLoadBalanceSingleLayerSignal*

 __global__ void moeSetSignalForCpuStageKernel(MoeLoadBalanceSingleLayerSignal* signal)
 {
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
     if (threadIdx.x == 0 and blockIdx.x == 0)
     {
         unsigned long long int loaded = signal->stepAndOwner;
@@ -91,7 +104,8 @@

 void moeWaitSignalForGpuStageDevice(MoeLoadBalanceSingleLayerSignal* signal, int* enabled, cudaStream_t stream)
 {
-    moeWaitSignalForGpuStageKernel<<<1, 1, 0, stream>>>(signal, enabled);
+    launchWithPdlWhenEnabled(
+        "moeWaitSignalForGpuStage", moeWaitSignalForGpuStageKernel, 1, 1, 0, stream, signal, enabled);
 }

 void moeWaitSignalForGpuStageForTest(MoeLoadBalanceSingleLayerSignal* signal, int* enabled)
@@ -119,7 +133,7 @@ void moeWaitSignalForGpuStageForTest(MoeLoadBalanceSingleLayerSignal* signal, in

 void moeSetSignalForCpuStageDevice(MoeLoadBalanceSingleLayerSignal* signal, cudaStream_t stream)
 {
-    moeSetSignalForCpuStageKernel<<<1, 1, 0, stream>>>(signal);
+    launchWithPdlWhenEnabled("moeSetSignalForCpuStage", moeSetSignalForCpuStageKernel, 1, 1, 0, stream, signal);
 }

 void moeSetSignalForCpuStageForTest(MoeLoadBalanceSingleLayerSignal* signal)
@@ -134,6 +148,10 @@ __global__ void zeroExpertTokenCountKernel(MoeLoadBalanceMetaInfo metaInfo, int*
     TYPE oldExpertTokenCount = {0};
     int* expertTokenCountPtr = expertTokenCount + metaInfo.expertCount * blockIdx.x;
     TYPE* typedExpertTokenCountPtr = reinterpret_cast<TYPE*>(expertTokenCountPtr);
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
     typedExpertTokenCountPtr[threadIdx.x] = oldExpertTokenCount;
 }

@@ -145,6 +163,10 @@ __global__ void shiftWindowKernel(MoeLoadBalanceMetaInfo metaInfo, int* const en
         return;
     }
     TYPE oldExpertTokenCount = {0};
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
     if (blockIdx.x > 0)
     {
         int* oldExpertTokenCountPtr = expertTokenCount + metaInfo.expertCount * (blockIdx.x - 1);
@@ -173,6 +195,10 @@ __global__ void statisticKernel(MoeLoadBalanceMetaInfo metaInfo, int* expertToke
         sharedExpertCount[i] = 0;
     }
     __syncthreads();
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
     for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < totalEltCount; idx += gridDim.x * blockDim.x)
     {
         int expertId = gatheredRawExpertIds[idx];
@@ -196,6 +222,10 @@ __global__ void updateLoadFactorKernel(MoeLoadBalanceMetaInfo metaInfo, MoeLoadB
         return;
     }
     int expertIdx = blockIdx.x * blockDim.x + threadIdx.x;
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
     int expertTokenCount = expertTokenCountPtr[expertIdx];
     float* loadFactor = statisticInfo.expertLoadFactor;
     loadFactor[expertIdx] = loadFactor[expertIdx] * statisticInfo.decayFactor + expertTokenCount;
@@ -228,6 +258,7 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic
             = {&metaInfo, static_cast<void*>(const_cast<int**>(&enabled)), static_cast<void*>(&expertTokenCount)};
         TLLM_CHECK_WITH_INFO(
             threadCount <= 1024, "expertCount=%d is too large and not supported now.", metaInfo.expertCount);
+        // TODO: add PDL support with cooperative launch
         TLLM_CUDA_CHECK(cudaLaunchCooperativeKernel(kernelFunc, gridDim, blockDim, &args[0], 0, stream));
     }

@@ -241,7 +272,7 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic
             blockCount = smCount;
         }
         int sharedMemorySize = metaInfo.expertCount * sizeof(int);
-        statisticKernel<<<blockCount, threadCount, sharedMemorySize, stream>>>(
+        launchWithPdlWhenEnabled("statisticKernel", statisticKernel, blockCount, threadCount, sharedMemorySize, stream,
             metaInfo, statisticInfo.expertTokenCount, totalEltCount, enabled, gatheredRawExpertIds);
     }

@@ -250,7 +281,7 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic
         // only last stage need update load factor.
         int threadCount = 128;
         int blockCount = (metaInfo.expertCount + threadCount - 1) / threadCount;
-        updateLoadFactorKernel<<<blockCount, threadCount, 0, stream>>>(
+        launchWithPdlWhenEnabled("updateLoadFactor", updateLoadFactorKernel, blockCount, threadCount, 0, stream,
             metaInfo, statisticInfo, statisticInfo.expertTokenCount, enabled);
     }
 }
@@ -278,11 +309,10 @@ void moeHierarchicalStatisticLocalDevice(MoeLoadBalanceMetaInfo metaInfo, int nu
         }
         dim3 gridDim(1);
         dim3 blockDim(threadCount);
-        void* args[]
-            = {&metaInfo, static_cast<void*>(const_cast<int**>(&enabled)), static_cast<void*>(&localExpertTokenCount)};
         TLLM_CHECK_WITH_INFO(
             threadCount <= 1024, "expertCount=%d is too large and not supported now.", metaInfo.expertCount);
-        TLLM_CUDA_CHECK(cudaLaunchKernel(kernelFunc, gridDim, blockDim, &args[0], 0, stream));
+        launchWithPdlWhenEnabled(
+            "zeroExpertTokenCount", kernelFn, gridDim, blockDim, 0, stream, metaInfo, enabled, localExpertTokenCount);
     }

     {
@@ -295,7 +325,7 @@ void moeHierarchicalStatisticLocalDevice(MoeLoadBalanceMetaInfo metaInfo, int nu
             blockCount = smCount;
         }
         int sharedMemorySize = metaInfo.expertCount * sizeof(int);
-        statisticKernel<<<blockCount, threadCount, sharedMemorySize, stream>>>(
+        launchWithPdlWhenEnabled("statisticKernel", statisticKernel, blockCount, threadCount, sharedMemorySize, stream,
             metaInfo, localExpertTokenCount, totalEltCount, enabled, localRawExpertIds);
     }
 }
@@ -305,8 +335,8 @@ void moeHierarchicalStatisticUpdate(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBala
 {
     int threadCount = 128;
     int blockCount = (metaInfo.expertCount + threadCount - 1) / threadCount;
-    updateLoadFactorKernel<<<blockCount, threadCount, 0, stream>>>(
-        metaInfo, statisticInfo, globalExpertTokenCount, enabled);
+    launchWithPdlWhenEnabled("updateLoadFactor", updateLoadFactorKernel, blockCount, threadCount, 0, stream, metaInfo,
+        statisticInfo, globalExpertTokenCount, enabled);
 }

 template <int MAX_EXPERT_COUNT = 1024, int THREAD_COUNT = 256, int ITEM_PER_THREAD = 4>
@@ -316,13 +346,18 @@ __global__ void moeComputeRouteNoRedundantKernel(MoeLoadBalanceMetaInfo metaInfo
     extern __shared__ int16_t sharedGlobalSlotIdsInfo[];
     int expertIds[ITEM_PER_THREAD];
     int slotIds[ITEM_PER_THREAD];
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
+
     for (int slotId = threadIdx.x; slotId < metaInfo.epSize * metaInfo.slotCountPerRank; slotId += THREAD_COUNT)
     {
         sharedGlobalSlotIdsInfo[slotId] = placementInfo.globalSlotIds[slotId];
     }

     int blockOffset = blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD;
-
     for (; blockOffset < tokenCount * metaInfo.topK; blockOffset += gridDim.x * THREAD_COUNT * ITEM_PER_THREAD)
     {
         int tokenIdxBase = blockOffset + threadIdx.x;
@@ -375,6 +410,12 @@ __global__ void moeComputeRouteKernel(MoeLoadBalanceMetaInfo metaInfo, MoePlacem

     __shared__ int sharedArbitrateExpertId[THREAD_COUNT * ITEM_PER_THREAD];
     __shared__ int sharedExpertCount[MAX_EXPERT_COUNT];
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
+
     for (int expertIdx = threadIdx.x; expertIdx < metaInfo.expertCount; expertIdx += THREAD_COUNT)
     {
         int replicaCount = placementInfo.expertReplicaCount[expertIdx];
@@ -480,6 +521,11 @@ __global__ void moeComputeRouteSortKernel(MoeLoadBalanceMetaInfo metaInfo, MoePl
     __shared__ int sharedSortedExpertId[THREAD_COUNT * ITEM_PER_THREAD];
     __shared__ int sharedExpertStartThread[MAX_EXPERT_COUNT];

+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
+
     for (int expertIdx = threadIdx.x; expertIdx < metaInfo.expertCount; expertIdx += THREAD_COUNT)
     {
         sharedExpertTokenCount[expertIdx] = 0;
@@ -496,7 +542,6 @@
     __syncthreads();

     int expertIds[ITEM_PER_THREAD];
-
     for (int blockOffset = blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD; blockOffset < tokenCount * metaInfo.topK;
         blockOffset += gridDim.x * THREAD_COUNT * ITEM_PER_THREAD)
     {
@@ -582,14 +627,15 @@ void moeComputeRouteDevice(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo pla
     int dynamicShmSize = sizeof(int16_t) * metaInfo.epSize * metaInfo.slotCountPerRank;
     if (metaInfo.expertCount == metaInfo.epSize * metaInfo.slotCountPerRank)
     {
+        auto* kernelFn = moeComputeRouteNoRedundantKernel<1024, kThreadCount, kEltPerThread>;
         // no redundant expert, so we don't need complex routing, but just assign to the correct solt.
-        moeComputeRouteNoRedundantKernel<1024, kThreadCount, kEltPerThread>
-            <<<blockCount, kThreadCount, dynamicShmSize, stream>>>(
-                metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount);
+        launchWithPdlWhenEnabled("moeComputeRouteNoRedundant", kernelFn, blockCount, kThreadCount, dynamicShmSize,
+            stream, metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount);
     }
     else
     {
-        moeComputeRouteKernel<1024, kThreadCount, kEltPerThread><<<blockCount, kThreadCount, dynamicShmSize, stream>>>(
+        auto* kernelFn = moeComputeRouteKernel<1024, kThreadCount, kEltPerThread>;
+        launchWithPdlWhenEnabled("moeComputeRoute", kernelFn, blockCount, kThreadCount, dynamicShmSize, stream,
            metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount, offsetByEpRank);
     }
 }
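
On the remaining TODO ("add PDL support with cooperative launch"): cudaLaunchCooperativeKernel takes no attribute list, which is presumably why that call site is left as-is. A hypothetical sketch of one way to resolve it, under the assumption that a CUDA 12-era toolkit lets cudaLaunchKernelEx combine cudaLaunchAttributeCooperative with the PDL attribute; this is unverified and not part of this commit:

    // Hypothetical sketch only: a cooperative counterpart to
    // launchWithPdlWhenEnabled. Not from this commit; unverified against the
    // cooperative-launch requirements of these kernels.
    template <typename KernelFn, typename... Args>
    inline void launchCooperativeWithPdlWhenEnabled(
        char const* name, KernelFn kernelFn, dim3 grid, dim3 block, cudaStream_t stream, Args&&... args)
    {
        TLLM_LOG_DEBUG("Enable PDL (cooperative) in %s", name);
        cudaLaunchConfig_t kernelConfig{};
        kernelConfig.gridDim = grid;
        kernelConfig.blockDim = block;
        kernelConfig.stream = stream;

        cudaLaunchAttribute attrs[2];
        attrs[0].id = cudaLaunchAttributeCooperative;
        attrs[0].val.cooperative = 1;
        attrs[1].id = cudaLaunchAttributeProgrammaticStreamSerialization;
        attrs[1].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
        kernelConfig.attrs = attrs;
        kernelConfig.numAttrs = 2;

        TLLM_CUDA_CHECK(cudaLaunchKernelEx(&kernelConfig, kernelFn, std::forward<Args>(args)...));
    }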
