[RFC] MXFP8 autotune IMA at some batch size #2800
charlotte12l wants to merge 1 commit into flashinfer-ai:main from
Conversation
📝 Walkthrough
This PR fixes an out-of-bounds read issue in the MoE GEMM kernel when using MxFp8 quantization. The fix reconstructs the `current_hidden_states_scale` tensor at the expected size when the buffer supplied during autotuning is too small.
Summary of Changes (Gemini Code Assist)
This pull request provides a provisional solution to a critical CUDA 'illegal memory access' error encountered during the autotuning process for MXFP8 quantization within vLLM, particularly affecting smaller batch sizes. The change ensures that the scale tensors used in the Mixture-of-Experts (MoE) GEMM kernel are allocated with the correct dimensions, thereby preventing runtime failures. While this fix resolves the error and maintains evaluation accuracy, the author notes a performance regression, indicating that further optimization or a more fundamental solution may be required.
Activity
Code Review
This pull request addresses a `CUDA error: an illegal memory access was encountered` failure that occurs during autotuning for MXFP8 with small batch sizes. The fix correctly identifies that the scale tensor created by `DynamicTensorSpec` can be undersized and resizes it to the correct dimension, which effectively prevents the crash. My feedback focuses on improving this fix to address the performance regression mentioned in the pull request description. I suggest using more realistic random data for the recreated scale tensor instead of filling it with ones, which should help the autotuner select a more optimal and performant kernel.
```python
if current_hidden_states_scale.numel() < sf_size:
    current_hidden_states_scale = torch.ones(
        (sf_size,),
        dtype=torch.uint8,
        device=hidden_states.device,
    )
```
This is a good catch to fix the illegal memory access error during autotuning. The logic to check and resize the current_hidden_states_scale tensor is correct.
However, as you noted, this fix may cause a performance regression. This is likely because filling the tensor with torch.ones does not provide realistic scale values for the autotuner. The autotuner might be selecting a suboptimal kernel based on this non-representative data.
To potentially resolve the performance regression, I suggest initializing the tensor with random data, which better simulates real-world scale values. This should help the autotuner find a more performant kernel.
```diff
 if current_hidden_states_scale.numel() < sf_size:
-    current_hidden_states_scale = torch.ones(
-        (sf_size,),
+    current_hidden_states_scale = torch.randint(
+        0, 256, (sf_size,),
         dtype=torch.uint8,
         device=hidden_states.device,
     )
```
🧹 Nitpick comments (1)
flashinfer/fused_moe/core.py (1)
1160-1173: UE8M0 scale value encoding: byte 1 ≠ 1.0 scale.

The fix correctly prevents out-of-bounds reads during autotuner profiling. However, `torch.ones(..., dtype=torch.uint8)` creates a tensor filled with byte value `1`, which in UE8M0 format represents `2^(1-127) = 2^(-126) ≈ 0`, not a scale of 1.0. For UE8M0, byte value `127` represents `2^(127-127) = 1.0`. While this may not cause crashes, using near-zero scales during profiling could affect tactic selection accuracy.

Suggested fix to use correct UE8M0 encoding for 1.0:
```diff
 if current_hidden_states_scale.numel() < sf_size:
-    current_hidden_states_scale = torch.ones(
+    current_hidden_states_scale = torch.full(
         (sf_size,),
+        127,  # UE8M0 encoding for scale = 1.0
         dtype=torch.uint8,
         device=hidden_states.device,
     )
```

Based on learnings: In FlashInfer's quantization code, `torch.float8_e4m3fn` is used as a carrier dtype for 1-byte scale factors (UE8M0, etc.) — the raw bytes are interpreted by C++ kernels according to the actual format semantics.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `flashinfer/fused_moe/core.py` around lines 1160-1173: the created padding scale uses `torch.ones(..., dtype=torch.uint8)`, which sets raw byte value 1 (not 1.0 in UE8M0). Replace that with a uint8 tensor filled with the UE8M0 encoding for 1.0 (byte value 127) so `current_hidden_states_scale` has correct 1.0 scale bytes; update the creation site that constructs `current_hidden_states_scale` (the `torch.ones` call using `sf_size`, `dtype=torch.uint8`, `device=hidden_states.device`) to create a tensor filled with 127 instead (or use the established float8 carrier path used elsewhere) to ensure profiling uses true 1.0 scales.
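To make the encoding concrete, here is a minimal decoding sketch (the helper name is hypothetical; UE8M0 stores only an 8-bit biased exponent, i.e. value = 2^(byte - 127)):

```python
import torch

def ue8m0_to_float(b: torch.Tensor) -> torch.Tensor:
    # UE8M0 stores an 8-bit biased exponent only: value = 2^(byte - 127)
    return torch.pow(2.0, b.to(torch.float32) - 127.0)

scales = torch.tensor([1, 126, 127, 128], dtype=torch.uint8)
print(ue8m0_to_float(scales))
# tensor([1.1755e-38, 5.0000e-01, 1.0000e+00, 2.0000e+00])
```

This is why a buffer of byte-value-1 "ones" behaves like near-zero scales during profiling, while byte 127 decodes to exactly 1.0.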
📒 Files selected for processing (1)
flashinfer/fused_moe/core.py
```python
current_hidden_states_scale = torch.ones(
    (sf_size,),
    dtype=torch.uint8,
    device=hidden_states.device,
)
```
Considering the e8m0 data, maybe use 126-127 (0.5-2).
```diff
 if current_hidden_states_scale.numel() < sf_size:
-    current_hidden_states_scale = torch.ones(
-        (sf_size,),
+    current_hidden_states_scale = torch.randint(
+        126, 128, (sf_size,),
         dtype=torch.uint8,
         device=hidden_states.device,
     )
```
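For reference, a quick check (assuming the same UE8M0 decoding as in the nitpick above) shows what this suggested initializer produces; note that the upper bound of `torch.randint` is exclusive, so the bytes are 126 or 127:

```python
import torch

s = torch.randint(126, 128, (8,), dtype=torch.uint8)  # bytes in {126, 127}
print(torch.pow(2.0, s.float() - 127.0))              # decoded scales in {0.5, 1.0}
```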
```diff
@@ -1157,7 +1157,20 @@ def forward(
     )
 elif self.fp8_quantization_type == Fp8QuantizationType.MxFp8:
     current_hidden_states_scale = extra_inputs[0]
```
```diff
-current_hidden_states_scale = extra_inputs[0]
+current_hidden_states_scale = hidden_states_scale
```
… (#2725)

## Summary

SM120 desktop Blackwell GPUs (RTX PRO 6000, RTX 5090) are blocked from NVFP4 MoE grouped GEMM due to hardcoded SM100-only checks.

**Changes:**
- `jit/fused_moe.py`: Add major version 12 to `supported_major_versions`
- `csrc/trtllm_fused_moe_kernel_launcher.cu`: `ICHECK_EQ(major, 10)` -> `ICHECK_GE(major, 10)`

**Benchmark** (Qwen3.5-397B on 4x RTX PRO 6000 SM120):

| Config | tok/s | Output |
|--------|-------|--------|
| compute_120f (CUDA 13.0) | 39.0 | Correct |
| compute_120a (CUDA 12.8) | 14.6 | Correct (slow fallback) |
| Marlin W4A16 | 46-49 | Correct |

**Root cause:** All TMA WS grouped GEMM autotuner tactics fail on `compute_120a`, requiring `compute_120f` (CUDA 13.0). CuTe DSL `admissible_archs` in vendored CUTLASS also needs `sm_120a`/`sm_120f` (cpasync/copy.py, tcgen05/mma.py, arch/mbar.py, etc).

Related: CUTLASS #2820, #2800; vLLM #33416, #33333; FlashInfer #2577

## Summary by CodeRabbit

* **Bug Fixes**
  * Broadened GPU architecture checks to accept additional modern compute capabilities (SM 10.x and 12.x), improving compatibility and clearer SM reporting.
  * Improved compute-capability detection and encoding, preserving user-provided architecture suffixes and more accurately generating nvcc architecture flags.
  * Expanded JIT module generation to include additional CUDA majors so fused-MoE kernels run on more recent GPUs.

Signed-off-by: Brandon Music <brandon.m.music@gmail.com>
Co-authored-by: Brandon Music <brandonmmusic-max@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Brandon Music <brandonmusic@pop-os.tail8674da.ts.net>
📌 Description
We are running vLLM with MXFP8 and hit the flashinfer autotune error below with `max_batch_size=64`; however, `max_batch_size=128` or `max_batch_size=256` is fine. We tried with `CUDA_LAUNCH_BLOCKING=1`, but the stack trace is the same. The current PR avoids the error and eval accuracy is fine; however, we doubt it is the correct fix, since perf regressed.
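For reviewers who want to reason about the shape logic, below is a minimal, self-contained sketch of the guard this PR adds. The `sf_size` formula and helper name are assumptions for illustration only; the real kernel may apply additional padding or swizzling to the scale-factor layout:

```python
import math
import torch

SF_BLOCK_SIZE = 32  # assumption: one UE8M0 scale byte per 32-element MXFP8 block

def pad_hidden_states_scale(scale: torch.Tensor, num_tokens: int,
                            hidden_size: int) -> torch.Tensor:
    # Recreate the scale tensor at full size if the buffer handed over
    # during autotuning is too small for the current batch shape.
    sf_size = num_tokens * math.ceil(hidden_size / SF_BLOCK_SIZE)
    if scale.numel() < sf_size:
        # 127 is the UE8M0 byte for scale 1.0; torch.ones would give byte 1,
        # i.e. 2**-126, as discussed in the review comments above.
        scale = torch.full((sf_size,), 127, dtype=torch.uint8,
                           device=scale.device)
    return scale
```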
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
Release Notes