Conversation
Hello @blueswhen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Summary of Changes
This PR, titled "fix: fix bugs in flashinfer_struct", addresses issues within the flashinfer_struct implementation for the DeepSeek2 and Llama models. Based on the changes, the core problem appears to be related to the sizing or indexing of the kv_indices buffer used by the FlashInfer library during attention calculations. The fix involves correcting the size calculation for this buffer, switching from a potentially overly large global max_seq_length to a more accurate size based on the maximum sequence length within the current batch (max_len_in_batch) or the total space needed for the batch (batch_size * self.max_len_in_batch). This likely resolves potential out-of-bounds access or incorrect memory allocation issues when handling variable batch sizes or sequence lengths.
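As a rough illustration of the resizing described above (the numbers here are hypothetical, not taken from the PR), the old and new buffer sizes compare as follows:

```python
# Hypothetical sizes, chosen only to illustrate the resizing described above.
batch_size = 4
max_len_in_batch = 512   # longest sequence in the current batch
max_seq_length = 1024    # global model maximum (basis of the old sizing)

old_size = max_seq_length                 # old: room for one max-length sequence
new_size = batch_size * max_len_in_batch  # new: room for the whole batch

print(old_size, new_size)  # 1024 2048 -- the old buffer is too small for batch_size > 1
```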
Highlights
- **Bug Fix:** Fixes bugs related to the `flashinfer_struct` implementation in the DeepSeek2 and Llama models.
- **FlashInfer Integration:** Corrects the size calculation for the `kv_indices` buffer used by FlashInfer, ensuring it is based on the actual batch size and the maximum sequence length within the batch.
- **Memory Management:** Improves memory handling for the `kv_indices` buffer by using `max_len_in_batch` instead of the potentially larger `max_seq_length`.
Changelog
- `lightllm/models/deepseek2/flashinfer_struct.py`
  - Corrected the size calculation for the `kv_indices` buffer initialization when `batch_size` is within or exceeds `model.graph_max_batch_size`, changing from `self.flashinfer_extra_state.max_seq_length` to `self.max_len_in_batch` (lines 28 and 31).
- `lightllm/models/llama/flashinfer_struct.py`
  - Corrected the size calculation for the `kv_indices` buffer initialization when `batch_size` is within or exceeds `model.graph_max_batch_size`, changing from `self.flashinfer_extra_state.max_seq_length` to `self.batch_size * self.max_len_in_batch` (lines 28 and 31).
  - Corrected the size calculation for the `kv_indices` buffer initialization in another branch of the `init_some_extra_state` method, changing from `self.flashinfer_extra_state.max_seq_length` to `self.batch_size * self.max_len_in_batch` (line 71).
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a `.gemini/` folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution.
Code Review
This pull request addresses critical bugs related to the sizing of the kv_indices tensor within the FlashInfer integration for both DeepSeek2 and Llama models. The changes ensure that kv_indices is allocated with sufficient memory for batch processing in both decode and prefill phases, preventing potential out-of-bounds errors and incorrect behavior. The precision of buffer slicing has also been improved. These are important fixes for the stability and correctness of the models. Great work on identifying and resolving these issues!
Summary of Findings
- **Corrected `kv_indices` Tensor Sizing in Decode Phase:** The pull request fixes undersized memory allocation for `kv_indices` in the decode logic of both the DeepSeek2 and Llama models, and refines buffer slicing for better precision. This prevents potential out-of-bounds memory access and crashes.
- **Corrected `kv_indices` Tensor Sizing in Llama Prefill Phase:** A similar critical bug, an undersized `kv_indices` allocation, was fixed in the Llama model's prefill logic, ensuring sufficient space for batch operations and preventing errors.
Merge Readiness
This pull request successfully fixes critical bugs in memory allocation that could lead to runtime errors or incorrect outputs when using FlashInfer. The changes are clear, targeted, and directly address the identified problems. Based on this review, the PR significantly improves the robustness of the codebase and appears ready for merging. As always, further testing and review by other team members are recommended. I am not authorized to approve pull requests.
```diff
                 : self.batch_size * self.max_len_in_batch
             ]
         else:
-            self.kv_indices = torch.zeros(self.flashinfer_extra_state.max_seq_length, dtype=torch.int32).to(
+            self.kv_indices = torch.zeros(self.batch_size * self.max_len_in_batch, dtype=torch.int32).to(
```
This hunk introduces crucial fixes and improvements to the sizing of `self.kv_indices` for the decode phase.

- On line 28 (within the `if` block): changing the slice dimension from `self.batch_size * self.flashinfer_extra_state.max_seq_length` to `self.batch_size * self.max_len_in_batch` makes the view into `self.kv_indices` more precisely reflect the needs of the current batch (up to `max_len_in_batch` per sequence). This is a good refinement for accuracy and potentially for FlashInfer's internal planning. (This part can be considered of medium importance.)
- On line 31 (within the `else` block): the change from `torch.zeros(self.flashinfer_extra_state.max_seq_length, ...)` to `torch.zeros(self.batch_size * self.max_len_in_batch, ...)` is a critical bug fix. The previous allocation was insufficient for `batch_size > 1`, as it only allocated space for a single max-length sequence. The `repack_kv_index` function and subsequent FlashInfer operations require space for all tokens in the batch (up to `sum(self.b_seq_len)`, which is bounded by `self.batch_size * self.max_len_in_batch`). This correction prevents out-of-bounds memory access and ensures correct functionality.
Overall, these changes significantly improve the robustness of kv_indices handling.
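A minimal, torch-free sketch of why the flat buffer must hold up to `batch_size * max_len_in_batch` entries. The `pack_kv_indices` helper and all lengths below are made up for illustration; this is not the project's actual `repack_kv_index`:

```python
# Sketch: packing per-sequence KV indices into one flat buffer.
# Hypothetical helper and lengths, chosen only to show the sizing bound.
def pack_kv_indices(b_seq_len, buffer_size):
    buf = [0] * buffer_size
    offset = 0
    for seq_id, seq_len in enumerate(b_seq_len):
        for pos in range(seq_len):
            buf[offset] = seq_id * 10_000 + pos  # stand-in for a real KV index
            offset += 1
    return buf, offset

b_seq_len = [512, 300, 512, 100]  # per-sequence KV lengths in the batch
needed = sum(b_seq_len)           # entries actually written: 1424
bound = len(b_seq_len) * max(b_seq_len)  # batch_size * max_len_in_batch = 2048
assert needed <= bound            # the bound the fix allocates for

buf, written = pack_kv_indices(b_seq_len, bound)
print(written)  # 1424
```

A buffer sized for only one max-length sequence (512 here) could not hold the 1424 entries this batch writes.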
```diff
                 : self.batch_size * self.max_len_in_batch
             ]
         else:
-            self.kv_indices = torch.zeros(self.flashinfer_extra_state.max_seq_length, dtype=torch.int32).to(
+            self.kv_indices = torch.zeros(self.batch_size * self.max_len_in_batch, dtype=torch.int32).to(
```
This hunk mirrors the important fixes and improvements for `self.kv_indices` sizing in the decode phase, as seen in the DeepSeek2 model.

- On line 28 (within the `if` block): adjusting the slice to use `self.max_len_in_batch` instead of `self.flashinfer_extra_state.max_seq_length` (when multiplied by `self.batch_size`) gives a more precise size for the `self.kv_indices` view, tailored to the current batch. (This part can be considered of medium importance.)
- On line 31 (within the `else` block): correcting the allocation from `torch.zeros(self.flashinfer_extra_state.max_seq_length, ...)` to `torch.zeros(self.batch_size * self.max_len_in_batch, ...)` is a critical bug fix. The original allocation was too small for batch processing; this change ensures sufficient memory for `repack_kv_index` and FlashInfer, preventing errors.
These are essential changes for the stability and correctness of FlashInfer integration in the Llama model.
```diff
         kv_starts = self.b1_cu_kv_seq_len.int()
         kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32).to(input_ids.device)
-        kv_indices = torch.zeros(self.flashinfer_extra_state.max_seq_length, dtype=torch.int32).to(
+        kv_indices = torch.zeros(self.batch_size * self.max_len_in_batch, dtype=torch.int32).to(
```
This change addresses a critical bug in the memory allocation for `kv_indices` during the prefill phase. The original code, `kv_indices = torch.zeros(self.flashinfer_extra_state.max_seq_length, ...)`, allocated space sufficient for only one sequence of maximum length. For batch prefill, the `repack_kv_index` operation and subsequent FlashInfer kernels require `kv_indices` to hold data for all tokens across all sequences in the batch, which can total up to `self.batch_size * self.max_len_in_batch`. The corrected allocation, `kv_indices = torch.zeros(self.batch_size * self.max_len_in_batch, ...)`, ensures `kv_indices` is large enough, preventing potential out-of-bounds writes and ensuring correct behavior. This is a vital fix for reliable prefill operations.
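To make the failure mode concrete, here is a pure-Python demonstration of the overflow the fix prevents. The sizes are assumed for illustration and no torch is involved; a plain list stands in for the tensor:

```python
# Demonstrate the overflow the old prefill sizing allows, with made-up sizes.
batch_size, max_len_in_batch, max_seq_length = 4, 512, 1024

total_tokens = batch_size * max_len_in_batch  # worst case the batch can need: 2048
old_buffer = [0] * max_seq_length             # old allocation: 1024 slots

try:
    for i in range(total_tokens):
        old_buffer[i] = i                     # writes past index 1023
    overflowed = False
except IndexError:
    overflowed = True

print(overflowed)  # True: the old buffer cannot hold a full batch
```

With the corrected sizing (`[0] * total_tokens`), the same loop completes without error.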
Force-pushed from 9f9fa16 to 1694338.