[fix] bugfix 1419: Add batch size shape validation in decode and prefill run() APIs#2801
qsang-nv wants to merge 5 commits into flashinfer-ai:main
Conversation
Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
Summary of Changes
Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the robustness of the FlashInfer library by introducing critical input validation within its core attention mechanisms, proactively checking the shape of query tensors before any kernel is launched.
Highlights
Activity
📝 Walkthrough
Adds runtime validations that assert query tensor dimensions match expected batch/token counts in batch-decode (paged KV, MLA) and prefill (paged, ragged, cudnn-aware) code paths, raising ValueError with explicit expected-shape messages on mismatch.
Changes
Sequence Diagram(s): (Skipped)
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~12 minutes
🚥 Pre-merge checks: ✅ Passed checks (5 passed)
Code Review
This pull request introduces important shape validation checks in the run() methods of several wrapper classes, including BatchDecodeWithPagedKVCacheWrapper, BatchDecodeMlaWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper, and BatchPrefillWithRaggedKVCacheWrapper. These checks validate the input query tensor shapes against expected dimensions derived from parameters set during the plan() phase. By raising a ValueError with a descriptive message upon mismatch, this change prevents potential out-of-bounds memory access and subsequent CUDA crashes, making the API more robust and user-friendly. The implementation is correct and effectively addresses the described issue.
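The validation pattern described above can be sketched as a small standalone helper (an illustrative sketch, not the actual flashinfer code; in the real wrappers the check lives inside run() and compares against values recorded during plan()):

```python
def validate_decode_q_shape(q_shape, batch_size):
    """Fail fast if q's leading dimension doesn't match the planned batch size.

    q_shape is the shape tuple of the query tensor (e.g. torch.Tensor.shape);
    batch_size is the value recorded during plan(). Raising here prevents
    launching a kernel that would read out of bounds.
    """
    if q_shape[0] != batch_size:
        raise ValueError(
            f"q.shape[0] ({q_shape[0]}) does not match batch_size ({batch_size}). "
            f"For batch decode, q must have shape [batch_size, num_heads, head_dim]."
        )
```

A mismatched query tensor then surfaces as a Python ValueError with a clear expected-shape message rather than a CUDA crash.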
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/decode.py`:
- Around line 1847-1852: The MLA batch validation only checks q_nope's batch
dimension against expected_batch_size (derived from
self._paged_kv_last_page_len_buf.size(0)); add the same guard for q_pe before
the kernel launch by verifying q_pe.size(0) == expected_batch_size and raising a
ValueError with a clear message (mirroring the q_nope error text) if it doesn't
match, so both q_nope and q_pe must have shape [batch_size, num_heads,
head_dim].
- Around line 1283-1287: The current check forces q.shape[0] == self._batch_size
but fails to allow multi-token inputs where q.shape[0] == self._batch_size *
q_len_per_req; update the validation in the block that references q,
self._batch_size and q_len_per_req to accept either first-dim ==
self._batch_size or first-dim == self._batch_size * q_len_per_req (use the same
q_len_per_req used later when reshaping at the line that reshapes q to
[batch_size, q_len_per_req, ...]); if neither matches, raise a ValueError with a
clear message showing both expected shapes.
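Taken together, the two findings amount to something like the following sketch (hypothetical standalone helpers for illustration; the real checks sit inside the wrapper run() methods and use the wrapper's cached attributes):

```python
def validate_mla_q_shapes(q_nope_shape, q_pe_shape, expected_batch_size):
    """Guard both MLA query tensors: mirror the q_nope check for q_pe."""
    for name, shape in (("q_nope", q_nope_shape), ("q_pe", q_pe_shape)):
        if shape[0] != expected_batch_size:
            raise ValueError(
                f"{name}.shape[0] ({shape[0]}) does not match batch_size "
                f"({expected_batch_size}); both q_nope and q_pe must have shape "
                f"[batch_size, num_heads, head_dim]."
            )


def validate_q_rows(q_rows, batch_size, q_len_per_req=1):
    """Accept either batch_size rows or batch_size * q_len_per_req rows."""
    if q_rows not in (batch_size, batch_size * q_len_per_req):
        raise ValueError(
            f"q.shape[0] ({q_rows}) must equal batch_size ({batch_size}) "
            f"or batch_size * q_len_per_req ({batch_size * q_len_per_req})."
        )
```

The second helper shows the multi-token relaxation: with q_len_per_req > 1, both the single-token and the multi-token leading dimension pass validation.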
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 9a5a5800-9f9f-4fe1-8e82-45607b59aa2e
📒 Files selected for processing (2)
flashinfer/decode.py
flashinfer/prefill.py
flashinfer/decode.py (outdated)
```python
if q.size(0) != self._batch_size:
    raise ValueError(
        f"q.shape[0] ({q.size(0)}) does not match batch_size ({self._batch_size}). "
        f"For batch decode, q must have shape [batch_size, num_heads, head_dim]."
    )
```
Batch-size check breaks valid trtllm-gen multi-token decode inputs.
At Line 1283, q.shape[0] is forced to self._batch_size, but Line 1340 supports q_len_per_req > 1 by reshaping q as [batch_size, q_len_per_req, ...]. This incorrectly rejects valid inputs where q.shape[0] == batch_size * q_len_per_req.
💡 Proposed fix
```diff
-        if q.size(0) != self._batch_size:
+        expected_q_rows = self._batch_size
+        if self._backend == "trtllm-gen":
+            effective_q_len = q_len_per_req if q_len_per_req is not None else 1
+            expected_q_rows *= effective_q_len
+        if q.size(0) != expected_q_rows:
             raise ValueError(
-                f"q.shape[0] ({q.size(0)}) does not match batch_size ({self._batch_size}). "
-                f"For batch decode, q must have shape [batch_size, num_heads, head_dim]."
+                f"q.shape[0] ({q.size(0)}) does not match expected rows ({expected_q_rows}). "
+                f"For batch decode, q must have shape "
+                f"[batch_size, num_heads, head_dim] (or [batch_size * q_len_per_req, num_heads, head_dim] for trtllm-gen)."
             )
```
Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
/bot run
[FAILED] Pipeline #46389474: 6/20 passed
Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
/bot run
[CANCELING] Pipeline #46414955: canceled
flashinfer/prefill.py (outdated)
```python
    q, k, self._cached_q_data_type, self._cached_kv_data_type
)
# Validate q shape matches qo_indptr
expected_qo_indptr_last = self._qo_indptr_buf[-1].item()
```
`self._qo_indptr_buf[-1].item()` would be a blocking operation (it requires device synchronization) if `_qo_indptr_buf` is a GPU tensor. I would encourage moving it to plan() if possible; we should get rid of synchronization in run functions.
Done, now the size is obtained in plan and saved as a scalar to be used for checking.
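The resolution described above (reading the scalar once at plan time so run() never synchronizes) can be sketched roughly as follows. This is a minimal stand-in, not the actual wrapper; a plain list substitutes for the GPU `qo_indptr` tensor, where the real code would call `.item()` once in plan():

```python
class PrefillWrapperSketch:
    """Minimal stand-in for the prefill wrapper's plan/run validation."""

    def plan(self, qo_indptr):
        self._qo_indptr_buf = qo_indptr
        # Read the last entry once here; plan() may synchronize, run() must not.
        # On a real GPU tensor this would be qo_indptr[-1].item().
        self._total_num_rows = int(qo_indptr[-1])

    def run(self, q_rows):
        # Compare against the cached Python scalar -- no device sync in run().
        if q_rows != self._total_num_rows:
            raise ValueError(
                f"q.shape[0] ({q_rows}) does not match qo_indptr[-1] "
                f"({self._total_num_rows})."
            )
```

With the cumulative indptr [0, 3, 7], plan() caches 7, and any run() call whose query row count differs raises immediately instead of touching the device.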
Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
/bot run
[SUCCESS] Pipeline #46586853: 13/20 passed
📌 Description
This PR fixes #1419, adding shape checks in several run() APIs to validate q.shape[0] before launching kernels, preventing silent OOB reads that cause CUDA crashes.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
- Tests have been added or updated as needed (unittest, etc.).
Reviewer Notes