Commit 292f9be

fix: include fp8_blockscale_gemm_90 in AOT jit-cache (#2533)
## Summary

- Add `fp8_blockscale_gemm_sm90` (`gen_fp8_blockscale_gemm_sm90_module`) to the AOT build list when SM90 is enabled.
- Avoid runtime JIT compilation of `fp8_blockscale_gemm_sm90` in environments without CUDA development headers, where compilation can fail with `cublasLt.h` not found.

## Changes

- `flashinfer/aot.py`: append `gen_fp8_blockscale_gemm_sm90_module()` under the `add_moe` + `has_sm90` gating.

## Related Issues

- Fixes #2527

## ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).
1 parent (c5b8a2e) · commit 292f9be

File tree

1 file changed: +3 −0 lines changed


flashinfer/aot.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -52,6 +52,7 @@
 from .jit.gemm import (
     gen_gemm_module,
     gen_gemm_sm90_module,
+    gen_fp8_blockscale_gemm_sm90_module,
     gen_gemm_sm100_module,
     gen_gemm_sm100_module_cutlass_fp4,
     gen_gemm_sm100_module_cutlass_fp8,
@@ -477,6 +478,8 @@ def gen_all_modules(
     jit_specs.append(gen_gemm_module())
     if has_sm90:
         jit_specs.append(gen_gemm_sm90_module())
+        # fp8 blockscale GEMM (SM90)
+        jit_specs.append(gen_fp8_blockscale_gemm_sm90_module())
         jit_specs.append(gen_fp4_quantization_sm90_module())
         jit_specs.append(gen_cutlass_fused_moe_sm90_module())
     if has_sm100:
```
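The pattern in the second hunk can be sketched in isolation: module generators are collected into a list, and architecture-specific modules are only appended when the corresponding SM generation is enabled, so they get prebuilt ahead of time instead of JIT-compiled at runtime. This is an illustrative sketch, not the actual `flashinfer` code; the `gen_*` stand-ins here just return spec names rather than real JIT specs.

```python
# Hypothetical stand-ins for the real flashinfer generator functions,
# which return JitSpec objects; here they return plain names for clarity.
def gen_gemm_module():
    return "gemm"

def gen_gemm_sm90_module():
    return "gemm_sm90"

def gen_fp8_blockscale_gemm_sm90_module():
    return "fp8_blockscale_gemm_sm90"

def gen_all_modules(has_sm90: bool) -> list:
    """Collect the modules to compile ahead of time (sketch)."""
    jit_specs = [gen_gemm_module()]  # architecture-independent modules
    if has_sm90:
        jit_specs.append(gen_gemm_sm90_module())
        # fp8 blockscale GEMM (SM90) -- the module this commit adds,
        # so it no longer needs runtime JIT (and CUDA dev headers).
        jit_specs.append(gen_fp8_blockscale_gemm_sm90_module())
    return jit_specs
```

With `has_sm90=True` the fp8 blockscale module lands in the AOT list; without SM90 it is skipped entirely, which is why the fix is purely additive under the existing gate.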
