Conversation
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request addresses API breaking changes for the upcoming 0.6.7 release by refining function signatures across several modules. It primarily focuses on making certain cache-related parameters keyword-only to improve API stability and clarity. Additionally, it makes quantization scale handling in normalization functions more flexible by temporarily allowing float inputs while guiding users towards the more robust tensor-based interface.

Highlights
No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info

⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID:

📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
📝 Walkthrough

The PR changes

Changes
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 inconclusive)
✅ Passed checks (2 passed)
Code Review
This pull request introduces several API changes to improve backward compatibility and future-proofing. The changes in flashinfer/decode.py and flashinfer/xqa.py correctly adapt to making k_sf_cache and v_sf_cache keyword-only arguments, which is a good API design practice. The modifications in flashinfer/norm/__init__.py add backward compatibility for the scale parameter by allowing a float value, while issuing a helpful FutureWarning. The implementation is sound. I have one minor suggestion regarding code style in flashinfer/norm/__init__.py to improve adherence to PEP 8.
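For illustration, the keyword-only pattern described here can be sketched as follows. This is a toy signature, not the actual FlashInfer API (only the parameter names `k_sf_cache` and `v_sf_cache` come from the review above); the bare `*` marker forces every parameter after it to be passed by name:

```python
def decode(q, kv_cache, *, k_sf_cache=None, v_sf_cache=None):
    """Everything after the bare `*` is keyword-only."""
    return {"k_sf": k_sf_cache, "v_sf": v_sf_cache}


# Keyword-only parameters must be passed by name:
out = decode("q", "kv", k_sf_cache="k_scales")

# Passing them positionally raises a TypeError:
try:
    decode("q", "kv", "k_scales")
except TypeError:
    pass  # decode() takes 2 positional arguments but 3 were given
```

This is why the call sites in `flashinfer/decode.py` and `flashinfer/xqa.py` had to be adapted: any caller that previously passed the scale-factor caches positionally now breaks at call time.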
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/norm/__init__.py (1)
186-187: ⚠️ Potential issue | 🟡 Minor

Update `scale` parameter docs to match the new API contract.

The docstrings still say `scale: torch.Tensor`, but the function now accepts float (deprecated) as well. This mismatch will confuse users.

Proposed doc update:

```diff
-    scale: torch.Tensor
-        Scale factor for quantization, shape (1,).
+    scale: Union[float, torch.Tensor]
+        Quantization scale. `torch.Tensor` of shape (1,) is preferred.
+        Passing `float` is deprecated and kept temporarily for compatibility.
```

Also applies to: 301-302
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/norm/__init__.py` around lines 186 - 187, Update the docstring for the parameter named "scale" to reflect the new API contract: change the type from "torch.Tensor" to indicate it accepts either a torch.Tensor or a float (with a note that float usage is deprecated), and clarify expected shape/semantics (e.g., torch.Tensor shape (1,) or scalar float). Locate the two occurrences in this module where "scale: torch.Tensor" is documented (the block around the earlier occurrence and the second occurrence near the later docstring) and update both to "scale: Union[torch.Tensor, float] — Scale factor for quantization; preferred as torch.Tensor of shape (1,), float is accepted but deprecated."
🧹 Nitpick comments (1)
flashinfer/norm/__init__.py (1)
65-77: Tighten non-tensor input validation for `scale`.

Current logic warns for any non-tensor value, but only float input is intended here. Add an explicit type gate so invalid inputs fail with a clear `TypeError` instead of implicit tensor-construction errors.

Proposed patch:

```diff
 def _normalize_scale_tensor(
     scale: Union[float, torch.Tensor], ref_tensor: torch.Tensor
 ) -> torch.Tensor:
     """Normalize quantization scale to 1D tensor of shape (1,) on target device."""
     if not isinstance(scale, torch.Tensor):
+        if not isinstance(scale, float):
+            raise TypeError(
+                f"scale must be float or torch.Tensor, got {type(scale).__name__}"
+            )
         import warnings

         warnings.warn(
             "Passing scale as a float is deprecated and will be removed in a future "
             "release. Use a torch.Tensor of shape (1,) instead.",
             FutureWarning,
             stacklevel=3,
         )
         scale = torch.tensor([scale], dtype=torch.float32, device=ref_tensor.device)
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/norm/__init__.py` around lines 65 - 77, The current normalization helper accepts any non-torch.Tensor and attempts to convert it, which masks invalid types; update the input validation so that if scale is a torch.Tensor proceed as before, if it's a float emit the existing FutureWarning and convert via torch.tensor([scale], dtype=torch.float32, device=ref_tensor.device), but if scale is neither float nor torch.Tensor raise a TypeError with a clear message; adjust the branch around scale/ref_tensor and the FutureWarning usage to enforce this type gate and avoid implicit tensor-construction errors.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 87962e9b-6d94-4598-9a79-016ec25b4bbd
📒 Files selected for processing (3)
flashinfer/decode.py, flashinfer/norm/__init__.py, flashinfer/xqa.py
/bot run
jimmyzho
left a comment
lgtm for decode, just left question for clarity
```python
    rcp_out_scale: float = 1.0,
    q_seq_len: int = 1,
    mask: Optional[torch.Tensor] = None,
    *,
```
Why do these parameters need to be keyword-only?
i think it's an optional feature (guessing)
to that end the rationale is documented (end of this page)
https://github.com/flashinfer-ai/flashinfer/blob/main/CONTRIBUTING.md
i'd like to put a * as soon as the basic features are done in the api. the extra things that pile on later passed positionally just get worse and worse for api stability
imo positional args shouldn't exceed 10, or it becomes harder to maintain
great questions! keep them coming
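The stability argument in the thread above can be made concrete with a toy sketch (not FlashInfer code; all names here are hypothetical): once parameters are keyword-only, maintainers can insert new ones between existing ones without breaking call sites, which positional passing cannot survive.

```python
# Version 1 of a toy API:
def attention_v1(q, k, v, *, window=None, softcap=None):
    return (q, k, v, window, softcap)


# Version 2 inserts a new keyword-only parameter *between* existing ones:
def attention_v2(q, k, v, *, window=None, sinks=None, softcap=None):
    return (q, k, v, window, softcap, sinks)


# A call site written against v1 with keywords keeps working against v2,
# and each argument still lands on the parameter the caller intended:
kwargs = dict(window=8, softcap=30.0)
assert attention_v1("q", "k", "v", **kwargs)[:5] == ("q", "k", "v", 8, 30.0)
assert attention_v2("q", "k", "v", **kwargs)[:5] == ("q", "k", "v", 8, 30.0)
```

Had `window` and `softcap` been positional, the same call against v2 would have silently bound `30.0` to `sinks` instead.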
[FAILED] Pipeline #46621950: 6/20 passed
ugh internal CI has caught errors on xqa.. i'll fix them later today
📌 Description
fix api breaking changes for 0.6.7 release
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit