
Commit 9573c71

ci: complete the list of modules in aot.py (#1746)
1 parent 7c8bcdb commit 9573c71

2 files changed (+80, -13 lines)

csrc/cudnn_sdpa_kernel_launcher.cu

Lines changed: 1 addition & 1 deletion
@@ -947,7 +947,7 @@ void decode(int64_t max_s_kv, at::Tensor q, at::Tensor k_cache, at::Tensor v_cac
      nullptr};
  static CUfunction lean_attn_reduction{nullptr};

-  static uint32_t sm_count = 0;
+  static int sm_count = 0;

  // Setup decode kernels
  if (hfunc_decode[0] == nullptr) {

flashinfer/aot.py

Lines changed: 79 additions & 12 deletions
@@ -1,3 +1,25 @@
+"""
+Copyright (c) 2025 by FlashInfer team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+AOT build script for FlashInfer.
+
+NOTE (Zihao): The following modules are intentionally excluded from the AOT build:
+- gen_pod_module
+- gen_deepgemm_sm100_module (it doesn't involve host-side compilation)
+"""
+
 import argparse
 import os
 import shutil
@@ -12,8 +34,12 @@
 from .fp8_quantization import gen_mxfp8_quantization_sm100_module
 from .cascade import gen_cascade_module
 from .fp4_quantization import (
-    gen_fp4_quantization_sm100_module,
     gen_fp4_quantization_sm90_module,
+    gen_fp4_quantization_sm100_module,
+    gen_fp4_quantization_sm103_module,
+    gen_fp4_quantization_sm110_module,
+    gen_fp4_quantization_sm120_module,
+    gen_fp4_quantization_sm121_module,
 )
 from .fused_moe import (
     gen_cutlass_fused_moe_sm100_module,
@@ -27,14 +53,18 @@
     gen_gemm_sm100_module_cutlass_fp4,
     gen_gemm_sm100_module_cutlass_fp8,
     gen_gemm_sm100_module_tgv,
+    gen_gemm_sm120_module,
+    gen_gemm_sm120_module_cutlass_fp4,
     gen_trtllm_gen_gemm_module,
 )
 from .jit import JitSpec, build_jit_specs
 from .jit import env as jit_env
 from .jit import (
+    gen_batch_attention_module,
     gen_batch_decode_module,
     gen_batch_mla_module,
     gen_batch_prefill_module,
+    gen_cudnn_fmha_module,
     gen_fmha_cutlass_sm100a_module,
     gen_single_decode_module,
     gen_single_prefill_module,
@@ -187,6 +217,18 @@ def gen_attention(
             use_sliding_window=use_sliding_window,
             use_logits_soft_cap=use_logits_soft_cap,
         )
+        yield gen_batch_attention_module(
+            dtype_q=dtype_qo,
+            dtype_kv=dtype_kv,
+            dtype_o=dtype_qo,
+            dtype_idx=torch.int32,
+            head_dim_qk=head_dim_qk,
+            head_dim_vo=head_dim_vo,
+            pos_encoding_mode=0,
+            # use_sliding_window=use_sliding_window,
+            use_logits_soft_cap=use_logits_soft_cap,
+            use_profiler=False,
+        )

     # FA3 MHA / MQA / GQA
     if has_sm90:
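Note: gen_batch_attention_module (imported from .jit in this diff) produces a JitSpec like the other gen_*_module helpers, and the AOT script compiles the collected specs with build_jit_specs. A minimal sketch of that flow, using only names shown in this diff; the dtype and head-dim values are illustrative rather than the AOT defaults, and build_jit_specs is assumed to accept a list of specs:

    # Sketch only: values are examples; build_jit_specs([...]) usage is assumed.
    import torch
    from flashinfer.jit import build_jit_specs, gen_batch_attention_module

    spec = gen_batch_attention_module(
        dtype_q=torch.float16,
        dtype_kv=torch.float16,
        dtype_o=torch.float16,
        dtype_idx=torch.int32,
        head_dim_qk=128,
        head_dim_vo=128,
        pos_encoding_mode=0,
        use_logits_soft_cap=False,
        use_profiler=False,
    )
    build_jit_specs([spec])  # compile ahead of time instead of at first use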
@@ -357,8 +399,7 @@ def gen_all_modules(
     fa3_head_dim_: List[Tuple[int, int]],
     use_sliding_window_: List[bool],
     use_logits_soft_cap_: List[bool],
-    has_sm90: bool,
-    has_sm100: bool,
+    sm_capabilities: dict,
     add_comm: bool,
     add_gemma: bool,
     add_oai_oss: bool,
@@ -368,6 +409,12 @@ def gen_all_modules(
     add_xqa: bool,
 ) -> List[JitSpec]:
     jit_specs: List[JitSpec] = []
+    has_sm90 = sm_capabilities.get("sm90", False)
+    has_sm100 = sm_capabilities.get("sm100", False)
+    has_sm103 = sm_capabilities.get("sm103", False)
+    has_sm110 = sm_capabilities.get("sm110", False)
+    has_sm120 = sm_capabilities.get("sm120", False)
+    has_sm121 = sm_capabilities.get("sm121", False)

     jit_specs += list(
         gen_attention(
@@ -406,6 +453,16 @@ def gen_all_modules(
         jit_specs.append(gen_mxfp8_quantization_sm100_module())
         jit_specs.append(gen_trtllm_gen_gemm_module())
         jit_specs.append(gen_trtllm_gen_fused_moe_sm100_module())
+    if has_sm103:
+        jit_specs.append(gen_fp4_quantization_sm103_module())
+    if has_sm110:
+        jit_specs.append(gen_fp4_quantization_sm110_module())
+    if has_sm120:
+        jit_specs.append(gen_fp4_quantization_sm120_module())
+        jit_specs.append(gen_gemm_sm120_module())
+        jit_specs.append(gen_gemm_sm120_module_cutlass_fp4())
+    if has_sm121:
+        jit_specs.append(gen_fp4_quantization_sm121_module())

     if add_comm:
         from .comm import gen_trtllm_comm_module, gen_vllm_comm_module
@@ -450,6 +507,9 @@ def gen_all_modules(
             )
         )

+    # Add cuDNN FMHA module
+    jit_specs.append(gen_cudnn_fmha_module())
+
     # dedup
     names = set()
     ret: List[JitSpec] = []
@@ -523,13 +583,20 @@ def has_sm(compute: str, version: str) -> bool:
             return True
         return version_at_least(torch.version.cuda, version)

-    return has_sm("compute_90", "12.3"), has_sm("compute_100", "12.8")
+    return {
+        "sm90": has_sm("compute_90", "12.3"),
+        "sm100": has_sm("compute_100", "12.8"),
+        "sm103": has_sm("compute_103", "12.8"),
+        "sm110": has_sm("compute_110", "12.9"),
+        "sm120": has_sm("compute_120", "13.0"),
+        "sm121": has_sm("compute_121", "13.0"),
+    }


 def register_default_modules() -> int:
     """Register the default set of modules"""
     config = get_default_config()
-    has_sm90, has_sm100 = detect_sm_capabilities()
+    sm_capabilities = detect_sm_capabilities()

     jit_specs = gen_all_modules(
         config["f16_dtype"],
@@ -538,8 +605,7 @@ def register_default_modules() -> int:
         config["fa3_head_dim"],
         config["use_sliding_window"],
         config["use_logits_soft_cap"],
-        has_sm90,
-        has_sm100,
+        sm_capabilities,
         config["add_comm"],
         config["add_gemma"],
         config["add_oai_oss"],
@@ -649,7 +715,7 @@ def main():
     if "FLASHINFER_CUDA_ARCH_LIST" not in os.environ:
         raise RuntimeError("Please explicitly set env var FLASHINFER_CUDA_ARCH_LIST.")

-    has_sm90, has_sm100 = detect_sm_capabilities()
+    sm_capabilities = detect_sm_capabilities()

     # Update data dir
     jit_env.FLASHINFER_CSRC_DIR = project_root / "csrc"
@@ -678,8 +744,10 @@ def main():
     print(" use_sliding_window:", config["use_sliding_window"])
     print(" use_logits_soft_cap:", config["use_logits_soft_cap"])
     print(" FLASHINFER_CUDA_ARCH_LIST:", os.environ["FLASHINFER_CUDA_ARCH_LIST"])
-    print(" has_sm90:", has_sm90)
-    print(" has_sm100:", has_sm100)
+    print(" SM capabilities detected:")
+    for sm_name, has_sm in sm_capabilities.items():
+        if has_sm:
+            print(f" {sm_name}: True")
     for key in [
         "add_comm",
         "add_gemma",
@@ -701,8 +769,7 @@ def main():
         config["fa3_head_dim"],
         config["use_sliding_window"],
         config["use_logits_soft_cap"],
-        has_sm90,
-        has_sm100,
+        sm_capabilities,
         config["add_comm"],
         config["add_gemma"],
         config["add_oai_oss"],
