+ """
+ Copyright (c) 2025 by FlashInfer team.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+ AOT build script for FlashInfer.
+
+ NOTE (Zihao): The following modules are intentionally excluded from the AOT build:
+ - gen_pod_module
+ - gen_deepgemm_sm100_module (it doesn't involve host-side compilation)
+ """
+
  import argparse
  import os
  import shutil
...
  from .fp8_quantization import gen_mxfp8_quantization_sm100_module
  from .cascade import gen_cascade_module
  from .fp4_quantization import (
-     gen_fp4_quantization_sm100_module,
      gen_fp4_quantization_sm90_module,
+     gen_fp4_quantization_sm100_module,
+     gen_fp4_quantization_sm103_module,
+     gen_fp4_quantization_sm110_module,
+     gen_fp4_quantization_sm120_module,
+     gen_fp4_quantization_sm121_module,
  )
  from .fused_moe import (
      gen_cutlass_fused_moe_sm100_module,
...
      gen_gemm_sm100_module_cutlass_fp4,
      gen_gemm_sm100_module_cutlass_fp8,
      gen_gemm_sm100_module_tgv,
+     gen_gemm_sm120_module,
+     gen_gemm_sm120_module_cutlass_fp4,
      gen_trtllm_gen_gemm_module,
  )
  from .jit import JitSpec, build_jit_specs
  from .jit import env as jit_env
  from .jit import (
+     gen_batch_attention_module,
      gen_batch_decode_module,
      gen_batch_mla_module,
      gen_batch_prefill_module,
+     gen_cudnn_fmha_module,
      gen_fmha_cutlass_sm100a_module,
      gen_single_decode_module,
      gen_single_prefill_module,
@@ -187,6 +217,18 @@ def gen_attention(
          use_sliding_window=use_sliding_window,
          use_logits_soft_cap=use_logits_soft_cap,
      )
+     yield gen_batch_attention_module(
+         dtype_q=dtype_qo,
+         dtype_kv=dtype_kv,
+         dtype_o=dtype_qo,
+         dtype_idx=torch.int32,
+         head_dim_qk=head_dim_qk,
+         head_dim_vo=head_dim_vo,
+         pos_encoding_mode=0,
+         # use_sliding_window=use_sliding_window,
+         use_logits_soft_cap=use_logits_soft_cap,
+         use_profiler=False,
+     )

      # FA3 MHA / MQA / GQA
      if has_sm90:
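
For context, `gen_attention` is a generator of `JitSpec` objects, so the new `gen_batch_attention_module` variants (note the sliding-window flag is left commented out for now) flow into the same ahead-of-time build as the existing decode/prefill modules. A minimal sketch of that collect-then-build pattern, using stand-in types rather than FlashInfer's actual `JitSpec` and `build_jit_specs`:

```python
from dataclasses import dataclass
from typing import Iterator, List

@dataclass(frozen=True)
class FakeSpec:
    """Stand-in for JitSpec: just a name identifying one compiled module."""
    name: str

def gen_attention_sketch(head_dims: List[int]) -> Iterator[FakeSpec]:
    # Each yield corresponds to one module variant, mirroring how the real
    # generator yields decode, prefill, and now batch-attention specs.
    for hd in head_dims:
        yield FakeSpec(f"batch_decode_hd{hd}")
        yield FakeSpec(f"batch_attention_hd{hd}")

def build_specs_sketch(specs: List[FakeSpec]) -> None:
    for spec in specs:
        print("compiling", spec.name)

build_specs_sketch(list(gen_attention_sketch([64, 128])))
```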
@@ -357,8 +399,7 @@ def gen_all_modules(
      fa3_head_dim_: List[Tuple[int, int]],
      use_sliding_window_: List[bool],
      use_logits_soft_cap_: List[bool],
-     has_sm90: bool,
-     has_sm100: bool,
+     sm_capabilities: dict,
      add_comm: bool,
      add_gemma: bool,
      add_oai_oss: bool,
@@ -368,6 +409,12 @@ def gen_all_modules(
      add_xqa: bool,
  ) -> List[JitSpec]:
      jit_specs: List[JitSpec] = []
+     has_sm90 = sm_capabilities.get("sm90", False)
+     has_sm100 = sm_capabilities.get("sm100", False)
+     has_sm103 = sm_capabilities.get("sm103", False)
+     has_sm110 = sm_capabilities.get("sm110", False)
+     has_sm120 = sm_capabilities.get("sm120", False)
+     has_sm121 = sm_capabilities.get("sm121", False)

      jit_specs += list(
          gen_attention(
@@ -406,6 +453,16 @@ def gen_all_modules(
          jit_specs.append(gen_mxfp8_quantization_sm100_module())
          jit_specs.append(gen_trtllm_gen_gemm_module())
          jit_specs.append(gen_trtllm_gen_fused_moe_sm100_module())
+     if has_sm103:
+         jit_specs.append(gen_fp4_quantization_sm103_module())
+     if has_sm110:
+         jit_specs.append(gen_fp4_quantization_sm110_module())
+     if has_sm120:
+         jit_specs.append(gen_fp4_quantization_sm120_module())
+         jit_specs.append(gen_gemm_sm120_module())
+         jit_specs.append(gen_gemm_sm120_module_cutlass_fp4())
+     if has_sm121:
+         jit_specs.append(gen_fp4_quantization_sm121_module())

      if add_comm:
          from .comm import gen_trtllm_comm_module, gen_vllm_comm_module
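
The per-architecture gating above generalizes cleanly because every capability flag comes from one dict. A hedged, self-contained sketch of the pattern (the `select_modules` helper and generator names here are illustrative, not FlashInfer APIs):

```python
from typing import Callable, Dict, List

def select_modules(
    capabilities: Dict[str, bool],
    generators: Dict[str, List[Callable[[], str]]],
) -> List[str]:
    """Collect module specs only for architectures that were detected."""
    specs: List[str] = []
    for sm_name, gens in generators.items():
        if capabilities.get(sm_name, False):
            specs.extend(gen() for gen in gens)
    return specs

caps = {"sm90": True, "sm103": False, "sm120": True}
gens = {
    "sm90": [lambda: "fp4_quantization_sm90"],
    "sm103": [lambda: "fp4_quantization_sm103"],
    "sm120": [lambda: "fp4_quantization_sm120", lambda: "gemm_sm120"],
}
print(select_modules(caps, gens))
# ['fp4_quantization_sm90', 'fp4_quantization_sm120', 'gemm_sm120']
```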
@@ -450,6 +507,9 @@ def gen_all_modules(
          )
      )

+     # Add cuDNN FMHA module
+     jit_specs.append(gen_cudnn_fmha_module())
+
      # dedup
      names = set()
      ret: List[JitSpec] = []
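
The dedup block that begins here presumably keeps the first spec per module name, so even though `gen_cudnn_fmha_module()` is appended unconditionally, any duplicate it introduces would be dropped at this step. A self-contained sketch of that first-wins dedup-by-name pattern (the `name` attribute is assumed for illustration):

```python
from dataclasses import dataclass
from typing import List

@dataclass(frozen=True)
class Spec:
    name: str  # stand-in for a JitSpec's module name

def dedup_by_name(specs: List[Spec]) -> List[Spec]:
    names = set()
    ret: List[Spec] = []
    for spec in specs:
        if spec.name not in names:  # first occurrence wins
            names.add(spec.name)
            ret.append(spec)
    return ret

print(dedup_by_name([Spec("cudnn_fmha"), Spec("gemm_sm120"), Spec("cudnn_fmha")]))
```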
@@ -523,13 +583,20 @@ def has_sm(compute: str, version: str) -> bool:
              return True
          return version_at_least(torch.version.cuda, version)

-     return has_sm("compute_90", "12.3"), has_sm("compute_100", "12.8")
+     return {
+         "sm90": has_sm("compute_90", "12.3"),
+         "sm100": has_sm("compute_100", "12.8"),
+         "sm103": has_sm("compute_103", "12.8"),
+         "sm110": has_sm("compute_110", "12.9"),
+         "sm120": has_sm("compute_120", "13.0"),
+         "sm121": has_sm("compute_121", "13.0"),
+     }


  def register_default_modules() -> int:
      """Register the default set of modules"""
      config = get_default_config()
-     has_sm90, has_sm100 = detect_sm_capabilities()
+     sm_capabilities = detect_sm_capabilities()

      jit_specs = gen_all_modules(
          config["f16_dtype"],
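
The dict shape pays off because each architecture pairs a compute target with the minimum CUDA toolkit that can compile for it. A hedged sketch of that contract, with a naive stand-in `version_at_least` and a simplified arch-list check (FlashInfer's actual helper and `has_sm` logic are not reproduced here):

```python
# Minimum CUDA toolkit per architecture, as encoded in the hunk above.
MIN_CUDA = {
    "sm90": ("compute_90", "12.3"),
    "sm100": ("compute_100", "12.8"),
    "sm103": ("compute_103", "12.8"),
    "sm110": ("compute_110", "12.9"),
    "sm120": ("compute_120", "13.0"),
    "sm121": ("compute_121", "13.0"),
}

def version_at_least(version: str, required: str) -> bool:
    # Stand-in for the real helper: naive dotted-version comparison.
    return tuple(map(int, version.split("."))) >= tuple(map(int, required.split(".")))

def detect_sketch(cuda_version: str, arch_list: set) -> dict:
    # Assumption: an arch counts as available only if its compute target is
    # requested AND the toolkit is new enough to target it.
    return {
        name: compute in arch_list and version_at_least(cuda_version, min_cuda)
        for name, (compute, min_cuda) in MIN_CUDA.items()
    }

print(detect_sketch("12.8", {"compute_90", "compute_100"}))
# {'sm90': True, 'sm100': True, 'sm103': False, 'sm110': False, ...}
```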
@@ -538,8 +605,7 @@ def register_default_modules() -> int:
          config["fa3_head_dim"],
          config["use_sliding_window"],
          config["use_logits_soft_cap"],
-         has_sm90,
-         has_sm100,
+         sm_capabilities,
          config["add_comm"],
          config["add_gemma"],
          config["add_oai_oss"],
@@ -649,7 +715,7 @@ def main():
      if "FLASHINFER_CUDA_ARCH_LIST" not in os.environ:
          raise RuntimeError("Please explicitly set env var FLASHINFER_CUDA_ARCH_LIST.")

-     has_sm90, has_sm100 = detect_sm_capabilities()
+     sm_capabilities = detect_sm_capabilities()

      # Update data dir
      jit_env.FLASHINFER_CSRC_DIR = project_root / "csrc"
@@ -678,8 +744,10 @@ def main():
      print(" use_sliding_window:", config["use_sliding_window"])
      print(" use_logits_soft_cap:", config["use_logits_soft_cap"])
      print(" FLASHINFER_CUDA_ARCH_LIST:", os.environ["FLASHINFER_CUDA_ARCH_LIST"])
-     print(" has_sm90:", has_sm90)
-     print(" has_sm100:", has_sm100)
+     print(" SM capabilities detected:")
+     for sm_name, has_sm in sm_capabilities.items():
+         if has_sm:
+             print(f" {sm_name}: True")
      for key in [
          "add_comm",
          "add_gemma",
@@ -701,8 +769,7 @@ def main():
          config["fa3_head_dim"],
          config["use_sliding_window"],
          config["use_logits_soft_cap"],
-         has_sm90,
-         has_sm100,
+         sm_capabilities,
          config["add_comm"],
          config["add_gemma"],
          config["add_oai_oss"],