Skip to content

Commit 18fd91a

Browse files
authored
bugfix: Fix the bug of the kernel-selection heuristic in trtllm-gen (#1307)
<!-- .github/pull_request_template.md --> ## 📌 Description This fixes a bug where the low-latency (swapsMmaAb) MLA kernels were still selected when the batch size is quite large in the high-throughput case (i.e., when attention DP is used). Accuracy is not affected, but performance can be much worse without the fix. ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> --------- Signed-off-by: Perkz Zheng <[email protected]>
1 parent 9c609f0 commit 18fd91a

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

include/flashinfer/trtllm/fmha/fmhaKernels.cuh

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -397,11 +397,19 @@ class TllmGenFmhaKernel {
397397
clusterDimX);
398398
}
399399

400-
// Compute the seqLenPerCtaKv for selecting the MLA generation kernel.
401-
int computeSeqLenPerCtaKv(RunnerParams const& params) const {
400+
// Determine if we should use the SwapsMmaAbForGeneration kernel for MLA generation.
401+
bool useSwapsMmaAbMlaGenKernel(RunnerParams const& params) const {
402+
// Use the SwapsMmaAbForGeneration kernel for MLA generation when the following conditions are
403+
// met:
404+
// 1. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned
405+
// later).
406+
// 2. The numCtas (after splitting the heads across multiple CTAs) <=
407+
// params.mMultiProcessorCount.
408+
402409
// The maximum number of Ctas per Kv sequence, which makes sure that each CtaKv has work to do.
403410
// Here we assume the stepKv is 256.
404411
int const maxNumCtasPerSeqKv = flashinfer::ceil_div(params.mMaxSeqLenKv, 256);
412+
;
405413
// The number of Ctas.
406414
int const numCtas = static_cast<int32_t>(params.mBatchSize * params.mMaxSeqLenQ *
407415
divUp(params.mNumHeadsQPerKv, 16));
@@ -410,8 +418,8 @@ class TllmGenFmhaKernel {
410418
std::min(maxNumCtasPerSeqKv, std::max(1, int32_t(params.mMultiProcessorCount / numCtas)));
411419
// Compute the seqLenPerCtaKv.
412420
int const seqLenPerCtaKv = flashinfer::ceil_div(params.mMaxSeqLenKv, numCtasPerSeqKv);
413-
// Return the seqLenPerCtaKv.
414-
return seqLenPerCtaKv;
421+
// Whether we should use the SwapsMmaAbForGeneration kernel for MLA generation.
422+
return seqLenPerCtaKv <= 1024 && numCtas <= params.mMultiProcessorCount;
415423
}
416424

417425
std::pair<uint64_t, std::string> hashFromRunnerParams(
@@ -424,10 +432,12 @@ class TllmGenFmhaKernel {
424432
// following conditions are met:
425433
// 1. The number of headsQPerKv is <= 32.
426434
// 2. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned
427-
// later).
435+
// later) and
436+
// the numCtas (after splitting the heads across multiple CTAs) <=
437+
// params.mMultiProcessorCount.
428438

429439
// Check the conditions.
430-
if (params.mNumHeadsQPerKv <= 32 || computeSeqLenPerCtaKv(params) <= 1024) {
440+
if (params.mNumHeadsQPerKv <= 32 || useSwapsMmaAbMlaGenKernel(params)) {
431441
kernelType = FmhaKernelType::SwapsMmaAbForGeneration;
432442
} else {
433443
// Otherwise, we use the high-throughput kernel.

0 commit comments

Comments
 (0)