diff --git a/csrc/trtllm_batched_gemm_runner.cu b/csrc/trtllm_batched_gemm_runner.cu index f86f2d4582..bf57fd5b9e 100644 --- a/csrc/trtllm_batched_gemm_runner.cu +++ b/csrc/trtllm_batched_gemm_runner.cu @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include "flashinfer/trtllm/batched_gemm/KernelRunner.h" @@ -115,7 +116,14 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner( } } - FLASHINFER_CHECK(!mPassingConfigIndices.empty(), "No kernel found for the given options"); + FLASHINFER_CHECK( + !mPassingConfigIndices.empty(), + "No kernel found for the given options: mDtypeA: %s, mDtypeB: %s, mDtypeC: %s, " + "mUseDeepSeekFp8: %d, " + "mTransposeMmaOutput: %d, mRouteAct: %d, mFusedAct: %d, mIsStaticBatch: %d, mTileSize: %d", + tg::dtypeToString(mOptions.dtypeA).c_str(), tg::dtypeToString(mOptions.dtypeB).c_str(), + tg::dtypeToString(mOptions.dtypeC).c_str(), mOptions.deepSeekFp8, mOptions.transposeMmaOutput, + mOptions.routeAct, mOptions.fusedAct, mOptions.staticBatch, mOptions.tileSize); } size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes( @@ -367,6 +375,7 @@ std::vector TrtllmGenBatchedGemmRunner::getValidConfigIndices( return false; }; + // Sort configs by options. std::vector sortedIndices = mPassingConfigIndices; std::sort(sortedIndices.begin(), sortedIndices.end(), cmpFunc); diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 5741611644..5e56bc67de 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -393,6 +393,12 @@ void trtllm_fp8_block_scale_moe_launcher( int32_t max_num_padded_tokens = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount( args.num_tokens, top_k, num_experts, tile_tokens_dim); + int32_t max_num_padded_tokens_gemm1 = + tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( + max_num_padded_tokens, args.intermediate_size, btg::dtypeGetNumBits(args.mDtypeElt)); + int32_t max_num_padded_tokens_gemm2 = + tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( + max_num_padded_tokens, args.hidden_size, btg::dtypeGetNumBits(args.mDtypeOut)); Tensor total_num_padded_tokens = alloc_tensor({1}, dl_int32, routing_logits->device); Tensor expanded_idx_to_permuted_idx = alloc_tensor({args.num_tokens * args.top_k}, dl_int32, routing_logits->device); @@ -413,16 +419,16 @@ void trtllm_fp8_block_scale_moe_launcher( // dl_float8_e4m3fn, hidden_states->device); // Tensor activation_output = alloc_tensor({max_num_padded_tokens, intermediate_size}, // dl_float8_e4m3fn, hidden_states->device); - Tensor gemm1_output = - alloc_tensor({max_num_padded_tokens, 2 * intermediate_size}, dl_uint8, hidden_states->device); + Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, 2 * intermediate_size}, dl_uint8, + hidden_states->device); Tensor gemm1_output_scale = alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens}, dl_float32, hidden_states->device); - Tensor activation_output = - alloc_tensor({max_num_padded_tokens, intermediate_size}, dl_uint8, hidden_states->device); - Tensor activation_output_scale = alloc_tensor({intermediate_size / 128, max_num_padded_tokens}, - dl_float32, hidden_states->device); - Tensor gemm2_output = - alloc_tensor({max_num_padded_tokens, args.hidden_size}, dl_bfloat16, hidden_states->device); + Tensor activation_output = alloc_tensor({max_num_padded_tokens_gemm1, intermediate_size}, + dl_uint8, hidden_states->device); + Tensor activation_output_scale = 
alloc_tensor( + {intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, hidden_states->device); + Tensor gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args.hidden_size}, dl_bfloat16, + hidden_states->device); int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim( args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim); @@ -519,7 +525,8 @@ void trtllm_fp8_block_scale_moe_launcher( // setup workspace workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens->data); - workspace.total_max_padded_tokens = max_num_padded_tokens; + workspace.total_max_padded_tokens = + std::max(max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2); workspace.ProjUpTileN = tile_tokens_dim; workspace.routing_expert_indexes = static_cast(expert_indexes->data); workspace.permuted_idx_size = static_cast(total_num_padded_tokens->data); @@ -764,6 +771,12 @@ Array trtllm_fp4_block_scale_moe_launcher( int32_t max_num_padded_tokens = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount( args.num_tokens, top_k, num_experts, tile_tokens_dim); + int32_t max_num_padded_tokens_gemm1 = + tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( + max_num_padded_tokens, args.intermediate_size, btg::dtypeGetNumBits(args.mDtypeElt)); + int32_t max_num_padded_tokens_gemm2 = + tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( + max_num_padded_tokens, args.hidden_size, btg::dtypeGetNumBits(args.mDtypeOut)); Tensor total_num_padded_tokens = alloc_tensor({1}, dl_int32, hidden_states->device); Tensor expanded_idx_to_permuted_idx = alloc_tensor({args.num_tokens, args.top_k}, dl_int32, hidden_states->device); @@ -788,20 +801,20 @@ Array trtllm_fp4_block_scale_moe_launcher( // Tensor gemm1_output = alloc_tensor( // {max_num_padded_tokens, gemm1_output_hidden}, // dtype_act == btg::Dtype::Bfloat16 ? dl_bfloat16 : dl_float8_e4m3fn, hidden_states->device); - Tensor gemm1_output = alloc_tensor({max_num_padded_tokens, gemm1_output_hidden}, + Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, gemm1_output_hidden}, dtype_act == btg::Dtype::Bfloat16 ? 
dl_bfloat16 : dl_uint8, hidden_states->device); Optional gemm1_output_scale = std::nullopt; if (dtype_act == btg::Dtype::E2m1 || dtype_act == btg::Dtype::MxE4m3) { - int64_t sf_size = tensorrt_llm::computeSwizzledLayoutSFSize(max_num_padded_tokens, + int64_t sf_size = tensorrt_llm::computeSwizzledLayoutSFSize(max_num_padded_tokens_gemm1, intermediate_size / sf_vec_size); // gemm1_output_scale = alloc_tensor({sf_size}, dl_float8_e4m3fn, hidden_states->device); gemm1_output_scale = alloc_tensor({sf_size}, dl_uint8, hidden_states->device); } - Tensor gemm2_output = - alloc_tensor({max_num_padded_tokens, args.hidden_size}, dl_bfloat16, hidden_states->device); + Tensor gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args.hidden_size}, dl_bfloat16, + hidden_states->device); int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim( args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim); @@ -958,7 +971,8 @@ Array trtllm_fp4_block_scale_moe_launcher( // setup workspace workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens->data); - workspace.total_max_padded_tokens = max_num_padded_tokens; + workspace.total_max_padded_tokens = + std::max(max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2); workspace.ProjUpTileN = tile_tokens_dim; workspace.routing_expert_indexes = static_cast(expert_indices->data); workspace.permuted_idx_size = static_cast(total_num_padded_tokens->data); diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 4d344d02bb..c963007094 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -76,7 +76,7 @@ def get_available_cubin_files( class ArtifactPath: TRTLLM_GEN_FMHA: str = "7206d64e67f4c8949286246d6e2e07706af5d223/fmha/trtllm-gen" TRTLLM_GEN_BMM: str = ( - "e6f22dcc3fdeb29ff87af2f4a2cb3d30b8d273e0/batched_gemm-45beda1-ee6a802" + "56fea80cb22f8b2ef2a2c6a822a075fb20b36803/batched_gemm-074aec4-cc00b23" ) TRTLLM_GEN_GEMM: str = ( "037e528e719ec3456a7d7d654f26b805e44c63b1/gemm-8704aa4-f91dc9e" @@ -91,7 +91,7 @@ class MetaInfoHash: "2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d" ) TRTLLM_GEN_BMM: str = ( - "c98b4ce69a39fd41556d67033c30ea814ef76b0a2fe16e798e55baf0104acc34" + "4a8ceeb356fc5339021acf884061e97e49e01da5c75dbf0f7cf4932c37a70152" ) DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48" TRTLLM_GEN_GEMM: str = ( diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 364d4182f1..4f862a4559 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -894,7 +894,9 @@ def __init__( self.gated_act_type = gated_act_type self.tile_tokens_dim = tile_tokens_dim - def get_tile_tokens_dim(self, num_tokens: int, top_k: int): + def get_tile_tokens_dim( + self, num_tokens: int, top_k: int, max_tile_tokens_dim: int = 128 + ): # Factor to account for the imbalance of the experts. # factor equals to the # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert @@ -910,10 +912,10 @@ def get_tile_tokens_dim(self, num_tokens: int, top_k: int): num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) # And pad the number to the next power of 2. tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) - # Cap to 8-64 tokens per CTA tile - # as it's the range supported by the kernel. 
- tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) - + if num_tokens_per_expert > 128 and num_tokens_per_expert < 256: + tile_tokens_dim = 192 + # Cap to 8-max_tile_tokens_dim tokens per CTA tile as it's the range supported by the kernel. + tile_tokens_dim = min(max(tile_tokens_dim, 8), max_tile_tokens_dim) return tile_tokens_dim def get_valid_tactics( @@ -931,7 +933,7 @@ def get_valid_tactics( ) = inputs num_tokens = routing_logits.shape[0] tile_tokens_dim = ( - self.get_tile_tokens_dim(num_tokens, self.top_k) + self.get_tile_tokens_dim(num_tokens, self.top_k, 128) if self.tile_tokens_dim is None else self.tile_tokens_dim ) @@ -975,7 +977,7 @@ def forward( ) = inputs num_tokens = routing_logits.shape[0] tile_tokens_dim = ( - self.get_tile_tokens_dim(num_tokens, self.top_k) + self.get_tile_tokens_dim(num_tokens, self.top_k, 128) if self.tile_tokens_dim is None else self.tile_tokens_dim ) diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 448f6d116a..f8ca3446fa 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -113,15 +113,18 @@ def next_positive_power_of_2(x: int) -> int: return n + 1 -def calculate_tile_tokens_dim(num_tokens: int, num_experts: int, top_k: int) -> int: +def calculate_tile_tokens_dim( + num_tokens: int, num_experts: int, top_k: int, max_tile_tokens_dim: int = 128 +) -> int: # Guess tokens per expert assuming perfect expert distribution first. num_tokens_per_expert = num_tokens * top_k // num_experts # And pad the number to the next power of 2. tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) - # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel. - tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) - + if num_tokens_per_expert > 128 and num_tokens_per_expert < 256: + tile_tokens_dim = 192 + # Cap to 8-max_tile_tokens_dim tokens per CTA tile as it's the range supported by the kernel. + tile_tokens_dim = min(max(tile_tokens_dim, 8), max_tile_tokens_dim) return tile_tokens_dim diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h index 629dfc4d27..6b1f910178 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h @@ -506,8 +506,19 @@ class BatchedGemmInterface { throw std::invalid_argument("Invalid combination of options"); } - int32_t const numCtasTile = + if (batchM) { + numCtasBatch = gemm::divUpMul(numCtasBatch, options.mClusterDimX); + } else { + numCtasBatch = gemm::divUpMul(numCtasBatch, options.mClusterDimY); + } + + int32_t numCtasTile = batchM ? 
gemm::divUp(options.mN, options.mTileN) : gemm::divUp(options.mM, options.mTileM); + if (batchM) { + numCtasTile = gemm::divUpMul(numCtasTile, options.mClusterDimY); + } else { + numCtasTile = gemm::divUpMul(numCtasTile, options.mClusterDimX); + } int32_t const numCtasInner = options.mNumSlicesForSplitK; return std::make_tuple(numCtasBatch, numCtasTile, numCtasInner); } diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h index c29fb24b0a..07dcd30be4 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h @@ -76,37 +76,40 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { // FIXME We create explicit constructor with all options to WAR stubgen issue in TRT-LLM. BatchedGemmOptions( gemm::AllReduceAlgo allReduceAlgo, gemm::BiasType biasType, int blockK, int clusterDimX, - int clusterDimY, int clusterDimZ, tg::Dtype dtypeAcc, tg::Dtype dtypeA, tg::Dtype dtypeB, - tg::Dtype dtypeC, tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, bool enablesEarlyExit, - bool enablesDelayedEarlyExit, bool enablesGlobalPtxKnobs, int epilogueLdtmDps, - int epilogueLdtmBits, int epilogueTileM, int epilogueTileN, bool gridTriggerSecondaryA, - bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, - bool gridWaitForPrimaryB, bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, - gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB, - int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, bool mockAllReduce, int n, - int numSlicesForSplitK, int numSlicesForSliceK, int numStages, int numStagesMma, - int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId, - bool outputDebugTensors, bool patchF2fp, std::optional sfBlockSizeA, - tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, - int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK, int tileK, int tileM, int tileN, - gemm::TileScheduler tileScheduler, bool transposeMmaOutput, bool useCustomMmaSchedule, - bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule, bool usePerTokenSfA, - bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore, bool useTwoTmaLoadWarps, - bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int worldSize, gemmGatedAct::ActType actType, - bool clampBeforeAct, std::vector batchedM, std::vector batchedN, - BatchMode batchMode, int numBatches, bool isStaticBatch, int numTokens, RouteImpl routeImpl, - bool gridWaitForPrimaryRouting, bool fusedAct, int numRegsPerThreadNonEpilogueWarp, - int numRegsPerThreadEpilogueWarp, int numRegsCastAWarps, bool useTmaOobOpt) + int clusterDimY, int clusterDimZ, gemm::CtaSwizzleType ctaSwizzleType, tg::Dtype dtypeAcc, + tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, + tg::Dtype dtypeMmaB, bool enablesEarlyExit, bool enablesDelayedEarlyExit, + bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, int epilogueTileM, + int epilogueTileN, bool gridTriggerSecondaryA, bool gridTriggerSecondaryB, + bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, + bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, gemm::KernelTraits kernelTraits, + gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, + int mmaM, 
int mmaN, bool mockAllReduce, int n, int numRegsCastAWarps, + int numRegsCopySfLdsSttm, int numRegsPerThreadEpilogueWarp, + int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK, int numSlicesForSliceK, + int numStages, int numStagesMma, int numStagesMmaWithinWorkTile, + int numStagesMmaAcrossWorkTile, int numStagesWorkId, bool outputDebugTensors, bool patchF2fp, + std::optional sfBlockSizeA, tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, + tg::SfLayout sfLayoutC, int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK, int tileK, + int tileM, int tileN, gemm::TileScheduler tileScheduler, bool transposeMmaOutput, + bool useCustomMmaSchedule, bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule, + bool usePerTokenSfA, bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore, + bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int worldSize, + gemmGatedAct::ActType actType, bool clampBeforeAct, std::vector batchedM, + std::vector batchedN, BatchMode batchMode, int numBatches, bool isStaticBatch, + int numTokens, RouteImpl routeImpl, std::optional routeSfsImpl, + bool gridWaitForPrimaryRouting, bool fusedAct, bool useTmaOobOpt) : gemmGatedAct::GemmGatedActOptions( gemm::GemmOptions( - allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, dtypeAcc, - dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, enablesEarlyExit, - enablesDelayedEarlyExit, enablesGlobalPtxKnobs, epilogueLdtmDps, epilogueLdtmBits, - epilogueTileM, epilogueTileN, gridTriggerSecondaryA, gridTriggerSecondaryB, - gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA, gridWaitForPrimaryB, - hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits, layoutA, layoutB, m, mmaK, - mmaKind, mmaM, mmaN, mockAllReduce, n, numSlicesForSplitK, numSlicesForSliceK, - numStages, numStagesMma, numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, + allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, + ctaSwizzleType, dtypeAcc, dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, + enablesEarlyExit, enablesDelayedEarlyExit, enablesGlobalPtxKnobs, epilogueLdtmDps, + epilogueLdtmBits, epilogueTileM, epilogueTileN, gridTriggerSecondaryA, + gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA, + gridWaitForPrimaryB, hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits, + layoutA, layoutB, m, mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n, + numRegsCopySfLdsSttm, numSlicesForSplitK, numSlicesForSliceK, numStages, + numStagesMma, numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, numStagesWorkId, outputDebugTensors, patchF2fp, sfBlockSizeA, sfLayoutA, sfLayoutB, sfLayoutC, sfReshapeFactor, sliceK, splitK, tileK, tileM, tileN, tileScheduler, transposeMmaOutput, useCustomMmaSchedule, useDeepSeekFp8, @@ -126,6 +129,7 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { mNumRegsCastAWarps(numRegsCastAWarps), mNumTokens(numTokens), mRouteImpl(routeImpl), + mRouteSfsImpl(routeSfsImpl), mUseTmaOobOpt(useTmaOobOpt) {} // Batched M-dimensions of GEMM. @@ -153,6 +157,8 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { int mNumTokens{32}; // Whether load the input tokens and do routing. RouteImpl mRouteImpl{RouteImpl::NoRoute}; + // Routing logic for scaling factors. If not specified, mRouteImpl is used. + std::optional mRouteSfsImpl{std::nullopt}; // Whether to use TMA out-of-bounds optimization to reduce wasted traffic. See details in // BatchedGemm/KernelParamsDecl.h. 
bool mUseTmaOobOpt{false}; @@ -235,6 +241,18 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw "E2m1 is not supported with DeepSeek FP8"); } + if (options.mRouteSfsImpl.has_value() && options.mRouteSfsImpl.value() != options.mRouteImpl) { + TLLM_CHECK_ERROR( + options.mRouteSfsImpl.value() == RouteImpl::Ldgsts && options.mRouteImpl == RouteImpl::Tma, + "RouteSfsImpl must be equal to RouteImpl, or Ldgsts, when RouteImpl is Tma"); + } else if (!options.mRouteSfsImpl.has_value()) { + if (updateOptions) { + options.mRouteSfsImpl = options.mRouteImpl; + } else { + TLLM_LOG_ERROR("RouteSfsImpl must be specified"); + return false; + } + } if (batchM) { if (options.mDtypeA == tg::Dtype::MxE2m1 && options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4) { TLLM_CHECK_ERROR(doesRouteImplUseNoRoute(options.mRouteImpl), @@ -269,18 +287,20 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw } } - if (doesRouteImplUseTma(options.mRouteImpl)) { + if (doesRouteImplUseTma(options.mRouteSfsImpl.value())) { TLLM_CHECK_ERROR(!batchM, "UTMALDG.GATHER4 only supported for batch N."); if (tg::mmaKindIsBlockFmt(options.mMmaKind)) { auto dtypeRoute = batchM ? options.mDtypeA : options.mDtypeB; - TLLM_CHECK_ERROR(options.mTileK % tg::dtypeNumEltsPerSf(dtypeRoute) == 0, - "tileK needs to be a multiple of 16 * tg::dtypeNumEltsPerSf(dtypeA)."); TLLM_CHECK_ERROR(options.mTileK % (tg::dtypeNumEltsPerSf(dtypeRoute) * 16) == 0, "tileK needs to be a multiple of 16 * tg::dtypeNumEltsPerSf(dtypeA)."); } } + if (options.mClusterDimX > 1) { + TLLM_CHECK_ERROR(!batchM, "2CTA Gemm currently only supports batch N."); + } + if (!batchM || doesRouteImplUseNoRoute(options.mRouteImpl)) { TLLM_CHECK_ERROR(options.mSfLayoutA == tg::SfLayout::R128c4, "options.mSfLayoutA has to be tg::SfLayout::R128c4 when not being routed"); @@ -301,6 +321,11 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw TLLM_CHECK_ERROR(options.mK % options.mTileK == 0, "K must be a multiple of TileK"); } + if (options.mClusterDimX > 1 && batchM && options.mRouteImpl != RouteImpl::NoRoute) { + TLLM_CHECK_ERROR(false, + "2CTA BatchedGemm does not support routing along M dimension. To support it, " + "change the input routing data layout to be padded to clusterDimX size."); + } return isValid; } @@ -343,6 +368,8 @@ inline std::string dumpOptions(BatchedGemmOptions const& options) { ss << "mNumTokens=" << options.mNumTokens << "," << std::endl; ss << "mRouteImpl=batchedGemm::RouteImpl(" << static_cast(options.mRouteImpl) << ")," << std::endl; + ss << "mRouteSfsImpl={batchedGemm::RouteImpl(" + << static_cast(options.mRouteSfsImpl.value()) << ")}," << std::endl; ss << "mGridWaitForPrimaryRouting=" << options.mGridWaitForPrimaryRouting << "," << std::endl; ss << "mFusedAct=" << options.mFusedAct << "," << std::endl; ss << "mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << "," diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h index 6f2b1c270d..e9d5a23a65 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h @@ -97,6 +97,23 @@ enum class TileScheduler { //////////////////////////////////////////////////////////////////////////////////////////////////// +enum class CtaSwizzleType : uint32_t { + // Rasterize CTAs along the M dimension. 
+ RasterizeAlongM = 0, + // Rasterize CTAs along the N dimension. + RasterizeAlongN, + // Swizzle CTAs in zig-zag pattern along M dimension, Zig-zag width is 2. + ZigZagAlongM2, + // Swizzle CTAs in zig-zag pattern along N dimension, Zig-zag width is 2. + ZigZagAlongN2, + // Swizzle CTAs in zig-zag pattern along M dimension, Zig-zag width is 4. + ZigZagAlongM4, + // Swizzle CTAs in zig-zag pattern along N dimension, Zig-zag width is 4. + ZigZagAlongN4, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + // Helper functions to check the SplitK type. #define SPLIT_K_FUNCTION(Mode) \ diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h index f9c7044700..fc3bd88101 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h @@ -91,16 +91,18 @@ struct GemmOptions { #endif GemmOptions() = default; + GemmOptions(AllReduceAlgo allReduceAlgo, BiasType biasType, int blockK, int clusterDimX, - int clusterDimY, int clusterDimZ, tg::Dtype dtypeAcc, tg::Dtype dtypeA, - tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, - bool enablesEarlyExit, bool enablesDelayedEarlyExit, bool enablesGlobalPtxKnobs, - int epilogueLdtmDps, int epilogueLdtmBits, int epilogueTileM, int epilogueTileN, - bool gridTriggerSecondaryA, bool gridTriggerSecondaryB, - bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, - bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, KernelTraits kernelTraits, - MatrixLayout layoutA, MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, - int mmaM, int mmaN, bool mockAllReduce, int n, int numSlicesForSplitK, + int clusterDimY, int clusterDimZ, CtaSwizzleType ctaSwizzleType, tg::Dtype dtypeAcc, + tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, + tg::Dtype dtypeMmaB, bool enablesEarlyExit, bool enablesDelayedEarlyExit, + bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, + int epilogueTileM, int epilogueTileN, bool gridTriggerSecondaryA, + bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit, + bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit, + bool hoistMmaTaskTryWaits, int k, KernelTraits kernelTraits, MatrixLayout layoutA, + MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, + bool mockAllReduce, int n, int numRegsCopySfLdsSttm, int numSlicesForSplitK, int numSlicesForSliceK, int numStages, int numStagesMma, int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId, bool outputDebugTensors, bool patchF2fp, std::optional sfBlockSizeA, @@ -117,6 +119,7 @@ struct GemmOptions { mClusterDimX{clusterDimX}, mClusterDimY{clusterDimY}, mClusterDimZ{clusterDimZ}, + mCtaSwizzleType{ctaSwizzleType}, mDtypeAcc{dtypeAcc}, mDtypeA{dtypeA}, mDtypeB{dtypeB}, @@ -148,6 +151,7 @@ struct GemmOptions { mMmaN{mmaN}, mMockAllReduce{mockAllReduce}, mN{n}, + mNumRegsCopySfLdsSttm(numRegsCopySfLdsSttm), mNumSlicesForSplitK{numSlicesForSplitK}, mNumSlicesForSliceK{numSlicesForSliceK}, mNumStages{numStages}, @@ -193,6 +197,8 @@ struct GemmOptions { int mClusterDimY{1}; // Cluster size in Z dim. int mClusterDimZ{1}; + // The type of CTA swizzle. + CtaSwizzleType mCtaSwizzleType{CtaSwizzleType::RasterizeAlongM}; // Data type of the accumulators. 
tg::Dtype mDtypeAcc{tg::Dtype::Fp32}; // Data type of the A matrix. @@ -263,6 +269,8 @@ struct GemmOptions { bool mMockAllReduce{false}; // The N dimension of GEMM. int mN{64 * 4}; + // Number of registers for the LDS+STTM warps. + int mNumRegsCopySfLdsSttm{0}; // Number of partitions along the K dimension. When mNumSlicesForSplitK > 1, // the problem is distributed across several SMs, where each CTA works on its local K slice. // Partial results are accumulated afterwards using either GMEM or DSMEM (in CGA) @@ -401,25 +409,36 @@ inline std::string toString(trtllm::gen::MmaKind e) { inline std::string dumpOptions(GemmOptions const& options) { std::stringstream ss; - ss << "mAllReduceAlgo=" << "gemm::AllReduceAlgo(" << static_cast(options.mAllReduceAlgo) - << ")" << "," << std::endl; - ss << "mBiasType=" << "gemm::BiasType(" << static_cast(options.mBiasType) << ")" << "," - << std::endl; + ss << "mAllReduceAlgo=" + << "gemm::AllReduceAlgo(" << static_cast(options.mAllReduceAlgo) << ")" + << "," << std::endl; + ss << "mBiasType=" + << "gemm::BiasType(" << static_cast(options.mBiasType) << ")" + << "," << std::endl; ss << "mBlockK=" << options.mBlockK << "," << std::endl; ss << "mClusterDimX=" << options.mClusterDimX << "," << std::endl; ss << "mClusterDimY=" << options.mClusterDimY << "," << std::endl; ss << "mClusterDimZ=" << options.mClusterDimZ << "," << std::endl; - ss << "mDtypeAcc=" << "trtllm::gen::Dtype(" << static_cast(options.mDtypeAcc) << ")" + ss << "mCtaSwizzleType=" + << "gemm::CtaSwizzleType(" << static_cast(options.mCtaSwizzleType) << ")" + << "," << std::endl; + ss << "mDtypeAcc=" + << "trtllm::gen::Dtype(" << static_cast(options.mDtypeAcc) << ")" + << "," << std::endl; + ss << "mDtypeA=" + << "trtllm::gen::Dtype(" << static_cast(options.mDtypeA) << ")" + << "," << std::endl; + ss << "mDtypeB=" + << "trtllm::gen::Dtype(" << static_cast(options.mDtypeB) << ")" + << "," << std::endl; + ss << "mDtypeC=" + << "trtllm::gen::Dtype(" << static_cast(options.mDtypeC) << ")" << "," << std::endl; - ss << "mDtypeA=" << "trtllm::gen::Dtype(" << static_cast(options.mDtypeA) << ")" << "," - << std::endl; - ss << "mDtypeB=" << "trtllm::gen::Dtype(" << static_cast(options.mDtypeB) << ")" << "," - << std::endl; - ss << "mDtypeC=" << "trtllm::gen::Dtype(" << static_cast(options.mDtypeC) << ")" << "," - << std::endl; - ss << "mDtypeMmaA=" << "trtllm::gen::Dtype(" << static_cast(options.mDtypeMmaA) << ")" + ss << "mDtypeMmaA=" + << "trtllm::gen::Dtype(" << static_cast(options.mDtypeMmaA) << ")" << "," << std::endl; - ss << "mDtypeMmaB=" << "trtllm::gen::Dtype(" << static_cast(options.mDtypeMmaB) << ")" + ss << "mDtypeMmaB=" + << "trtllm::gen::Dtype(" << static_cast(options.mDtypeMmaB) << ")" << "," << std::endl; ss << "mEnablesEarlyExit=" << options.mEnablesEarlyExit << "," << std::endl; ss << "mEnablesDelayedEarlyExit=" << options.mEnablesDelayedEarlyExit << "," << std::endl; @@ -436,19 +455,22 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mHoistLoadTaskInit=" << options.mHoistLoadTaskInit << "," << std::endl; ss << "mHoistMmaTaskTryWaits=" << options.mHoistMmaTaskTryWaits << "," << std::endl; ss << "mK=" << options.mK << "," << std::endl; - ss << "mKernelTraits={}" << "," << std::endl; - ss << "mLayoutA=gemm::MatrixLayout(" << static_cast(options.mLayoutA) << ")" << "," - << std::endl; - ss << "mLayoutB=gemm::MatrixLayout(" << static_cast(options.mLayoutB) << ")" << "," - << std::endl; + ss << "mKernelTraits={}" + << "," << std::endl; + ss << 
"mLayoutA=gemm::MatrixLayout(" << static_cast(options.mLayoutA) << ")" + << "," << std::endl; + ss << "mLayoutB=gemm::MatrixLayout(" << static_cast(options.mLayoutB) << ")" + << "," << std::endl; ss << "mM=" << options.mM << "," << std::endl; ss << "mMmaK=" << options.mMmaK << "," << std::endl; - ss << "mMmaKind=" << "trtllm::gen::MmaKind(" << static_cast(options.mMmaKind) << ")" + ss << "mMmaKind=" + << "trtllm::gen::MmaKind(" << static_cast(options.mMmaKind) << ")" << "," << std::endl; ss << "mMmaM=" << options.mMmaM << "," << std::endl; ss << "mMmaN=" << options.mMmaN << "," << std::endl; ss << "mMockAllReduce=" << options.mMockAllReduce << "," << std::endl; ss << "mN=" << options.mN << "," << std::endl; + ss << "mNumRegsCopySfLdsSttm=" << options.mNumRegsCopySfLdsSttm << "," << std::endl; ss << "mNumSlicesForSplitK=" << options.mNumSlicesForSplitK << "," << std::endl; ss << "mNumSlicesForSliceK=" << options.mNumSlicesForSliceK << "," << std::endl; ss << "mNumStages=" << options.mNumStages << "," << std::endl; @@ -461,23 +483,30 @@ inline std::string dumpOptions(GemmOptions const& options) { if (options.mSfBlockSizeA.has_value()) { ss << "mSfBlockSizeA=" << options.mSfBlockSizeA.value() << "," << std::endl; } else { - ss << "mSfBlockSizeA=" << "std::nullopt" << ", " << std::endl; + ss << "mSfBlockSizeA=" + << "std::nullopt" + << ", " << std::endl; } - ss << "mSfLayoutA=" << "trtllm::gen::SfLayout(" << static_cast(options.mSfLayoutA) << ")" + ss << "mSfLayoutA=" + << "trtllm::gen::SfLayout(" << static_cast(options.mSfLayoutA) << ")" << "," << std::endl; - ss << "mSfLayoutB=" << "trtllm::gen::SfLayout(" << static_cast(options.mSfLayoutB) << ")" + ss << "mSfLayoutB=" + << "trtllm::gen::SfLayout(" << static_cast(options.mSfLayoutB) << ")" << "," << std::endl; - ss << "mSfLayoutC=" << "trtllm::gen::SfLayout(" << static_cast(options.mSfLayoutC) << ")" + ss << "mSfLayoutC=" + << "trtllm::gen::SfLayout(" << static_cast(options.mSfLayoutC) << ")" << "," << std::endl; ss << "mSfReshapeFactor=" << options.mSfReshapeFactor << "," << std::endl; ss << "mSliceK=" << options.mSliceK << "," << std::endl; - ss << "mSplitK=" << "gemm::SplitK(" << static_cast(options.mSplitK) << ")" << "," - << std::endl; + ss << "mSplitK=" + << "gemm::SplitK(" << static_cast(options.mSplitK) << ")" + << "," << std::endl; ss << "mTileK=" << options.mTileK << "," << std::endl; ss << "mTileM=" << options.mTileM << "," << std::endl; ss << "mTileN=" << options.mTileN << "," << std::endl; - ss << "mTileScheduler=" << "gemm::TileScheduler(" << static_cast(options.mTileScheduler) - << ")" << "," << std::endl; + ss << "mTileScheduler=" + << "gemm::TileScheduler(" << static_cast(options.mTileScheduler) << ")" + << "," << std::endl; ss << "mTransposeMmaOutput=" << options.mTransposeMmaOutput << "," << std::endl; ss << "mUseCustomMmaSchedule=" << options.mUseCustomMmaSchedule << "," << std::endl; ss << "mUseDeepSeekFp8=" << options.mUseDeepSeekFp8 << "," << std::endl; @@ -673,18 +702,27 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in if ((options.mMmaKind == tg::MmaKind::MxFp4NvFp4 || options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4 || options.mDtypeC == tg::Dtype::MxE4m3) && options.mMmaM != 128) { - // MMA M must be 128 when the input uses block scaling, or when the output is an Mx format. - int newTileM = 128 * divUp(options.mTileM, 128); - TLLM_LOG_WARNING("Unsupported MmaM (", options.mMmaM, - ") for MmaKind=", gemm::toString(options.mMmaKind), - ". 
Setting MmaM to 128 and TileM to ", newTileM); - if (updateOptions) { - options.mMmaM = 128; - options.mTileM = newTileM; + if (options.mClusterDimX == 1) { + // MMA M must be 128 when the input uses block scaling, or when the output is an Mx format. + int newTileM = 128 * divUp(options.mTileM, 128); + TLLM_LOG_WARNING("Unsupported MmaM (", options.mMmaM, + ") for MmaKind=", gemm::toString(options.mMmaKind), + ". Setting MmaM to 128 and TileM to ", newTileM); + if (updateOptions) { + options.mMmaM = 128; + options.mTileM = newTileM; + } else { + return false; + } } else { - return false; + TLLM_CHECK_ERROR(options.mMmaM == 256 && options.mTileM == 128, + "2CTA UTCxMMA only supports mmaM = 256 and tileM = 128."); } } + if (options.mClusterDimX > 1) { + TLLM_CHECK_ERROR(options.mLayoutB != MatrixLayout::BlockMajorK, + "layoutB == MatrixLayout::BlockMajorK is not supported for now"); + } if (options.mMmaKind == tg::MmaKind::MxFp4NvFp4 || options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4) { TLLM_CHECK_ERROR(isBlackwell, "Block scaling is only supported on Blackwell"); @@ -869,14 +907,26 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } if (!options.mSliceK) { - TLLM_CHECK_ERROR(options.mMmaM <= options.mEpilogueTileM, + TLLM_CHECK_ERROR(options.mMmaM / options.mClusterDimX <= options.mEpilogueTileM, "EpilogueTileM must be larger or equal than mmaM."); + } else { + // FIXME: this is not necessary limitation. Simply fixing num repeats in TmemSliceKA should be + // enough. + TLLM_CHECK_ERROR((options.mTileN & (options.mTileN - 1)) == 0, + "For Slice-K TileN is required to be a power of 2"); } + + if (options.mClusterDimX == 2) { + TLLM_CHECK_ERROR(options.mMmaM == 256, "Only mmaM = 256 is supported for 2CTA UTCMMA."); + TLLM_CHECK_ERROR(options.mMmaN % 16 == 0, "mmaN needs to be multiple of 16 for 2CTA UTCMMA."); + } + TLLM_CHECK_ERROR( options.mTileM % options.mEpilogueTileM == 0 && options.mTileN % options.mEpilogueTileN == 0, "TileM and TileN must be divisible by EpilogueTileM and EpilogueTileN respectively."); - TLLM_CHECK_ERROR(options.mClusterDimX == 1 && options.mClusterDimY == 1, - "GEMM does not support cluster in X and Y dimensions."); + TLLM_CHECK_ERROR( + (options.mClusterDimX == 1 || options.mClusterDimX == 2) && options.mClusterDimY == 1, + "GEMM does not support cluster in X and Y dimensions."); TLLM_CHECK_ERROR(options.mClusterDimZ == 1 || options.mNumSlicesForSplitK > 1, "Cluster DimZ is only allowed for split-k."); TLLM_CHECK_ERROR(options.mTileM <= 128, "GEMM does not support TileM > 128."); @@ -1003,6 +1053,9 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in "Non-DeepSeekFp8 requires persistent scheduler when using numStagesMma >1"); } } + if (options.mUseDeepSeekFp8) { + TLLM_CHECK_ERROR(options.mClusterDimX == 1, "2CTA Gemm is not supported for DeepSeekFp8"); + } if (options.mUseDeepSeekFp8) { TLLM_CHECK_ERROR(options.mDtypeA == tg::Dtype::E4m3 && options.mDtypeB == tg::Dtype::E4m3, "A and B dtype must be E4m3 for DeepSeek Fp8. 
Found dtypeA=", @@ -1259,7 +1312,8 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in options.mNumStagesMma, options.mNumSlicesForSplitK, options.mNumSlicesForSliceK, options.mSplitK, options.mUseTmaStore, options.mTransposeMmaOutput, options.mAllReduceAlgo, options.mTileScheduler == TileScheduler::Persistent, options.mUseDeepSeekFp8, - options.mUsePerTokenSfA, options.mUsePerTokenSfB, options.mBiasType); + options.mUsePerTokenSfA, options.mUsePerTokenSfB, + /* useTwoCtas*/ options.mClusterDimX == 2, options.mBiasType); } return true; diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h index eba3f54737..7e0474bb5f 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h @@ -137,12 +137,11 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in std::vector shape = {static_cast(hiddenSize), static_cast(numTokens)}; if (useTmaOobOpt /* also implies input/output activation */) { - // If TMA OOB optimization is used, we use 3D logical shape (M, tileM, K) or (N, tileN, K). - // The outer dimension is extended to make room for the possible counterbalance positive - // offset from the middle "bound" dimension. The counterbalance should be no more than - // ctaTileNumTokens. + // If TMA OOB optimization is used: + // Shape [hidden, tokens] Stride [1, hidden] becomes + // Shape [hidden, tileN, TmaDimMax, TmaDimMax] Stride [1, hidden, XLargeN - hidden, hidden] shape = {static_cast(hiddenSize), static_cast(ctaTileNumTokens), - static_cast(numTokens + ctaTileNumTokens)}; + static_cast(tg::TmaDimMax), static_cast(tg::TmaDimMax)}; } else if (isWeights) { // If the matrix is a weights matrix, we use 3D logical shape (B, M, K) or (B, N, K). shape = {static_cast(hiddenSize), static_cast(numTokens), @@ -153,7 +152,8 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in // Swap the first two dimension as mentioned before. std::vector stride = {1, static_cast(hiddenSize)}; if (useTmaOobOpt) { - stride = {1, static_cast(hiddenSize), static_cast(hiddenSize)}; + stride = {1, static_cast(hiddenSize), static_cast(tg::XLargeN - hiddenSize), + static_cast(hiddenSize)}; } else if (isWeights) { stride = {1, static_cast(hiddenSize), static_cast(hiddenSize) * static_cast(numTokens)}; @@ -164,6 +164,10 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in // Alternate layouts (MajorMn and BlockMajorK) do not apply to matrixC if (matrixType != MatrixType::MatrixC) { + // When using 2CTA MMA, we only need to load half of the tile in each CTA for B. + if (matrixType == MatrixType::MatrixB && tileShape[1] > 1 && options.mClusterDimX == 2) { + tileShape[1] /= 2; + } gemm::MatrixLayout layout = (matrixType == MatrixType::MatrixA) ? options.mLayoutA : options.mLayoutB; // Note, only the weights support non MajorK layouts @@ -290,6 +294,7 @@ static auto makeTmaShapeStrideSfAb(int mM, int mN, int mK, MatrixType matrixType } return std::make_tuple(std::vector{}, std::vector{}, std::vector{}); } + template static KernelParams setKernelParams( GemmOptions_ const& options, bool const batchM, void const* ptrA, void const* ptrB, void* ptrC, @@ -428,7 +433,7 @@ static KernelParams setKernelParams( tg::Dtype const dTypeSf = (options.mDtypeB == tg::Dtype::E2m1) ? 
tg::Dtype::E4m3 : tg::Dtype::UE8m0; - if (batchedGemm::doesRouteImplUseTma(options.mRouteImpl)) { + if (batchedGemm::doesRouteImplUseTma(options.mRouteSfsImpl.value())) { // The input is NOT padded: // [act0, act1, act2, ...] @@ -445,7 +450,7 @@ static KernelParams setKernelParams( params.tmaSfB[0] = gemm::buildNdTmaDescriptor( dTypeSf, options.mMmaKind, shapeSfB, strideSfB, tileShapesSfB, const_cast(dSfB), /*doSwizzle*/ true); - } else if (batchedGemm::doesRouteImplUseNoRoute(options.mRouteImpl)) { + } else if (batchedGemm::doesRouteImplUseNoRoute(options.mRouteSfsImpl.value())) { // The input is padded: // [act0, padding, padding, ... TileN size .., act1, padding, padding, ...] @@ -473,7 +478,6 @@ static KernelParams setKernelParams( } else { params.ptrC = ptrC; } - } else { // B is the expert if (0 != options.mN % options.mTileN) { @@ -508,7 +512,7 @@ static KernelParams setKernelParams( tg::Dtype const dTypeSf = (options.mDtypeA == tg::Dtype::E2m1) ? tg::Dtype::E4m3 : tg::Dtype::UE8m0; - if (options.mRouteImpl == batchedGemm::RouteImpl::NoRoute) { + if (options.mRouteSfsImpl.value() == batchedGemm::RouteImpl::NoRoute) { // The input is padded: // [act0, padding, padding, ... tileM size .., act1, padding, padding, ...] auto const inputNumTokensSfA = ctaOffset * options.mTileM; diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h index 640b3a69f0..4d79f83c23 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h @@ -163,7 +163,7 @@ class KernelTraits { int32_t numSlicesForSplitK, int32_t numSlicesForSliceK, SplitK splitK, bool useTmaStore, bool transposeMmaOutput, AllReduceAlgo allReduceAlgo, bool usePersistentScheduler, bool useDeepSeekFp8, bool usePerTokenSfA, - bool usePerTokenSfB, BiasType biasType) + bool usePerTokenSfB, bool useTwoCtas, BiasType biasType) : mMmaKind{mmaKind} { // // SMEM @@ -213,8 +213,8 @@ class KernelTraits { // LoadB { // Number of bytes in load B shared memory. - auto const numSmemBytesLoadB = - numStages * tileN * tileK * getNumSmemBitsPerElt(dtypeB, mMmaKind) / 8 /* bits */; + auto const numSmemBytesLoadB = numStages * (useTwoCtas ? tileN / 2 : tileN) * tileK * + getNumSmemBitsPerElt(dtypeB, mMmaKind) / 8 /* bits */; // Number of bytes for load B alignment for TMA load. auto const numBytesAlignmentLoadB = 1024; // No need to reuse the first chunk. diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h index 393949a516..c7f1020dea 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h @@ -23,6 +23,23 @@ namespace gen { //////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMA OOB optimization constants. +// +// CUDA Programming Guide states that "globalDim must be non-zero and less than or equal to 2^32". +// In practice, the kernel acts funny with TMA shape of 2^32 so we use 2^31. +constexpr unsigned long TmaDimMax = 1UL << 31; +// Chosen so that LargeN * XLargeN * sizeof(dtype) >= 2^64 which causes overflow and effectively +// becomes 0. 
As sizeof(dtype) can be as small as 0.5B, we choose LargeN = 2^30 and XLargeN = 2^35 +// so overflow can happen. +constexpr unsigned long LargeN = 1UL << 30; +// Used in TMA stride. Should be less than 2^40. +constexpr unsigned long XLargeN = 1UL << 35; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template inline T ceilDiv(T m, T n) { return (m + n - T(1)) / n; diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h index 5f066468e6..8cd2cc93d1 100644 --- a/include/flashinfer/trtllm/fused_moe/runner.h +++ b/include/flashinfer/trtllm/fused_moe/runner.h @@ -52,6 +52,13 @@ enum class RoutingMethodType : int64_t { Unspecified = 6, }; +inline int32_t maybeGetMinTokenCount(int32_t numPaddedTokens, int32_t hiddenSize, + int32_t dtypeSizeBits) { + // Pad so total size exceeds 128KiB for performance reasons + int32_t minNumTokensRequired = common::divUp(128 * 1024 * 8, hiddenSize * dtypeSizeBits); + return std::max(numPaddedTokens, minNumTokensRequired); +} + inline std::string serializeMoeRoutingMethodType(RoutingMethodType routingMethodType) { switch (routingMethodType) { case RoutingMethodType::Default: @@ -71,35 +78,41 @@ inline std::string serializeMoeRoutingMethodType(RoutingMethodType routingMethod }; } -inline int32_t getMaxPermutedPaddedCount(int32_t numTokens, int32_t expertsPerToken, - int32_t numExperts, int32_t padding) { - auto const expandedRowCount = numTokens * expertsPerToken; - auto const maxPaddingRequired = (padding - 1) * numExperts; - return common::roundUp(expandedRowCount + maxPaddingRequired, padding); -} - inline int32_t getMaxNumCtasInBatchDim(int32_t numTokens, int32_t topK, int32_t numExperts, int32_t tileTokensDim) { - // Get maximum number of CTAs in batch dim per expert. - auto const maxCtasInBatchDimPerExpert = common::ceilDiv(numTokens, tileTokensDim); - // Get maximum enabled experts. - auto const maxEnabledExperts = std::min(numTokens * topK, numExperts); - // Get maximum number of CTAs in batch dim. - auto maxNumCtasInBatchDim = maxEnabledExperts * maxCtasInBatchDimPerExpert; - - // For large token counts, the above bound can be pessimistic since not all the tokens can - // be routed to all the enabled experts. Instead we can essentially bound the number of CTAs - // by permuted buffer size. However, this method will be overly pessimistic for low-token - // counts - auto const tilesForPermutedBuffer = common::ceilDiv( - getMaxPermutedPaddedCount(numTokens, topK, numExperts, tileTokensDim), tileTokensDim); - - // Set maxNumCtasInBatchDim to be the minimum of the two methods - maxNumCtasInBatchDim = std::min(maxNumCtasInBatchDim, tilesForPermutedBuffer); - + // For MoE, mNumTokens != 0 and the number of CTAs is known only at runtime. + // We launch maximally possible number of CTAs and use ptrNumNonExitingCtas to determine + // the actual number of CTAs to run. + + // Initialize number of tokens with the number of expanded tokens after routing. + int32_t numRemainingTokens = numTokens * topK; + int32_t maxNumCtasInBatchDim = 0; + // First, distribute one token each expert until token depletion to maximize CTA tile count. 
+ int32_t numExpertsFilled = std::min(numExperts, numRemainingTokens); + maxNumCtasInBatchDim += numExpertsFilled; + numRemainingTokens -= numExpertsFilled; + // Next, greedily pour all remaining tokens to one expert to maximize CTA tile count. + // E.g., at this point tokens over 4 experts are [1, 1, 1, 1], and we have 4 tokens left. + // If each CTA handles 4 tokens/expert, the greedy strategy is to pour all remaining tokens + // to any one expert to get to the 5th CTA tile. Otherwise, we can only get 4 tiles in total. + // + // Another way to reason about this is to pour the remaining tokens into buckets of some fixed + // capacity. These buckets, if full, can then be attributed to any expert; it does not have to + // belong to the same expert every time. + if (numRemainingTokens > 0) { + // For every tileTokenDim tokens, we add an extra CTA tile in the token dimension. + // The number of CTA tiles is given by divDown(numRemainingTokens, tokenTileDim). + maxNumCtasInBatchDim += (numRemainingTokens / tileTokensDim); + } return maxNumCtasInBatchDim; } +inline int32_t getMaxPermutedPaddedCount(int32_t numTokens, int32_t expertsPerToken, + int32_t numExperts, int32_t padding) { + int32_t maxCtas = getMaxNumCtasInBatchDim(numTokens, expertsPerToken, numExperts, padding); + return maxCtas * padding; +} + class Runner { public: explicit Runner(); diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 880c739259..2765143cbf 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -17,7 +17,6 @@ from abc import ABC, abstractmethod from enum import IntEnum from typing import Dict - import pytest import torch from cuda.bindings import runtime @@ -1839,7 +1838,7 @@ def cache_permute_indices(): @pytest.mark.parametrize("num_tokens", [1, 8, 1024]) @pytest.mark.parametrize("hidden_size", [1024, 8192]) -@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 384]) +@pytest.mark.parametrize("intermediate_size", [384, 768, 1024, 2048]) @pytest.mark.parametrize( "moe_impl", [ @@ -2075,7 +2074,17 @@ def test_moe_quantization_classes( routed_scaling = routing_config["routed_scaling"] num_experts = routing_config["num_experts"] routing_method_type = routing_config["routing_method_type"] - tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k) + + tile_tokens_dim = calculate_tile_tokens_dim( + num_tokens, + num_experts, + top_k, + max_tile_tokens_dim=128 + if ( + type(moe_impl) is FP4Moe and moe_impl.quant_mode != QuantMode.FP4_MXFP4_Bf16 + ) + else 64, + ) # Validation checks assert top_k <= num_experts @@ -2234,35 +2243,3 @@ def test_moe_quantization_classes( rtol=tolerances["rtol"], percent=tolerances["percent"], ) - - -if __name__ == "__main__": - # pytest.main([__file__, "-v"]) - routing_config = { - "num_experts": 256, - "top_k": 8, - "padding": 8, - "n_groups": 8, - "top_k_groups": 4, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.DeepSeekV3, - "compatible_moe_impls": [ - FP8BlockScaleMoe, - ], - } - weight_processing = { - "use_shuffled_weight": False, - "layout": WeightLayout.MajorK, - "compatible_moe_impls": [FP8BlockScaleMoe], - } - test_moe_quantization_classes( - num_tokens=4, - hidden_size=1024, - intermediate_size=1024, - moe_impl=FP8BlockScaleMoe(), - routing_config=routing_config, - weight_processing=weight_processing, - gated_act_type=GatedActType.SwiGlu, - cache_permute_indices=cache_permute_indices, - )
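The Python-side change above raises the tile-size cap from 64 to a caller-supplied max_tile_tokens_dim and adds a 192-token tile for the 128-256 tokens-per-expert range. Below is a minimal sketch of the updated selection logic, mirroring calculate_tile_tokens_dim from flashinfer/utils.py in this diff; the body of next_positive_power_of_2 is not fully shown above, so a functionally equivalent stand-in is used here:

def next_positive_power_of_2(x: int) -> int:
    # Smallest power of two >= x (returns 1 for x < 1); equivalent stand-in for the
    # bit-twiddling helper in flashinfer/utils.py.
    return 1 if x < 1 else 1 << (x - 1).bit_length()


def calculate_tile_tokens_dim(
    num_tokens: int, num_experts: int, top_k: int, max_tile_tokens_dim: int = 128
) -> int:
    # Guess tokens per expert assuming a perfect expert distribution.
    num_tokens_per_expert = num_tokens * top_k // num_experts
    # Pad to the next power of two.
    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
    # New in this diff: prefer a 192-token tile between 128 and 256 tokens per expert.
    if 128 < num_tokens_per_expert < 256:
        tile_tokens_dim = 192
    # Cap to the 8..max_tile_tokens_dim range supported by the kernel; the 192 tile
    # only survives when the caller allows tiles larger than 128.
    return min(max(tile_tokens_dim, 8), max_tile_tokens_dim)


# 1024 tokens, 256 experts, top_k=8 -> 32 tokens/expert -> 32-token tiles.
assert calculate_tile_tokens_dim(1024, 256, 8) == 32
# 6144 tokens, 256 experts, top_k=8 -> 192 tokens/expert -> 192-token tiles if permitted.
assert calculate_tile_tokens_dim(6144, 256, 8, max_tile_tokens_dim=256) == 192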
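On the C++ side, runner.h now derives the permuted-buffer size from the worst-case CTA count in the batch dimension and pads intermediate GEMM outputs to span at least 128 KiB. The following is a hypothetical Python transcription of getMaxNumCtasInBatchDim, getMaxPermutedPaddedCount, and maybeGetMinTokenCount as written in this diff, to illustrate the arithmetic (the snake_case names exist only in this sketch):

def get_max_num_ctas_in_batch_dim(num_tokens: int, top_k: int, num_experts: int,
                                  tile_tokens_dim: int) -> int:
    # Worst case: first give one expanded token to as many experts as possible,
    # opening one CTA tile per enabled expert ...
    remaining = num_tokens * top_k
    num_ctas = min(num_experts, remaining)
    remaining -= num_ctas
    # ... then every further tile_tokens_dim tokens can open at most one more tile.
    return num_ctas + remaining // tile_tokens_dim


def get_max_permuted_padded_count(num_tokens: int, top_k: int, num_experts: int,
                                  tile_tokens_dim: int) -> int:
    # The permuted buffer holds tile_tokens_dim rows per worst-case CTA tile.
    return get_max_num_ctas_in_batch_dim(num_tokens, top_k, num_experts,
                                         tile_tokens_dim) * tile_tokens_dim


def maybe_get_min_token_count(num_padded_tokens: int, hidden_size: int,
                              dtype_size_bits: int) -> int:
    # Pad so a [tokens, hidden_size] buffer spans at least 128 KiB (divUp in bits).
    min_tokens = -(-(128 * 1024 * 8) // (hidden_size * dtype_size_bits))
    return max(num_padded_tokens, min_tokens)


# 8 tokens, top_k=8, 256 experts, 64-token tiles -> at most 64 CTA tiles, 4096 padded rows.
assert get_max_num_ctas_in_batch_dim(8, 8, 256, 64) == 64
assert get_max_permuted_padded_count(8, 8, 256, 64) == 4096
# A bf16 (16-bit) gemm2 output with hidden_size=1024 is padded to at least 64 rows.
assert maybe_get_min_token_count(8, 1024, 16) == 64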