
Conversation

@yzh119 (Collaborator) commented Jan 11, 2026

This PR extends all TRT LLM attention APIs to optionally return Log-Sum-Exp (LSE) values.

Changes

  • C++ bindings: Added LSE parameter to trtllm_paged_attention_decode and trtllm_paged_attention_context
  • Allocated softmaxStatsPtr buffer when LSE is requested for multi-block kernels
  • Python APIs: Added return_lse and lse parameters to:
    • trtllm_batch_decode_with_kv_cache
    • trtllm_batch_context_with_kv_cache
    • trtllm_batch_decode_with_kv_cache_mla
  • All APIs return an (output, lse) tuple when return_lse=True
  • The LSE tensor is auto-allocated if not provided when return_lse=True
  • XQA backend raises ValueError for LSE requests (not supported)

The underlying kernels already supported LSE through the lsePtr and softmaxStatsPtr fields in TllmGenFmhaRunnerParams. This change exposes that capability through the Python API.
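
For context, the LSE of a query token at a given head is the log of its attention softmax denominator, log(sum_j exp(s_j)), which is what allows partial attention outputs to be merged downstream (e.g., across KV partitions). A minimal usage sketch of the new Python API follows; only return_lse and lse are introduced by this PR, and the remaining argument names are assumptions rather than the verified signature:

    import flashinfer

    # Request LSE alongside the attention output. With return_lse=False (the
    # default) the call returns only the output tensor, as before.
    out, lse = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
        query,             # [num_tokens, num_qo_heads, head_dim] (assumed layout)
        kv_cache,          # paged KV cache
        workspace_buffer,  # workspace tensor
        block_tables,
        seq_lens,
        max_seq_len,
        return_lse=True,   # new: request per-(token, head) log-sum-exp values
    )
    # lse holds one float per (token, head); a preallocated buffer can be
    # passed via lse=... to skip the auto-allocation.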

Fixes #2169

Generated with Claude Code


coderabbitai bot commented Jan 11, 2026

Warning

Rate limit exceeded

@yzh119 has exceeded the limit for the number of commits that can be reviewed per hour. Please wait 3 minutes and 59 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.


📥 Commits

Reviewing files that changed from the base of the PR and between 2062dec and 0c753fa.

📒 Files selected for processing (4)
  • csrc/trtllm_fmha_kernel_launcher.cu
  • flashinfer/decode.py
  • flashinfer/mla.py
  • flashinfer/prefill.py


@gemini-code-assist
Contributor

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the TRT LLM attention kernels by exposing the previously internal Log-Sum-Exp (LSE) values through both the C++ and Python APIs. Users can now optionally retrieve LSE outputs, which are useful for beam search, confidence estimation, and other downstream tasks that need the attention softmax statistics. The changes integrate into the existing attention functions without altering the core kernel logic.

Highlights

  • LSE Return Support: Enabled optional return of Log-Sum-Exp (LSE) values across TRT LLM attention APIs, exposing a capability previously only available internally within the kernels.
  • C++ API Extension: The C++ bindings trtllm_paged_attention_decode and trtllm_paged_attention_context now accept an optional lse parameter to capture LSE values.
  • Python API Integration: Introduced return_lse and lse parameters to the Python functions trtllm_batch_decode_with_kv_cache, trtllm_batch_context_with_kv_cache, and trtllm_batch_decode_with_kv_cache_mla, allowing users to request and provide LSE tensors.
  • Dynamic Buffer Allocation: Implemented conditional allocation of the softmaxStatsPtr buffer in C++ kernels when LSE values are requested, ensuring efficient memory usage.
  • Unified Return Type: Modified Python APIs to return a (output, lse) tuple when return_lse is set to True, providing a consistent interface for accessing both attention output and LSE values.
  • XQA Backend Limitation: Added checks to raise a ValueError if LSE return is requested for the XQA backend, as it currently does not support this feature.


Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request successfully extends the TRT LLM attention APIs to optionally return Log-Sum-Exp (LSE) values. The changes are consistently applied across the C++ bindings and the Python APIs for decode, MLA, and context attention modes. The implementation correctly handles the allocation of the LSE tensor and adjusts the return types.

My main feedback is on the C++ implementation in csrc/trtllm_fmha_kernel_launcher.cu, where hardcoded values are used for buffer allocation, which could be optimized. I've left a specific comment with a suggestion for this. Additionally, there's a minor point about a potentially misleading comment that could be improved for better code clarity.

Overall, this is a great addition that exposes useful functionality from the underlying kernels.

Comment on lines +164 to +174
    if (lse != nullptr) {
      size_t max_batch_size = 8192;
      size_t max_num_qo_heads = 256;
      size_t num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8);
      runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc<int32_t>(
          num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace");
      runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>(
          sizeof(float2) * num_qo_heads * sum_seq_q, 16, "trtllm_gen_softmax_workspace");
      runner_params.multiCtasKvScratchPtr =
          float_allocator.aligned_alloc<void>(0, 16, "trtllm_gen_scratch_workspace");
    }

Severity: medium

This block for allocating LSE-related buffers has a couple of areas for improvement:

  1. Hardcoded values: max_batch_size and max_num_qo_heads are hardcoded. This can lead to inefficient workspace memory allocation. It's better to use the dynamic batch_size and num_qo_heads values from the function arguments to make the allocation more precise.

  2. Unclear allocations: The preceding comment mentions allocating the softmaxStatsPtr, but this block also allocates multiCtasKvCounterPtr and multiCtasKvScratchPtr. These seem related to multi-CTA mode, which is disabled for context attention. If these are not needed for LSE computation in context mode, they should be removed. If they are needed, the comment should be updated for clarity.

The suggestion below addresses the hardcoded values.

    if (lse != nullptr) {
      size_t num_semaphores = round_up(batch_size * num_qo_heads, 8);
      runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc<int32_t>(
          num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace");
      runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>(
          sizeof(float2) * num_qo_heads * sum_seq_q, 16, "trtllm_gen_softmax_workspace");
      runner_params.multiCtasKvScratchPtr =
          float_allocator.aligned_alloc<void>(0, 16, "trtllm_gen_scratch_workspace");
    }


claude bot commented Jan 11, 2026

Code Review: LSE Return Support for TRT LLM Attention Kernels

Thank you for this contribution! The implementation successfully exposes LSE functionality from the underlying TRT LLM kernels. I have reviewed the changes and have feedback:

Strengths

  1. Clean API design following existing FlashInfer patterns
  2. Proper error handling for unsupported XQA backend
  3. Convenient auto-allocation of LSE tensors
  4. Good documentation updates

Critical Issue: Workspace Buffer Allocation

In csrc/trtllm_fmha_kernel_launcher.cu:163-174, context mode allocates multiCtasKvCounterPtr when LSE is requested, even though mMultiCtasKvMode=false. This wastes about 8 MB of the workspace buffer (8192 x 256 semaphores at 4 bytes each). The counter should only be allocated when multi-block mode is enabled; remove lines 168-169 from the context LSE branch.

Potential Issues

  1. Missing input validation when the user provides an lse tensor (shape, dtype, and device should be verified; see the sketch after this list)
  2. No test coverage for the new LSE functionality
  3. Documentation could clarify what LSE values represent
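
A minimal sketch of the validation point 1 asks for (a hypothetical helper, not part of this PR; the expected shape of one float32 per (token, head) is inferred from the softmaxStatsPtr allocation above):

    import torch

    def _check_lse(lse, num_tokens, num_qo_heads, device):
        # Hypothetical helper: validate a user-provided LSE buffer before
        # handing it to the kernel.
        if lse.shape != (num_tokens, num_qo_heads):
            raise ValueError(
                f"lse must have shape {(num_tokens, num_qo_heads)}, got {tuple(lse.shape)}")
        if lse.dtype != torch.float32:
            raise ValueError(f"lse must be float32, got {lse.dtype}")
        if lse.device != device:
            raise ValueError(f"lse must be on {device}, got {lse.device}")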

Recommendations Before Merging

  1. Fix workspace allocation in Context mode
  2. Add input validation for user-provided LSE tensors
  3. Add test coverage for LSE functionality

Overall solid implementation - main concern is the buffer allocation issue which should be straightforward to fix.
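
For point 3 of the recommendations, a property check along these lines would cover the basics (names hypothetical; out_with_lse and out_without_lse come from calling the same API with and without return_lse=True):

    import torch

    def check_lse_consistency(out_with_lse, out_without_lse, lse):
        # Requesting LSE must not change the attention output, and the returned
        # LSE values must be finite float32, one per (token, head).
        torch.testing.assert_close(out_with_lse, out_without_lse)
        assert lse.dtype == torch.float32
        assert torch.isfinite(lse).all()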

Review via Claude Code

