[Fmha] support nvfp4 output keepsMmaAb generation kernels #2795
PerkzZheng wants to merge 1 commit into flashinfer-ai:main
Conversation
Summary of Changes (Gemini Code Assist): This pull request enhances FlashInfer's FMHA generation kernels by enabling and optimizing support for nvfp4 output, which is particularly important for speculative decoding performance. The changes update core cubin artifacts, refine the kernel selection mechanism so that nvfp4 output leverages the GQA generation heuristics, and integrate support for shared paged KV indices. The testing suite is also expanded to validate these capabilities across a broader range of configurations.
Activity
📝 Walkthrough

This PR simplifies the GQA generation kernel selection logic by removing a special-case dtype check and making the tile-size heuristic selection unconditional, while also expanding test coverage for generative attention decoding across multiple head dimensions (64, 128, 256) and removing dtype-specific test skips.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: 2 passed, 1 failed (warning)
Code Review
This pull request enables support for nvfp4 output in generation kernels, which involves updating cubin artifacts, adjusting kernel selection logic, and expanding test coverage. The changes appear to be well-aligned with the PR's objectives. I've identified one area for improvement concerning a hardcoded value marked with a FIXME, which should be addressed to prevent future issues.
    // FIXME: set this with options.mUsesSharedPagedKvIdx.
    params.mUsesSharedPagedKvIdx = true;
There's a FIXME here to set mUsesSharedPagedKvIdx from options, but it's currently hardcoded to true. This introduces technical debt and could lead to issues if this parameter needs to be configurable in the future. It would be best to plumb this option through from TllmGenFmhaRunnerParams and set it dynamically.
Suggested change:

    - // FIXME: set this with options.mUsesSharedPagedKvIdx.
    - params.mUsesSharedPagedKvIdx = true;
    + // FIXME: set this with options.mUsesSharedPagedKvIdx.
    + params.mUsesSharedPagedKvIdx = options.mUsesSharedPagedKvIdx;
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Around line 818-819: Replace the hardcoded true assignment for
params.mUsesSharedPagedKvIdx with the corresponding field from the caller
options object so the flag reflects the caller's intent; locate the assignment
to params.mUsesSharedPagedKvIdx and set it from options.mUsesSharedPagedKvIdx
(or the actual options struct in scope, e.g., kernelOptions or opts) and remove
the FIXME comment so non-shared layouts are handled correctly.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f11a9187-4bec-48c6-8002-e55c771a0b61
📒 Files selected for processing (4)
- flashinfer/artifacts.py
- include/flashinfer/trtllm/fmha/fmhaKernels.cuh
- include/flashinfer/trtllm/fmha/kernelParams.h
- tests/attention/test_trtllm_gen_attention.py
    // FIXME: set this with options.mUsesSharedPagedKvIdx.
    params.mUsesSharedPagedKvIdx = true;
Wire mUsesSharedPagedKvIdx from options instead of hardcoding true.
At Line 819, forcing this to true makes all calls behave as shared-index mode and ignores caller intent, which can produce incorrect paged-KV indexing for non-shared layouts.
Suggested fix:

    - // FIXME: set this with options.mUsesSharedPagedKvIdx.
    - params.mUsesSharedPagedKvIdx = true;
    + params.mUsesSharedPagedKvIdx = options.mUsesSharedPagedKvIdx;
Code Review
This pull request enables support for nvfp4 output with keepsMmaAb generation kernels in FMHA. The changes include updating kernel selection logic, expanding test coverage by removing a pytest.skip and adding more configurations, and updating cubin artifacts. A new field mUsesSharedPagedKvIdx is introduced for vLLM/FlashInfer paged KV indices. My review includes a suggestion to address a FIXME related to the hardcoded initialization of this new field.
    // TODO: Integrate trtllm block-sparse attention kernels when needed.
    params.mUseBlockSparseAttention = false;
    // FIXME: set this with options.mUsesSharedPagedKvIdx.
    params.mUsesSharedPagedKvIdx = true;
As the FIXME comment on the preceding line indicates, mUsesSharedPagedKvIdx is currently hardcoded to true. This should be properly configured from the options object instead of being hardcoded. To fix this, you'll likely need to add the mUsesSharedPagedKvIdx field to the TllmGenFmhaRunnerParams struct and ensure it's populated correctly from the calling code.
Suggested change:

    - params.mUsesSharedPagedKvIdx = true;
    + params.mUsesSharedPagedKvIdx = options.mUsesSharedPagedKvIdx;
    )
    @pytest.mark.parametrize("enable_pdl", [True, False, None])
    @pytest.mark.parametrize("enable_sink", [True, False])
    @pytest.mark.parametrize("enable_pdl", [False])
Are enable_pdl and enable_sink also not expected to work?
Right, let me revert the changes. I used it for debugging locally. Thanks!
/bot run

@PerkzZheng is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/artifacts.py`:
- Line 138: The constant TRTLLM_GEN_FMHA currently points to an artifact path
that returns 404 and will break runtime kernel loading; verify or correct the
artifact reference by either uploading the missing files to the path referenced
by TRTLLM_GEN_FMHA or update TRTLLM_GEN_FMHA to the correct existing artifact
path and matching checksum values (ensure checksums.txt location and filenames
match the uploaded cubin assets); update only the TRTLLM_GEN_FMHA value and
associated checksum entries so runtime can fetch the CUDA kernels successfully.
/bot run

[FAILED] Pipeline #46475427: 10/20 passed
script: https://paste.ubuntu.com/p/k2WYykRJ4Y/
llama33-70b-nvfp4-tp4-pr2795:
@baonudesifeizhai thank you! If I'm reading this right fused latency goes from
nvidia/Llama-3.3-70B-Instruct-NVFP4 (4-card TP), nvidia/Llama-3.1-8B-Instruct-NVFP4 (1 card)....
@saltyminty it seems that the failing ones are not related. Is it okay to merge it? Thanks.
Commit message:
- Update cubin artifact path/checksum to new build with nvfp4 output support
- Fix kernel selection: remove E2M1 output dtype condition from the mixed-precision path, allowing nvfp4 output to use GQA generation kernel selection heuristics
- Always invoke selectTileSizeQForGqaGeneration (not just for maxSeqLenQ > 1)
- Add mUsesSharedPagedKvIdx field to KernelParams for vLLM/FlashInfer paged KV indices
- Remove speculative-decode skip for nvfp4 output in tests
- Expand test coverage: head_dim [64, 128, 256], additional batch configs

AI-assisted. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Force-pushed from 86c1c32 to 2632da4 (compare)
Just rebased. Feel free to trigger CI again. Thanks!
🧹 Nitpick comments (1)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (1)
776-777: Reduce selector overhead now that heuristic selection is unconditional.

Line 777 now runs on every GQA-generation selection path. Consider removing the per-call std::unordered_map construction in selectTileSizeQForGqaGeneration to avoid extra host overhead in decode-heavy workloads.

♻️ Suggested refactor

    - std::unordered_map<int, float> kernelMainloopCost = {
    -     {128, 2.2}, {64, 1.68}, {32, 1.48}, {16, 1.2}, {8, 1.0}
    - };
    -
    - std::unordered_map<int, float> kernelReductionCost = {
    -     {128, 1.32}, {64, 1.2}, {32, 1.08}, {16, 1.03}, {8, 1.0}
    - };
    + auto kernelMainloopCost = [](int tileSizeQ) -> float {
    +   switch (tileSizeQ) {
    +     case 128: return 2.2f;
    +     case 64: return 1.68f;
    +     case 32: return 1.48f;
    +     case 16: return 1.2f;
    +     case 8: return 1.0f;
    +     default: return FLT_MAX;
    +   }
    + };
    + auto kernelReductionCost = [](int tileSizeQ) -> float {
    +   switch (tileSizeQ) {
    +     case 128: return 1.32f;
    +     case 64: return 1.2f;
    +     case 32: return 1.08f;
    +     case 16: return 1.03f;
    +     case 8: return 1.0f;
    +     default: return FLT_MAX;
    +   }
    + };
      ...
    - float modelingKernelTime = kernelMainloopCost.at(tileSizeQ) * seqLenPerCtaKv +
    -                            kernelReductionCost.at(tileSizeQ) * kernelReductionSeqLenFactor *
    -                            ctaLaunchParams.mMaxNumCtasKv;
    + float modelingKernelTime = kernelMainloopCost(tileSizeQ) * seqLenPerCtaKv +
    +                            kernelReductionCost(tileSizeQ) * kernelReductionSeqLenFactor *
    +                            ctaLaunchParams.mMaxNumCtasKv;

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh` around lines 776-777, selectTileSizeQForGqaGeneration currently builds a std::unordered_map on every call, which adds host-side overhead now that it is invoked unconditionally. Move the map out of the function (e.g., a static or thread_local std::unordered_map, or a member cache in the owning class) and look up entries rather than recreating the container each time; ensure thread safety (e.g., a mutex or thread_local storage) so repeated calls during decoding reuse the prebuilt data.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 9ea0fe68-2306-4c7d-8170-3aa1aa442fe1
📒 Files selected for processing (2)
- include/flashinfer/trtllm/fmha/fmhaKernels.cuh
- tests/attention/test_trtllm_gen_attention.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/attention/test_trtllm_gen_attention.py
wait, TRTLLM_GEN_FMHA: str = "3fec9f12548f83f44e4ca60394a2946238a677f1/fmha/trtllm-gen/" — that's the whole point

No worries, this is expected. We have another MR just merged (#2836) which includes the required cubins for this MR.
📌 Description
Qwen3-480B (num_qo_heads=96, num_kv_heads=8, head_dim_qk=128, head_dim_vo=128)
Speedup (baseline / opt)
GPT-OSS (num_qo_heads=64, num_kv_heads=8, head_dim_qk=64, head_dim_vo=64)
Speedup (baseline / opt)
Summary
Speedup scales strongly with s_qo (speculative decode query length):
- s_qo=2: 1.1–1.8x speedup across both models
- s_qo=4: 1.9–2.6x speedup
- s_qo=8: 2.8–5.1x speedup (peak 5.12x on GPT-OSS, bs=32)
#2632
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
Bug Fixes
Tests