Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions cpp/tensorrt_llm/common/envUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
*/

#pragma once
#include "tensorrt_llm/common/cudaUtils.h"
#include <cstdint>
#include <cuda_runtime.h>
#include <optional>
#include <string>

Expand Down Expand Up @@ -55,6 +57,26 @@ int getEnvMmhaKernelBlockSize();
// Whether PDL is enabled.
bool getEnvEnablePDL();

// Launch `kernelFn` through cudaLaunchKernelEx, requesting Programmatic Dependent
// Launch (PDL) stream serialization only when getEnvEnablePDL() reports it enabled.
//
// \param name           human-readable kernel name, used for debug logging only
// \param kernelFn       __global__ function pointer to launch
// \param grid/block     launch dimensions
// \param dynamicShmSize dynamic shared memory bytes (3rd launch parameter)
// \param stream         CUDA stream the kernel is enqueued on
// \param args           kernel arguments, perfectly forwarded to cudaLaunchKernelEx
//
// NOTE: kernels launched this way with PDL on are expected to call
// cudaGridDependencySynchronize()/cudaTriggerProgrammaticLaunchCompletion()
// themselves (requires SM90+).
template <typename KernelFn, typename... Args>
inline void launchWithPdlWhenEnabled(char const* name, KernelFn kernelFn, dim3 grid, dim3 block, size_t dynamicShmSize,
    cudaStream_t stream, Args&&... args)
{
    // Query once and reuse for both the log gate and the launch attribute.
    bool const enablePdl = tensorrt_llm::common::getEnvEnablePDL();
    if (enablePdl)
    {
        // Bug fix: the previous code logged "Enable PDL" unconditionally, even
        // when PDL was disabled, which made the debug trace misleading.
        TLLM_LOG_DEBUG("Enable PDL in %s", name);
    }

    // Zero-initialize so any cudaLaunchConfig_t fields we do not set explicitly
    // are well-defined rather than stack garbage.
    cudaLaunchConfig_t kernelConfig = {};
    kernelConfig.gridDim = grid;
    kernelConfig.blockDim = block;
    kernelConfig.dynamicSmemBytes = dynamicShmSize;
    kernelConfig.stream = stream;

    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = enablePdl;
    kernelConfig.attrs = attrs;
    kernelConfig.numAttrs = 1;

    TLLM_CUDA_CHECK(cudaLaunchKernelEx(&kernelConfig, kernelFn, std::forward<Args>(args)...));
}

bool getEnvUseUCXKvCache();

bool getEnvUseMPIKvCache();
Expand Down
13 changes: 12 additions & 1 deletion cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ namespace tensorrt_llm
namespace kernels
{

using tensorrt_llm::common::launchWithPdlWhenEnabled;

// Quantize a contiguous shared-memory buffer containing elements of DType into NVFP4 with per-16-element FP8 scales.
// Output layout (repeated per 16-element group per lane), followed by one global scale float:
// [WARP_SIZE * 8 bytes packed e2m1 values] [WARP_SIZE * 1 byte E4M3 per-group scales] ... [global_scale (4 bytes)]
Expand Down Expand Up @@ -1069,6 +1071,10 @@ public:

int sendIndex = mPairInfo.channel;
uint32_t phaseParity = 0;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
#endif
for (; sendIndex < tokenCount; sendIndex += mPairInfo.runChannelCount)
{
int tokenIndex = sendIndexMapping == nullptr ? sendIndex : sendIndexMapping[sendIndex];
Expand Down Expand Up @@ -1140,6 +1146,10 @@ public:
int recvIndex = mPairInfo.channel;
uint32_t phaseParity = 0;
bool needRelease = false;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
#endif
for (; recvIndex < tokenCount; recvIndex += mPairInfo.runChannelCount)
{
int tokenIndex = recvIndexMapping == nullptr ? recvIndex : recvIndexMapping[recvIndex];
Expand Down Expand Up @@ -1459,7 +1469,8 @@ void moeAllToAll(FusedMoeCommKernelParam params, FusedMoeWorkspace workspace, cu

dim3 block = FusedMoeCommunicator::getLaunchBlockDim(groupCountPerCta);
dim3 grid = FusedMoeCommunicator::getLaunchGridDim(params.worldInfo.epInfo.epSize, groupCountPerCta);
kernelFn<<<grid, block, totalDynamicShmSize, stream>>>(params, workspace, hasBasicFields);
launchWithPdlWhenEnabled(
"moeAllToAll", kernelFn, grid, block, totalDynamicShmSize, stream, params, workspace, hasBasicFields);
TLLM_CUDA_CHECK(cudaGetLastError());
}

Expand Down
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/kernels/fusedMoeCommKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <cuda_runtime_api.h>

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/moeCommKernelsCommon.h"

namespace tensorrt_llm
Expand Down
78 changes: 62 additions & 16 deletions cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <cub/cub.cuh>

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h"

namespace cg = cooperative_groups;
Expand All @@ -28,6 +29,8 @@ namespace tensorrt_llm
namespace kernels
{

using tensorrt_llm::common::launchWithPdlWhenEnabled;

int getOwnerDevice(unsigned long long int stepAndOwner)
{
return static_cast<int>(stepAndOwner & MoeLoadBalanceSingleLayerSignal::kDevice);
Expand Down Expand Up @@ -71,6 +74,11 @@ __device__ __forceinline__ void moeWaitSignalForGpuStageFunc(MoeLoadBalanceSingl

__global__ void moeWaitSignalForGpuStageKernel(MoeLoadBalanceSingleLayerSignal* signal, int* enabled)
{

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
#endif
if (threadIdx.x == 0 and blockIdx.x == 0)
{
moeWaitSignalForGpuStageFunc(signal, enabled);
Expand All @@ -79,6 +87,11 @@ __global__ void moeWaitSignalForGpuStageKernel(MoeLoadBalanceSingleLayerSignal*

__global__ void moeSetSignalForCpuStageKernel(MoeLoadBalanceSingleLayerSignal* signal)
{

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
#endif
if (threadIdx.x == 0 and blockIdx.x == 0)
{
unsigned long long int loaded = signal->stepAndOwner;
Expand All @@ -91,7 +104,8 @@ __global__ void moeSetSignalForCpuStageKernel(MoeLoadBalanceSingleLayerSignal* s

// Enqueue the GPU-stage wait on `stream`: a single-thread kernel that blocks
// until the CPU hands ownership of `signal` to the GPU, then writes the
// enabled/disabled flag into `enabled`.
// Defect fixed: the captured text carried both the legacy <<<1,1,0,stream>>>
// launch and the PDL-aware launch (diff residue), i.e. a double launch; keep
// only the merged, PDL-aware form.
void moeWaitSignalForGpuStageDevice(MoeLoadBalanceSingleLayerSignal* signal, int* enabled, cudaStream_t stream)
{
    // One thread is sufficient: the kernel itself gates all work on
    // threadIdx.x == 0 && blockIdx.x == 0.
    launchWithPdlWhenEnabled(
        "moeWaitSignalForGpuStage", moeWaitSignalForGpuStageKernel, 1, 1, 0, stream, signal, enabled);
}

void moeWaitSignalForGpuStageForTest(MoeLoadBalanceSingleLayerSignal* signal, int* enabled)
Expand Down Expand Up @@ -119,7 +133,7 @@ void moeWaitSignalForGpuStageForTest(MoeLoadBalanceSingleLayerSignal* signal, in

// Enqueue the CPU-stage handoff on `stream`: a single-thread kernel that
// advances `signal` so ownership returns to the CPU side.
// Defect fixed: the captured text carried both the legacy <<<1,1,0,stream>>>
// launch and the PDL-aware launch (diff residue), i.e. a double launch; keep
// only the merged, PDL-aware form.
void moeSetSignalForCpuStageDevice(MoeLoadBalanceSingleLayerSignal* signal, cudaStream_t stream)
{
    // One thread is sufficient: the kernel itself gates all work on
    // threadIdx.x == 0 && blockIdx.x == 0.
    launchWithPdlWhenEnabled("moeSetSignalForCpuStage", moeSetSignalForCpuStageKernel, 1, 1, 0, stream, signal);
}

void moeSetSignalForCpuStageForTest(MoeLoadBalanceSingleLayerSignal* signal)
Expand All @@ -138,6 +152,10 @@ __global__ void zeroExpertTokenCountKernel(MoeLoadBalanceMetaInfo metaInfo, int*
TYPE oldExpertTokenCount = {0};
int* expertTokenCountPtr = expertTokenCount + metaInfo.expertCount * blockIdx.x;
TYPE* typedExpertTokenCountPtr = reinterpret_cast<TYPE*>(expertTokenCountPtr);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
#endif
typedExpertTokenCountPtr[threadIdx.x] = oldExpertTokenCount;
}

Expand All @@ -149,6 +167,10 @@ __global__ void shiftWindowKernel(MoeLoadBalanceMetaInfo metaInfo, int* const en
return;
}
TYPE oldExpertTokenCount = {0};
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
#endif
if (blockIdx.x > 0)
{
int* oldExpertTokenCountPtr = expertTokenCount + metaInfo.expertCount * (blockIdx.x - 1);
Expand Down Expand Up @@ -177,6 +199,10 @@ __global__ void statisticKernel(MoeLoadBalanceMetaInfo metaInfo, int* expertToke
sharedExpertCount[i] = 0;
}
__syncthreads();
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
#endif
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < totalEltCount; idx += gridDim.x * blockDim.x)
{
int expertId = gatheredRawExpertIds[idx];
Expand All @@ -200,6 +226,10 @@ __global__ void updateLoadFactorKernel(MoeLoadBalanceMetaInfo metaInfo, MoeLoadB
return;
}
int expertIdx = blockIdx.x * blockDim.x + threadIdx.x;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
#endif
int expertTokenCount = expertTokenCountPtr[expertIdx];
float* loadFactor = statisticInfo.expertLoadFactor;
loadFactor[expertIdx] = loadFactor[expertIdx] * statisticInfo.decayFactor + expertTokenCount;
Expand Down Expand Up @@ -232,6 +262,7 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic
= {&metaInfo, static_cast<void*>(const_cast<int**>(&enabled)), static_cast<void*>(&expertTokenCount)};
TLLM_CHECK_WITH_INFO(
threadCount <= 1024, "expertCount=%d is too large and not supported now.", metaInfo.expertCount);
// TODO: add PDL support with cooperative launch
TLLM_CUDA_CHECK(cudaLaunchCooperativeKernel(kernelFunc, gridDim, blockDim, &args[0], 0, stream));
}

Expand All @@ -245,7 +276,7 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic
blockCount = smCount;
}
int sharedMemorySize = metaInfo.expertCount * sizeof(int);
statisticKernel<<<blockCount, threadCount, sharedMemorySize, stream>>>(
launchWithPdlWhenEnabled("statisticKernel", statisticKernel, blockCount, threadCount, sharedMemorySize, stream,
metaInfo, statisticInfo.expertTokenCount, totalEltCount, enabled, gatheredRawExpertIds);
}

Expand All @@ -254,7 +285,7 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic
// only last stage need update load factor.
int threadCount = 128;
int blockCount = (metaInfo.expertCount + threadCount - 1) / threadCount;
updateLoadFactorKernel<<<blockCount, threadCount, 0, stream>>>(
launchWithPdlWhenEnabled("updateLoadFactor", updateLoadFactorKernel, blockCount, threadCount, 0, stream,
metaInfo, statisticInfo, statisticInfo.expertTokenCount, enabled);
}
}
Expand Down Expand Up @@ -282,11 +313,10 @@ void moeHierarchicalStatisticLocalDevice(MoeLoadBalanceMetaInfo metaInfo, int nu
}
dim3 gridDim(1);
dim3 blockDim(threadCount);
void* args[]
= {&metaInfo, static_cast<void*>(const_cast<int**>(&enabled)), static_cast<void*>(&localExpertTokenCount)};
TLLM_CHECK_WITH_INFO(
threadCount <= 1024, "expertCount=%d is too large and not supported now.", metaInfo.expertCount);
TLLM_CUDA_CHECK(cudaLaunchKernel(kernelFunc, gridDim, blockDim, &args[0], 0, stream));
launchWithPdlWhenEnabled(
"zeroExpertTokenCount", kernelFunc, gridDim, blockDim, 0, stream, metaInfo, enabled, localExpertTokenCount);
}

{
Expand All @@ -299,7 +329,7 @@ void moeHierarchicalStatisticLocalDevice(MoeLoadBalanceMetaInfo metaInfo, int nu
blockCount = smCount;
}
int sharedMemorySize = metaInfo.expertCount * sizeof(int);
statisticKernel<<<blockCount, threadCount, sharedMemorySize, stream>>>(
launchWithPdlWhenEnabled("statisticKernel", statisticKernel, blockCount, threadCount, sharedMemorySize, stream,
metaInfo, localExpertTokenCount, totalEltCount, enabled, localRawExpertIds);
}
}
Expand All @@ -309,8 +339,8 @@ void moeHierarchicalStatisticUpdate(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBala
{
int threadCount = 128;
int blockCount = (metaInfo.expertCount + threadCount - 1) / threadCount;
updateLoadFactorKernel<<<blockCount, threadCount, 0, stream>>>(
metaInfo, statisticInfo, globalExpertTokenCount, enabled);
launchWithPdlWhenEnabled("updateLoadFactor", updateLoadFactorKernel, blockCount, threadCount, 0, stream, metaInfo,
statisticInfo, globalExpertTokenCount, enabled);
}

template <int MAX_EXPERT_COUNT = 1024, int THREAD_COUNT = 256, int ITEM_PER_THREAD = 4>
Expand All @@ -320,13 +350,18 @@ __global__ void moeComputeRouteNoRedundantKernel(MoeLoadBalanceMetaInfo metaInfo
extern __shared__ int16_t sharedGlobalSlotIdsInfo[];
int expertIds[ITEM_PER_THREAD];
int slotIds[ITEM_PER_THREAD];

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
#endif

for (int slotId = threadIdx.x; slotId < metaInfo.epSize * metaInfo.slotCountPerRank; slotId += THREAD_COUNT)
{
sharedGlobalSlotIdsInfo[slotId] = placementInfo.globalSlotIds[slotId];
}

int blockOffset = blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD;

for (; blockOffset < tokenCount * metaInfo.topK; blockOffset += gridDim.x * THREAD_COUNT * ITEM_PER_THREAD)
{
int tokenIdxBase = blockOffset + threadIdx.x;
Expand Down Expand Up @@ -379,6 +414,12 @@ __global__ void moeComputeRouteKernel(MoeLoadBalanceMetaInfo metaInfo, MoePlacem

__shared__ int sharedArbitrateExpertId[THREAD_COUNT * ITEM_PER_THREAD];
__shared__ int sharedExpertCount[MAX_EXPERT_COUNT];

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
#endif

for (int expertIdx = threadIdx.x; expertIdx < metaInfo.expertCount; expertIdx += THREAD_COUNT)
{
int replicaCount = placementInfo.expertReplicaCount[expertIdx];
Expand Down Expand Up @@ -484,6 +525,11 @@ __global__ void moeComputeRouteSortKernel(MoeLoadBalanceMetaInfo metaInfo, MoePl
__shared__ int sharedSortedExpertId[THREAD_COUNT * ITEM_PER_THREAD];
__shared__ int sharedExpertStartThread[MAX_EXPERT_COUNT];

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
#endif

for (int expertIdx = threadIdx.x; expertIdx < metaInfo.expertCount; expertIdx += THREAD_COUNT)
{
sharedExpertTokenCount[expertIdx] = 0;
Expand All @@ -500,7 +546,6 @@ __global__ void moeComputeRouteSortKernel(MoeLoadBalanceMetaInfo metaInfo, MoePl
__syncthreads();

int expertIds[ITEM_PER_THREAD];

for (int blockOffset = blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD; blockOffset < tokenCount * metaInfo.topK;
blockOffset += gridDim.x * THREAD_COUNT * ITEM_PER_THREAD)
{
Expand Down Expand Up @@ -586,14 +631,15 @@ void moeComputeRouteDevice(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo pla
int dynamicShmSize = sizeof(int16_t) * metaInfo.epSize * metaInfo.slotCountPerRank;
if (metaInfo.expertCount == metaInfo.epSize * metaInfo.slotCountPerRank)
{
auto* kernelFn = moeComputeRouteNoRedundantKernel<1024, kThreadCount, kEltPerThread>;
// no redundant expert, so we don't need complex routing, but just assign to the correct solt.
moeComputeRouteNoRedundantKernel<1024, kThreadCount, kEltPerThread>
<<<blockCount, kThreadCount, dynamicShmSize, stream>>>(
metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount);
launchWithPdlWhenEnabled("moeComputeRouteNoRedundant", kernelFn, blockCount, kThreadCount, dynamicShmSize,
stream, metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount);
}
else
{
moeComputeRouteKernel<1024, kThreadCount, kEltPerThread><<<blockCount, kThreadCount, dynamicShmSize, stream>>>(
auto* kernelFn = moeComputeRouteKernel<1024, kThreadCount, kEltPerThread>;
launchWithPdlWhenEnabled("moeComputeRoute", kernelFn, blockCount, kThreadCount, dynamicShmSize, stream,
metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount, offsetByEpRank);
}
}
Expand Down
Loading