
fix: fix bugs in flashinfer_struct #921

Merged
hiworldwzj merged 1 commit into main from fix_flashinfer on Jun 6, 2025

Conversation

@blueswhen (Collaborator)

No description provided.

@gemini-code-assist (bot) left a comment


Hello @blueswhen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello! Gemini Code Assist here, providing a summary of this pull request to help everyone get up to speed quickly.

This PR, titled "fix: fix bugs in flashinfer_struct", addresses issues within the flashinfer_struct implementation for the DeepSeek2 and Llama models. Based on the changes, the core problem appears to be related to the sizing or indexing of the kv_indices buffer used by the FlashInfer library during attention calculations. The fix involves correcting the size calculation for this buffer, switching from a potentially overly large global max_seq_length to a more accurate size based on the maximum sequence length within the current batch (max_len_in_batch) or the total space needed for the batch (batch_size * self.max_len_in_batch). This likely resolves potential out-of-bounds access or incorrect memory allocation issues when handling variable batch sizes or sequence lengths.

Highlights

  • Bug Fix: Fixes bugs related to the flashinfer_struct implementation in DeepSeek2 and Llama models.
  • FlashInfer Integration: Corrects the calculation for the size of the kv_indices buffer used by FlashInfer, ensuring it's based on the actual batch size and maximum sequence length within the batch.
  • Memory Management: Improves memory handling for the kv_indices buffer by using max_len_in_batch instead of a potentially larger max_seq_length.
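To make the sizing issue in these highlights concrete, here is a minimal pure-Python sketch. It is not the LightLLM implementation: repack_kv_index_mock and all the lengths below are invented for illustration, standing in for the real repack_kv_index kernel that writes one slot per KV token across the whole batch.

```python
# Hypothetical sketch, not the LightLLM source: repack_kv_index_mock and the
# example lengths are invented to show why sizing kv_indices by a single
# global max_seq_length is too small for batched decode.

def repack_kv_index_mock(b_seq_len, buffer_size):
    """Stand-in for the real repack_kv_index kernel: it must write one slot
    per KV token across the whole batch."""
    total = sum(b_seq_len)
    if total > buffer_size:
        raise IndexError(f"kv_indices too small: need {total}, have {buffer_size}")
    return total  # number of slots written

b_seq_len = [300, 250, 400]        # per-request KV lengths in this batch
batch_size = len(b_seq_len)
max_len_in_batch = max(b_seq_len)  # 400
max_seq_length = 512               # global per-sequence limit

# Old allocation: room for one max-length sequence only.
try:
    repack_kv_index_mock(b_seq_len, max_seq_length)
except IndexError as exc:
    print("old sizing fails:", exc)   # needs 950 slots, has 512

# Fixed allocation: room for every token in the batch.
written = repack_kv_index_mock(b_seq_len, batch_size * max_len_in_batch)
print("fixed sizing writes", written, "entries into a",
      batch_size * max_len_in_batch, "slot buffer")
```

With three sequences totalling 950 tokens, the old 512-slot buffer overflows while the fixed 1200-slot buffer fits comfortably, which mirrors the failure mode the review describes.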

Changelog

  • lightllm/models/deepseek2/flashinfer_struct.py
    • Corrected the size calculation for the kv_indices buffer initialization when batch_size is within or exceeds model.graph_max_batch_size, changing the sizing basis from self.flashinfer_extra_state.max_seq_length to self.batch_size * self.max_len_in_batch (lines 28 and 31).
  • lightllm/models/llama/flashinfer_struct.py
    • Corrected the size calculation for the kv_indices buffer initialization when batch_size is within or exceeds model.graph_max_batch_size, changing from self.flashinfer_extra_state.max_seq_length to self.batch_size * self.max_len_in_batch (lines 28 and 31).
    • Corrected the size calculation for the kv_indices buffer initialization in another branch of the init_some_extra_state method, changing from self.flashinfer_extra_state.max_seq_length to self.batch_size * self.max_len_in_batch (line 71).

@gemini-code-assist (bot) left a comment


Code Review

This pull request addresses critical bugs related to the sizing of the kv_indices tensor within the FlashInfer integration for both DeepSeek2 and Llama models. The changes ensure that kv_indices is allocated with sufficient memory for batch processing in both decode and prefill phases, preventing potential out-of-bounds errors and incorrect behavior. The precision of buffer slicing has also been improved. These are important fixes for the stability and correctness of the models. Great work on identifying and resolving these issues!

Summary of Findings

  • Corrected kv_indices Tensor Sizing in Decode Phase: The pull request critically fixes undersized memory allocation for kv_indices in the decode logic for both DeepSeek2 and Llama models. It also refines buffer slicing for better precision. This prevents potential out-of-bounds memory access and crashes.
  • Corrected kv_indices Tensor Sizing in Llama Prefill Phase: A similar critical bug regarding undersized kv_indices allocation was fixed in the Llama model's prefill logic, ensuring sufficient space for batch operations and preventing errors.

Merge Readiness

This pull request successfully fixes critical bugs in memory allocation that could lead to runtime errors or incorrect outputs when using FlashInfer. The changes are clear, targeted, and directly address the identified problems. Based on this review, the PR significantly improves the robustness of the codebase and appears ready for merging. As always, further testing and review by other team members are recommended. I am not authorized to approve pull requests.

Comment on lines 28 to 31 of lightllm/models/deepseek2/flashinfer_struct.py:

                    : self.batch_size * self.max_len_in_batch
                ]
            else:
    -           self.kv_indices = torch.zeros(self.flashinfer_extra_state.max_seq_length, dtype=torch.int32).to(
    +           self.kv_indices = torch.zeros(self.batch_size * self.max_len_in_batch, dtype=torch.int32).to(
Severity: critical

This hunk introduces crucial fixes and improvements to the sizing of self.kv_indices for the decode phase.

  • On line 28 (within the if block): Changing the slice dimension from self.batch_size * self.flashinfer_extra_state.max_seq_length to self.batch_size * self.max_len_in_batch sizes the self.kv_indices view to the actual needs of the current batch (up to max_len_in_batch per sequence). This is a good refinement for accuracy and potentially for FlashInfer's internal planning. (This part of the change is of medium importance.)

  • On line 31 (within the else block): The change from torch.zeros(self.flashinfer_extra_state.max_seq_length, ...) to torch.zeros(self.batch_size * self.max_len_in_batch, ...) is a critical bug fix. The previous allocation was insufficient for batch_size > 1, as it only allocated space for a single max-length sequence. The repack_kv_index function and subsequent FlashInfer operations require space for all tokens in the batch (up to sum(self.b_seq_len), which is bounded by self.batch_size * self.max_len_in_batch). This correction prevents out-of-bounds memory access and ensures correct functionality.

Overall, these changes significantly improve the robustness of kv_indices handling.
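The bound this comment relies on, sum(self.b_seq_len) <= self.batch_size * self.max_len_in_batch, can be sanity-checked in isolation. The following sketch uses made-up lengths, not values from the PR, and a randomized check rather than the real tensors:

```python
import random

random.seed(0)

# The sum of per-sequence KV lengths never exceeds batch_size times the
# longest length in the batch, so a buffer of batch_size * max_len_in_batch
# always has room for the repacked indices.
for _ in range(1000):
    b_seq_len = [random.randint(1, 4096) for _ in range(random.randint(1, 8))]
    assert sum(b_seq_len) <= len(b_seq_len) * max(b_seq_len)

# The old sizing (one global max_seq_length) already overflows for a
# two-request batch where both sequences are near the limit.
max_seq_length = 4096
b_seq_len = [4096, 4000]
need = sum(b_seq_len)                            # 8096 slots required
assert need > max_seq_length                     # old buffer: out of bounds
assert need <= len(b_seq_len) * max(b_seq_len)   # fixed buffer: fits
print("sizing bound holds")
```

This is exactly why the fix is safe: the new allocation is an upper bound on the space repack_kv_index can ever need, while the old one is exceeded as soon as the batch holds more than one long sequence.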

Comment on lines 28 to 31 of lightllm/models/llama/flashinfer_struct.py:

                    : self.batch_size * self.max_len_in_batch
                ]
            else:
    -           self.kv_indices = torch.zeros(self.flashinfer_extra_state.max_seq_length, dtype=torch.int32).to(
    +           self.kv_indices = torch.zeros(self.batch_size * self.max_len_in_batch, dtype=torch.int32).to(
Severity: critical

This hunk mirrors the important fixes and improvements for self.kv_indices sizing in the decode phase, as seen in the deepseek2 model.

  • On line 28 (within the if block): Adjusting the slice to use self.batch_size * self.max_len_in_batch instead of self.batch_size * self.flashinfer_extra_state.max_seq_length gives the self.kv_indices view a size tailored to the current batch. (This part of the change is of medium importance.)

  • On line 31 (within the else block): Correcting the allocation from torch.zeros(self.flashinfer_extra_state.max_seq_length, ...) to torch.zeros(self.batch_size * self.max_len_in_batch, ...) is a critical bug fix. The original allocation was too small for batch processing, and this change ensures sufficient memory for repack_kv_index and FlashInfer, preventing errors.

These are essential changes for the stability and correctness of FlashInfer integration in the Llama model.

Comment on line 71 of lightllm/models/llama/flashinfer_struct.py:

        kv_starts = self.b1_cu_kv_seq_len.int()
        kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32).to(input_ids.device)
    -   kv_indices = torch.zeros(self.flashinfer_extra_state.max_seq_length, dtype=torch.int32).to(
    +   kv_indices = torch.zeros(self.batch_size * self.max_len_in_batch, dtype=torch.int32).to(

Severity: critical

This change addresses a critical bug in the memory allocation for kv_indices during the prefill phase.
The original code kv_indices = torch.zeros(self.flashinfer_extra_state.max_seq_length, ...) allocated space sufficient for only one sequence of maximum length.
For batch prefill, the repack_kv_index operation and subsequent FlashInfer kernels require kv_indices to hold data for all tokens across all sequences in the batch. This total size can be up to self.batch_size * self.max_len_in_batch.
The corrected allocation kv_indices = torch.zeros(self.batch_size * self.max_len_in_batch, ...) ensures that kv_indices is large enough, preventing potential out-of-bounds writes and ensuring correct behavior. This is a vital fix for reliable prefill operations.

@blueswhen force-pushed the fix_flashinfer branch 3 times, most recently from 9f9fa16 to 1694338 on June 6, 2025 07:50
@hiworldwzj merged commit 5d248e9 into main on Jun 6, 2025
1 check passed
@hiworldwzj deleted the fix_flashinfer branch on June 6, 2025 07:55