import torch.version

from .activation import act_func_def_str, gen_act_and_mul_module
+ from .fp8_quantization import gen_mxfp8_quantization_sm100_module
from .cascade import gen_cascade_module
from .fp4_quantization import (
    gen_fp4_quantization_sm100_module,
from .fused_moe import (
    gen_cutlass_fused_moe_sm100_module,
    gen_cutlass_fused_moe_sm90_module,
+     gen_trtllm_gen_fused_moe_sm100_module,
+ )
+ from .gemm import (
+     gen_gemm_module,
+     gen_gemm_sm90_module,
+     gen_gemm_sm100_module,
+     gen_gemm_sm100_module_cutlass_fp4,
+     gen_gemm_sm100_module_cutlass_fp8,
+     gen_trtllm_gen_gemm_module,
)
- from .gemm import gen_gemm_module, gen_gemm_sm90_module, gen_gemm_sm100_module
from .jit import JitSpec, build_jit_specs
from .jit import env as jit_env
from .jit import (
    gen_batch_decode_module,
    gen_batch_mla_module,
    gen_batch_prefill_module,
    gen_fmha_cutlass_sm100a_module,
-     gen_jit_spec,
    gen_single_decode_module,
    gen_single_prefill_module,
+     gen_trtllm_gen_fmha_module,
)
from .mla import gen_mla_module
from .norm import gen_norm_module
from .page import gen_page_module
from .quantization import gen_quantization_module
from .rope import gen_rope_module
from .sampling import gen_sampling_module
- from .tllm_utils import get_trtllm_utils_spec
- from .utils import version_at_least
+ from .tllm_utils import gen_trtllm_utils_module
+ from .utils import gen_logging_module, version_at_least
+ from .xqa import gen_xqa_module
from .compilation_context import CompilationContext

@@ -275,6 +285,9 @@ def gen_attention(
            use_logits_soft_cap=False,
        )

+     # trtllm_gen_fmha
+     yield gen_trtllm_gen_fmha_module()
+
    # MLA
    # NOTE: fp8 kv not supported in MLA
    mla_backend_ = ["fa2"] + (["fa3"] if has_sm90 else [])
@@ -296,6 +309,46 @@ def gen_attention(
        yield gen_mla_module()


+ def gen_xqa(
+     use_fp16_: List[bool],
+     token_per_page_: List[int],
+     head_size_: List[int],
+     head_grp_size_: List[int],
+     use_sliding_window_: List[bool],
+     has_sm90: bool,
+ ) -> Iterator[JitSpec]:
+     """Generate XQA modules for various configurations."""
+     if not has_sm90:
+         return  # XQA requires SM90+
+
+     for (
+         use_fp16,
+         token_per_page,
+         head_size,
+         head_grp_size,
+         use_sliding_window,
+     ) in product(
+         use_fp16_,
+         token_per_page_,
+         head_size_,
+         head_grp_size_,
+         use_sliding_window_,
+     ):
+         # Skip invalid configurations
+         if head_size % 16 != 0 or head_size > 256 or head_size < 16:
+             continue
+         if token_per_page not in [16, 32, 64, 128]:
+             continue
+
+         yield gen_xqa_module(
+             use_fp16=use_fp16,
+             token_per_page=token_per_page,
+             head_size=head_size,
+             head_grp_size=head_grp_size,
+             use_sliding_window=use_sliding_window,
+         )
+
+
def gen_all_modules(
    f16_dtype_: List[torch.dtype],
    f8_dtype_: List[torch.dtype],
@@ -311,6 +364,7 @@ def gen_all_modules(
    add_moe: bool,
    add_act: bool,
    add_misc: bool,
+     add_xqa: bool,
) -> List[JitSpec]:
    jit_specs: List[JitSpec] = []

@@ -343,14 +397,23 @@ def gen_all_modules(
            jit_specs.append(gen_fp4_quantization_sm100_module())
            jit_specs.append(gen_cutlass_fused_moe_sm100_module())
            jit_specs.append(gen_gemm_sm100_module())
+             jit_specs.append(gen_gemm_sm100_module_cutlass_fp4())
+             jit_specs.append(gen_gemm_sm100_module_cutlass_fp8())
+             jit_specs.append(gen_mxfp8_quantization_sm100_module())
+             jit_specs.append(gen_trtllm_gen_gemm_module())
+             jit_specs.append(gen_trtllm_gen_fused_moe_sm100_module())

    if add_comm:
        from .comm import gen_trtllm_comm_module, gen_vllm_comm_module
        from .comm.nvshmem import gen_nvshmem_module
+         from .comm.trtllm_alltoall import gen_comm_alltoall_module
+         from .comm.trtllm_mnnvl_ar import gen_trtllm_mnnvl_comm_module

        jit_specs.append(gen_nvshmem_module())
+         jit_specs.append(gen_comm_alltoall_module())
        if has_sm100:
            jit_specs.append(gen_trtllm_comm_module())
+             jit_specs.append(gen_trtllm_mnnvl_comm_module())
        jit_specs.append(gen_vllm_comm_module())

    if add_misc:
@@ -363,7 +426,25 @@ def gen_all_modules(
            gen_sampling_module(),
        ]
        if has_sm90:
-             jit_specs.append(get_trtllm_utils_spec())
+             jit_specs.append(gen_trtllm_utils_module())
+
+     if add_xqa:
+         # Define XQA configurations to iterate over
+         xqa_use_fp16_ = [True, False]  # fp16 and bf16
+         xqa_token_per_page_ = [16, 32, 64, 128]
+         xqa_head_size_ = [64, 128, 256]
+         xqa_head_grp_size_ = [1, 2, 4, 8]  # Different group sizes for MQA/GQA
+
+         jit_specs += list(
+             gen_xqa(
+                 xqa_use_fp16_,
+                 xqa_token_per_page_,
+                 xqa_head_size_,
+                 xqa_head_grp_size_,
+                 use_sliding_window_,
+                 has_sm90,
+             )
+         )

    # dedup
    names = set()
@@ -479,6 +560,11 @@ def main():
        type=parse_bool,
        help="Add miscellaneous kernels",
    )
+     parser.add_argument(
+         "--add-xqa",
+         type=parse_bool,
+         help="Add XQA (Cross-Query Attention) kernels",
+     )
    args = parser.parse_args()

    # Default values
@@ -488,13 +574,13 @@ def main():
    fa2_head_dim_ = [
        (64, 64),
        (128, 128),
-         # (256, 256),
+         (256, 256),
    ]
    fa3_head_dim_ = [
        (192, 128),
        (128, 128),
-         # (64, 64),
-         # (256, 256),
+         (64, 64),
+         (256, 256),
    ]
    f16_dtype_ = [
        torch.float16,
@@ -506,18 +592,19 @@ def main():
    ]
    use_sliding_window_ = [
        False,
-         # True,
+         True,
    ]
    use_logits_soft_cap_ = [
        False,
-         # True,
+         True,
    ]
-     add_comm = False
-     add_gemma = False
+     add_comm = True
+     add_gemma = True
    add_oai_oss = True
-     add_moe = False
-     add_act = False
+     add_moe = True
+     add_act = True
    add_misc = True
+     add_xqa = True

    # Override
    if args.out_dir:
@@ -537,17 +624,19 @@ def main():
    if args.use_logits_soft_cap:
        use_logits_soft_cap_ = [parse_bool(s) for s in args.use_logits_soft_cap]
    if args.add_comm is not None:
-         add_comm = bool(args.add_comm)
+         add_comm = args.add_comm
    if args.add_gemma is not None:
-         add_gemma = bool(args.add_gemma)
+         add_gemma = args.add_gemma
    if args.add_oai_oss is not None:
-         add_oai_oss = bool(args.add_oai_oss)
+         add_oai_oss = args.add_oai_oss
    if args.add_moe is not None:
-         add_moe = bool(args.add_moe)
+         add_moe = args.add_moe
    if args.add_act is not None:
-         add_act = bool(args.add_act)
+         add_act = args.add_act
    if args.add_misc is not None:
-         add_misc = bool(args.add_misc)
+         add_misc = args.add_misc
+     if args.add_xqa is not None:
+         add_xqa = args.add_xqa

    # Cuda Arch
    if "FLASHINFER_CUDA_ARCH_LIST" not in os.environ:
@@ -603,21 +692,11 @@ def has_sm(compute: str, version: str) -> bool:
    print(" add_moe:", add_moe)
    print(" add_act:", add_act)
    print(" add_misc:", add_misc)
+     print(" add_xqa:", add_xqa)

    # Generate JIT specs
    print("Generating JIT specs...")
-     jit_specs = [
-         gen_jit_spec(
-             "logging",
-             [
-                 jit_env.FLASHINFER_CSRC_DIR / "logging.cc",
-             ],
-             extra_include_paths=[
-                 jit_env.SPDLOG_INCLUDE_DIR,
-                 jit_env.FLASHINFER_INCLUDE_DIR,
-             ],
-         )
-     ]
+     jit_specs = [gen_logging_module()]
    jit_specs += gen_all_modules(
        f16_dtype_,
        f8_dtype_,
@@ -633,6 +712,7 @@ def has_sm(compute: str, version: str) -> bool:
        add_moe,
        add_act,
        add_misc,
+         add_xqa,
    )
    print("Total ops:", len(jit_specs))
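For reference, below is a minimal sketch of exercising the new gen_xqa generator on its own, outside the AOT script. The absolute module paths (flashinfer.aot, flashinfer.jit) are assumptions inferred from the relative imports in this diff, and build_jit_specs is only referenced in a comment to avoid guessing its exact signature; adjust both to the actual package layout.

# Illustrative only: module paths below are assumed, not taken from this diff.
from flashinfer.aot import gen_xqa          # gen_xqa is the generator added above
from flashinfer.jit import build_jit_specs  # same helper this file imports

# One known-valid configuration: head_size is a multiple of 16 in [16, 256]
# and token_per_page is in {16, 32, 64, 128}; gen_xqa yields nothing
# unless has_sm90 is True.
specs = list(
    gen_xqa(
        use_fp16_=[True],
        token_per_page_=[64],
        head_size_=[128],
        head_grp_size_=[8],
        use_sliding_window_=[False],
        has_sm90=True,
    )
)
print("XQA specs:", len(specs))
# build_jit_specs(specs)  # compile ahead of time, as the AOT script does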