From e8d7e2b532e1a0984d2e80f221fc864c4e15f726 Mon Sep 17 00:00:00 2001 From: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com> Date: Wed, 1 Oct 2025 06:00:56 -0700 Subject: [PATCH 1/4] Update the routing for TRTLLMGEN to support kimi k2 and qwen Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com> --- csrc/trtllm_fused_moe_routing_deepseek.cu | 453 +++++++++++------- csrc/trtllm_fused_moe_routing_llama4.cu | 227 +++++---- csrc/trtllm_fused_moe_routing_renormalize.cu | 329 ++++++++++--- .../flashinfer/trtllm/fused_moe/DevKernel.h | 84 ++-- .../trtllm/fused_moe/RoutingKernel.cuh | 154 +++--- .../trtllm/fused_moe/RoutingKernel.h | 54 ++- .../trtllm/fused_moe/RoutingKernelTopK.cuh | 62 ++- tests/moe/test_trtllm_gen_fused_moe.py | 4 +- 8 files changed, 911 insertions(+), 456 deletions(-) diff --git a/csrc/trtllm_fused_moe_routing_deepseek.cu b/csrc/trtllm_fused_moe_routing_deepseek.cu index 4ea1ba178e..c14e591326 100644 --- a/csrc/trtllm_fused_moe_routing_deepseek.cu +++ b/csrc/trtllm_fused_moe_routing_deepseek.cu @@ -23,34 +23,25 @@ namespace routingDeepSeek { //////////////////////////////////////////////////////////////////////////////////////////////////// -static constexpr int NumThreads = 384; -static constexpr int NumWarps = NumThreads / WarpSize; +static constexpr int NumKimiK2Experts = 384; +static constexpr int NumDeepseekExperts = 256; static constexpr int NumTopGroupScores = 2; static constexpr int MaxNumTopExperts = 8; -static constexpr int MaxNumTopGroupsDefault = 16; - -__host__ __device__ int getMaxNumTopGroups(const bool useGroups, const int numExperts) { - if (useGroups || numExperts <= 256) { - return 4; - } else { - return 16; - } -} +static constexpr int MaxNumTopGroups = 4; +static constexpr int MaxNumGroups = 8; template __global__ void routingMainKernel(KernelParams params) { // declare types using OutputT = typename KernelParams::OutputT; using InputT = typename KernelParams::InputT; - static constexpr int NumWarps = NumThreads / WarpSize; - int MaxNumTopGroups = getMaxNumTopGroups(KernelParams::UseGroups, params.mNumExperts); // declare shared memory structure // number of experts is bounded by number of threads - __shared__ float __attribute((aligned(128))) smemScoreSigmoid[NumThreads]; - __shared__ float __attribute((aligned(128))) smemScoreBias[NumThreads]; + __shared__ float __attribute((aligned(128))) smemScoreSigmoid[KernelParams::MaxNumExperts]; + __shared__ float __attribute((aligned(128))) smemScoreBias[KernelParams::MaxNumExperts]; // number of expert groups is bounded by number of warps - __shared__ float __attribute((aligned(128))) smemGroupScores[NumWarps]; + __shared__ float __attribute((aligned(128))) smemGroupScores[MaxNumGroups]; // needed for warp reduce auto block = cg::this_thread_block(); @@ -83,8 +74,8 @@ __global__ void routingMainKernel(KernelParams params) { // initialize the mPtrExpertCounts if (params.mPtrExpertCounts) { - int32_t globalThreadIdx = blockIdx.x * NumThreads + threadIdx.x; - int32_t globalThreadStride = gridDim.x * NumThreads; + int32_t globalThreadIdx = blockIdx.x * blockDim.x + threadIdx.x; + int32_t globalThreadStride = gridDim.x * blockDim.x; int32_t expertCountsNum = 2 * params.mNumExperts; initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); } @@ -97,114 +88,165 @@ __global__ void routingMainKernel(KernelParams params) { } #endif - // get our assigned thread score; each warp represents one expert group - float score = - expertSelected ? 
static_cast(params.mPtrScores[scoreIdx]) : invalidScoreFloat; - // get the sigmoid score - // note that for invalid values, we simply use a negative value: - // sigmoig scores are always strictly positive - auto scoreSigmoid = sigmoid_accurate(score); - // write the sigmoid score to shared for later use - if (expertSelected) { - smemScoreSigmoid[threadExpert] = scoreSigmoid; - } - // get the score with bias - // note that with invalid values, because sigmoid is < 1 and bias is -1, - // we must get a negative value, which is smaller than any valid value - auto scoreBias = float{scoreSigmoid + float{biasVal}}; - - if (expertSelected) { - smemScoreBias[threadExpert] = scoreBias; - } - - // registers for top group score reduction - float topExpGroupScores[NumTopGroupScores]; - [[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores]; - float topGroups[MaxNumTopGroupsDefault]; // bound of params.mNumLimitedGroups - int32_t topGroupIdx[MaxNumTopGroupsDefault]; - float expertScoreGroup[MaxNumTopGroupsDefault]; - int32_t expertIdxGroup[MaxNumTopGroupsDefault]; - float topScores[MaxNumTopExperts]; // bound of params.mTopK - int32_t topExperts[MaxNumTopExperts]; - - if constexpr (KernelParams::UseGroups) { - topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert, - /* minValue */ invalidScoreFloat); + if (params.mPtrScores != nullptr) { + // get our assigned thread score; each warp represents one expert group + float score = + expertSelected ? static_cast(params.mPtrScores[scoreIdx]) : invalidScoreFloat; + // get the sigmoid score + // note that for invalid values, we simply use a negative value: + // sigmoig scores are always strictly positive + auto scoreSigmoid = sigmoid_accurate(score); + // write the sigmoid score to shared for later use + if (expertSelected) { + smemScoreSigmoid[threadExpert] = scoreSigmoid; + } + // get the score with bias + // note that with invalid values, because sigmoid is < 1 and bias is -1, + // we must get a negative value, which is smaller than any valid value + auto scoreBias = float{scoreSigmoid + float{biasVal}}; - // get the final group score and write it to shared - if (cute::elect_one_sync()) { - auto groupScore = topExpGroupScores[0] + topExpGroupScores[1]; - smemGroupScores[warpIdx] = groupScore; + if (expertSelected) { + smemScoreBias[threadExpert] = scoreBias; } - } - // make group scores available to all warps - __syncthreads(); + // registers for top group score reduction + float topExpGroupScores[NumTopGroupScores]; + [[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores]; + float topGroups[MaxNumTopGroups]; // bound of params.mNumLimitedGroups + int32_t topGroupIdx[MaxNumTopGroups]; + float expertScoreGroup[MaxNumTopGroups]; + int32_t expertIdxGroup[MaxNumTopGroups]; + float topScores[MaxNumTopExperts]; // bound of params.mTopK + int32_t topExperts[MaxNumTopExperts]; - auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; - if (warpIdx == 0) { - // a single warp performs the selection of top groups, and goes on to select the final experts if constexpr (KernelParams::UseGroups) { - float groupScore = - laneIdx < params.mNumExpertGroups ? 
smemGroupScores[laneIdx] : invalidScoreFloat; - - topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx, + topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert, /* minValue */ invalidScoreFloat); + // get the final group score and write it to shared + if (cute::elect_one_sync()) { + auto groupScore = topExpGroupScores[0] + topExpGroupScores[1]; + smemGroupScores[warpIdx] = groupScore; + } + } + + // make group scores available to all warps + __syncthreads(); - // final expert selection: get relevant indexes and scores from shared + auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; + if constexpr (KernelParams::UseGroups) { // a single warp performs the selection of top groups, + // and goes on to select the final experts + if (warpIdx == 0) { + float groupScore = + laneIdx < params.mNumExpertGroups ? smemGroupScores[laneIdx] : invalidScoreFloat; + topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx, + /* minValue */ invalidScoreFloat); + // final expert selection: get relevant indexes and scores from shared #pragma unroll - for (int ii = 0; ii < MaxNumTopGroups; ++ii) { // bound of params.mNumLimitedGroups - auto groupIdx = topGroupIdx[ii]; - expertIdxGroup[ii] = groupIdx * params.mNumExpertsPerGroup + laneIdx; - // note: expertSelected implies laneIdx < params.mNumExpertsPerGroup. - // we have params.mNumExpertsPerGroup == params.mNumExperts / params.mNumExpertGroups, - // thus groupIdx <= params.mNumExpertGroups - 1 => - // groupIdx * params.mNumExpertsPerGroup <= params.mNumExperts - params.mNumExpertsPerGroup - // => expertIdxGroup[ii] < params.mNumExperts <= NumThreads, - // so the access is safe here - expertScoreGroup[ii] = groupIdx < params.mNumExpertGroups && expertSelected - ? smemScoreBias[expertIdxGroup[ii]] - : invalidScoreFloat; + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { // bound of params.mNumLimitedGroups + auto groupIdx = topGroupIdx[ii]; + expertIdxGroup[ii] = groupIdx * params.mNumExpertsPerGroup + laneIdx; + // note: expertSelected implies laneIdx < params.mNumExpertsPerGroup. + // we have params.mNumExpertsPerGroup == params.mNumExperts / params.mNumExpertGroups, + // thus groupIdx <= params.mNumExpertGroups - 1 => + // groupIdx * params.mNumExpertsPerGroup <= params.mNumExperts - + // params.mNumExpertsPerGroup + // => expertIdxGroup[ii] < params.mNumExperts <= NumThreads, + // so the access is safe here + expertScoreGroup[ii] = groupIdx < params.mNumExpertGroups && expertSelected + ? smemScoreBias[expertIdxGroup[ii]] + : invalidScoreFloat; + } + + topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + /* minValue */ invalidScoreFloat, params.mTopK); } - } else { + } else if constexpr (KernelParams::MaxNumExperts > topk::MaxNumExpertsUnit) { // without groups, each thread just takes `MaxNumTopGroups` experts + int constexpr NumExpertWarps = + (KernelParams::MaxNumExperts - 1) / topk::MaxNumExpertsUnit + 1; + int constexpr NumInterTopK = NumExpertWarps * MaxNumTopExperts; + __shared__ float __attribute((aligned(128))) smemInterTopScores[NumInterTopK]; + __shared__ int32_t __attribute((aligned(128))) smemInterTopExperts[NumInterTopK]; + if (warpIdx < NumExpertWarps) { + int offset = warpIdx * WarpSize * MaxNumTopGroups; +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { + auto expertIdx = ii * WarpSize + laneIdx; + expertIdxGroup[ii] = offset + expertIdx; + expertScoreGroup[ii] = offset + expertIdx < params.mNumExperts + ? 
smemScoreBias[offset + expertIdx] + : invalidScoreFloat; + } + topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + /* minValue */ invalidScoreFloat, params.mTopK); - for (int ii = 0; ii < MaxNumTopGroups; ++ii) { - auto expertIdx = ii * WarpSize + laneIdx; - expertIdxGroup[ii] = expertIdx; - expertScoreGroup[ii] = - expertIdx < params.mNumExperts ? smemScoreBias[expertIdx] : invalidScoreFloat; + if (laneIdx < params.mTopK) { + smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = topScores[laneIdx]; + smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = topExperts[laneIdx]; + } } - } - topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, - /* minValue */ invalidScoreFloat, params.mTopK); - - // determine our lane's expert index and write to output - int32_t expertIdx = 0; + __syncthreads(); + if (warpIdx == 0) { + int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WarpSize + 1; + float intermidiateScore[NumInterTopKPerThread]; + int32_t intermidiateExpert[NumInterTopKPerThread]; + for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize) { + int ii = i / WarpSize; + if (i < NumInterTopK) { + intermidiateScore[ii] = smemInterTopScores[i]; + intermidiateExpert[ii] = smemInterTopExperts[i]; + } else { + intermidiateScore[ii] = invalidScoreFloat; + intermidiateExpert[ii] = KernelParams::MaxNumExperts - 1; + } + } + topk::reduceTopK(warp, topScores, topExperts, intermidiateScore, intermidiateExpert, + /* minValue */ invalidScoreFloat, params.mTopK); + } + } else { + if (warpIdx == 0) { + // without groups, each thread just takes `MaxNumTopGroups` experts #pragma unroll - for (int ii = 0; ii < params.mTopK; ++ii) { // bound of params.mTopK - expertIdx = laneIdx == ii ? topExperts[ii] : expertIdx; + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { + auto expertIdx = ii * WarpSize + laneIdx; + expertIdxGroup[ii] = expertIdx; + expertScoreGroup[ii] = + expertIdx < params.mNumExperts ? smemScoreBias[expertIdx] : invalidScoreFloat; + } + topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + /* minValue */ invalidScoreFloat, params.mTopK); + } } - // determine whether our expert is local to this GPU - auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx; - auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; - - float scoreNorm = laneIdx < params.mTopK ? smemScoreSigmoid[expertIdx] : 0.F; - auto redNorm = cg::reduce(warp, scoreNorm, cg::plus{}); - auto finalScore = OutputT{scoreNorm * params.mRouteScale / redNorm}; - // write expert idx out already - auto idxTopK = blockIdx.x * params.mTopK + laneIdx; - if (laneIdx < params.mTopK && params.mPtrExpertIdx != nullptr) { - PackedScoreIdx packedScore{static_cast(finalScore), - static_cast(expertIdx)}; - params.mPtrExpertIdx[idxTopK] = packedScore; - } + if (warpIdx == 0) { + // determine our lane's expert index and write to output + int32_t expertIdx = 0; +#pragma unroll + for (int ii = 0; ii < params.mTopK; ++ii) { // bound of params.mTopK + expertIdx = laneIdx == ii ? topExperts[ii] : expertIdx; + } + // determine whether our expert is local to this GPU + auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx; + auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && + (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; + + float scoreNorm = laneIdx < params.mTopK ? 
smemScoreSigmoid[expertIdx] : 0.F; + auto redNorm = cg::reduce(warp, scoreNorm, cg::plus{}); + auto finalScore = OutputT{scoreNorm * params.mRouteScale / redNorm}; + + // write expert idx out already + auto idxTopK = blockIdx.x * params.mTopK + laneIdx; + if (laneIdx < params.mTopK && params.mPtrTopKPacked != nullptr) { + PackedScoreIdx packedScore{static_cast(finalScore), + static_cast(expertIdx)}; + params.mPtrTopKPacked[idxTopK] = packedScore; + } - if (laneIdx < params.mTopK && params.mPtrExpertWeights != nullptr) { - params.mPtrExpertWeights[idxTopK] = finalScore; + if (laneIdx < params.mTopK && params.mPtrTopKWeights != nullptr && + params.mPtrTopKIds == nullptr) { + params.mPtrTopKWeights[idxTopK] = finalScore; + } } } } @@ -214,7 +256,8 @@ __global__ void routingMainKernel(KernelParams params) { template #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) - routingIndicesClusterKernel(KernelParams params) { + __launch_bounds__(KernelParams::MaxNumExperts) + routingIndicesClusterKernel(KernelParams params) { using OutputT = typename KernelParams::OutputT; int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); @@ -225,8 +268,8 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) if constexpr (KernelParams::UsePdl) { cudaGridDependencySynchronize(); } - - routingPermutation(params, nullptr, warpIdx, clusterBlockRank); } #else @@ -239,9 +282,10 @@ __global__ void routingIndicesClusterKernel(KernelParams params) { template #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -__global__ void routingIndicesCoopKernel(KernelParams params) { - static constexpr int NumWarps = NumThreads / WarpSize; +__global__ void __launch_bounds__(KernelParams::MaxNumExperts) + routingIndicesCoopKernel(KernelParams params) { // number of experts is bounded by number of threads + int constexpr NumThreads = KernelParams::MaxNumExperts; __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads]; __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads]; // needed for the exclusive sum of token offsets @@ -283,7 +327,8 @@ __global__ void routingIndicesCoopKernel(KernelParams params) { // Define a lambda to avoid code duplication in both branches. auto loopBody = [&](int ii, int expandedIdx) { - int32_t expertIdx = params.mPtrExpertIdx[expandedIdx].idx; + int32_t expertIdx = params.mPtrTopKIds != nullptr ? 
params.mPtrTopKIds[expandedIdx] + : params.mPtrTopKPacked[expandedIdx].idx; expertIndexes[ii] = expertIdx; // check whether this expert is local to our GPU at all and ignore if not auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx; @@ -366,7 +411,7 @@ __global__ void routingIndicesCoopKernel(KernelParams params) { const int32_t permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); // write out padded count - if (gridBlockIdx == 0 && warpIdx == NumWarps - 1 && cute::elect_one_sync()) { + if (gridBlockIdx == 0 && warpIdx == NumThreads / WarpSize - 1 && cute::elect_one_sync()) { params.mPtrPermutedIdxSize[0] = permutedIdxSize; params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; } @@ -416,60 +461,95 @@ __global__ void routingIndicesCoopKernel(KernelParams params) { } #endif +int constexpr getMaxNumExperts(int32_t numExperts) { + if (numExperts <= topk::MaxNumExpertsUnit) { + return topk::MaxNumExpertsUnit; + } else if (numExperts <= NumDeepseekExperts) { + return NumDeepseekExperts; + } else if (numExperts <= NumKimiK2Experts) { + return NumKimiK2Experts; + } else { + TLLM_LOG_ERROR("Unsupported numExperts"); + return 0; + } +} + //////////////////////////////////////////////////////////////////////////////////////////////////// +#define LAUNCH_ROUTING_DEEPSEEK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + extraFlag1, forceFloatInput) \ + if (data.mNumExperts <= topk::MaxNumExpertsUnit) { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, \ + numThreads, smemSize, stream, extraFlag1, \ + forceFloatInput, topk::MaxNumExpertsUnit); \ + } else if (data.mNumExperts <= NumDeepseekExperts) { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, \ + numThreads, smemSize, stream, extraFlag1, \ + forceFloatInput, NumDeepseekExperts); \ + } else if (data.mNumExperts <= NumKimiK2Experts) { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, \ + numThreads, smemSize, stream, extraFlag1, \ + forceFloatInput, NumKimiK2Experts); \ + } else { \ + TLLM_LOG_ERROR("Unsupported numExperts"); \ + } + void runImpl(Data& data, void* stream) { - static constexpr int NumWarps = NumThreads / WarpSize; - int MaxNumTopGroups = getMaxNumTopGroups(data.mNumExpertGroups > 1, data.mNumExperts); - - // Validate that the template parameter matches the data - // FLASHINFER_CHECK(data.mNumExperts == NumExperts, "DeepSeek routing kernel expects exactly ", - // NumExperts, " experts, got ", data.mNumExperts); - FLASHINFER_CHECK(data.mPtrExpertIdx != nullptr || data.mPtrPermutedIdxSize != nullptr || - data.mPtrExpertWeights != nullptr, - "Routing kernel requires at least one output parameter"); + FLASHINFER_CHECK( + data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, + "Routing kernel requires at least one input parameter"); + if (data.mPtrTopKIds != nullptr) { + FLASHINFER_CHECK(data.mPtrTopKWeights != nullptr, + "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for " + "DeepSeek routing."); + } if (data.mPtrExpandedIdxToPermutedIdx != nullptr || data.mPtrPermutedIdxToTokenIdx != nullptr) - FLASHINFER_CHECK(data.mPtrExpertIdx != nullptr && data.mPtrPermutedIdxSize, - "If permuted index is required, `mPtrExpertIdx` is also required"); + FLASHINFER_CHECK( + (data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr) && data.mPtrPermutedIdxSize, + "If permuted index is required, `mPtrTopKPacked` or 
`mPtrTopKIds` is also required"); FLASHINFER_CHECK(!data.mUseRoutingSoftmax, "Routing with softmax not implemented yet"); FLASHINFER_CHECK(data.mNumLimitedGroups <= MaxNumTopGroups, - "Routing kernel expects <= ", MaxNumTopGroups, " top groups, got ", + "Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups, data.mNumLimitedGroups); FLASHINFER_CHECK(data.mTopK <= MaxNumTopExperts, - "Routing kernel expects topK experts <= ", MaxNumTopExperts, ", got ", + "Routing kernel expects topK experts <= %d, got %d", MaxNumTopExperts, data.mTopK); - FLASHINFER_CHECK(data.mTopK <= WarpSize, "Routing kernel expects top K <= warp size, got ", + FLASHINFER_CHECK(data.mTopK <= WarpSize, "Routing kernel expects top K <= warp size, got %d", data.mTopK); FLASHINFER_CHECK(data.mTopK * data.mNumLimitedGroups <= WarpSize, - "Routing kernel expects top K * top groups <= warp size (for now), got ", - data.mTopK, " * ", data.mNumLimitedGroups); - FLASHINFER_CHECK(data.mNumExperts >= MaxNumTopExperts, "Routing kernel expects ", - MaxNumTopExperts, " to be at most #experts ", data.mNumExperts); - FLASHINFER_CHECK(data.mNumExperts <= NumThreads, "Routing kernel expects #experts ", - data.mNumExperts, " <= #threads ", NumThreads); + "Routing kernel expects top K * top groups <= warp size (for now), got %d * %d", + data.mTopK, data.mNumLimitedGroups); + FLASHINFER_CHECK(data.mNumExperts >= MaxNumTopExperts, + "Routing kernel expects %d to be at most #experts %d", MaxNumTopExperts, + data.mNumExperts); + FLASHINFER_CHECK(data.mNumExperts <= NumKimiK2Experts, + "Routing kernel expects #experts %d <= #threads %d", data.mNumExperts, + NumKimiK2Experts); FLASHINFER_CHECK(data.mNumExpertGroups >= data.mNumLimitedGroups, - "Routing kernel expects top groups ", data.mNumLimitedGroups, - " to be limited by #expert groups ", data.mNumExpertGroups); + "Routing kernel expects top groups %d to be limited by #expert groups %d", + data.mNumLimitedGroups, data.mNumExpertGroups); if (data.mNumExpertGroups > 1) { - FLASHINFER_CHECK(data.mNumExpertGroups <= NumWarps, "Routing kernel expects #experts groups ", - data.mNumExpertGroups, " to be <= #warps ", NumWarps); + FLASHINFER_CHECK(data.mNumExpertGroups <= MaxNumGroups, + "Routing kernel expects #experts groups %d to be <= #warps %d", + data.mNumExpertGroups, MaxNumGroups); FLASHINFER_CHECK(data.mNumExperts % data.mNumExpertGroups == 0, - "Routing kernel expects #experts ", data.mNumExperts, - " to be a multiple of #expert groups ", data.mNumExpertGroups); - FLASHINFER_CHECK(data.mNumExperts / data.mNumExpertGroups <= WarpSize, - "Routing kernel expects #experts per group <= warp size, got ", - data.mNumExperts / data.mNumExpertGroups); + "Routing kernel expects #experts %d to be a multiple of #expert groups %d", + data.mNumExperts, data.mNumExpertGroups); + FLASHINFER_CHECK( + data.mNumExperts / data.mNumExpertGroups <= WarpSize, + "Routing kernel expects #experts per group <= warp size, got %d, data.mNumExpertGroups %d", + data.mNumExperts / data.mNumExpertGroups, data.mNumExpertGroups); } else { - FLASHINFER_CHECK(data.mNumExperts <= WarpSize * MaxNumTopGroups, - "Routing kernel expects #experts ", data.mNumExperts, - " <= WarpSize * MaxNumTopGroups ", WarpSize * MaxNumTopGroups); - FLASHINFER_CHECK(data.mTopK <= NumWarps, "Routing kernel expects top K ", data.mTopK, - " to be <= #warps ", NumWarps); + FLASHINFER_CHECK(data.mTopK <= topk::MaxNumTopK, + "Routing kernel expects top K %d to be <= #warps %d", data.mTopK, + topk::MaxNumTopK); } - 
FLASHINFER_CHECK(data.mNumExperts % 4 == 0, "Routing kernel expects #experts ", data.mNumExperts, - " to be a multiple of 4."); - FLASHINFER_CHECK(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got ", + FLASHINFER_CHECK(data.mNumExperts % 4 == 0, + "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); + FLASHINFER_CHECK(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", data.mPaddingLog2); + int const numBlocks = data.mNumTokens; + int const numThreadsHist = getMaxNumExperts(data.mNumExperts); bool const useSingleCluster = data.mNumTokens <= 1024; if (!useSingleCluster) { @@ -495,30 +575,39 @@ void runImpl(Data& data, void* stream) { int const numBlocksCoop = 128; // Maximum number of tokens supported by the kernel using a cooperative launch. - int const maxTokensCoop = (numBlocksCoop * NumThreads * 64) / data.mTopK; - LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, - /*coopLaunch=*/false, routingMainKernel, numBlocks, NumThreads, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); + int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; + if (data.mPtrTopKIds == nullptr) { + int const numThreadsMain = + data.mNumExperts < NumDeepseekExperts ? NumDeepseekExperts : NumKimiK2Experts; + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingMainKernel, numBlocks, numThreadsMain, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); + } else { + // Reset the global histograms. + LAUNCH_ROUTING_DEEPSEEK(data, false, routingInitExpertCounts, + (2 * data.mNumExperts - 1) / numThreadsHist + 1, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/false); + } if (data.mPtrPermutedIdxSize != nullptr) { if (useSingleCluster) { - LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, - /*coopLaunch=*/false, routingIndicesClusterKernel, - NumBlocksPerCluster, NumThreads, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingIndicesClusterKernel, + NumBlocksPerCluster, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); } else if (data.mNumTokens <= maxTokensCoop) { - LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, - /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, - NumThreads, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); } else { const int32_t expandedIdxSize = data.mNumTokens * data.mTopK; - - const int32_t histogramEltsPerBlock = 8 * NumThreads; - const int32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * NumThreads; + const int32_t histogramEltsPerBlock = 8 * numThreadsHist; + const int32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; // Limit grid size (both kernels use a grid-stride loop). 
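      // Note: this is the general fallback for large batches. routingIndicesHistogramKernel
      // accumulates per-expert counts and routingIndicesOffsetsKernel turns them into permuted
      // offsets; both run with numThreadsHist (= getMaxNumExperts(data.mNumExperts)) threads per
      // block and use grid-stride loops, so the caps below only bound the launch size, not the
      // amount of work covered. For example, with 256 experts (numThreadsHist = 256) and
      // topK = 8, a 100k-token batch needs ceil(800000 / 2048) = 391 histogram blocks, well
      // under the 1024 cap.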
const int32_t maxNumBlocks = 1024; @@ -528,20 +617,22 @@ void runImpl(Data& data, void* stream) { int const numBlocksOffsets = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); - LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, - /*coopLaunch=*/false, routingIndicesHistogramKernel, - numBlocksHistogram, NumThreads, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); - LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, - /*coopLaunch=*/false, routingIndicesOffsetsKernel, - numBlocksOffsets, NumThreads, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingIndicesHistogramKernel, + numBlocksHistogram, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); } } } +//////////////////////////////////////////////////////////////////////////////////////////////////// + void run(Data& data, void* stream) { runImpl(data, stream); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/trtllm_fused_moe_routing_llama4.cu b/csrc/trtllm_fused_moe_routing_llama4.cu index 4b9d8da130..ebdd0b8720 100644 --- a/csrc/trtllm_fused_moe_routing_llama4.cu +++ b/csrc/trtllm_fused_moe_routing_llama4.cu @@ -25,7 +25,7 @@ namespace routingLlama4 { static constexpr int NumThreads = 1024; static constexpr int NumWarps = NumThreads / WarpSize; static constexpr int MaxNumTopExperts = 1; -static constexpr int MaxNumExperts = 128; +static constexpr int NumExpertsLimit = 128; static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads; static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps; static constexpr int WarpKernelSmemStride = 33; @@ -85,7 +85,7 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam __shared__ int32_t __attribute(( aligned(128))) smemExpertTokenCountFull[WarpKernelMaxNumTokens][WarpKernelSmemStride]; static_assert(WarpKernelSmemStride == WarpSize + 1); - static_assert(MaxNumExperts / sizeof(int32_t) <= WarpSize); + static_assert(KernelParams::MaxNumExperts / sizeof(int32_t) <= WarpSize); // values needed for the top-1 reduction, if required InputT minScore = InputT{-INFINITY}; @@ -106,7 +106,7 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam } #endif - if (params.mPtrScores != nullptr) { + if (params.mPtrScores != nullptr && params.mPtrTopKIds == nullptr) { // if we use `mPtrScores` as input, we need to perform the top-1 reduction // for each token, we load the scores then use `reduceTopK` for this. 
// each thread works on 4 experts, so a local reduction is done before @@ -128,29 +128,40 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam expertTokenCount; // we also compute the final score here and write it out if required auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})}; - if (params.mPtrExpertWeights != nullptr) { - params.mPtrExpertWeights[tokenIdx] = finalScore; + if (params.mPtrTopKWeights != nullptr) { + params.mPtrTopKWeights[tokenIdx] = finalScore; } } } } else { - // if we do not have `mPtrScores` as input, we expect that `mPtrExpertWeights` - // contains the top-1 packed score and index already. - // Each thread represents a token here, and we extract the relevant score - // The assumption is that the #tokens is limited by warp-size + // if we do not have `mPtrScores` as input, we expect that `params.mPtrTopKPacked` or + // `params.mPtrTopKIds` and `params.mPtrTopKWeights` contains the top-1 packed score and index + // already. Each thread represents a token here, and we extract the relevant score The + // assumption is that the #tokens is limited by warp-size static_assert(WarpKernelMaxNumTokens <= WarpSize); - TypePacked scoreIdx = - threadIdx.x < params.mNumTokens ? params.mPtrExpertIdx[threadIdx.x] : TypePacked{}; + TypePacked scoreIdx = TypePacked{}; + if (params.mPtrTopKIds != nullptr) { + if (threadIdx.x < params.mNumTokens) { + scoreIdx = TypePacked{static_cast(params.mPtrTopKWeights[threadIdx.x]), + static_cast(params.mPtrTopKIds[threadIdx.x])}; + } + } else { + if (threadIdx.x < params.mNumTokens) { + scoreIdx = TypePacked{static_cast(params.mPtrTopKPacked[threadIdx.x].score), + static_cast(params.mPtrTopKPacked[threadIdx.x].idx)}; + if (params.mPtrTopKWeights != nullptr) { + // we also compute the final score here and write it out if required + auto finalScore = OutputT{sigmoid_accurate(float{scoreIdx.score})}; + params.mPtrTopKWeights[threadIdx.x] = finalScore; + } + } + } + int32_t expertTokenCount = 0; setBits(expertTokenCount, 1, scoreIdx.idx % ExpertsPerThread); if (threadIdx.x < params.mNumTokens) { smemExpertTokenCountFull[threadIdx.x][scoreIdx.idx / ExpertsPerThread] = expertTokenCount; } - // we also compute the final score here and write it out if required - auto finalScore = OutputT{sigmoid_accurate(float{scoreIdx.score})}; - if (params.mPtrExpertWeights != nullptr && threadIdx.x < params.mNumTokens) { - params.mPtrExpertWeights[threadIdx.x] = finalScore; - } } // make the full table available to all threads @@ -212,12 +223,10 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -#if !defined(PDL_PROFILE) || PDL_PROFILE == 0 // we can trigger the next kernel at this point if constexpr (KernelParams::UsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } -#endif #endif // at this point, all values for offsets are ready, except the final offsets @@ -298,7 +307,13 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu cudaGridDependencySynchronize(); } - if (params.mPtrScores != nullptr) { + if (params.mPtrTopKIds != nullptr) { + if (validToken) { + TypePacked packedScore{static_cast(params.mPtrTopKWeights[warpTokenIdx]), + static_cast(params.mPtrTopKIds[warpTokenIdx])}; + smemPackedScoreIdx[warpIdx] = packedScore; + } + } else if (params.mPtrScores != nullptr) { // in this case, each warp represents a token // we then exchange all token max scores, s.t. 
afterwards, each thread // represents a token @@ -306,23 +321,34 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu int32_t warpMaxExpertIdx[MaxNumTopExperts]; if (validToken) { - routingTopKExperts(warp, warpMaxScore, warpMaxExpertIdx, - laneIdx, params.mNumExperts, - params.mPtrScores + scoreOffset); + routingTopKExperts( + warp, warpMaxScore, warpMaxExpertIdx, laneIdx, params.mNumExperts, + params.mPtrScores + scoreOffset); if (cute::elect_one_sync()) { auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})}; TypePacked packedScore{finalScore, static_cast(warpMaxExpertIdx[0])}; smemPackedScoreIdx[warpIdx] = packedScore; } } - // make packed scores available to all threads in cluster - __cluster_barrier_arrive(); - __cluster_barrier_wait(); + } else { + if (validToken) { + smemPackedScoreIdx[warpIdx] = params.mPtrTopKPacked[warpTokenIdx]; + } } - routingPermutation(params, smemPackedScoreIdx, warpIdx, - clusterBlockRank); + // make packed scores available to all threads in cluster + __cluster_barrier_arrive(); + __cluster_barrier_wait(); + + if (params.mPtrTopKIds != nullptr || params.mPtrScores != nullptr) { + routingPermutation(params, smemPackedScoreIdx, warpIdx, + clusterBlockRank); + } else { + routingPermutation(params, smemPackedScoreIdx, warpIdx, + clusterBlockRank); + } } #else __global__ void routingIndicesClusterKernel(KernelParams params) { @@ -334,27 +360,27 @@ __global__ void routingIndicesClusterKernel(KernelParams params) { // this kernel is needed in case we have scores as input for the histogram kernel template -__global__ void __launch_bounds__(NumThreadsHist) +__global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesHistogramScoresKernel(KernelParams params) { using OutputT = typename KernelParams::OutputT; using InputT = typename KernelParams::InputT; using TypePacked = PackedScoreIdx; - static constexpr int VecSize = MaxNumExperts / WarpSize; + static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; // we assume that #experts is a multiple of 4, so VecSize must be 4. 
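  // Note: with MaxNumExperts taken from KernelParams, this assert effectively requires
  // MaxNumExperts == 4 * WarpSize (i.e. 128). The Llama4 launcher only instantiates the single
  // tier returned by getMaxNumExperts (presumably topk::MaxNumExpertsUnit == 128), so the
  // assumption is expected to hold.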
static_assert(VecSize == 4); int32_t const laneIdx = cutlass::arch::LaneId(); int32_t const warpIdx = threadIdx.x / WarpSize; - int32_t const globalWarpIdx = blockIdx.x * NumWarpsHist + warpIdx; - int32_t const globalWarpStride = gridDim.x * NumWarpsHist; + int32_t const globalWarpIdx = blockIdx.x * KernelParams::MaxNumExperts / WarpSize + warpIdx; + int32_t const globalWarpStride = gridDim.x * KernelParams::MaxNumExperts / WarpSize; InputT minScore = InputT{-INFINITY}; auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); // initialize the mPtrExpertCounts int32_t expertCountsNum = 2 * params.mNumExperts; - int32_t globalThreadIdx = blockIdx.x * NumThreads + threadIdx.x; - int32_t globalThreadStride = gridDim.x * NumThreads; + int32_t globalThreadIdx = blockIdx.x * KernelParams::MaxNumExperts + threadIdx.x; + int32_t globalThreadStride = gridDim.x * KernelParams::MaxNumExperts; initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) @@ -372,67 +398,98 @@ __global__ void __launch_bounds__(NumThreadsHist) int32_t warpMaxExpertIdx[MaxNumTopExperts]; InputT warpMaxScore[MaxNumTopExperts]; - routingTopKExperts(warp, warpMaxScore, warpMaxExpertIdx, - laneIdx, params.mNumExperts, - params.mPtrScores + scoreOffset); + if (params.mPtrTopKIds != nullptr) { + if (laneIdx < MaxNumTopExperts) { + warpMaxExpertIdx[laneIdx] = params.mPtrTopKIds[tokenIdx]; + warpMaxScore[laneIdx] = static_cast(params.mPtrTopKWeights[tokenIdx]); + } + } else if (params.mPtrScores != nullptr) { + routingTopKExperts( + warp, warpMaxScore, warpMaxExpertIdx, laneIdx, params.mNumExperts, + params.mPtrScores + scoreOffset); + } else { + if (laneIdx < MaxNumTopExperts) { + warpMaxExpertIdx[laneIdx] = params.mPtrTopKPacked[tokenIdx].idx; + warpMaxScore[laneIdx] = params.mPtrTopKPacked[tokenIdx].score; + } + } if (cute::elect_one_sync()) { auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})}; TypePacked packedScore{finalScore, static_cast(warpMaxExpertIdx[0])}; - params.mPtrExpertIdx[tokenIdx] = packedScore; + params.mPtrTopKPacked[tokenIdx] = packedScore; } } } //////////////////////////////////////////////////////////////////////////////////////////////////// +int constexpr getMaxNumExperts(int32_t numExperts) { + if (numExperts <= topk::MaxNumExpertsUnit) { + return topk::MaxNumExpertsUnit; + } else { + TLLM_LOG_ERROR("Unsupported numExperts"); + return 0; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + void runImpl(Data const& data, void* stream) { - FLASHINFER_CHECK(data.mPtrExpertIdx != nullptr || data.mPtrScores != nullptr, - "Routing kernel requires at least one input parameter"); + FLASHINFER_CHECK( + data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, + "Routing kernel requires at least one input parameter"); + if (data.mPtrTopKIds != nullptr) { + FLASHINFER_CHECK( + data.mPtrTopKWeights != nullptr, + "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for Llama4 routing."); + } FLASHINFER_CHECK( data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr, "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers"); FLASHINFER_CHECK(data.mTopK <= MaxNumTopExperts, - "Routing kernel expects topK experts <= ", MaxNumTopExperts, ", got ", + "Routing 
kernel expects topK experts <= %d, got %d", MaxNumTopExperts, data.mTopK); - FLASHINFER_CHECK(data.mNumExperts <= MaxNumExperts, "Routing kernel expects #experts ", - data.mNumExperts, " to be at most max #experts ", MaxNumExperts); - static_assert(MaxNumExperts <= NumThreads, "#experts must be bounded by #threads"); - static_assert(MaxNumExperts <= NumThreadsHist, "#experts must be bounded by #threads"); - FLASHINFER_CHECK(data.mNumExperts % 4 == 0, "Routing kernel expects #experts ", data.mNumExperts, - " to be a multiple of 4."); - FLASHINFER_CHECK(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got ", + FLASHINFER_CHECK(data.mNumExperts <= NumExpertsLimit, + "Routing kernel expects #experts %d to be no more than %d", data.mNumExperts, + NumExpertsLimit); + FLASHINFER_CHECK(data.mNumExperts % 4 == 0, + "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); + FLASHINFER_CHECK(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", data.mPaddingLog2); bool const useSingleWarp = (data.mPtrScores == nullptr && data.mNumTokens <= WarpKernelMaxNumTokens) || data.mNumTokens < WarpKernelMaxNumTokens; bool const useSingleCluster = - data.mNumTokens <= - (data.mPtrScores != nullptr ? MaxNumTokensSingleClusterScores : MaxNumTokensSingleCluster); + data.mNumTokens <= ((data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr) + ? MaxNumTokensSingleClusterScores + : MaxNumTokensSingleCluster); if (!useSingleCluster) { - FLASHINFER_CHECK(data.mPtrExpertIdx != nullptr, - "When #tokens is large, `mPtrExpertIdx` is a required input."); + FLASHINFER_CHECK( + (data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr), + "When #tokens is large, `mPtrTopKPacked` or `mPtrTopKIds` is a required input."); FLASHINFER_CHECK(data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input."); } + int const numThreadsHist = getMaxNumExperts(data.mNumExperts); if (useSingleWarp) { - LAUNCH_ROUTING(data, - /*coopLaunch=*/false, routingIndicesWarpKernel, 1, WarpSize, - /*smemSize=*/0, // No dynamic smem - stream); + LAUNCH_ROUTING_LLAMA4(data, + /*coopLaunch=*/false, routingIndicesWarpKernel, 1, WarpSize, + /*smemSize=*/0, // No dynamic smem + stream); } else if (useSingleCluster) { - LAUNCH_ROUTING(data, - /*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster, - NumThreads, - /*smemSize=*/0, // No dynamic smem - stream); + LAUNCH_ROUTING_LLAMA4(data, + /*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster, + NumThreads, + /*smemSize=*/0, // No dynamic smem + stream); } else { const uint32_t expandedIdxSize = data.mNumTokens * data.mTopK; - const uint32_t histogramEltsPerBlock = 8 * NumThreadsHist; - const uint32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * NumThreadsHist; + const uint32_t histogramEltsPerBlock = 8 * numThreadsHist; + const uint32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; // Limit grid size (all kernels use a grid-stride loop). 
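    // Note: the histogram-path kernels now take their block size from
    // getMaxNumExperts(data.mNumExperts) instead of the former fixed NumThreadsHist, so the
    // block size stays in sync with the templated MaxNumExperts used inside the kernels.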
const uint32_t maxNumBlocks = 1024; @@ -442,34 +499,36 @@ void runImpl(Data const& data, void* stream) { int const numBlocksOffsets = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); - if (data.mPtrScores != nullptr) { - LAUNCH_ROUTING(data, - /*coopLaunch=*/false, routingIndicesHistogramScoresKernel, maxNumBlocks, - NumThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream); + if (data.mPtrScores != nullptr && data.mPtrTopKIds == nullptr) { + LAUNCH_ROUTING_LLAMA4(data, + /*coopLaunch=*/false, routingIndicesHistogramScoresKernel, maxNumBlocks, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); } else { // Reset the global histograms. - CHECK_CUDA_ERROR(cudaMemsetAsync(data.mPtrExpertCounts, 0, - static_cast(2 * data.mNumExperts) * sizeof(int32_t), - (cudaStream_t)stream)); + LAUNCH_ROUTING_LLAMA4(data, false, routingInitExpertCounts, + (2 * data.mNumExperts - 1) / numThreadsHist + 1, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); } - LAUNCH_ROUTING(data, - /*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram, - NumThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream); - LAUNCH_ROUTING(data, - /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, - NumThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream); + LAUNCH_ROUTING_LLAMA4(data, + /*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); + LAUNCH_ROUTING_LLAMA4(data, + /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); } } void run(Data const& data, void* stream) { - FLASHINFER_CHECK(data.mPtrExpertIdx != nullptr || data.mPtrScores != nullptr, - "Routing kernel requires at least one input parameter"); + FLASHINFER_CHECK( + data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, + "Routing kernel requires at least one input parameter"); FLASHINFER_CHECK( data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr, diff --git a/csrc/trtllm_fused_moe_routing_renormalize.cu b/csrc/trtllm_fused_moe_routing_renormalize.cu index 87611bc18b..63c2a3a72e 100644 --- a/csrc/trtllm_fused_moe_routing_renormalize.cu +++ b/csrc/trtllm_fused_moe_routing_renormalize.cu @@ -71,6 +71,150 @@ __forceinline__ __device__ void routingTopKExperts( } } +template +__global__ void __launch_bounds__(KernelParams::MaxNumExperts) + routingIndicesBlockKernel(KernelParams params) { + // types used in this kernel + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + using BaseType = std::conditional_t; + using TypePacked = PackedScoreIdx; + int constexpr MaxNumExperts = KernelParams::MaxNumExperts; + + int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + int32_t const laneIdx = cutlass::arch::LaneId(); + int32_t const expert = threadIdx.x; + auto scoreOffset = warpIdx * params.mNumExperts; + bool validToken = warpIdx < params.mNumTokens; + + static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; + static constexpr int totalExpertCounts = BlockKernelMaxNumTokens * MaxNumExperts; + __shared__ int8_t __attribute((aligned(128))) smemOffset[totalExpertCounts]; + __shared__ int8_t __attribute((aligned(128))) smemKIdx[totalExpertCounts]; + + using Scan = cub::BlockScan; + 
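+  // smemKIdx / smemOffset form a [token][expert] table: smemKIdx records which top-K slot (if
+  // any) selected a given expert for a given token, and smemOffset later holds that token's
+  // position within the expert's bucket. The block-wide scans further down turn the per-expert
+  // CTA counts and padded counts into ctaOffset / expertScanCounts via ExclusiveSum.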
__shared__ typename Scan::TempStorage tempStorage; + + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + for (int i = threadIdx.x; i < totalExpertCounts; i += blockDim.x) { + smemOffset[i] = int8_t{-1}; + smemKIdx[i] = int8_t{-1}; + } + __syncthreads(); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // then wait on primary grid + if constexpr (KernelParams::UsePdl) { + cudaGridDependencySynchronize(); + } +#endif + + if (params.mPtrTopKIds != nullptr) { + if (validToken) { + if (laneIdx < params.mTopK) { + int offset = warpIdx * MaxNumExperts + params.mPtrTopKIds[warpIdx * params.mTopK + laneIdx]; + smemKIdx[offset] = static_cast(laneIdx); + } + } + } else if (params.mPtrScores != nullptr) { + // in this case, each warp represents a token + BaseType score[VecSize]; + int32_t idx[VecSize]; + + BaseType warpTopKScore[MaxNumTopExperts]; + int32_t warpTopKExpertIdx[MaxNumTopExperts]; + + BaseType minScore = BaseType{-INFINITY}; + if (validToken) { + routingTopKExperts( + warp, score, idx, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, + params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb, + params.mApplySoftmaxAfterTopK); + + if (laneIdx < params.mTopK) { + int offset = warpIdx * MaxNumExperts + warpTopKExpertIdx[laneIdx]; + smemKIdx[offset] = static_cast(laneIdx); + if (params.mPtrTopKWeights != nullptr) { + params.mPtrTopKWeights[warpIdx * params.mTopK + laneIdx] = + OutputT{warpTopKScore[laneIdx]}; + } + } + } // end if (validToken) + } + __syncthreads(); + + // set local experts + auto localExpertIdx = expert - params.mLocalExpertsStartIdx; + auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < params.mNumLocalExperts && + (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; + // Get the count of each expert and the offset for each token + int accExpertCount = 0; + + if (isLocalExpert) { + int offset = expert; + for (int j = 0; j < BlockKernelMaxNumTokens; j++) { + if (smemKIdx[offset] >= 0) { + smemOffset[offset] = static_cast(accExpertCount); + accExpertCount++; + } + offset += MaxNumExperts; + } + } + __syncthreads(); + // Get the number of CTAs and the offset for each CTA + const int32_t numCta = divUpLog2(accExpertCount, params.mPaddingLog2); + int32_t ctaOffset = 0; + int32_t numNonExitingCtas; + Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); + + int32_t expertScanCounts = 0; + Scan(tempStorage) + .ExclusiveSum(divUpMulLog2(accExpertCount, params.mPaddingLog2), expertScanCounts); + __syncthreads(); + + if (isLocalExpert) { + for (int cta = 0; cta < numCta; ++cta) { + const int32_t localExpertIdx = + (expert - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; + params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; + params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = + min(mulLog2(ctaOffset + cta + 1, params.mPaddingLog2), + mulLog2(ctaOffset, params.mPaddingLog2) + accExpertCount); + } + } + + // at this point, we can write out padded count + if (threadIdx.x == 0) { + const int32_t permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + params.mPtrPermutedIdxSize[0] = permutedIdxSize; + params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // we can trigger the next kernel at this point + if constexpr (KernelParams::UsePdl) { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif + + for (int tokenIdx = 0; tokenIdx < params.mNumTokens; tokenIdx++) { + int offset = tokenIdx * 
MaxNumExperts + threadIdx.x; + if (smemKIdx[offset] >= 0) { + int const expandedIdx = tokenIdx * params.mTopK + smemKIdx[offset]; + int const offsetWithinExpert = static_cast(smemOffset[offset]); + int const offsetForExpert = expertScanCounts; + int const permutedIdx = isLocalExpert ? offsetForExpert + offsetWithinExpert : int32_t{-1}; + + params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx; + if (isLocalExpert) { + params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; + } + } + } +} + template #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads) @@ -82,9 +226,7 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu using BaseType = std::conditional_t; using TypePacked = PackedScoreIdx; - static constexpr int VecSize = MaxNumExperts / WarpSize; - // we assume that #experts is a multiple of 4, so VecSize must be 4. - static_assert(VecSize == 4); + static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; __shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps * MaxNumTopExperts]; @@ -125,27 +267,17 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu TypePacked{warpTopKScore[laneIdx], static_cast(warpTopKExpertIdx[laneIdx])}; } } // end if (validToken) + } + + // make packed scores available to all threads in cluster + __cluster_barrier_arrive(); + __cluster_barrier_wait(); - // make packed scores available to all threads in cluster - __cluster_barrier_arrive(); - __cluster_barrier_wait(); + if (params.mPtrScores != nullptr) { routingPermutation(params, smemPackedScoreIdx, warpIdx, clusterBlockRank); - } - // else { - // if (validToken && laneIdx < params.mTopK) { - // // auto score = reinterpret_cast(params.mPtrExpertWeights)[warpTokenIdx * - // params.mNumExperts + laneIdx]; - // // auto expertIdx = reinterpret_cast(params.mPtrExpertIdx)[warpTokenIdx * - // params.mTopK + laneIdx]; TypePacked packed = - // reinterpret_cast(params.mPtrExpertIdx)[warpTokenIdx * params.mTopK + laneIdx]; - // smemPackedScoreIdx[warpIdx * params.mTopK + laneIdx] = packed; - // } - // __cluster_barrier_arrive(); - // __cluster_barrier_wait(); - // } - else { + } else { routingPermutation(params, smemPackedScoreIdx, warpIdx, clusterBlockRank); @@ -161,20 +293,18 @@ __global__ void __launch_bounds__(NumThreads) // this kernel is needed in case we have scores as input for the histogram kernel template -__global__ void __launch_bounds__(NumThreadsHist) +__global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesHistogramScoresKernel(KernelParams params) { using OutputT = typename KernelParams::OutputT; using InputT = typename KernelParams::InputT; using BaseType = std::conditional_t; - static constexpr int VecSize = MaxNumExperts / WarpSize; - // we assume that #experts is a multiple of 4, so VecSize must be 4. 
- static_assert(VecSize == 4); + static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; int32_t const laneIdx = cutlass::arch::LaneId(); int32_t const warpIdx = threadIdx.x / WarpSize; - int32_t const globalWarpIdx = blockIdx.x * NumWarpsHist + warpIdx; - int32_t const globalWarpStride = gridDim.x * NumWarpsHist; + int32_t const globalWarpIdx = blockIdx.x * KernelParams::MaxNumExperts / WarpSize + warpIdx; + int32_t const globalWarpStride = gridDim.x * KernelParams::MaxNumExperts / WarpSize; BaseType minScore = BaseType{-INFINITY}; auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); @@ -188,8 +318,8 @@ __global__ void __launch_bounds__(NumThreadsHist) // initialize the mPtrExpertCounts int32_t expertCountsNum = 2 * params.mNumExperts; - int32_t globalThreadIdx = blockIdx.x * NumThreads + threadIdx.x; - int32_t globalThreadStride = gridDim.x * NumThreads; + int32_t globalThreadIdx = blockIdx.x * KernelParams::MaxNumExperts + threadIdx.x; + int32_t globalThreadStride = gridDim.x * KernelParams::MaxNumExperts; initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) @@ -216,52 +346,94 @@ __global__ void __launch_bounds__(NumThreadsHist) if (laneIdx < params.mTopK) { PackedScoreIdx packedScore{static_cast(warpTopKScore[laneIdx]), static_cast(warpTopKExpertIdx[laneIdx])}; - params.mPtrExpertIdx[tokenIdx * params.mTopK + laneIdx] = packedScore; + params.mPtrTopKPacked[tokenIdx * params.mTopK + laneIdx] = packedScore; } } } +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int32_t constexpr getMaxNumExperts(int32_t numExperts) { + if (numExperts <= topk::MaxNumExpertsUnit) { + return topk::MaxNumExpertsUnit; + } else if (numExperts <= NumExpertsLimit) { + return NumExpertsLimit; + } else { + TLLM_LOG_ERROR("Unsupported numExperts"); + return 0; + } +} + //////////////////////////////////////////////////////////////////////////////////////////////////// -void runImpl(Data const& data, void* stream) { - TVM_FFI_ICHECK(data.mPtrExpertIdx != nullptr || data.mPtrScores != nullptr) - << "Routing kernel requires at least one input parameter"; - TVM_FFI_ICHECK(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && - data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr) - << "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers"; - TVM_FFI_ICHECK_LE(data.mTopK, MaxNumTopExperts) - << "Routing kernel expects topK experts <= " << MaxNumTopExperts << ", got " << data.mTopK; - TVM_FFI_ICHECK_LE(data.mNumExperts, MaxNumExperts) - << "Routing kernel expects #experts " << data.mNumExperts << " to be at most max #experts " - << MaxNumExperts; - static_assert(MaxNumExperts <= NumThreads, "#experts must be bounded by #threads"); - static_assert(MaxNumExperts <= NumThreadsHist, "#experts must be bounded by #threads"); - TVM_FFI_ICHECK_EQ(data.mNumExperts % 4, 0) - << "Routing kernel expects #experts " << data.mNumExperts << " to be a multiple of 4."; - TVM_FFI_ICHECK_LT(data.mPaddingLog2, 8) - << "Routing kernel expects padding log2 < 8, got " << data.mPaddingLog2; +#define LAUNCH_ROUTING_RENORNALIZE(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag1) \ + if (data.mNumExperts <= topk::MaxNumExpertsUnit) { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag1, 
topk::MaxNumExpertsUnit); \ + } else if (data.mNumExperts <= NumExpertsLimit) { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag1, NumExpertsLimit); \ + } else { \ + TLLM_LOG_ERROR("Unsupported numExperts"); \ + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// +void run(Data const& data, void* stream) { + TVM_FFI_ICHECK_LE( + data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, + "Routing kernel requires at least one input parameter"); + if (data.mPtrTopKIds != nullptr) { + TVM_FFI_ICHECK_LE(data.mPtrTopKWeights != nullptr, + "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for " + "Renormalize routing."); + } + TVM_FFI_ICHECK_LE( + data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && + data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr, + "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers"); + TVM_FFI_ICHECK_LE(data.mTopK <= MaxNumTopExperts, + "Routing kernel expects topK experts <= %d, got %d", MaxNumTopExperts, + data.mTopK); + TVM_FFI_ICHECK_LE(data.mNumExperts <= NumExpertsLimit, + "Routing kernel expects #experts %d to be no more than %d", data.mNumExperts, + NumExpertsLimit); + TVM_FFI_ICHECK_LE(data.mNumExperts % 4 == 0, + "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); + TVM_FFI_ICHECK_LE(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", + data.mPaddingLog2); + + bool const useSingleBlock = data.mNumTokens <= BlockKernelMaxNumTokens; bool const useSingleCluster = - data.mNumTokens <= - (data.mPtrScores != nullptr ? MaxNumTokensSingleClusterScores : MaxNumTokensSingleCluster); - - if (!useSingleCluster) { - TVM_FFI_ICHECK(data.mPtrExpertIdx != nullptr) - << "When #tokens is large, `mPtrExpertIdx` is a required input."; - TVM_FFI_ICHECK(data.mPtrExpertCounts != nullptr) - << "When #tokens is large, `mPtrExpertCounts` is a required input."; + data.mNumTokens <= ((data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr) + ? MaxNumTokensSingleClusterScores + : MaxNumTokensSingleCluster); + + if (!useSingleCluster && !useSingleBlock) { + TVM_FFI_ICHECK_LE( + (data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr), + "When #tokens is large, `mPtrTopKPacked` or `mPtrTopKIds` is a required input."); + TVM_FFI_ICHECK_LE(data.mPtrExpertCounts != nullptr, + "When #tokens is large, `mPtrExpertCounts` is a required input."); } - if (useSingleCluster) { - LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, false, routingIndicesClusterKernel, NumBlocksPerCluster, - NumThreads, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK, /*forceFloatInput=*/false); + if (useSingleBlock) { + //@TODO: For now we use the single block kernel for cases with token number no larger than 4. + // We will future tune this threshold based on the performance. 
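+    // routingIndicesBlockKernel processes the whole batch in a single CTA (one warp per token,
+    // one thread per expert) and builds the expert offsets and permuted indices entirely in
+    // shared memory, so the histogram/offsets passes below are skipped for these small batches
+    // (mNumTokens <= BlockKernelMaxNumTokens, reportedly 4 per the TODO above).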
+ LAUNCH_ROUTING_RENORNALIZE(data, false, routingIndicesBlockKernel, 1, data.mNumExperts, + /*smemSize=*/0, // No dynamic smem + stream, data.mDoSoftmaxBeforeTopK); + } else if (useSingleCluster) { + LAUNCH_ROUTING_RENORNALIZE(data, false, routingIndicesClusterKernel, NumBlocksPerCluster, + NumThreads, + /*smemSize=*/0, // No dynamic smem + stream, data.mDoSoftmaxBeforeTopK); } else { uint32_t const expandedIdxSize = data.mNumTokens * data.mTopK; - - uint32_t const histogramEltsPerBlock = 8 * NumThreadsHist; - uint32_t const offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * NumThreadsHist; + uint32_t const numThreadsHist = getMaxNumExperts(data.mNumExperts); + uint32_t const histogramEltsPerBlock = 8 * numThreadsHist; + uint32_t const offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; // Limit grid size (all kernels use a grid-stride loop). uint32_t const maxNumBlocks = 1024; @@ -271,25 +443,26 @@ void runImpl(Data const& data, void* stream) { int const numBlocksOffsets = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); - if (data.mPtrScores != nullptr) { - LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, false, routingIndicesHistogramScoresKernel, maxNumBlocks, - NumThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK, /*forceFloatInput=*/false); + if (data.mPtrScores != nullptr && data.mPtrTopKIds == nullptr) { + LAUNCH_ROUTING_RENORNALIZE(data, false, routingIndicesHistogramScoresKernel, maxNumBlocks, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mDoSoftmaxBeforeTopK); } else { // Reset the global histograms. - CHECK_CUDA_ERROR(cudaMemsetAsync(data.mPtrExpertCounts, 0, - static_cast(2 * data.mNumExperts) * sizeof(int32_t), - (cudaStream_t)stream)); + LAUNCH_ROUTING_RENORNALIZE(data, false, routingInitExpertCounts, + (2 * data.mNumExperts - 1) / numThreadsHist + 1, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mDoSoftmaxBeforeTopK); } - LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, false, routingIndicesHistogramKernel, numBlocksHistogram, - NumThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK, /*forceFloatInput=*/false); - LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, false, routingIndicesOffsetsKernel, numBlocksOffsets, - NumThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK, /*forceFloatInput=*/false); + LAUNCH_ROUTING_RENORNALIZE(data, false, routingIndicesHistogramKernel, numBlocksHistogram, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mDoSoftmaxBeforeTopK); + LAUNCH_ROUTING_RENORNALIZE(data, false, routingIndicesOffsetsKernel, numBlocksOffsets, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mDoSoftmaxBeforeTopK); } } diff --git a/include/flashinfer/trtllm/fused_moe/DevKernel.h b/include/flashinfer/trtllm/fused_moe/DevKernel.h index 0eb426b0e0..2d8f5491b1 100644 --- a/include/flashinfer/trtllm/fused_moe/DevKernel.h +++ b/include/flashinfer/trtllm/fused_moe/DevKernel.h @@ -112,39 +112,61 @@ namespace moe::dev { FLASHINFER_WARN("Unsupported pair"); \ } -#define LAUNCH_ROUTING(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \ - if (data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float), kernel, numBlocks, numThreads, \ - smemSize, stream); \ - } else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16), kernel, numBlocks, \ - numThreads, 
smemSize, stream); \ - } else { \ - FLASHINFER_WARN("Unsupported dtypeExpW"); \ +#define LAUNCH_ROUTING_LLAMA4(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \ + if (data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, 128 /* Always 128 for llama4*/), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_PDL(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, 128 /* Always 128 for llama4*/), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else { \ + TLLM_LOG_ERROR("Unsupported dtypeExpW"); \ } -#define LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, forceFloatInput) \ - if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, true), kernel, numBlocks, numThreads, \ - smemSize, stream); \ - } else if (data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, false), kernel, numBlocks, numThreads, \ - smemSize, stream); \ - } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag && forceFloatInput) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, true), kernel, numBlocks, \ - numThreads, smemSize, stream); \ - } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, true), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && forceFloatInput) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, false), kernel, numBlocks, \ - numThreads, smemSize, stream); \ - } else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, false), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else { \ - FLASHINFER_WARN("Unsupported dtypeExpW"); \ +#define LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, \ + numThreads, smemSize, stream, extraFlag, \ + forceFloatInput, numExperts) \ + if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag) { \ + LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, true), kernel, numBlocks, \ + numThreads, smemSize, stream); \ + } else if (data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, false), kernel, numBlocks, \ + numThreads, smemSize, stream); \ + } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag && forceFloatInput) { \ + LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, true), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag) { \ + LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && forceFloatInput) { \ + LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, false), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else { \ + TLLM_LOG_ERROR("Unsupported dtypeExpW"); \ + } + 
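
Note for readers of this header: LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT above and LAUNCH_ROUTING_WITH_NUM_EXPERTS below both expand to the same two-level dispatch: the runtime dtype (and extra flag) selects the compile-time input/output types, and the expert-count bucket passed down from the renormalize launcher selects the compile-time MaxNumExperts template argument. The standalone C++ sketch below illustrates only that dispatch pattern; the names (dispatchRouting, launchRoutingKernel) and the bucket constants are illustrative assumptions, not part of this patch.

#include <cstdio>

// Hypothetical stand-in for the real kernels; an actual launch would be
// routingKernel<InputT, OutputT, MaxNumExperts, ExtraFlag><<<blocks, threads>>>(params).
template <typename InputT, typename OutputT, int MaxNumExperts, bool ExtraFlag>
void launchRoutingKernel(int numExperts) {
  std::printf("MaxNumExperts=%d bucket chosen for numExperts=%d, extraFlag=%d\n",
              MaxNumExperts, numExperts, static_cast<int>(ExtraFlag));
}

// Runtime numExperts -> compile-time bucket, mirroring the role of getMaxNumExperts()
// and the expert-count branch in the renormalize launch macro.
template <typename InputT, typename OutputT, bool ExtraFlag>
void dispatchRouting(int numExperts) {
  constexpr int kMaxNumExpertsUnit = 128;  // assumed to match topk::MaxNumExpertsUnit
  constexpr int kNumExpertsLimit = 512;    // assumed to match NumExpertsLimit
  if (numExperts <= kMaxNumExpertsUnit) {
    launchRoutingKernel<InputT, OutputT, kMaxNumExpertsUnit, ExtraFlag>(numExperts);
  } else if (numExperts <= kNumExpertsLimit) {
    launchRoutingKernel<InputT, OutputT, kNumExpertsLimit, ExtraFlag>(numExperts);
  } else {
    std::printf("Unsupported numExperts=%d\n", numExperts);
  }
}

int main() {
  dispatchRouting<float, float, true>(128);   // lands in the 128-expert bucket
  dispatchRouting<float, float, false>(384);  // a 384-expert model lands in the 512-expert bucket
  return 0;
}
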
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag1, numExperts) \ + if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag1) { \ + LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, true), kernel, numBlocks, \ + numThreads, smemSize, stream); \ + } else if (data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, false), kernel, numBlocks, \ + numThreads, smemSize, stream); \ + } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag1) { \ + LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else { \ + TLLM_LOG_ERROR("Unsupported dtypeExpW"); \ } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh b/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh index 3422e2999a..dd7d5c474d 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh @@ -28,7 +28,6 @@ #include "RoutingKernelTopK.cuh" //////////////////////////////////////////////////////////////////////////////////////////////////// - namespace moe::dev { //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -43,9 +42,6 @@ static constexpr int NumBlocksPerCluster = 8; // Performance tuning knob. static constexpr int NumEltsPerOffsetTilePerThread = 8; -static constexpr int NumThreadsHist = 256; -static constexpr int NumWarpsHist = NumThreadsHist / WarpSize; - //////////////////////////////////////////////////////////////////////////////////////////////////// static __device__ inline float sigmoid_accurate(float x) { return 0.5f * tanhf(0.5f * x) + 0.5f; } @@ -103,27 +99,32 @@ __device__ void initArr(int startIdx, int numElts, int stride, DataType* arr, Da template __device__ void calcSoftmax(cg::thread_block_tile const& warp, DataType (&scores)[VecSize]) { - DataType maxScore = DataType{-INFINITY}; - DataType sumScore = DataType{0.f}; - + // Compute in float to support half/bfloat16 inputs safely. + float maxScore = -INFINITY; + float sumScore = 0.f; // Get the max score for each token +#pragma unroll for (int i = 0; i < VecSize; ++i) { - maxScore = scores[i] >= maxScore ? scores[i] : maxScore; + float si = static_cast(scores[i]); + maxScore = si >= maxScore ? 
si : maxScore; } - maxScore = cg::reduce(warp, maxScore, cg::greater()); + maxScore = cg::reduce(warp, maxScore, cg::greater()); // Get the summation of scores for each token #pragma unroll for (int i = 0; i < VecSize; ++i) { - scores[i] = static_cast(exp(scores[i] - maxScore)); - sumScore += scores[i]; + float si = static_cast(scores[i]); + float e = expf(si - maxScore); + scores[i] = static_cast(e); + sumScore += e; } - sumScore = cg::reduce(warp, sumScore, cg::plus()); + sumScore = cg::reduce(warp, sumScore, cg::plus()); // Normalize the scores #pragma unroll for (int i = 0; i < VecSize; ++i) { - scores[i] = static_cast(scores[i] / sumScore); + float si = static_cast(scores[i]) / sumScore; + scores[i] = static_cast(si); } } @@ -207,8 +208,13 @@ __device__ void routingPermutation(KernelParams params, auto loopBody = [&](int ii, int expandedIdx) { TypePacked scoreIdx; if constexpr (LoadExpertIdxFromGlobal) { - scoreIdx = TypePacked{static_cast(params.mPtrExpertIdx[expandedIdx].score), - static_cast(params.mPtrExpertIdx[expandedIdx].idx)}; + if (params.mPtrTopKIds != nullptr) { + scoreIdx = TypePacked{static_cast(params.mPtrTopKWeights[expandedIdx]), + static_cast(params.mPtrTopKIds[expandedIdx])}; + } else { + scoreIdx = TypePacked{static_cast(params.mPtrTopKPacked[expandedIdx].score), + static_cast(params.mPtrTopKPacked[expandedIdx].idx)}; + } } else { TypePacked const* remoteSmem = cg::cluster_group::map_shared_rank( smemPackedScoreIdx, expandedIdx / (NumWarps * params.mTopK)); @@ -221,8 +227,8 @@ __device__ void routingPermutation(KernelParams params, auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + scoreIdx.idx, 1) : 0; - if (params.mPtrExpertWeights != nullptr) { - params.mPtrExpertWeights[expandedIdx] = OutputT{scoreIdx.score}; + if (params.mPtrTopKWeights != nullptr && params.mPtrTopKIds == nullptr) { + params.mPtrTopKWeights[expandedIdx] = OutputT{scoreIdx.score}; } }; @@ -335,7 +341,7 @@ __device__ void routingPermutation(KernelParams params, // We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx, // mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens // TODO: this is not sufficient to ensure visibility in the next kernel! -#if !defined(PDL_PROFILE) || PDL_PROFILE == 0 +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) if constexpr (KernelParams::UsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } @@ -378,11 +384,12 @@ __device__ void routingPermutation(KernelParams params, // Note: the histogram calculation could also be fused with routingMainKernel, but this might be // inefficient if we have one CTA per token doing a single global atomic. template -__global__ void __launch_bounds__(NumThreadsHist) +__global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesHistogramKernel(KernelParams params) { using OutputT = typename KernelParams::OutputT; + // number of experts is bounded by number of threads - __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreadsHist]; + __shared__ int32_t __attribute((aligned(128))) smemExpertCount[KernelParams::MaxNumExperts]; // For unrolling. 
uint32_t constexpr NumEltsPerThread = 8; @@ -404,22 +411,29 @@ __global__ void __launch_bounds__(NumThreadsHist) uint32_t const expandedIdxSize = params.mNumTokens * params.mTopK; uint32_t const localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; - uint32_t const gridBlockOffset = blockIdx.x * NumThreadsHist; - uint32_t const gridStride = gridDim.x * NumThreadsHist; + uint32_t const gridBlockOffset = blockIdx.x * KernelParams::MaxNumExperts; + uint32_t const gridStride = gridDim.x * KernelParams::MaxNumExperts; // Define a lambda to avoid code duplication in branches. auto loopBody = [&](int expandedIdx) { - PackedScoreIdx scoreIdx = params.mPtrExpertIdx[expandedIdx]; + PackedScoreIdx scoreIdx; + int idx; + if (params.mPtrTopKIds != nullptr) { + idx = params.mPtrTopKIds[expandedIdx]; + } else { + // If params.mPtrTopKIds != nullptr, we don't need to store the weights + if (params.mPtrTopKWeights != nullptr) { + scoreIdx = params.mPtrTopKPacked[expandedIdx]; + idx = scoreIdx.idx; + params.mPtrTopKWeights[expandedIdx] = static_cast(scoreIdx.score); + } + } // check whether this expert is local to our GPU at all and ignore if not - auto localExpertIdx = scoreIdx.idx - params.mLocalExpertsStartIdx; + auto localExpertIdx = idx - params.mLocalExpertsStartIdx; auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; if (isLocalExpert) { - atomicAdd(&smemExpertCount[scoreIdx.idx], 1); - } - - if (params.mPtrExpertWeights != nullptr) { - params.mPtrExpertWeights[expandedIdx] = static_cast(scoreIdx.score); + atomicAdd(&smemExpertCount[idx], 1); } }; @@ -427,15 +441,15 @@ __global__ void __launch_bounds__(NumThreadsHist) for (uint32_t expandedIdx0 = gridBlockOffset * NumEltsPerThread; expandedIdx0 < expandedIdxSize; expandedIdx0 += gridStride * NumEltsPerThread) { // Fast path if bound checks aren't necessary - if (expandedIdx0 + NumEltsPerThread * NumThreadsHist <= expandedIdxSize) { + if (expandedIdx0 + NumEltsPerThread * KernelParams::MaxNumExperts <= expandedIdxSize) { #pragma unroll for (uint32_t ii = 0; ii < NumEltsPerThread; ii++) { - uint32_t expandedIdx = expandedIdx0 + ii * NumThreadsHist + threadIdx.x; + uint32_t expandedIdx = expandedIdx0 + ii * KernelParams::MaxNumExperts + threadIdx.x; loopBody(expandedIdx); } } else { for (uint32_t expandedIdx = expandedIdx0 + threadIdx.x; expandedIdx < expandedIdxSize; - expandedIdx += NumThreadsHist) { + expandedIdx += KernelParams::MaxNumExperts) { loopBody(expandedIdx); } } @@ -456,18 +470,20 @@ __global__ void __launch_bounds__(NumThreadsHist) //////////////////////////////////////////////////////////////////////////////////////////////////// template -__global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(KernelParams params) { +__global__ void __launch_bounds__(KernelParams::MaxNumExperts) + routingIndicesOffsetsKernel(KernelParams params) { using OutputT = typename KernelParams::OutputT; // number of experts is bounded by number of threads - __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreadsHist]; - __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreadsHist]; - __shared__ int32_t __attribute((aligned(128))) smemExpertTileOffset[NumThreadsHist]; + __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[KernelParams::MaxNumExperts]; + __shared__ int32_t __attribute((aligned(128))) smemExpertCount[KernelParams::MaxNumExperts]; + __shared__ int32_t 
__attribute((aligned(128))) smemExpertTileOffset[KernelParams::MaxNumExperts]; // needed for the exclusive sum of token offsets - using Scan = cub::BlockScan; + using Scan = cub::BlockScan; __shared__ typename Scan::TempStorage tempStorage; static constexpr int MaxExpandedIdxPerThread = NumEltsPerOffsetTilePerThread; - static constexpr int MaxExpandedIdxPerBlock = NumThreadsHist * MaxExpandedIdxPerThread; + static constexpr int MaxExpandedIdxPerBlock = + KernelParams::MaxNumExperts * MaxExpandedIdxPerThread; int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); @@ -514,7 +530,8 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke __syncthreads(); // The first block writes out padded count - if (blockIdx.x == 0 && warpIdx == NumWarpsHist - 1 && cute::elect_one_sync()) { + if (blockIdx.x == 0 && warpIdx == KernelParams::MaxNumExperts / WarpSize - 1 && + cute::elect_one_sync()) { const int32_t permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); params.mPtrPermutedIdxSize[0] = permutedIdxSize; params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; @@ -557,20 +574,21 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke // Define a lambda to avoid code duplication in branches. auto loopBody = [&](int ii, int expandedIdx) { - PackedScoreIdx scoreIdx = params.mPtrExpertIdx[expandedIdx]; - expertIndexes[ii] = scoreIdx.idx; + expertIndexes[ii] = params.mPtrTopKIds ? params.mPtrTopKIds[expandedIdx] + : params.mPtrTopKPacked[expandedIdx].idx; // check whether this expert is local to our GPU at all and ignore if not - auto localExpertIdx = scoreIdx.idx - params.mLocalExpertsStartIdx; + auto localExpertIdx = expertIndexes[ii] - params.mLocalExpertsStartIdx; auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; - expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + scoreIdx.idx, 1) : 0; + expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + expertIndexes[ii], 1) : 0; }; // For all tiles but the last, all indices are in bounds. if (tileIdx < numTiles - 1) { #pragma unroll for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1) { - auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x; + auto expandedIdx = + tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x; loopBody(ii, expandedIdx); } } else { @@ -584,13 +602,14 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride) { // Whether it's safe to do multiple iterations without bound checks. 
bool const takeFastPath = - tileIdx * MaxExpandedIdxPerBlock + (ii0 + IterStride) * NumThreadsHist <= + tileIdx * MaxExpandedIdxPerBlock + (ii0 + IterStride) * KernelParams::MaxNumExperts <= expandedIdxSize; if (takeFastPath) { #pragma unroll for (int32_t jj = 0; jj < IterStride; jj++) { int const ii = ii0 + jj; - auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x; + auto expandedIdx = + tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x; loopBody(ii, expandedIdx); } } else { @@ -598,7 +617,8 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke #pragma unroll for (int32_t jj = 0; jj < IterStride; jj++) { int const ii = ii0 + jj; - auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x; + auto expandedIdx = + tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x; if (expandedIdx >= expandedIdxSize) { doBreak = true; break; @@ -653,13 +673,15 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke if (tileIdx < numTiles - 1) { #pragma unroll for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1) { - auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x; + auto expandedIdx = + tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x; storeLoopBody(ii, expandedIdx); } } else { #pragma unroll for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1) { - auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x; + auto expandedIdx = + tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x; if (expandedIdx >= expandedIdxSize) { break; } @@ -669,16 +691,40 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -// Trigger secondary kernel. -// Note: this does not guarantee the visibility of prior writes unless the consumer executes a -// dependency sync. -#if !defined(PDL_PROFILE) || PDL_PROFILE == 0 + // Trigger secondary kernel. + // Note: this does not guarantee the visibility of prior writes unless the consumer executes a + // dependency sync. if constexpr (KernelParams::UsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } -#endif #endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) } +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void __launch_bounds__(KernelParams::MaxNumExperts) + routingInitExpertCounts(KernelParams params) { + // initialize the mPtrExpertCounts + int32_t expertCountsNum = 2 * params.mNumExperts; + int32_t globalThreadIdx = blockIdx.x * KernelParams::MaxNumExperts + threadIdx.x; + int32_t globalThreadStride = gridDim.x * KernelParams::MaxNumExperts; + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // Wait on primary grid. + if constexpr (KernelParams::UsePdl) { + cudaGridDependencySynchronize(); + } +#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + + initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // Wait on primary grid. 
+ if constexpr (KernelParams::UsePdl) { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +} } // namespace routing } // namespace moe::dev diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h index e182fc114d..064b99f4ae 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h @@ -56,15 +56,23 @@ struct DataBase { int32_t* mPtrPermutedIdxToTokenIdx{nullptr}; // optional: if `nullptr`, it is not filled // dim: [mNumTokens, mTopK] - void* mPtrExpertWeights{nullptr}; + // When mPtrTopKIds is provided, mPtrTopKWeights must be also provided as inputs. + // Otherwise, mPtrTopKWeights is the output scores of the topK experts. + void* mPtrTopKWeights{nullptr}; + // optional: if `nullptr`, it is not filled + // dim: [mNumTokens, mTopK] + // mPtrTopKIds[i] is the index of the expert for the i-th token in the top-k experts + // Together with mPtrTopKWeights, they form the top-k experts for each token + int32_t* mPtrTopKIds{nullptr}; + // optional: if `nullptr`, scores are used directly as input. // If it is given, it must represent a packed value s.t. the most significant // 16/32 bits represent the score without sigmoid activation and // the least significant 16 bits represent the index of the chosen expert (unsigned). // note: this is required if the number of tokens is large. // dim: [mNumTokens, mTopK] - void* mPtrExpertIdx{nullptr}; - // optional: if `nullptr`, `mPtrExpertIdx` must be provided. + void* mPtrTopKPacked{nullptr}; + // optional: if `nullptr`, `mPtrTopKPacked` must be provided. // If it is given, it represents the scores without sigmoid activation for // each token and expert. 
// note: if it is provided, we always re-compute the top1 scores @@ -92,10 +100,11 @@ struct DataBase { int32_t mNumLocalExperts; }; -template +template struct KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; + static constexpr int MaxNumExperts = MaxNumExperts_; static constexpr bool UsePdl = UsePdl_; // Public pointer members @@ -106,7 +115,8 @@ struct KernelParamsBase { int32_t* mPtrCtaIdxXyToBatchIdx = nullptr; int32_t* mPtrCtaIdxXyToMnLimit = nullptr; int32_t* mPtrNumNonExitingCtas = nullptr; - OutputT* mPtrExpertWeights = nullptr; + OutputT* mPtrTopKWeights = nullptr; + int32_t* mPtrTopKIds = nullptr; InputT const* mPtrScores = nullptr; // Public scalar members @@ -128,7 +138,8 @@ struct KernelParamsBase { mPtrCtaIdxXyToBatchIdx = data.mPtrCtaIdxXyToBatchIdx; mPtrCtaIdxXyToMnLimit = data.mPtrCtaIdxXyToMnLimit; mPtrNumNonExitingCtas = data.mPtrNumNonExitingCtas; - mPtrExpertWeights = static_cast(data.mPtrExpertWeights); + mPtrTopKWeights = static_cast(data.mPtrTopKWeights); + mPtrTopKIds = static_cast(data.mPtrTopKIds); mPtrScores = (InputT const*)data.mPtrScores; mNumTokens = data.mNumTokens; @@ -160,17 +171,17 @@ struct Data : public DataBase { bool mUseRoutingSoftmax; }; -template -struct KernelParams : public KernelParamsBase { +template +struct KernelParams : public KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; static constexpr bool UseGroups = UseGroups_; - PackedScoreIdx* mPtrExpertIdx = nullptr; + PackedScoreIdx* mPtrTopKPacked = nullptr; - // OutputT* mPtrExpertWeightsFull = nullptr; - // Note: this variable(mPtrExpertWeightsFull) might need to be added back for the low-latency + // OutputT* mPtrTopKWeightsFull = nullptr; + // Note: this variable(mPtrTopKWeightsFull) might need to be added back for the low-latency // kernels for MoE in tllm-gen in the future OutputT const* mPtrRoutingBias = nullptr; @@ -186,9 +197,9 @@ struct KernelParams : public KernelParamsBase { KernelParams params; params.setBaseParams(data); - params.mPtrExpertIdx = (PackedScoreIdx*)data.mPtrExpertIdx; + params.mPtrTopKPacked = (PackedScoreIdx*)data.mPtrTopKPacked; - // params.mPtrExpertWeightsFull = static_cast(data.mPtrExpertWeightsFull); + // params.mPtrTopKWeightsFull = static_cast(data.mPtrTopKWeightsFull); params.mPtrRoutingBias = static_cast(data.mPtrRoutingBias); params.mNumExpertGroups = data.mNumExpertGroups; @@ -215,12 +226,12 @@ struct Data : public DataBase { tg::Dtype mDtypeExpW{tg::Dtype::Bfloat16}; }; -template -struct KernelParams : public KernelParamsBase { +template +struct KernelParams : public KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; - PackedScoreIdx* mPtrExpertIdx = nullptr; + PackedScoreIdx* mPtrTopKPacked = nullptr; int32_t mTopK; @@ -228,7 +239,7 @@ struct KernelParams : public KernelParamsBase { KernelParams params; params.setBaseParams(data); - params.mPtrExpertIdx = (PackedScoreIdx*)data.mPtrExpertIdx; + params.mPtrTopKPacked = (PackedScoreIdx*)data.mPtrTopKPacked; params.mTopK = data.mTopK; return params; } @@ -253,14 +264,15 @@ struct Data : public DataBase { bool mApplySoftmaxAfterTopK{false}; }; -template -struct KernelParams : public KernelParamsBase { +template +struct KernelParams : public KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; static constexpr bool DoSoftmaxBeforeTopK = DoSoftmaxBeforeTopK_; - PackedScoreIdx* mPtrExpertIdx = nullptr; + PackedScoreIdx* mPtrTopKPacked = nullptr; int32_t mTopK = 0; @@ -271,7 +283,7 @@ struct KernelParams : public KernelParamsBase { 
KernelParams params; params.setBaseParams(data); - params.mPtrExpertIdx = (PackedScoreIdx*)data.mPtrExpertIdx; + params.mPtrTopKPacked = (PackedScoreIdx*)data.mPtrTopKPacked; params.mNormTopkProb = data.mNormTopkProb; params.mApplySoftmaxAfterTopK = data.mApplySoftmaxAfterTopK; params.mTopK = data.mTopK; diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh b/include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh index 0f61d81eb2..ed08a70ae8 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh @@ -312,14 +312,14 @@ __forceinline__ __device__ void reduceTopK(cg::thread_block_tile const }; template -__forceinline__ __device__ void reduceTopK(cg::thread_block_tile const& warp, - Type (&out)[K], int32_t (&outIdx)[K], Type (&value)[N], - int32_t (&idx)[N], Type const minValue, - int actualK = K) { +__forceinline__ __device__ void reduceTopKFunc(cg::thread_block_tile const& warp, + Type (&out)[K], int32_t (&outIdx)[K], + Type (&value)[N], int32_t (&idx)[N], + Type const minValue, int actualK = K) { static_assert(K > 0, "Top K must have K > 0"); static_assert(K < WarpSize, "Top K must have K < WarpSize"); static_assert(N > 0, "Top K must have N > 0"); - static_assert(N <= 16, "Only support candidates number less than or equal to 128"); + static_assert(N < 5, "Only support candidates number less than or equal to 128"); using RedType = TopKRedType; RedType topK[N]; #pragma unroll @@ -346,6 +346,58 @@ __forceinline__ __device__ void reduceTopK(cg::thread_block_tile const } }; +template +__forceinline__ __device__ void reduceTopK(cg::thread_block_tile const& warp, + Type (&out)[K], int32_t (&outIdx)[K], Type (&value)[N], + int32_t (&idx)[N], Type const minValue, + int actualK = K) { + static_assert(K > 0, "Top K must have K > 0"); + static_assert(K < WarpSize, "Top K must have K < WarpSize"); + static_assert(N > 0, "Top K must have N > 0"); + static_assert(N <= 16, "Only support candidates number less than or equal to 16*32=512"); + using RedType = TopKRedType; + + if constexpr (N <= 4) { + reduceTopKFunc(warp, out, outIdx, value, idx, minValue, actualK); + } else { + constexpr int numLoops = (N - 1) / 4 + 1; + constexpr int numResults = (numLoops * K - 1) / WarpSize + 1; + + Type topKBufferValue[numResults]; + int32_t topKBufferIdx[numResults]; + int32_t laneIdx = threadIdx.x % WarpSize; + + for (int ii = 0; ii < numResults; ++ii) { + topKBufferValue[ii] = minValue; + topKBufferIdx[ii] = ii * WarpSize - 1; //@todo: check if this is correct + } + for (int loop = 0; loop < numLoops; ++loop) { + int start = loop * 4; + Type topKValue[K]; + int32_t topKIdx[K]; + Type inValue[4]; + int32_t inIdx[4]; + for (int i = 0; i < 4; ++i) { + inValue[i] = value[start + i]; + inIdx[i] = idx[start + i]; + } + reduceTopKFunc(warp, topKValue, topKIdx, inValue, inIdx, minValue, actualK); + int inOffset = laneIdx % K; + if (laneIdx >= loop * K && laneIdx < (loop + 1) * K) { + topKBufferValue[0] = topKValue[inOffset]; + topKBufferIdx[0] = topKIdx[inOffset]; + } + if (loop == numLoops - 1 && (laneIdx < (numLoops * K - WarpSize))) { + topKBufferValue[1] = topKValue[inOffset]; + topKBufferIdx[1] = topKIdx[inOffset]; + } + } + + reduceTopKFunc(warp, out, outIdx, topKBufferValue, topKBufferIdx, minValue, + actualK); + } +}; + #undef TOPK_SWAP } // namespace topk } // namespace moe::dev::routing diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 880c739259..8382106239 100644 --- 
a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -1858,8 +1858,8 @@ def cache_permute_indices(): "num_experts": 384, "top_k": 8, "padding": 8, - "n_groups": 12, - "top_k_groups": 4, + "n_groups": 1, + "top_k_groups": 1, "routed_scaling": 2.5, "has_routing_bias": True, "routing_method_type": RoutingMethodType.DeepSeekV3, From 49297c7cbfd61e4f39b00112add11ed26456e92c Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Wed, 1 Oct 2025 16:19:12 -0700 Subject: [PATCH 2/4] fix compile error Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- csrc/trtllm_fused_moe_kernel_launcher.cu | 24 +++--- csrc/trtllm_fused_moe_routing_renormalize.cu | 76 +++++++++---------- csrc/trtllm_fused_moe_runner.cu | 12 +-- flashinfer/fused_moe/core.py | 9 ++- .../flashinfer/trtllm/fused_moe/DevKernel.h | 1 + .../trtllm/fused_moe/RoutingKernelTopK.cuh | 2 + include/flashinfer/trtllm/fused_moe/runner.h | 2 +- 7 files changed, 64 insertions(+), 62 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 5741611644..9963c47af6 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -333,18 +333,18 @@ void trtllm_fp8_block_scale_moe_launcher( << "routing_bias has incorrect shape."; } - if (n_group <= 0 || topk_group <= 0) { - TVM_FFI_ICHECK_EQ(top_k, 1) << "Current routing kernel (no groups) only supports top_k=1."; - } else { - TVM_FFI_ICHECK_LE(top_k, 8) << "Current routing kernel (with groups) only supports top_k<=8."; - TVM_FFI_ICHECK_LE(topk_group, 4) - << "Current routing kernel (with groups) only supports topk_group<=4."; - TVM_FFI_ICHECK_LE(topk_group, n_group) << "n_group must not be smaller than topk_group."; - TVM_FFI_ICHECK_EQ(num_experts % n_group, 0) << "num_experts must be divisible by n_group"; - // This check ensures we have enough experts in the selected groups to handle the top_k routing - TVM_FFI_ICHECK_LT(top_k, (topk_group * num_experts / n_group)) - << "top_k must be less than total number of experts in selected groups"; - } +// if (n_group <= 0 || topk_group <= 0) { +// TVM_FFI_ICHECK_EQ(top_k, 1) << "Current routing kernel (no groups) only supports top_k=1."; +// } else { +// TVM_FFI_ICHECK_LE(top_k, 8) << "Current routing kernel (with groups) only supports top_k<=8."; +// TVM_FFI_ICHECK_LE(topk_group, 4) +// << "Current routing kernel (with groups) only supports topk_group<=4."; +// TVM_FFI_ICHECK_LE(topk_group, n_group) << "n_group must not be smaller than topk_group."; +// TVM_FFI_ICHECK_EQ(num_experts % n_group, 0) << "num_experts must be divisible by n_group"; +// // This check ensures we have enough experts in the selected groups to handle the top_k routing +// TVM_FFI_ICHECK_LT(top_k, (topk_group * num_experts / n_group)) +// << "top_k must be less than total number of experts in selected groups"; +// } TVM_FFI_ICHECK_EQ(num_experts % 4, 0) << "Routing kernel expects that num_experts must be divisible by 4"; TVM_FFI_ICHECK_GT(num_experts, top_k) << "num_experts must be greater than top_k"; diff --git a/csrc/trtllm_fused_moe_routing_renormalize.cu b/csrc/trtllm_fused_moe_routing_renormalize.cu index 63c2a3a72e..0778861d93 100644 --- a/csrc/trtllm_fused_moe_routing_renormalize.cu +++ b/csrc/trtllm_fused_moe_routing_renormalize.cu @@ -22,10 +22,11 @@ namespace routingRenormalize { static constexpr int NumThreads = 1024; static constexpr int NumWarps = NumThreads / WarpSize; -static constexpr 
int MaxNumTopExperts = 8; -static constexpr int MaxNumExperts = 128; +static constexpr int MaxNumTopExperts = 10; +static constexpr int NumExpertsLimit = 512; static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads; static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps; +static constexpr int BlockKernelMaxNumTokens = 4; template __forceinline__ __device__ void routingTopKExperts( @@ -380,28 +381,25 @@ int32_t constexpr getMaxNumExperts(int32_t numExperts) { //////////////////////////////////////////////////////////////////////////////////////////////////// void run(Data const& data, void* stream) { - TVM_FFI_ICHECK_LE( - data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, - "Routing kernel requires at least one input parameter"); + TVM_FFI_ICHECK( + data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr) + << "Routing kernel requires at least one input parameter"; if (data.mPtrTopKIds != nullptr) { - TVM_FFI_ICHECK_LE(data.mPtrTopKWeights != nullptr, - "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for " - "Renormalize routing."); + TVM_FFI_ICHECK(data.mPtrTopKWeights != nullptr) + << "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for " + "Renormalize routing."; } - TVM_FFI_ICHECK_LE( + TVM_FFI_ICHECK( data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && - data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr, - "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers"); - TVM_FFI_ICHECK_LE(data.mTopK <= MaxNumTopExperts, - "Routing kernel expects topK experts <= %d, got %d", MaxNumTopExperts, - data.mTopK); - TVM_FFI_ICHECK_LE(data.mNumExperts <= NumExpertsLimit, - "Routing kernel expects #experts %d to be no more than %d", data.mNumExperts, - NumExpertsLimit); - TVM_FFI_ICHECK_LE(data.mNumExperts % 4 == 0, - "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); - TVM_FFI_ICHECK_LE(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", - data.mPaddingLog2); + data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr) + << "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers"; + TVM_FFI_ICHECK_LE(data.mTopK, MaxNumTopExperts) + << "Routing kernel expects topK experts <= " << MaxNumTopExperts << ", got " << data.mTopK; + TVM_FFI_ICHECK_LE(data.mNumExperts, NumExpertsLimit) + << "Routing kernel expects #experts " << data.mNumExperts << " to be no more than " << NumExpertsLimit << "."; + TVM_FFI_ICHECK_EQ(data.mNumExperts % 4, 0) + << "Routing kernel expects #experts " << data.mNumExperts << " to be a multiple of 4."; + TVM_FFI_ICHECK_LE(data.mPaddingLog2, 8) << "Routing kernel expects padding log2 < 8, got " << data.mPaddingLog2; bool const useSingleBlock = data.mNumTokens <= BlockKernelMaxNumTokens; @@ -411,11 +409,11 @@ void run(Data const& data, void* stream) { : MaxNumTokensSingleCluster); if (!useSingleCluster && !useSingleBlock) { - TVM_FFI_ICHECK_LE( - (data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr), - "When #tokens is large, `mPtrTopKPacked` or `mPtrTopKIds` is a required input."); - TVM_FFI_ICHECK_LE(data.mPtrExpertCounts != nullptr, - "When #tokens is large, `mPtrExpertCounts` is a required input."); + TVM_FFI_ICHECK + (data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr) + << "When #tokens is large, 
`mPtrTopKPacked` or `mPtrTopKIds` is a required input."; + TVM_FFI_ICHECK(data.mPtrExpertCounts != nullptr) + << "When #tokens is large, `mPtrExpertCounts` is a required input."; } if (useSingleBlock) { @@ -466,19 +464,19 @@ void run(Data const& data, void* stream) { } } -void run(Data const& data, void* stream) { - TVM_FFI_ICHECK(data.mPtrExpertIdx != nullptr || data.mPtrScores != nullptr) - << "Routing kernel requires at least one input parameter"; - TVM_FFI_ICHECK(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && - data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr) - << "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers"; - TVM_FFI_ICHECK_LE(data.mTopK, MaxNumTopExperts) - << "Routing kernel expects topK experts <= " << MaxNumTopExperts << ", got " << data.mTopK; - TVM_FFI_ICHECK_LT(data.mPaddingLog2, 8) - << "Routing kernel expects padding log2 < 8, got " << data.mPaddingLog2; - - runImpl(data, stream); -} +// void run(Data const& data, void* stream) { +// TVM_FFI_ICHECK(data.mPtrExpertIdx != nullptr || data.mPtrScores != nullptr) +// << "Routing kernel requires at least one input parameter"; +// TVM_FFI_ICHECK(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && +// data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr) +// << "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers"; +// TVM_FFI_ICHECK_LE(data.mTopK, MaxNumTopExperts) +// << "Routing kernel expects topK experts <= " << MaxNumTopExperts << ", got " << data.mTopK; +// TVM_FFI_ICHECK_LT(data.mPaddingLog2, 8) +// << "Routing kernel expects padding log2 < 8, got " << data.mPaddingLog2; + +// runImpl(data, stream); +// } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index 931431fa2f..37134d950a 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -65,12 +65,12 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mUsePdl = true; // output: - routingData.mPtrExpertIdx = routingExpertIndexes; + routingData.mPtrTopKPacked = routingExpertIndexes; routingData.mPtrExpertCounts = expertCountHistogram; routingData.mPtrPermutedIdxSize = permutedIdxSize; routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx; routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx; - routingData.mPtrExpertWeights = expertWeights; + routingData.mPtrTopKWeights = expertWeights; routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx; routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit; @@ -102,12 +102,12 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mUsePdl = true; // output: - routingData.mPtrExpertIdx = routingExpertIndexes; + routingData.mPtrTopKPacked = routingExpertIndexes; routingData.mPtrExpertCounts = expertCountHistogram; routingData.mPtrPermutedIdxSize = permutedIdxSize; routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx; routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx; - routingData.mPtrExpertWeights = expertWeights; + routingData.mPtrTopKWeights = expertWeights; routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx; routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit; @@ -144,12 +144,12 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t 
numTokens, int3 // // Outputs // - routingData.mPtrExpertIdx = routingExpertIndexes; + routingData.mPtrTopKPacked = routingExpertIndexes; routingData.mPtrExpertCounts = expertCountHistogram; routingData.mPtrPermutedIdxSize = permutedIdxSize; routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx; routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx; - routingData.mPtrExpertWeights = expertWeights; + routingData.mPtrTopKWeights = expertWeights; // // Grouped Gemm Launch Config Buffers diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index e98e47d2a0..d4306bd59c 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1340,7 +1340,7 @@ def _fake_trtllm_fp8_per_tensor_scale_moe( intermediate_size: int, local_expert_offset: int, local_num_experts: int, - routed_scaling_factor: float, + routed_scaling_factor: Optional[float], use_routing_scales_on_input: bool, tile_tokens_dim: int = 8, routing_method_type: int = 0, @@ -1372,7 +1372,7 @@ def trtllm_fp8_block_scale_moe_op( intermediate_size: int, local_expert_offset: int, local_num_experts: int, - routed_scaling_factor: float, + routed_scaling_factor: Optional[float], tile_tokens_dim: int, routing_method_type: int, use_shuffled_weight: bool = False, @@ -1381,6 +1381,7 @@ def trtllm_fp8_block_scale_moe_op( ) -> torch.Tensor: if enable_pdl is None: enable_pdl = device_support_pdl(hidden_states.device) + # Call the C++ function for block scale MoE moe_op.trtllm_fp8_block_scale_moe( routing_logits, @@ -1427,7 +1428,7 @@ def _fake_trtllm_fp8_block_scale_moe( intermediate_size: int, local_expert_offset: int, local_num_experts: int, - routed_scaling_factor: float, + routed_scaling_factor: Optional[float], tile_tokens_dim: int = 8, routing_method_type: int = 0, use_shuffled_weight: bool = False, @@ -1755,7 +1756,7 @@ def trtllm_fp8_block_scale_moe( intermediate_size: int, local_expert_offset: int, local_num_experts: int, - routed_scaling_factor: float, + routed_scaling_factor: Optional[float], tile_tokens_dim: int = 8, routing_method_type: int = 0, use_shuffled_weight: bool = False, diff --git a/include/flashinfer/trtllm/fused_moe/DevKernel.h b/include/flashinfer/trtllm/fused_moe/DevKernel.h index 2d8f5491b1..4681e5acd8 100644 --- a/include/flashinfer/trtllm/fused_moe/DevKernel.h +++ b/include/flashinfer/trtllm/fused_moe/DevKernel.h @@ -31,6 +31,7 @@ #include "../../exception.h" // #include #include "flashinfer/trtllm/common/cudaUtils.h" +#include "tensorrt_llm/common/logger.h" namespace moe::dev { diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh b/include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh index ed08a70ae8..3da1447779 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh @@ -31,6 +31,8 @@ namespace cg = cooperative_groups; //////////////////////////////////////////////////////////////////////////////////////////////////// static constexpr int WarpSize = 32; +static constexpr int MaxNumExpertsUnit = 128; +static constexpr int MaxNumTopK = 10; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h index 5f066468e6..6a06eacc34 100644 --- a/include/flashinfer/trtllm/fused_moe/runner.h +++ b/include/flashinfer/trtllm/fused_moe/runner.h @@ -32,7 +32,7 @@ namespace trtllmgen_moe { namespace Routing { // The type of method in 
top-K routing, for use in torch custom op -// Please keep this in sync with the counterpart defined in +// Please keep this in sync with the counterpart defined in // flashinfer/fused_moe/core.py enum class RoutingMethodType : int64_t { // Default: Softmax -> TopK From 1adc570c9a8959a1677263923a403c0911669413 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Wed, 1 Oct 2025 17:58:41 -0700 Subject: [PATCH 3/4] fix format Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- csrc/trtllm_fused_moe_kernel_launcher.cu | 24 +++++++++--------- csrc/trtllm_fused_moe_routing_renormalize.cu | 26 ++++++++++---------- include/flashinfer/trtllm/fused_moe/runner.h | 2 +- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 9963c47af6..511b439c95 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -333,18 +333,18 @@ void trtllm_fp8_block_scale_moe_launcher( << "routing_bias has incorrect shape."; } -// if (n_group <= 0 || topk_group <= 0) { -// TVM_FFI_ICHECK_EQ(top_k, 1) << "Current routing kernel (no groups) only supports top_k=1."; -// } else { -// TVM_FFI_ICHECK_LE(top_k, 8) << "Current routing kernel (with groups) only supports top_k<=8."; -// TVM_FFI_ICHECK_LE(topk_group, 4) -// << "Current routing kernel (with groups) only supports topk_group<=4."; -// TVM_FFI_ICHECK_LE(topk_group, n_group) << "n_group must not be smaller than topk_group."; -// TVM_FFI_ICHECK_EQ(num_experts % n_group, 0) << "num_experts must be divisible by n_group"; -// // This check ensures we have enough experts in the selected groups to handle the top_k routing -// TVM_FFI_ICHECK_LT(top_k, (topk_group * num_experts / n_group)) -// << "top_k must be less than total number of experts in selected groups"; -// } + // if (n_group <= 0 || topk_group <= 0) { + // TVM_FFI_ICHECK_EQ(top_k, 1) << "Current routing kernel (no groups) only supports top_k=1."; + // } else { + // TVM_FFI_ICHECK_LE(top_k, 8) << "Current routing kernel (with groups) only supports + // top_k<=8."; TVM_FFI_ICHECK_LE(topk_group, 4) + // << "Current routing kernel (with groups) only supports topk_group<=4."; + // TVM_FFI_ICHECK_LE(topk_group, n_group) << "n_group must not be smaller than topk_group."; + // TVM_FFI_ICHECK_EQ(num_experts % n_group, 0) << "num_experts must be divisible by n_group"; + // // This check ensures we have enough experts in the selected groups to handle the top_k + // routing TVM_FFI_ICHECK_LT(top_k, (topk_group * num_experts / n_group)) + // << "top_k must be less than total number of experts in selected groups"; + // } TVM_FFI_ICHECK_EQ(num_experts % 4, 0) << "Routing kernel expects that num_experts must be divisible by 4"; TVM_FFI_ICHECK_GT(num_experts, top_k) << "num_experts must be greater than top_k"; diff --git a/csrc/trtllm_fused_moe_routing_renormalize.cu b/csrc/trtllm_fused_moe_routing_renormalize.cu index 0778861d93..79d4166938 100644 --- a/csrc/trtllm_fused_moe_routing_renormalize.cu +++ b/csrc/trtllm_fused_moe_routing_renormalize.cu @@ -381,25 +381,26 @@ int32_t constexpr getMaxNumExperts(int32_t numExperts) { //////////////////////////////////////////////////////////////////////////////////////////////////// void run(Data const& data, void* stream) { - TVM_FFI_ICHECK( - data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr) + TVM_FFI_ICHECK(data.mPtrTopKPacked != nullptr || 
data.mPtrScores != nullptr || + data.mPtrTopKIds != nullptr) << "Routing kernel requires at least one input parameter"; if (data.mPtrTopKIds != nullptr) { TVM_FFI_ICHECK(data.mPtrTopKWeights != nullptr) - << "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for " - "Renormalize routing."; + << "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for " + "Renormalize routing."; } - TVM_FFI_ICHECK( - data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && - data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr) + TVM_FFI_ICHECK(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && + data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr) << "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers"; TVM_FFI_ICHECK_LE(data.mTopK, MaxNumTopExperts) - << "Routing kernel expects topK experts <= " << MaxNumTopExperts << ", got " << data.mTopK; + << "Routing kernel expects topK experts <= " << MaxNumTopExperts << ", got " << data.mTopK; TVM_FFI_ICHECK_LE(data.mNumExperts, NumExpertsLimit) - << "Routing kernel expects #experts " << data.mNumExperts << " to be no more than " << NumExpertsLimit << "."; + << "Routing kernel expects #experts " << data.mNumExperts << " to be no more than " + << NumExpertsLimit << "."; TVM_FFI_ICHECK_EQ(data.mNumExperts % 4, 0) - << "Routing kernel expects #experts " << data.mNumExperts << " to be a multiple of 4."; - TVM_FFI_ICHECK_LE(data.mPaddingLog2, 8) << "Routing kernel expects padding log2 < 8, got " << data.mPaddingLog2; + << "Routing kernel expects #experts " << data.mNumExperts << " to be a multiple of 4."; + TVM_FFI_ICHECK_LE(data.mPaddingLog2, 8) + << "Routing kernel expects padding log2 < 8, got " << data.mPaddingLog2; bool const useSingleBlock = data.mNumTokens <= BlockKernelMaxNumTokens; @@ -409,8 +410,7 @@ void run(Data const& data, void* stream) { : MaxNumTokensSingleCluster); if (!useSingleCluster && !useSingleBlock) { - TVM_FFI_ICHECK - (data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr) + TVM_FFI_ICHECK(data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr) << "When #tokens is large, `mPtrTopKPacked` or `mPtrTopKIds` is a required input."; TVM_FFI_ICHECK(data.mPtrExpertCounts != nullptr) << "When #tokens is large, `mPtrExpertCounts` is a required input."; diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h index 6a06eacc34..5f066468e6 100644 --- a/include/flashinfer/trtllm/fused_moe/runner.h +++ b/include/flashinfer/trtllm/fused_moe/runner.h @@ -32,7 +32,7 @@ namespace trtllmgen_moe { namespace Routing { // The type of method in top-K routing, for use in torch custom op -// Please keep this in sync with the counterpart defined in +// Please keep this in sync with the counterpart defined in // flashinfer/fused_moe/core.py enum class RoutingMethodType : int64_t { // Default: Softmax -> TopK From c9e42cacef0506c57777d8d8efcf859219529951 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Wed, 1 Oct 2025 19:42:08 -0700 Subject: [PATCH 4/4] update some change Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- csrc/trtllm_fused_moe_kernel_launcher.cu | 10 +++--- csrc/trtllm_fused_moe_routing_renormalize.cu | 3 ++ tests/moe/test_trtllm_gen_fused_moe.py | 36 ++++++++++---------- 3 files changed, 26 insertions(+), 23 deletions(-) diff --git 
a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 511b439c95..c7e0142ba9 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -684,8 +684,8 @@ Array trtllm_fp4_block_scale_moe_launcher( TVM_FFI_ICHECK(topk_group.has_value()) << "if n_group is given, topk_group must be given"; TVM_FFI_ICHECK_EQ(num_experts % n_group.value(), 0) << "num_experts must be divisible by n_group"; - TVM_FFI_ICHECK(top_k <= 8 && top_k > 0) - << "Current routing kernel (with groups) only supports top_k<=8 && top_k>0."; + // TVM_FFI_ICHECK(top_k <= 8 && top_k > 0) + // << "Current routing kernel (with groups) only supports top_k<=8 && top_k>0."; TVM_FFI_ICHECK(topk_group.value() <= 4 && topk_group.value() > 0) << "Current routing kernel only (with groups) supports topk_group<=4 && topk_group > 0."; TVM_FFI_ICHECK_LE(topk_group.value(), n_group.value()) @@ -698,9 +698,9 @@ Array trtllm_fp4_block_scale_moe_launcher( static_cast(routing_method_type) == RoutingMethodType::RenormalizeNaive || static_cast(routing_method_type) == RoutingMethodType::TopK) { - TVM_FFI_ICHECK(top_k <= 8 && top_k > 0) - << "Current routing kernel (no groups, renormalize/topk) only supports top_k<=8 && " - "top_k>0."; + // TVM_FFI_ICHECK(top_k <= 8 && top_k > 0) + // << "Current routing kernel (no groups, renormalize/topk) only supports top_k<=8 && " + // "top_k>0."; } else if (static_cast(routing_method_type) == RoutingMethodType::Llama4) { TVM_FFI_ICHECK_EQ(top_k, 1) << "Current routing kernel (no groups, Llama4) only supports top_k=1."; diff --git a/csrc/trtllm_fused_moe_routing_renormalize.cu b/csrc/trtllm_fused_moe_routing_renormalize.cu index 79d4166938..23442efb74 100644 --- a/csrc/trtllm_fused_moe_routing_renormalize.cu +++ b/csrc/trtllm_fused_moe_routing_renormalize.cu @@ -462,6 +462,9 @@ void run(Data const& data, void* stream) { /*smemSize=*/0, // No dynamic smem stream, data.mDoSoftmaxBeforeTopK); } + cudaDeviceSynchronize(); + cudaError_t result = cudaGetLastError(); + std::cout << "cudaGetLastError: " << cudaGetErrorString(result) << std::endl; } // void run(Data const& data, void* stream) { diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 8382106239..3a7ecc3c6c 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -1838,8 +1838,8 @@ def cache_permute_indices(): @pytest.mark.parametrize("num_tokens", [1, 8, 1024]) -@pytest.mark.parametrize("hidden_size", [1024, 8192]) -@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 384]) +@pytest.mark.parametrize("hidden_size", [1024, 2048, 8192]) +@pytest.mark.parametrize("intermediate_size", [2048, 1024, 5120, 768, 384]) @pytest.mark.parametrize( "moe_impl", [ @@ -1906,8 +1906,8 @@ def cache_permute_indices(): ), pytest.param( { - "num_experts": 128, - "top_k": 8, + "num_experts": 512, + "top_k": 10, "padding": 8, "n_groups": None, "top_k_groups": None, @@ -1917,21 +1917,21 @@ def cache_permute_indices(): "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe], }, id="Renorm", - marks=pytest.mark.skip( - reason="Disabled for testing speed - similar to RenormalizeNaive" - ), + # marks=pytest.mark.skip( + # reason="Disabled for testing speed - similar to RenormalizeNaive" + # ), ), pytest.param( { - "num_experts": 128, - "top_k": 8, + "num_experts": 512, + "top_k": 10, "padding": 8, "n_groups": None, "top_k_groups": None, "routed_scaling": None, "has_routing_bias": False, 
"routing_method_type": RoutingMethodType.RenormalizeNaive, - "compatible_moe_impls": [FP4Moe], + "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], }, id="RenormNaive", ), @@ -2034,13 +2034,13 @@ def test_moe_quantization_classes( pytest.skip( f"Incompatible: {moe_impl.name} + {gated_act_type} + {routing_config['routing_method_type']} + {num_tokens}" ) - elif gated_act_type == GatedActType.SwiGlu and ( - hidden_size > 1024 or intermediate_size > 1024 - ): - # Skip some tests for SwiGlu for testing speed - pytest.skip( - f"Skip for testing speed: {gated_act_type} + {hidden_size} + {intermediate_size}" - ) + # elif gated_act_type == GatedActType.SwiGlu and ( + # hidden_size > 1024 or intermediate_size > 1024 + # ): + # # Skip some tests for SwiGlu for testing speed + # pytest.skip( + # f"Skip for testing speed: {gated_act_type} + {hidden_size} + {intermediate_size}" + # ) if type(moe_impl) not in routing_config["compatible_moe_impls"]: pytest.skip( @@ -2079,7 +2079,7 @@ def test_moe_quantization_classes( # Validation checks assert top_k <= num_experts - assert top_k <= 8 + # assert top_k <= 8 if (top_k_groups is not None) and (n_groups is not None) and (n_groups > 0): assert top_k_groups <= 4 assert num_experts > n_groups