Add Option For Fixed cta_tile_q#2830
Conversation
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (4)
🚧 Files skipped from review as they are similar to previous changes (2)
📝 WalkthroughWalkthroughAdds a new planning parameter Changes
Sequence Diagram(s)sequenceDiagram
autonumber
participant Python as Python Plan Caller
participant JIT as JIT Binding
participant GPU as CUDA Scheduler
Python->>JIT: .plan(..., fixed_cta_tile_q)
JIT->>GPU: BatchPrefillWithKVCachePlan(..., fixed_cta_tile_q, ...)
GPU->>GPU: PrefillPlan(..., fixed_cta_tile_q) / PrefillSplitQOKVIndptr(...)
GPU-->>JIT: plan array / tile-size decisions
JIT-->>Python: return plan/results
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
📝 Coding Plan
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request enhances FlashInfer's attention planning by providing a mechanism to enforce a fixed CTA tile size. This change is crucial for achieving deterministic behavior and batch invariance, particularly when dealing with varying batch sizes or CUDA graph optimizations, by preventing the dynamic selection of tile sizes that could lead to inconsistent outputs. Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request introduces an option to set a fixed cta_tile_q size, which is a valuable feature for ensuring batch invariance and deterministic outputs. The changes are well-implemented across the C++ backend and Python wrappers, and the inclusion of new tests for validation and invariance is commendable.
My main feedback is regarding code duplication in the Python validation logic for fixed_cta_tile_q. I've left specific comments suggesting refactoring this into a shared helper function to improve maintainability. Other than that, the changes look solid.
There was a problem hiding this comment.
🧹 Nitpick comments (3)
flashinfer/decode.py (1)
2675-2684: Consider extracting duplicated validation to a helper function.The validation logic for
fixed_cta_tile_qis duplicated betweenplan()(lines 988-997) andfast_decode_plan()(lines 2675-2684). Consider extracting this to a small helper function to improve maintainability.♻️ Suggested helper extraction
def _validate_fixed_cta_tile_q(fixed_cta_tile_q: Optional[int], head_dim: int) -> int: """Validate and normalize fixed_cta_tile_q parameter. Returns -1 for auto heuristic, or the validated value. """ if fixed_cta_tile_q is None: return -1 if fixed_cta_tile_q not in (16, 64, 128): raise ValueError( f"fixed_cta_tile_q should be one of {{16, 64, 128}}, got {fixed_cta_tile_q}" ) if fixed_cta_tile_q == 128 and head_dim >= 256: raise ValueError( f"fixed_cta_tile_q=128 is not supported with head_dim={head_dim} (requires head_dim < 256)" ) return fixed_cta_tile_q🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/decode.py` around lines 2675 - 2684, Extract the duplicated fixed_cta_tile_q validation into a helper function (e.g., _validate_fixed_cta_tile_q) that accepts fixed_cta_tile_q and head_dim and returns -1 when fixed_cta_tile_q is None or the validated value otherwise; move the three checks (None -> -1, membership in (16,64,128), and the 128 vs head_dim >= 256 error) into that helper, and replace the inline logic in both plan() and fast_decode_plan() with a call to this helper to normalize/validate the value and preserve the same ValueError messages and semantics.flashinfer/prefill.py (1)
1818-1827: Scopefixed_cta_tile_qvalidation to FA2-resolved plans.
fixed_cta_tile_qis documented as FA2-specific, but current validation executes unconditionally before backend resolution. This can reject non-FA2 plans for a parameter that is otherwise ignored there. Consider validating compatibility only when the resolved backend isfa2.Also applies to: 2809-2818
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 1818 - 1827, The validation for fixed_cta_tile_q currently runs unconditionally; restrict this check so it only executes for FA2-resolved plans by wrapping the existing checks for fixed_cta_tile_q and the head_dim_vo compatibility check in a conditional that first verifies the plan/backend is FA2 (e.g., resolved_backend == "fa2" or plan.is_fa2) at the point after backend resolution; apply the same change to the other duplicate validation block (the one around the second occurrence) so non-FA2 backends won’t be rejected for this FA2-specific parameter.tests/attention/test_batch_invariant_fa2.py (1)
56-67: Consider reducing the new Cartesian test expansion.Adding
fixed_cta_tile_qas a full-axis multiplier triples already-large matrices and may make GPU CI much slower/flakier. Prefer a smaller targeted matrix (or dedicated focused cases) forfixed_cta_tile_qcoverage.Also applies to: 198-210
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_batch_invariant_fa2.py` around lines 56 - 67, The Cartesian expansion from adding full-axis parametrize for fixed_cta_tile_q is inflating test matrix size; restrict it by replacing pytest.mark.parametrize("fixed_cta_tile_q", [16, 64, 128]) with a much smaller set (e.g., a single representative value like [64] or two targeted values like [16, 64]) or move fixed_cta_tile_q into a separate, focused test marked slow so it doesn't multiply all other params; update the parametrize for fixed_cta_tile_q and/or add a dedicated test function (or pytest.mark.slow) to cover the remaining tiles without expanding the entire grid, ensuring references to the same parameter name fixed_cta_tile_q are adjusted accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@flashinfer/decode.py`:
- Around line 2675-2684: Extract the duplicated fixed_cta_tile_q validation into
a helper function (e.g., _validate_fixed_cta_tile_q) that accepts
fixed_cta_tile_q and head_dim and returns -1 when fixed_cta_tile_q is None or
the validated value otherwise; move the three checks (None -> -1, membership in
(16,64,128), and the 128 vs head_dim >= 256 error) into that helper, and replace
the inline logic in both plan() and fast_decode_plan() with a call to this
helper to normalize/validate the value and preserve the same ValueError messages
and semantics.
In `@flashinfer/prefill.py`:
- Around line 1818-1827: The validation for fixed_cta_tile_q currently runs
unconditionally; restrict this check so it only executes for FA2-resolved plans
by wrapping the existing checks for fixed_cta_tile_q and the head_dim_vo
compatibility check in a conditional that first verifies the plan/backend is FA2
(e.g., resolved_backend == "fa2" or plan.is_fa2) at the point after backend
resolution; apply the same change to the other duplicate validation block (the
one around the second occurrence) so non-FA2 backends won’t be rejected for this
FA2-specific parameter.
In `@tests/attention/test_batch_invariant_fa2.py`:
- Around line 56-67: The Cartesian expansion from adding full-axis parametrize
for fixed_cta_tile_q is inflating test matrix size; restrict it by replacing
pytest.mark.parametrize("fixed_cta_tile_q", [16, 64, 128]) with a much smaller
set (e.g., a single representative value like [64] or two targeted values like
[16, 64]) or move fixed_cta_tile_q into a separate, focused test marked slow so
it doesn't multiply all other params; update the parametrize for
fixed_cta_tile_q and/or add a dedicated test function (or pytest.mark.slow) to
cover the remaining tiles without expanding the entire grid, ensuring references
to the same parameter name fixed_cta_tile_q are adjusted accordingly.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: fc264a2a-c5dc-487a-85e0-2de88b01bbbd
📒 Files selected for processing (9)
csrc/batch_prefill.cucsrc/batch_prefill_jit_binding.cuflashinfer/decode.pyflashinfer/pod.pyflashinfer/prefill.pyflashinfer/sparse.pyinclude/flashinfer/attention/scheduler.cuhtests/attention/test_batch_invariant_fa2.pytests/attention/test_batch_prefill.py
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/decode.py`:
- Line 849: The parameter fixed_cta_tile_q is currently ignored in
non-tensor-core paths; mirror the fixed_split_size guard by rejecting/raising
when fixed_cta_tile_q is non-None and use_tensor_cores is False in the public
plan() (validate-and-drop) and similarly drop/raise in fast_decode_plan() where
parameters are pre-resolution. Additionally, after backend resolution (the code
path that inspects the resolved tensor-core backend and chooses FA2), add a
check that if the chosen tensor-core path is not "fa2" and fixed_cta_tile_q is
non-None, reject it (raise or error) so fixed_cta_tile_q is only accepted when
the final backend is fa2; update the same checks where fixed_split_size is
handled to cover fixed_cta_tile_q as well.
In `@flashinfer/prefill.py`:
- Line 1679: After resolving the effective backend (i.e., after handling
backend="auto") add a fail-fast check: if fixed_cta_tile_q is not None and the
resolved backend is not "fa2", raise a clear ValueError indicating
fixed_cta_tile_q is only supported for the fa2 backend. Locate the check and
insertion point inside plan() where self._backend == "fa2" is handled and
move/duplicate the validation so it runs after backend resolution (not before),
and mirror the same change in the other affected call sites referenced (around
the fixed_cta_tile_q occurrences and plan-like flows at the other locations) so
non-None values are rejected unless backend == "fa2".
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 979bb9a2-f818-4ca8-8b09-204e5e45ea05
📒 Files selected for processing (3)
flashinfer/decode.pyflashinfer/prefill.pyflashinfer/utils.py
✅ Files skipped from review due to trivial changes (1)
- flashinfer/utils.py
📌 Description
This PR adds functionality for a caller to set a fixed
cta_tile_qsize.This is mainly a use case for batch invariance as dynamically chosen
cta_tile_qvalues can lead to variant outputs.🔍 Related Issues
Fixes #2424
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
Also tested integration with vLLM using a slightly modified
test_logprobs_bitwise_batch_invariance_bs1_vs_bsNwhich uses Qwen/Qwen3-1.7B (gqa_group_size=2).vLLM with 0.6.6 FlashInfer:
Logging the args of the failing request to FlashInfer's
BatchPrefillWithPagedKVCacheWrapperplan:Single request:
avg_packed_qo_len = 372 * 2 = 744and soFA2DetermineCtaTileQ -> 128Batched with other requests:
avg_packed_qo_len = 684 * 2 / 28 = 48.8571428571and soFA2DetermineCtaTileQ -> 64vLLM with FlashInfer built off of this branch:
Logging the args of the previously failing request sent to FlashInfer's
BatchPrefillWithPagedKVCacheWrapperplan:avg_packed_qo_len = 372 * 2 = 744, which would typically causeFA2DetermineCtaTileQ -> 128but doesn't matter asfixed_cta_tile_qoverrides to 128 anyways.Batched with other requests:
avg_packed_qo_len = 665 * 2 / 26 = 51.1538461538, which would typically causeFA2DetermineCtaTileQ -> 64butfixed_cta_tile_qoverrides to128.Summary by CodeRabbit
New Features
Tests