diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 99a51039bc..d0017c6630 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -135,7 +135,7 @@ class ArtifactPath: When compiling new cubins for backend directories, update the corresponding path. """ - TRTLLM_GEN_FMHA: str = "f1ed60e5666a7620683a8c34a41c850a25029b35/fmha/trtllm-gen/" + TRTLLM_GEN_FMHA: str = "e7afc4134bb53eaab63fb85163d5943fb190621c/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( "b55211623be7f5697c5262ffd8361fc06c147bc9/batched_gemm-b3c1646-c111d7c/" ) @@ -155,7 +155,7 @@ class CheckSumHash: """ TRTLLM_GEN_FMHA: str = ( - "10a54e8c3175099481aed2739ae30fa0f782368c40f9ad1b423ed8353315d65b" + "5bd87798e560a63e883902fc5468146ffff0d3551bf337d2f81bd02893e9dc39" ) TRTLLM_GEN_BMM: str = ( "0af823880730c4f0b3832d2208fab035946694b83444410b9309db5613d60195" diff --git a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh index aa0e640ede..1e64fb329f 100644 --- a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh +++ b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh @@ -207,119 +207,130 @@ class TllmGenFmhaKernel { // start here void run(RunnerParams const& params) const { - // The selectKernelParams that might be updated. SelectKernelParams selectKernelParams{params}; - // The parameters for launching the kernel. CtaLaunchParams ctaLaunchParams; - // The iteration index (used to detect a deadlock of selecting new kernels). - int selectKernelIter = 0; - // While loop. - while (true) { - // Any value >= 2 should work here, but we set it larger in case that we - // might have more complicated heuristic in the future. - FLASHINFER_CHECK(selectKernelIter < 8, - "A deadlock is detected when selecting trtllm-gen kernels."); - - // Select the kernel. - selectKernel(params, selectKernelParams); - // Load the kernel. 
- auto [func, kernelMeta] = loadKernel(params, selectKernelParams); - // Compute the number of CTAs in X, Y and Z dimension and the cluster size in the X dimension. + // Kernel selection loop (bounded). Each pass may update selectKernelParams (e.g. switch + // MultiCtasKvMode to Disabled, upgrade to CgaSmemReduction, or reduce headDimPerCtaV) and + // request a re-select via mSelectNewKernel. Each trigger fires at most once, so the sequence + // converges in at most kMaxKernelSelectionPasses passes. + static constexpr int kMaxKernelSelectionPasses = 4; + CUfunction func{}; + KernelMeta kernelMeta{}; + for (int pass = 0; pass < kMaxKernelSelectionPasses; ++pass) { + selectKernel(params, selectKernelParams); + std::tie(func, kernelMeta) = loadKernel(params, selectKernelParams); computeCtaAndClusterConfig(ctaLaunchParams, params, kernelMeta, selectKernelParams); - - // Need to select a new kernel if mSelectNewKernel is true. - if (selectKernelParams.mSelectNewKernel) { - selectKernelIter++; - continue; + if (!selectKernelParams.mSelectNewKernel) { + break; } + FLASHINFER_CHECK(pass + 1 < kMaxKernelSelectionPasses, + "trtllm-gen kernel selection did not converge in %d passes.", + kMaxKernelSelectionPasses); + } - // Prepare the kernel parameters. - auto kernelParams = KernelParams::setKernelParams( - params, kernelMeta, ctaLaunchParams.mMaxNumCtasQ, ctaLaunchParams.mMaxNumCtasKv); - - // Prepare kernel parameters list for cuLaunchKernelEx. - void* kernelParamsList[] = {&kernelParams}; - CUlaunchConfig launch_config; - launch_config.blockDimX = kernelMeta.mThreadsPerCTA; - launch_config.blockDimY = 1; - launch_config.blockDimZ = 1; - launch_config.gridDimX = ctaLaunchParams.mNumCtasX; - launch_config.gridDimY = ctaLaunchParams.mNumCtasY; - launch_config.gridDimZ = ctaLaunchParams.mNumCtasZ; - launch_config.hStream = params.stream; - launch_config.sharedMemBytes = kernelMeta.mSharedMemBytes; - - // Debug info. 
-      IKL_LOG_DEBUG("TRTLLM-Gen launch info (in TllmGenFmhaKernel %s, %s, %s, %d): kernelName = %s",
-                    toStr(mDtypeQ), toStr(mDtypeKv), toStr(mDtypeOut), mSM, kernelMeta.mFuncName);
-      IKL_LOG_DEBUG(
-          "TRTLLM-Gen launch info: maxSeqLenQ = %d, "
-          "maxSeqLenKv = %d, "
-          "numHeadsQ = %d, "
-          "numHeadsKv = %d, batchSize = %d, kernelType = %d",
-          params.mMaxSeqLenQ, params.mMaxSeqLenKv, params.mNumHeadsQ, params.mNumHeadsKv,
-          params.mBatchSize, static_cast<int>(params.mKernelType));
-      IKL_LOG_DEBUG(
-          "TRTLLM-Gen launch info: numCtasX = %d, numCtasY = %d, numCtasZ = %d, clusterDimX = %d",
-          ctaLaunchParams.mNumCtasX, ctaLaunchParams.mNumCtasY, ctaLaunchParams.mNumCtasZ,
-          ctaLaunchParams.mClusterDimX);
-      CUlaunchAttribute launch_attribute[3];
-      launch_attribute[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
-      launch_attribute[0].value.clusterDim.x = ctaLaunchParams.mClusterDimX;
-      launch_attribute[0].value.clusterDim.y = 1;
-      launch_attribute[0].value.clusterDim.z = 1;
-      launch_attribute[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
-      launch_attribute[1].value.clusterSchedulingPolicyPreference =
-          ctaLaunchParams.mClusterDimX > 1 ? CU_CLUSTER_SCHEDULING_POLICY_SPREAD
-                                           : CU_CLUSTER_SCHEDULING_POLICY_DEFAULT;
-      launch_attribute[2].id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION;
-      launch_attribute[2].value.programmaticStreamSerializationAllowed = params.enable_pdl;
-      launch_config.attrs = launch_attribute;
-      launch_config.numAttrs = 3;
-      // Add setting for non-portable cluster size.
-      if (ctaLaunchParams.mClusterDimX > 8) {
-        cuErrCheck(cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED,
-                                      1  // Enable non-portable cluster sizes
-                                      ));
-      }
+    // Prepare the kernel parameters.
+ auto kernelParams = KernelParams::setKernelParams( + params, kernelMeta, ctaLaunchParams.mMaxNumCtasQ, ctaLaunchParams.mMaxNumCtasKv); - // Force using GmemReduction for the multiCtasKvMode if the CgaSmemReduction needs more than - // one wave (due to the cluster occupancy limit). - // TODO: find a better heuristic of using CgaSmemReduction. - if (isCgaSmemReduction(selectKernelParams.mMultiCtasKvMode)) { - // The maximum number of active clusters that could co-exist. - int maxActiveClusters = 1; - cuErrCheck(cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &launch_config)); - // Use the GmemReduction instead if it needs more than one wave. - if (maxActiveClusters * ctaLaunchParams.mClusterDimX < - (ctaLaunchParams.mNumCtasX * ctaLaunchParams.mNumCtasY * ctaLaunchParams.mNumCtasZ)) { - selectKernelParams.mForceGmemReduction = true; - selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReduction; - // continue to select a new kernel. - continue; - } - } + void* kernelParamsList[] = {&kernelParams}; + CUlaunchAttribute launch_attribute[3]; + CUlaunchConfig launch_config; + buildLaunchConfig(launch_config, launch_attribute, kernelMeta, ctaLaunchParams, params); - cuErrCheck(cuLaunchKernelEx(&launch_config, func, kernelParamsList, nullptr)); + // Debug info. 
+    IKL_LOG_DEBUG("TRTLLM-Gen launch info (in TllmGenFmhaKernel %s, %s, %s, %d): kernelName = %s",
+                  toStr(mDtypeQ), toStr(mDtypeKv), toStr(mDtypeOut), mSM, kernelMeta.mFuncName);
+    IKL_LOG_DEBUG(
+        "TRTLLM-Gen launch info: maxSeqLenQ = %d, "
+        "maxSeqLenKv = %d, "
+        "numHeadsQ = %d, "
+        "numHeadsKv = %d, batchSize = %d, kernelType = %d",
+        params.mMaxSeqLenQ, params.mMaxSeqLenKv, params.mNumHeadsQ, params.mNumHeadsKv,
+        params.mBatchSize, static_cast<int>(params.mKernelType));
+    IKL_LOG_DEBUG(
+        "TRTLLM-Gen launch info: numCtasX = %d, numCtasY = %d, numCtasZ = %d, clusterDimX = %d",
+        ctaLaunchParams.mNumCtasX, ctaLaunchParams.mNumCtasY, ctaLaunchParams.mNumCtasZ,
+        ctaLaunchParams.mClusterDimX);
-      // Run the separate reduction kernel if needed.
-      tensorrt_llm::kernels::runFmhaReduction(kernelMeta, kernelParams, params.mMultiProcessorCount,
-                                              params.enable_pdl, params.stream);
+    setNonPortableClusterIfNeeded(func, ctaLaunchParams);
-      if (params.lsePtr != nullptr) {
-        flashinfer::ComputeLSEFromMD(params.softmaxStatsPtr, params.lsePtr,
-                                     params.mSumOfSeqLensQ * params.mNumHeadsQ, params.enable_pdl,
-                                     params.stream);
+    // Force GmemReduction if CgaSmemReduction would need more than one wave (cluster occupancy
+    // limit). TODO: find a better heuristic of using CgaSmemReduction.
+ if (isCgaSmemReduction(selectKernelParams.mMultiCtasKvMode)) { + int maxActiveClusters = 1; + cuErrCheck(cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &launch_config)); + if (maxActiveClusters * ctaLaunchParams.mClusterDimX < + ctaLaunchParams.mNumCtasX * ctaLaunchParams.mNumCtasY * ctaLaunchParams.mNumCtasZ) { + selectKernelParams.mForceGmemReduction = true; + selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReduction; + selectKernel(params, selectKernelParams); + std::tie(func, kernelMeta) = loadKernel(params, selectKernelParams); + computeCtaAndClusterConfig(ctaLaunchParams, params, kernelMeta, selectKernelParams); + FLASHINFER_CHECK(!selectKernelParams.mSelectNewKernel, + "trtllm-gen kernel selection did not converge after CgaSmemReduction " + "fallback to GmemReduction."); + // Rebuild kernelParams: setKernelParams uses kernelMeta (TMA descriptors, tile shapes) + // which changed when switching from CgaSmemReduction to GmemReduction kernel. + kernelParams = KernelParams::setKernelParams( + params, kernelMeta, ctaLaunchParams.mMaxNumCtasQ, ctaLaunchParams.mMaxNumCtasKv); + buildLaunchConfig(launch_config, launch_attribute, kernelMeta, ctaLaunchParams, params); + setNonPortableClusterIfNeeded(func, ctaLaunchParams); } - // Break the while op. - break; + } + + cuErrCheck(cuLaunchKernelEx(&launch_config, func, kernelParamsList, nullptr)); + + // Run the separate reduction kernel if needed. + tensorrt_llm::kernels::runFmhaReduction(kernelMeta, kernelParams, params.mMultiProcessorCount, + params.enable_pdl, params.stream); + + if (params.lsePtr != nullptr) { + flashinfer::ComputeLSEFromMD(params.softmaxStatsPtr, params.lsePtr, + params.mSumOfSeqLensQ * params.mNumHeadsQ, params.enable_pdl, + params.stream); } } private: + // Fill a CUlaunchConfig and its associated attribute array from the current kernel and CTA + // params. 
The caller owns the storage for launch_attribute (must be an array of at least 3 + // elements) and is responsible for ensuring it outlives launch_config. + void buildLaunchConfig(CUlaunchConfig& launch_config, CUlaunchAttribute* launch_attribute, + KernelMeta const& kernelMeta, CtaLaunchParams const& ctaLaunchParams, + RunnerParams const& params) const { + launch_config.blockDimX = kernelMeta.mThreadsPerCTA; + launch_config.blockDimY = 1; + launch_config.blockDimZ = 1; + launch_config.gridDimX = ctaLaunchParams.mNumCtasX; + launch_config.gridDimY = ctaLaunchParams.mNumCtasY; + launch_config.gridDimZ = ctaLaunchParams.mNumCtasZ; + launch_config.hStream = params.stream; + launch_config.sharedMemBytes = kernelMeta.mSharedMemBytes; + launch_attribute[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + launch_attribute[0].value.clusterDim.x = ctaLaunchParams.mClusterDimX; + launch_attribute[0].value.clusterDim.y = 1; + launch_attribute[0].value.clusterDim.z = 1; + launch_attribute[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; + launch_attribute[1].value.clusterSchedulingPolicyPreference = + ctaLaunchParams.mClusterDimX > 1 ? CU_CLUSTER_SCHEDULING_POLICY_SPREAD + : CU_CLUSTER_SCHEDULING_POLICY_DEFAULT; + launch_attribute[2].id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION; + launch_attribute[2].value.programmaticStreamSerializationAllowed = params.enable_pdl; + launch_config.attrs = launch_attribute; + launch_config.numAttrs = 3; + } + + // Enable non-portable cluster sizes when clusterDimX exceeds the portable limit of 8. + void setNonPortableClusterIfNeeded(CUfunction func, + CtaLaunchParams const& ctaLaunchParams) const { + if (ctaLaunchParams.mClusterDimX > 8) { + cuErrCheck(cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, + 1 // Enable non-portable cluster sizes + )); + } + } + // Is it MLA generation kernel ? 
inline bool isMlaGenKernel(RunnerParams const& params) const { return params.mHeadDimQk == 576 && params.mHeadDimV == 512; @@ -432,10 +443,13 @@ class TllmGenFmhaKernel { // Enable the CgaSmemReduction if the numCtasPerSeqKv <= 16 as the maximum cluster dimension // is 16. Only the swapsMmaAbForGeneration kernel supports the CgaSmemReduction for now. + // CgaSmemReduction exceeds the shared memory limit for MLA decode with tileSizeQ >= 32 + // (headDimQk=576 requires more smem than the device allows for that tile size). if (!isDsv3MinLatencyMode && numCtasPerSeqKv > 1 && numCtasPerSeqKv <= 16 && isSwapsMmaAbForGenerationKernel(selectKernelParams.mKernelType) && isGmemReduction(selectKernelParams.mMultiCtasKvMode) && - !selectKernelParams.mForceGmemReduction) { + !selectKernelParams.mForceGmemReduction && + (!isMlaGenKernel(params) || selectKernelParams.mTileSizeQ < 32)) { selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::CgaSmemReduction; // Need to select a different kernel. selectKernelParams.mSelectNewKernel = true; @@ -514,53 +528,106 @@ class TllmGenFmhaKernel { return seqLenPerCtaKv <= 1024 && numCtas <= params.mMultiProcessorCount; } - // Select the MLA generation kernel. - void selectMlaGenerationKernel(RunnerParams const& params, - SelectKernelParams& selectKernelParams) const { - // We use the low-latency kernel (SwapsMmaAbForGeneration with tileSizeQ = 16) when any of the - // following conditions are met: - // 1. The number of headsQPerKv is <= 32. - // 2. The number of headsQPerKv is < 128 for sparseMla. - // 3. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned later) - // and - // the numCtas (after splitting the heads across multiple CTAs) <= - // params.mMultiProcessorCount. - // The sparseMla kernel will always use the 2CTA high-throughput kernel. + // Select the sparse MLA generation kernel. + // Heuristics benchmarked on B200 (SM=148, sparseMlaTopK=2048). 
+ void selectSparseMlaGenerationKernel(RunnerParams const& params, + SelectKernelParams& selectKernelParams) const { + // numHeadsQ <= 32 : SwapsMmaAbForGeneration + // tileSizeQ = numHeadsQPerKv/2 at batch=1 (GPU under-utilized with full tile; halving creates + // 2x more head-splitting CTAs), or numHeadsQPerKv at batch>=2. + // Threshold: batchSize * maxNumCtasPerSeqKv <= MP/8 (crossover at batch=1->2 on B200). + // Benchmarks (seqLen=8192, topK=2048): half tileSizeQ wins by 2-6% at batch=1; + // full tileSizeQ wins by 2-11% at batch>=2. + // numHeadsQ >= 64 : KeepsMmaAbForGeneration, tileSizeQ = 64 + // numHeadsQ=128 at large batch : 2CTA (clusterDimX=2, headDimPerCtaV=256) + // otherwise : 1CTA, headDimPerCtaV fine-tuned later + // Note: at small batch e4m3 prefers SwapsMmaAb tileSizeQ=32 (+10%), but fp16 prefers + // KeepsMmaAb tileSizeQ=64 (+19% at numHeadsQ=128). We keep KeepsMmaAb for numHeadsQ>=64 + // to avoid penalizing fp16. - // The kernel type. FmhaKernelType& kernelType = selectKernelParams.mKernelType; - // The tile size for Q. int& tileSizeQ = selectKernelParams.mTileSizeQ; - // Check the conditions. - if (params.mNumHeadsQPerKv <= 32 || (params.mSparseMla && params.mNumHeadsQPerKv < 128) || - useSwapsMmaAbMlaGenKernel(params)) { + if (params.mNumHeadsQPerKv <= 32) { kernelType = FmhaKernelType::SwapsMmaAbForGeneration; - // Currently, only tileSizeQ = 8 or 16 are supported. - tileSizeQ = params.mNumHeadsQPerKv <= 8 ? 8 : 16; + selectKernelParams.mTileSizeKv = 128; + // mMultiCtasKvMode defaults to GmemReduction from the constructor. computeCtaAndClusterConfig + // may upgrade it to CgaSmemReduction; that update is preserved naturally across + // re-selections. The base tileSizeQ is numHeadsQPerKv (one CTA covers all Q heads per token). + // At batch=1 the GPU is under-utilized, so we halve tileSizeQ to create 2x more + // head-splitting CTAs. Threshold: batchSize * maxNumCtasPerSeqKv <= MP/8. 
+ // effectiveSeqLenKv = min(seqLen, topK) = 2048 -> maxNumCtasPerSeqKv = 16. + // Condition: batchSize * 16 <= MP/8 -> batchSize <= 1 (crossover at batch=1->2). + // Only halve when half tileSizeQ >= 8 (no valid SwapsMmaAb kernel below tileSizeQ=8). + int const fullTileSizeQ = params.mNumHeadsQPerKv; + int const halfTileSizeQ = fullTileSizeQ / 2; + int const effectiveSeqLenKv = std::min(params.mMaxSeqLenKv, params.mSparseMlaTopK); + int const maxNumCtasPerSeqKv = + flashinfer::ceil_div(effectiveSeqLenKv, selectKernelParams.mTileSizeKv); + bool const useHalfTileSizeQ = halfTileSizeQ >= 8 && params.mBatchSize * maxNumCtasPerSeqKv <= + params.mMultiProcessorCount / 8; + tileSizeQ = useHalfTileSizeQ ? halfTileSizeQ : fullTileSizeQ; } else { - // Otherwise, we use the high-throughput kernel. + // numHeadsQ >= 64: use KeepsMmaAbForGeneration. kernelType = FmhaKernelType::KeepsMmaAbForGeneration; - // Use the tileSizeQ = 64 for MLA high-throughput generation kernels. tileSizeQ = 64; - // Always use the separate reduction kernel. - if (isMultiCtasKvEnabled(selectKernelParams.mMultiCtasKvMode)) { + selectKernelParams.mTileSizeKv = 128; + // Upgrade GmemReduction (constructor default) to GmemReductionWithSeparateKernel. + // If computeCtaAndClusterConfig already set it to Disabled (numCtasPerSeqKv==1), the + // isGmemReduction() guard is false and the Disabled state is preserved on re-selection. + if (isGmemReduction(selectKernelParams.mMultiCtasKvMode)) { selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReductionWithSeparateKernel; } - // The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128. - FLASHINFER_CHECK( - !params.mSparseMla || params.mNumHeadsQPerKv == 128, - "The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128, got %d", - params.mNumHeadsQPerKv); - // The 2CTA keepsMmaAbForGeneration kernel is used when the numHeadsQPerKv is 128. 
- if (params.mNumHeadsQPerKv == 128) { + // For numHeadsQ=128, use 2CTA when there are enough CTAs to amortize 2CTA overhead. + // numCtasPerToken = numHeadsQPerKv / tileSizeQ (number of CTAs per token per batch item). + // Benchmarks (fp16/e4m3, sparseMlaTopK=2048): + // batch=1 : 1CTA wins by ~20%; batch=8 : 1CTA wins by 3-8% + // batch=16 : 2CTA wins by 8-16%; batch=32+: 2CTA wins by 12-20% + // Threshold: batchSize * numCtasPerToken * 8 > MP -> crossover at batch ~ MP/16 ~ 9. + int const numCtasPerToken = params.mNumHeadsQPerKv / 64; + bool const use2Cta = params.mNumHeadsQPerKv == 128 && + params.mBatchSize * numCtasPerToken * 8 > params.mMultiProcessorCount; + if (use2Cta) { selectKernelParams.mUses2CtaMma = true; - // Each Cta only handles 256 headDimV. selectKernelParams.mHeadDimPerCtaV = 256; } } } + // Select the MLA generation kernel. + void selectMlaGenerationKernel(RunnerParams const& params, + SelectKernelParams& selectKernelParams) const { + // The kernel type. + FmhaKernelType& kernelType = selectKernelParams.mKernelType; + // The tile size for Q. + int& tileSizeQ = selectKernelParams.mTileSizeQ; + + if (params.mSparseMla) { + selectSparseMlaGenerationKernel(params, selectKernelParams); + } else { + // Non-sparse MLA: use SwapsMmaAb when numHeadsQPerKv <= 32 or seqLenPerCtaKv is small. + bool const useSwapsMmaAb = params.mNumHeadsQPerKv <= 32 || useSwapsMmaAbMlaGenKernel(params); + + if (useSwapsMmaAb) { + kernelType = FmhaKernelType::SwapsMmaAbForGeneration; + // Non-sparse MLA (legacy): tileSizeQ capped at 16. + tileSizeQ = params.mNumHeadsQPerKv <= 8 ? 8 : 16; + } else { + kernelType = FmhaKernelType::KeepsMmaAbForGeneration; + tileSizeQ = 64; + // Always use the separate reduction kernel. + if (isMultiCtasKvEnabled(selectKernelParams.mMultiCtasKvMode)) { + selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReductionWithSeparateKernel; + } + // For non-sparse MLA, always use 2CTA for numHeadsQPerKv=128 (legacy behavior). 
+ if (params.mNumHeadsQPerKv == 128) { + selectKernelParams.mUses2CtaMma = true; + selectKernelParams.mHeadDimPerCtaV = 256; + } + } + } + } + // Selects a heuristic tileSizeQ if groupsTokensHeadsQ is true. void selectTileSizeQForGqaGeneration(RunnerParams const& params, SelectKernelParams& selectKernelParams) const { diff --git a/include/flashinfer/trtllm/fmha/kernelParams.h b/include/flashinfer/trtllm/fmha/kernelParams.h index 6d3be7c6d4..86ce33f737 100644 --- a/include/flashinfer/trtllm/fmha/kernelParams.h +++ b/include/flashinfer/trtllm/fmha/kernelParams.h @@ -203,6 +203,8 @@ struct KernelParams { int32_t mSparseMlaTopK; // The flag to use block sparse attention. bool mUseBlockSparseAttention; + // Whether the indices for K & V pages are shared as unified index (vLLM/FlashInfer). + bool mUsesSharedPagedKvIdx; // Create the TMA shape/stride for Q. template @@ -828,6 +830,8 @@ struct KernelParams { params.mSparseMlaTopK = options.mSparseMlaTopK; // TODO: Integrate trtllm block-sparse attention kernels when needed. params.mUseBlockSparseAttention = false; + // Whether the indices for K & V pages are shared as unified index (vLLM/FlashInfer). + params.mUsesSharedPagedKvIdx = true; return params; } };