Skip to content

Commit f3322f1

Browse files
committed
add pdl support for more kernels
Signed-off-by: Zhenhuan Chen <zhenhuanc@nvidia.com>
1 parent 2e5850c commit f3322f1

File tree

7 files changed

+100
-18
lines changed

7 files changed

+100
-18
lines changed

cpp/tensorrt_llm/common/envUtils.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
*/
1717

1818
#pragma once
19+
#include "tensorrt_llm/common/cudaUtils.h"
1920
#include <cstdint>
21+
#include <cuda_runtime.h>
2022
#include <optional>
2123
#include <string>
2224

@@ -55,6 +57,26 @@ int getEnvMmhaKernelBlockSize();
5557
// Whether PDL is enabled.
5658
bool getEnvEnablePDL();
5759

60+
// Launch a kernel through cudaLaunchKernelEx, attaching the programmatic
// stream serialization (PDL) launch attribute when it is enabled via the
// environment (see getEnvEnablePDL()).
//
// name           : human-readable kernel name, used only for debug logging.
// kernelFn       : kernel function pointer to launch.
// grid, block    : launch dimensions.
// dynamicShmSize : dynamic shared memory size in bytes for the launch.
// stream         : CUDA stream to launch on.
// args           : kernel arguments, perfectly forwarded to the launch.
//
// NOTE(review): kernels launched with the PDL attribute enabled are expected
// to call cudaGridDependencySynchronize() before consuming upstream results
// (the callers in this commit guard that with __CUDA_ARCH__ >= 900).
template <typename KernelFn, typename... Args>
inline void launchWithPdlWhenEnabled(char const* name, KernelFn kernelFn, dim3 grid, dim3 block, size_t dynamicShmSize,
    cudaStream_t stream, Args&&... args)
{
    // Read the env switch once; it gates both the log line and the attribute.
    bool const pdlEnabled = tensorrt_llm::common::getEnvEnablePDL();
    if (pdlEnabled)
    {
        // Bug fix: the previous code logged "Enable PDL in %s" unconditionally,
        // even when PDL was disabled, which was misleading when debugging.
        TLLM_LOG_DEBUG("Enable PDL in %s", name);
    }

    // Value-initialize so any config fields not explicitly set below are zero
    // rather than indeterminate.
    cudaLaunchConfig_t kernelConfig{};
    kernelConfig.gridDim = grid;
    kernelConfig.blockDim = block;
    kernelConfig.dynamicSmemBytes = dynamicShmSize;
    kernelConfig.stream = stream;

    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = pdlEnabled;
    kernelConfig.attrs = attrs;
    kernelConfig.numAttrs = 1;

    TLLM_CUDA_CHECK(cudaLaunchKernelEx(&kernelConfig, kernelFn, std::forward<Args>(args)...));
}
79+
5880
bool getEnvUseUCXKvCache();
5981

6082
bool getEnvUseMPIKvCache();

cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ namespace tensorrt_llm
2727
namespace kernels
2828
{
2929

30+
using tensorrt_llm::common::launchWithPdlWhenEnabled;
31+
3032
// Quantize a contiguous shared-memory buffer containing elements of DType into NVFP4 with per-16-element FP8 scales.
3133
// Output layout (repeated per 16-element group per lane), followed by one global scale float:
3234
// [WARP_SIZE * 8 bytes packed e2m1 values] [WARP_SIZE * 1 byte E4M3 per-group scales] ... [global_scale (4 bytes)]
@@ -1069,6 +1071,9 @@ public:
10691071

10701072
int sendIndex = mPairInfo.channel;
10711073
uint32_t phaseParity = 0;
1074+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
1075+
cudaGridDependencySynchronize();
1076+
#endif
10721077
for (; sendIndex < tokenCount; sendIndex += mPairInfo.runChannelCount)
10731078
{
10741079
int tokenIndex = sendIndexMapping == nullptr ? sendIndex : sendIndexMapping[sendIndex];
@@ -1140,6 +1145,9 @@ public:
11401145
int recvIndex = mPairInfo.channel;
11411146
uint32_t phaseParity = 0;
11421147
bool needRelease = false;
1148+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
1149+
cudaGridDependencySynchronize();
1150+
#endif
11431151
for (; recvIndex < tokenCount; recvIndex += mPairInfo.runChannelCount)
11441152
{
11451153
int tokenIndex = recvIndexMapping == nullptr ? recvIndex : recvIndexMapping[recvIndex];
@@ -1459,7 +1467,8 @@ void moeAllToAll(FusedMoeCommKernelParam params, FusedMoeWorkspace workspace, cu
14591467

14601468
dim3 block = FusedMoeCommunicator::getLaunchBlockDim(groupCountPerCta);
14611469
dim3 grid = FusedMoeCommunicator::getLaunchGridDim(params.worldInfo.epInfo.epSize, groupCountPerCta);
1462-
kernelFn<<<grid, block, totalDynamicShmSize, stream>>>(params, workspace, hasBasicFields);
1470+
launchWithPdlWhenEnabled(
1471+
"moeAllToAll", kernelFn, grid, block, totalDynamicShmSize, stream, params, workspace, hasBasicFields);
14631472
TLLM_CUDA_CHECK(cudaGetLastError());
14641473
}
14651474

cpp/tensorrt_llm/kernels/fusedMoeCommKernels.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <cuda_runtime_api.h>
2121

2222
#include "tensorrt_llm/common/cudaUtils.h"
23+
#include "tensorrt_llm/common/envUtils.h"
2324
#include "tensorrt_llm/kernels/moeCommKernelsCommon.h"
2425

2526
namespace tensorrt_llm

cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <cub/cub.cuh>
2020

2121
#include "tensorrt_llm/common/cudaUtils.h"
22+
#include "tensorrt_llm/common/envUtils.h"
2223
#include "tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h"
2324

2425
namespace cg = cooperative_groups;
@@ -28,6 +29,8 @@ namespace tensorrt_llm
2829
namespace kernels
2930
{
3031

32+
using tensorrt_llm::common::launchWithPdlWhenEnabled;
33+
3134
int getOwnerDevice(unsigned long long int stepAndOwner)
3235
{
3336
return static_cast<int>(stepAndOwner & MoeLoadBalanceSingleLayerSignal::kDevice);
@@ -138,6 +141,9 @@ __global__ void zeroExpertTokenCountKernel(MoeLoadBalanceMetaInfo metaInfo, int*
138141
TYPE oldExpertTokenCount = {0};
139142
int* expertTokenCountPtr = expertTokenCount + metaInfo.expertCount * blockIdx.x;
140143
TYPE* typedExpertTokenCountPtr = reinterpret_cast<TYPE*>(expertTokenCountPtr);
144+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
145+
cudaGridDependencySynchronize();
146+
#endif
141147
typedExpertTokenCountPtr[threadIdx.x] = oldExpertTokenCount;
142148
}
143149

@@ -177,6 +183,9 @@ __global__ void statisticKernel(MoeLoadBalanceMetaInfo metaInfo, int* expertToke
177183
sharedExpertCount[i] = 0;
178184
}
179185
__syncthreads();
186+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
187+
cudaGridDependencySynchronize();
188+
#endif
180189
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < totalEltCount; idx += gridDim.x * blockDim.x)
181190
{
182191
int expertId = gatheredRawExpertIds[idx];
@@ -282,11 +291,10 @@ void moeHierarchicalStatisticLocalDevice(MoeLoadBalanceMetaInfo metaInfo, int nu
282291
}
283292
dim3 gridDim(1);
284293
dim3 blockDim(threadCount);
285-
void* args[]
286-
= {&metaInfo, static_cast<void*>(const_cast<int**>(&enabled)), static_cast<void*>(&localExpertTokenCount)};
287294
TLLM_CHECK_WITH_INFO(
288295
threadCount <= 1024, "expertCount=%d is too large and not supported now.", metaInfo.expertCount);
289-
TLLM_CUDA_CHECK(cudaLaunchKernel(kernelFunc, gridDim, blockDim, &args[0], 0, stream));
296+
launchWithPdlWhenEnabled(
297+
"zeroExpertTokenCount", kernelFunc, gridDim, blockDim, 0, stream, metaInfo, enabled, localExpertTokenCount);
290298
}
291299

292300
{
@@ -299,7 +307,7 @@ void moeHierarchicalStatisticLocalDevice(MoeLoadBalanceMetaInfo metaInfo, int nu
299307
blockCount = smCount;
300308
}
301309
int sharedMemorySize = metaInfo.expertCount * sizeof(int);
302-
statisticKernel<<<blockCount, threadCount, sharedMemorySize, stream>>>(
310+
launchWithPdlWhenEnabled("statisticKernel", statisticKernel, blockCount, threadCount, sharedMemorySize, stream,
303311
metaInfo, localExpertTokenCount, totalEltCount, enabled, localRawExpertIds);
304312
}
305313
}
@@ -327,6 +335,10 @@ __global__ void moeComputeRouteNoRedundantKernel(MoeLoadBalanceMetaInfo metaInfo
327335

328336
int blockOffset = blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD;
329337

338+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
339+
cudaGridDependencySynchronize();
340+
#endif
341+
330342
for (; blockOffset < tokenCount * metaInfo.topK; blockOffset += gridDim.x * THREAD_COUNT * ITEM_PER_THREAD)
331343
{
332344
int tokenIdxBase = blockOffset + threadIdx.x;
@@ -501,6 +513,10 @@ __global__ void moeComputeRouteSortKernel(MoeLoadBalanceMetaInfo metaInfo, MoePl
501513

502514
int expertIds[ITEM_PER_THREAD];
503515

516+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
517+
cudaGridDependencySynchronize();
518+
#endif
519+
504520
for (int blockOffset = blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD; blockOffset < tokenCount * metaInfo.topK;
505521
blockOffset += gridDim.x * THREAD_COUNT * ITEM_PER_THREAD)
506522
{
@@ -586,14 +602,15 @@ void moeComputeRouteDevice(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo pla
586602
int dynamicShmSize = sizeof(int16_t) * metaInfo.epSize * metaInfo.slotCountPerRank;
587603
if (metaInfo.expertCount == metaInfo.epSize * metaInfo.slotCountPerRank)
588604
{
605+
auto* kernelFn = moeComputeRouteNoRedundantKernel<1024, kThreadCount, kEltPerThread>;
589606
// no redundant expert, so we don't need complex routing, but just assign to the correct solt.
590-
moeComputeRouteNoRedundantKernel<1024, kThreadCount, kEltPerThread>
591-
<<<blockCount, kThreadCount, dynamicShmSize, stream>>>(
592-
metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount);
607+
launchWithPdlWhenEnabled("moeComputeRouteNoRedundant", kernelFn, blockCount, kThreadCount, dynamicShmSize,
608+
stream, metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount);
593609
}
594610
else
595611
{
596-
moeComputeRouteKernel<1024, kThreadCount, kEltPerThread><<<blockCount, kThreadCount, dynamicShmSize, stream>>>(
612+
auto* kernelFn = moeComputeRouteKernel<1024, kThreadCount, kEltPerThread>;
613+
launchWithPdlWhenEnabled("moeComputeRoute", kernelFn, blockCount, kThreadCount, dynamicShmSize, stream,
597614
metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount, offsetByEpRank);
598615
}
599616
}

cpp/tensorrt_llm/kernels/moePrepareKernels.cu

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ namespace tensorrt_llm::kernels
3030
namespace moe_prepare
3131
{
3232

33+
using tensorrt_llm::common::launchWithPdlWhenEnabled;
34+
3335
__device__ __forceinline__ void st_release_sys_global(uint64_t volatile* ptr, uint64_t val)
3436
{
3537
asm volatile("st.release.sys.global.u64 [%0], %1;" ::"l"(ptr), "l"(val) : "memory");
@@ -110,6 +112,10 @@ __device__ __forceinline__ void computeCountAndSendStatics(int* experts, int tok
110112
int* localSendIndice = sendIndiceWorkspace + targetRankId * maxTokenCountPerRank;
111113
int* localBackwardIndice = backwardIndiceWorkspace + targetRankId * maxTokenCountPerRank;
112114

115+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
116+
cudaGridDependencySynchronize();
117+
#endif
118+
113119
for (int i = tileId; i < readRankTokenCount; i += tileCountPerBlock)
114120
{
115121
int expertRankId = laneInTile < topK ? experts[i * topK + laneInTile] / expertCountPerRank : epSize;
@@ -163,6 +169,11 @@ __device__ __forceinline__ void recvCountAndStatics(int* recvIndiceWorkspace, in
163169

164170
CounterCommunicator counter(workspace.getFifoConnInfo(false, rankId, targetRankId, 0, rankCount, 1));
165171
int communicationCount = gatheredExpertStatics == nullptr ? 1 : expertCount + 1;
172+
173+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
174+
cudaGridDependencySynchronize();
175+
#endif
176+
166177
for (int i = rankTile.thread_rank(); i < communicationCount; i += THREADS_PER_PIPELINE)
167178
{
168179
int recvValue = counter.acquireValue(i);
@@ -218,6 +229,9 @@ __global__ void moveIndiceDevice(int* sendCountsCumsum, int* recvCountsCumsum, i
218229
int count = endIndex - startIndex;
219230
int* localSendIndice = sendIndice + targetRankId * maxTokenCountPerRank;
220231
int* localBackwardIndice = backwardIndice + targetRankId * maxTokenCountPerRank;
232+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
233+
cudaGridDependencySynchronize();
234+
#endif
221235
for (int localIdx = threadIdx.x; localIdx < count; localIdx += blockDim.x)
222236
{
223237
gatherSendIndice[startIndex + localIdx] = localSendIndice[localIdx];
@@ -230,6 +244,9 @@ __global__ void moveIndiceDevice(int* sendCountsCumsum, int* recvCountsCumsum, i
230244
int startIndex = targetRankId == 0 ? 0 : recvCountsCumsum[targetRankId - 1];
231245
int endIndex = recvCountsCumsum[targetRankId];
232246
int count = endIndex - startIndex;
247+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
248+
cudaGridDependencySynchronize();
249+
#endif
233250
for (int localIdx = threadIdx.x; localIdx < count; localIdx += blockDim.x)
234251
{
235252
gatherRecvIndice[startIndex + localIdx] = startIndex + localIdx;
@@ -249,6 +266,10 @@ __global__ void computeCumsumDevice(int* sendCountsCumsum, int* recvCountsCumsum
249266
int threadData = tid < rankCount ? inputOutputPtr[tid] : 0;
250267
__syncthreads();
251268

269+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
270+
cudaGridDependencySynchronize();
271+
#endif
272+
252273
BlockScan(temp_storage).InclusiveSum(threadData, threadData);
253274
if (tid < rankCount)
254275
{
@@ -261,6 +282,9 @@ __global__ void memsetExpertIdsDevice(
261282
{
262283
int maxTokenCount = maxTokenCountPerRank * rankCount;
263284
int totalRecvTokenCount = *(recvCountsCumsum + rankCount - 1);
285+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
286+
cudaGridDependencySynchronize();
287+
#endif
264288
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i + totalRecvTokenCount * topK < maxTokenCount * topK;
265289
i += gridDim.x * blockDim.x)
266290
{
@@ -300,17 +324,20 @@ void computeCountAndIndice(int* experts, int* sendCounts, int* recvCounts, int*
300324
{
301325
kernelFn = computeCountAndIndiceDevice<2>;
302326
}
303-
kernelFn<<<grid, block, 0, stream>>>(experts, sendCounts, recvCounts, sendIndiceWorkspace, backwardIndiceWorkspace,
304-
recvIndiceWorkspace, expertStatics, gatheredExpertStatics, workspace, tokenCount, maxTokenCountPerRank, topK,
305-
slotCount, expertCount, rankId, rankCount);
327+
328+
launchWithPdlWhenEnabled("computeCountAndIndice", kernelFn, grid, block, 0, stream, experts, sendCounts, recvCounts,
329+
sendIndiceWorkspace, backwardIndiceWorkspace, recvIndiceWorkspace, expertStatics, gatheredExpertStatics,
330+
workspace, tokenCount, maxTokenCountPerRank, topK, slotCount, expertCount, rankId, rankCount);
306331
}
307332

308333
void computeCumsum(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount, cudaStream_t stream)
309334
{
310335
int block_size = CUMSUM_THREADS_PER_BLOCK;
311336
dim3 block(block_size);
312337
dim3 grid(2);
313-
computeCumsumDevice<<<grid, block, 0, stream>>>(sendCountsCumsum, recvCountsCumsum, rankId, rankCount);
338+
339+
launchWithPdlWhenEnabled("computeCumsum", computeCumsumDevice, grid, block, 0, stream, sendCountsCumsum,
340+
recvCountsCumsum, rankId, rankCount);
314341
}
315342

316343
void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, int* gatherSendIndice,
@@ -319,17 +346,22 @@ void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, i
319346
{
320347
dim3 block(512);
321348
dim3 grid(rankCount, 2);
322-
moveIndiceDevice<<<grid, block, 0, stream>>>(sendCountsCumsum, recvCountsCumsum, sendIndice, gatherSendIndice,
323-
backwardIndice, gatherBackwardIndice, recvIndice, gatherRecvIndice, maxTokenCountPerRank);
349+
350+
launchWithPdlWhenEnabled("moveIndice", moveIndiceDevice, grid, block, 0, stream, sendCountsCumsum, recvCountsCumsum,
351+
sendIndice, gatherSendIndice, backwardIndice, gatherBackwardIndice, recvIndice, gatherRecvIndice,
352+
maxTokenCountPerRank);
324353
}
325354

326355
void memsetExpertIds(int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int slotCount,
327356
int rankCount, cudaStream_t stream)
328357
{
329358
int smCount = tensorrt_llm::common::getMultiProcessorCount();
330359
int block_size = 256;
331-
memsetExpertIdsDevice<<<smCount, block_size, 0, stream>>>(
332-
expertIds, recvCountsCumsum, maxTokenCountPerRank, topK, slotCount, rankCount);
360+
dim3 block(block_size);
361+
dim3 grid(smCount);
362+
363+
launchWithPdlWhenEnabled("memsetExpertIds", memsetExpertIdsDevice, grid, block, 0, stream, expertIds,
364+
recvCountsCumsum, maxTokenCountPerRank, topK, slotCount, rankCount);
333365
}
334366

335367
size_t getMoePrepareWorkspaceSize(int epSize)

cpp/tensorrt_llm/kernels/moePrepareKernels.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <map>
2020

2121
#include "tensorrt_llm/common/cudaUtils.h"
22+
#include "tensorrt_llm/common/envUtils.h"
2223

2324
#define DEBUG_PIPELINE 0
2425

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818

1919
#include "KernelRunner.h"
2020
#include "tensorrt_llm/common/assert.h"
21-
#include "tensorrt_llm/common/envUtils.h"
2221
#include "trtllmGen_bmm_export/BatchedGemmInterface.h"
2322
#include "trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h"
2423
// DO NOT include cudaUtils.h and logger.h before BatchedGemmInterface.h as it #undef TLLM_LOG_INFO and co.
2524
#include "tensorrt_llm/common/cudaUtils.h"
25+
#include "tensorrt_llm/common/envUtils.h"
2626
#include "tensorrt_llm/common/logger.h"
2727

2828
namespace tensorrt_llm

0 commit comments

Comments
 (0)