feat: add LSE return support to TRT LLM attention kernels #2332
base: main
Conversation
This commit extends all TRT LLM attention APIs to optionally return Log-Sum-Exp (LSE) values, addressing issue #2169.

Changes:
- C++ bindings: added an LSE parameter to `trtllm_paged_attention_decode` and `trtllm_paged_attention_context`
- Allocated the `softmaxStatsPtr` buffer when LSE is requested for multi-block kernels
- Python APIs: added `return_lse` and `lse` parameters to:
  - `trtllm_batch_decode_with_kv_cache`
  - `trtllm_batch_context_with_kv_cache`
  - `trtllm_batch_decode_with_kv_cache_mla`
- All APIs now return an `(output, lse)` tuple when `return_lse=True`
- The XQA backend raises `ValueError` for LSE requests (not supported)
- The LSE tensor is auto-allocated if not provided when `return_lse=True`

The underlying kernels already supported LSE through the `lsePtr` and `softmaxStatsPtr` fields in `TllmGenFmhaRunnerParams`; this change exposes that capability through the Python API.

Fixes #2169

Co-authored-by: Zihao Ye <[email protected]>

Generated with Claude Code
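To make the new return value concrete, here is a small self-contained PyTorch reference for the quantity the kernels report per query token and head. This is a sketch of the math only, not the kernel code path; the function name `attention_with_lse` and the single-token shapes are illustrative assumptions, not part of this PR's API.

```python
import torch

def attention_with_lse(q, k, v, scale=None):
    """Reference attention for one query token that also returns log-sum-exp.

    q: [num_heads, head_dim]; k, v: [seq_len, num_heads, head_dim].
    """
    scale = scale if scale is not None else q.shape[-1] ** -0.5
    # logits[h, s] = scale * <q_h, k_{s,h}>
    logits = torch.einsum("hd,shd->hs", q.float(), k.float()) * scale
    lse = torch.logsumexp(logits, dim=-1)  # the per-head value return_lse=True exposes
    out = torch.einsum("hs,shd->hd", torch.softmax(logits, dim=-1), v.float())
    return out, lse

q = torch.randn(8, 64)
k = torch.randn(128, 8, 64)
v = torch.randn(128, 8, 64)
out, lse = attention_with_lse(q, k, v)  # mirrors the (output, lse) tuple of the new API
```

Returning the LSE alongside the output is what makes it possible to correctly merge partial attention results computed over disjoint KV chunks, which is a common reason to request it.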
Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the TRT LLM attention kernels by exposing the previously internal Log-Sum-Exp (LSE) values through both the C++ and Python APIs. The feature lets callers optionally retrieve LSE outputs, which are useful for advanced use cases such as beam search, confidence estimation, or other downstream tasks that rely on the statistical properties of the raw attention scores. The changes integrate into the existing attention functions, providing access to this information without altering the core kernel logic.
Code Review
This pull request successfully extends the TRT LLM attention APIs to optionally return Log-Sum-Exp (LSE) values. The changes are consistently applied across the C++ bindings and the Python APIs for decode, MLA, and context attention modes. The implementation correctly handles the allocation of the LSE tensor and adjusts the return types.
My main feedback is on the C++ implementation in csrc/trtllm_fmha_kernel_launcher.cu, where hardcoded values are used for buffer allocation, which could be optimized. I've left a specific comment with a suggestion for this. Additionally, there's a minor point about a potentially misleading comment that could be improved for better code clarity.
Overall, this is a great addition that exposes useful functionality from the underlying kernels.
if (lse != nullptr) {
  size_t max_batch_size = 8192;
  size_t max_num_qo_heads = 256;
  size_t num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8);
  runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc<int32_t>(
      num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace");
  runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>(
      sizeof(float2) * num_qo_heads * sum_seq_q, 16, "trtllm_gen_softmax_workspace");
  runner_params.multiCtasKvScratchPtr =
      float_allocator.aligned_alloc<void>(0, 16, "trtllm_gen_scratch_workspace");
}
This block for allocating LSE-related buffers has a couple of areas for improvement:

- Hardcoded values: `max_batch_size` and `max_num_qo_heads` are hardcoded, which can lead to inefficient workspace memory allocation. It's better to use the dynamic `batch_size` and `num_qo_heads` values from the function arguments to make the allocation more precise.
- Unclear allocations: The preceding comment mentions allocating the `softmaxStatsPtr`, but this block also allocates `multiCtasKvCounterPtr` and `multiCtasKvScratchPtr`. These seem related to multi-CTA mode, which is disabled for context attention. If they are not needed for LSE computation in context mode, they should be removed; if they are needed, the comment should be updated for clarity.
The suggestion below addresses the hardcoded values.
if (lse != nullptr) {
  size_t num_semaphores = round_up(batch_size * num_qo_heads, 8);
  runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc<int32_t>(
      num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace");
  runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>(
      sizeof(float2) * num_qo_heads * sum_seq_q, 16, "trtllm_gen_softmax_workspace");
  runner_params.multiCtasKvScratchPtr =
      float_allocator.aligned_alloc<void>(0, 16, "trtllm_gen_scratch_workspace");
}
Code Review: LSE Return Support for TRT LLM Attention Kernels

Thank you for this contribution! The implementation successfully exposes LSE functionality from the underlying TRT LLM kernels. I have reviewed the changes and have the following feedback.
Critical Issue: Workspace Buffer Allocation

In csrc/trtllm_fmha_kernel_launcher.cu:163-174, context mode allocates `multiCtasKvCounterPtr` when LSE is requested, even though `mMultiCtasKvMode = false`. This wastes 8 MB of workspace buffer; the counter should only be allocated when multi-block mode is enabled. Remove lines 168-169 from the context LSE branch.
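As a quick check of that figure (a sketch assuming 4-byte `uint32` counters, per the allocation quoted earlier in this thread):

```python
# Sanity check on the "8 MB" figure, using the hardcoded bounds
# from the quoted allocation (assumes 4-byte uint32 counters).
max_batch_size, max_num_qo_heads = 8192, 256

def round_up(x, m):
    return (x + m - 1) // m * m

num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8)
print(num_semaphores * 4 / 2**20)  # -> 8.0 (MiB)
```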
Overall, this is a solid implementation; the main concern is the buffer allocation issue, which should be straightforward to fix.

Review via Claude Code