
[fix] bugfix 1419: Add batch size shape validation in decode and prefill run() APIs #2801

Open
qsang-nv wants to merge 5 commits into flashinfer-ai:main from qsang-nv:issue_1419

Conversation

@qsang-nv
Collaborator

@qsang-nv qsang-nv commented Mar 17, 2026

📌 Description

This PR fixes #1419 by adding shape checks to several run() APIs that validate q.shape[0] before launching kernels, preventing silent out-of-bounds (OOB) reads that otherwise cause CUDA crashes.
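The checks follow a simple fail-fast pattern. Below is a minimal, dependency-free sketch of that pattern; the helper name is assumed, and it operates on a shape tuple rather than a real tensor, so it is illustrative rather than the actual flashinfer code:

```python
def validate_q_batch(q_shape: tuple, batch_size: int) -> None:
    """Fail fast if the query's leading dimension disagrees with the planned batch size."""
    if q_shape[0] != batch_size:
        raise ValueError(
            f"q.shape[0] ({q_shape[0]}) does not match batch_size ({batch_size}). "
            f"For batch decode, q must have shape [batch_size, num_heads, head_dim]."
        )

validate_q_batch((4, 8, 128), batch_size=4)  # matching batch size: no error raised
```

In the real wrappers the comparison is `q.size(0)` against a value recorded during `plan()`, so the mismatch surfaces on the host with a descriptive error before any kernel launch.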

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes
    • Added runtime validations to batched decode (including MLA path) to detect query/batch-size mismatches early.
    • Added runtime validations to prefill to ensure token/count consistency across paged and ragged flows, with backend-specific checks.
    • Error messages now report exact expected shapes/counts to make misconfigurations easier to diagnose.

Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the robustness of the FlashInfer library by introducing critical input validation within its core attention mechanisms. By proactively checking the shape of query tensors (q) before kernel launches, it eliminates a class of silent data access errors that previously resulted in difficult-to-diagnose CUDA crashes, thereby improving the stability and reliability of the library.

Highlights

  • Input Validation: Implemented shape checks in the run() APIs of flashinfer/decode.py and flashinfer/prefill.py to validate the q.shape[0] against expected batch sizes or total token counts.
  • Error Prevention: These checks prevent silent out-of-bounds (OOB) reads that could lead to CUDA crashes, directly addressing issue #1419 (Ragged prefill SM90 fails with illegal memory access for total seqlen >= 16384) and improving system stability.


Changelog
  • flashinfer/decode.py
    • Added a ValueError check in the run method to ensure q.size(0) matches _batch_size for batch decode operations.
    • Added a ValueError check in the run method for MLA batch decode to ensure q_nope.size(0) matches expected_batch_size.
  • flashinfer/prefill.py
    • Added a ValueError check in the run method to ensure q.size(0) matches _qo_indptr_buf[-1].item() for paged prefill.
    • Added a ValueError check in the run method to ensure q.size(0) matches _qo_indptr_buf[-1].item() for ragged prefill.
Activity
  • No specific review comments or activities have been recorded for this pull request yet.

@coderabbitai
Contributor

coderabbitai bot commented Mar 17, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds runtime validations that assert query tensor dimensions match expected batch/token counts in batch-decode (paged KV, MLA) and prefill (paged, ragged, cudnn-aware) code paths, raising ValueError with explicit expected-shape messages on mismatch.

Changes

  • Batch Decode Validation (flashinfer/decode.py): Added runtime checks in BatchDecodeWithPagedKVCacheWrapper.run and BatchDecodeMlaWithPagedKVCacheWrapper.run to validate q / q_nope / q_pe batch dimensions against batch_size derived from _paged_kv_last_page_len_buf. Raises ValueError with explicit expected-shape messages on mismatch.
  • Prefill Validation (flashinfer/prefill.py): Cached qo_indptr[-1] as self._qo_indptr_last in plan() for the paged and ragged wrappers; updated run() to validate q size against that cached value. For backend == "cudnn" it compares q.numel() vs self._qo_indptr_last; otherwise it compares q.size(0) vs self._qo_indptr_last. Raises backend- and path-specific ValueError messages on mismatch.


Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

Suggested reviewers

  • cyx-6
  • aleozlx
  • bkryu
  • yzh119
  • nv-yunzheq

Poem

🐰 I counted tokens, soft and bright,
Hopped through pages, checked each byte.
If shapes mislead, I sound the bell —
"Align your batches, do it well!"
A cheerful hop to keep things right.

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
  • Title check (✅ Passed): The title clearly and specifically describes the main change: adding batch size shape validation to decode and prefill run() APIs to fix issue 1419.
  • Description check (✅ Passed): The description adequately explains what the PR does (fixes issue 1419 by adding shape checks), references the related issue, and follows the template structure with sections for description and related issues.
  • Linked Issues check (✅ Passed): The PR implementation directly addresses the objective in issue #1419: adding runtime shape validation for q.shape[0] (query batch dimension) in decode and prefill run() APIs to prevent out-of-bounds reads and CUDA crashes.
  • Out of Scope Changes check (✅ Passed): All changes are narrowly focused on adding shape validation in run() methods for decode and prefill wrappers; no unrelated modifications or scope creep detected beyond the stated objective of issue #1419.
  • Docstring Coverage (✅ Passed): Docstring coverage is 100.00%, which is sufficient; the required threshold is 80.00%.




@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces important shape validation checks in the run() methods of several wrapper classes, including BatchDecodeWithPagedKVCacheWrapper, BatchDecodeMlaWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper, and BatchPrefillWithRaggedKVCacheWrapper. These checks validate the input query tensor shapes against expected dimensions derived from parameters set during the plan() phase. By raising a ValueError with a descriptive message upon mismatch, this change prevents potential out-of-bounds memory access and subsequent CUDA crashes, making the API more robust and user-friendly. The implementation is correct and effectively addresses the described issue.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/decode.py`:
- Around line 1847-1852: The MLA batch validation only checks q_nope's batch
dimension against expected_batch_size (derived from
self._paged_kv_last_page_len_buf.size(0)); add the same guard for q_pe before
the kernel launch by verifying q_pe.size(0) == expected_batch_size and raising a
ValueError with a clear message (mirroring the q_nope error text) if it doesn't
match, so both q_nope and q_pe must have shape [batch_size, num_heads,
head_dim].
- Around line 1283-1287: The current check forces q.shape[0] == self._batch_size
but fails to allow multi-token inputs where q.shape[0] == self._batch_size *
q_len_per_req; update the validation in the block that references q,
self._batch_size and q_len_per_req to accept either first-dim ==
self._batch_size or first-dim == self._batch_size * q_len_per_req (use the same
q_len_per_req used later when reshaping at the line that reshapes q to
[batch_size, q_len_per_req, ...]); if neither matches, raise a ValueError with a
clear message showing both expected shapes.
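The first finding above (guard q_pe with the same check already applied to q_nope) can be sketched as follows. This is a hypothetical helper operating on shape tuples, not the actual flashinfer code:

```python
def validate_mla_q(q_nope_shape: tuple, q_pe_shape: tuple, expected_batch_size: int) -> None:
    """Check both MLA query tensors against the planned batch size."""
    for name, shape in (("q_nope", q_nope_shape), ("q_pe", q_pe_shape)):
        if shape[0] != expected_batch_size:
            raise ValueError(
                f"{name}.shape[0] ({shape[0]}) does not match batch_size "
                f"({expected_batch_size}). For MLA batch decode, {name} must have "
                f"shape [batch_size, num_heads, head_dim]."
            )

validate_mla_q((4, 16, 512), (4, 16, 64), expected_batch_size=4)  # both match: no error
```

Iterating over both tensors keeps the two error messages symmetric, so a mismatch on either q_nope or q_pe names the offending tensor explicitly.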


📥 Commits

Reviewing files that changed from the base of the PR and between e4dc66f and 1f1e38e.

📒 Files selected for processing (2)
  • flashinfer/decode.py
  • flashinfer/prefill.py

Comment on lines +1283 to +1287
        if q.size(0) != self._batch_size:
            raise ValueError(
                f"q.shape[0] ({q.size(0)}) does not match batch_size ({self._batch_size}). "
                f"For batch decode, q must have shape [batch_size, num_heads, head_dim]."
            )

⚠️ Potential issue | 🔴 Critical

Batch-size check breaks valid trtllm-gen multi-token decode inputs.

At Line 1283, q.shape[0] is forced to self._batch_size, but Line 1340 supports q_len_per_req > 1 by reshaping q as [batch_size, q_len_per_req, ...]. This incorrectly rejects valid inputs where q.shape[0] == batch_size * q_len_per_req.

💡 Proposed fix
-        if q.size(0) != self._batch_size:
+        expected_q_rows = self._batch_size
+        if self._backend == "trtllm-gen":
+            effective_q_len = q_len_per_req if q_len_per_req is not None else 1
+            expected_q_rows *= effective_q_len
+        if q.size(0) != expected_q_rows:
             raise ValueError(
-                f"q.shape[0] ({q.size(0)}) does not match batch_size ({self._batch_size}). "
-                f"For batch decode, q must have shape [batch_size, num_heads, head_dim]."
+                f"q.shape[0] ({q.size(0)}) does not match expected rows ({expected_q_rows}). "
+                f"For batch decode, q must have shape "
+                f"[batch_size, num_heads, head_dim] (or [batch_size * q_len_per_req, num_heads, head_dim] for trtllm-gen)."
             )

@qsang-nv qsang-nv requested a review from saltyminty March 17, 2026 07:27
Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
@yzh119 yzh119 changed the title from "[fix] fix issue 1419" to "[fix] bugfix 1419: Add batch size shape validation in decode and prefill run() APIs" Mar 17, 2026
@qsang-nv
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !424 has been created, and the CI pipeline #46389474 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46389474: 6/20 passed

Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
@qsang-nv
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !424 has been updated with latest changes, and the CI pipeline #46414955 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #46414955: canceled

            q, k, self._cached_q_data_type, self._cached_kv_data_type
        )
        # Validate q shape matches qo_indptr
        expected_qo_indptr_last = self._qo_indptr_buf[-1].item()
Collaborator

self._qo_indptr_buf[-1].item() would be a blocking operation (it requires synchronization) if _qo_indptr_buf is a GPU tensor, so I would encourage moving it to plan if possible; we should get rid of synchronization in run functions.

Collaborator Author

Done, now the size is obtained in plan and saved as a scalar to be used for checking.
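The resolution can be sketched as a plan()/run() split. Class and attribute names below are illustrative (not the actual flashinfer API), and a plain list stands in for the qo_indptr tensor: plan() reads qo_indptr[-1] once and keeps it as a Python int, so run() does a pure host-side compare with no device synchronization.

```python
class PrefillWrapperSketch:
    """Illustrative wrapper: cache qo_indptr[-1] at plan() time, check it in run()."""

    def __init__(self) -> None:
        self._qo_indptr_last = None

    def plan(self, qo_indptr: list) -> None:
        # In the real code qo_indptr may live on the GPU; reading its last
        # element here, once, is where any synchronization cost is paid.
        self._qo_indptr_last = int(qo_indptr[-1])

    def run(self, q_shape: tuple) -> None:
        # Host-side integer compare: no .item() call, no sync in the hot path.
        if q_shape[0] != self._qo_indptr_last:
            raise ValueError(
                f"q.shape[0] ({q_shape[0]}) does not match qo_indptr[-1] "
                f"({self._qo_indptr_last})."
            )

w = PrefillWrapperSketch()
w.plan([0, 3, 7])    # 7 total query tokens across the batch
w.run((7, 8, 128))   # matches the cached scalar: no error raised
```

Paying the read once in plan() is cheap because plan() already does host-side setup per batch, while run() may be called many times on the critical path.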

Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
@qsang-nv
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !424 has been updated with latest changes, and the CI pipeline #46586853 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #46586853: 13/20 passed



Development

Successfully merging this pull request may close these issues.

Ragged prefill SM90 fails with illegal memory access for total seqlen >= 16384

3 participants