 from typing import List, Tuple
 
 import torch
+import torch.version
 from torch.utils.cpp_extension import _get_cuda_arch_flags
 
 from .activation import act_func_def_str, gen_act_and_mul_module
 from .cascade import gen_cascade_module
-from .comm import gen_trtllm_comm_module, gen_vllm_comm_module
 from .fp4_quantization import gen_fp4_quantization_sm100_module
 from .fused_moe import gen_cutlass_fused_moe_sm100_module
 from .gemm import gen_gemm_module, gen_gemm_sm90_module, gen_gemm_sm100_module
@@ -42,11 +42,15 @@ def gen_fa2(
     head_dim_vo: int,
     use_sliding_window: bool,
     use_logits_soft_cap: bool,
+    use_attention_sink: bool,
 ) -> List[JitSpec]:
     if dtype_qo.itemsize == dtype_kv.itemsize and dtype_qo != dtype_kv:
         return []
     if dtype_qo.itemsize == 1:
         return []  # fp8 tensor cores not supported in fa2
+
+    # TODO: support for AoT sink attention.
+
     return [
         gen_single_prefill_module(
             backend="fa2",
@@ -105,6 +109,7 @@ def gen_fa3(
     head_dim_vo: int,
     use_sliding_window: bool,
     use_logits_soft_cap: bool,
+    use_attention_sink: bool,
 ) -> List[JitSpec]:
     if dtype_q != dtype_kv:
         return []  # fa3 template do not support mixed precision
@@ -116,6 +121,8 @@ def gen_fa3(
         if head_dim_qk == 192 or head_dim_qk == 64:
             return []  # (192, 128) & (64, 64) not supported for fp8 yet.
 
+    # TODO: support for AoT sink attention.
+
     return [
         gen_single_prefill_module(
             backend="fa3",
@@ -155,6 +162,7 @@ def gen_attention(
     has_sm90: bool,
     has_sm100: bool,
     add_gemma: bool,
+    add_oai_oss: bool,
 ) -> List[JitSpec]:
     head_dim_ckv = 512
     head_dim_kpe = 64
@@ -181,6 +189,7 @@ def gen_attention(
             head_dim_vo=head_dim_vo,
             use_sliding_window=use_sliding_window,
             use_logits_soft_cap=use_logits_soft_cap,
+            use_attention_sink=False,
         )
 
     # FA3 MHA / MQA / GQA
@@ -206,6 +215,7 @@ def gen_attention(
             head_dim_vo=head_dim_vo,
             use_sliding_window=use_sliding_window,
             use_logits_soft_cap=use_logits_soft_cap,
+            use_attention_sink=False,
         )
 
     # Gemma
@@ -226,6 +236,7 @@ def gen_attention(
                 head_dim_vo=256,
                 use_sliding_window=use_sliding_window,
                 use_logits_soft_cap=use_logits_soft_cap,
+                use_attention_sink=False,
             )
         if has_sm90:
             for (
@@ -245,8 +256,30 @@ def gen_attention(
                     head_dim_vo=256,
                     use_sliding_window=use_sliding_window,
                     use_logits_soft_cap=use_logits_soft_cap,
+                    use_attention_sink=False,
                 )
 
+    # OAI OSS
+    if add_oai_oss:
+        for (
+            dtype_qo,
+            dtype_kv,
+            use_sliding_window,
+        ) in product(
+            f16_dtype_,
+            f16_dtype_ + f8_dtype_,
+            [True],
+        ):
+            jit_specs += gen_fa2(
+                dtype_qo=dtype_qo,
+                dtype_kv=dtype_kv,
+                head_dim_qk=64,
+                head_dim_vo=64,
+                use_sliding_window=use_sliding_window,
+                use_logits_soft_cap=False,
+                use_attention_sink=True,
+            )
+
     # fmha_cutlass_sm100a
     # NOTE: currently there's only one uri.
     if has_sm100:
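
Editor's note, not part of the commit: the `add_oai_oss` branch above enumerates FA2 variants with head_dim 64, sliding window enabled, logits soft cap disabled, and attention sink enabled, pairing each f16 query/output dtype with every f16 and fp8 KV-cache dtype. A standalone sketch of that enumeration follows; the dtype lists are illustrative assumptions, not necessarily the script's defaults.

```python
# Sketch of the combinations the add_oai_oss branch generates.
# The dtype lists below are illustrative assumptions, not the script's defaults.
from itertools import product

import torch

f16_dtype_ = [torch.float16, torch.bfloat16]
f8_dtype_ = [torch.float8_e4m3fn]

for dtype_qo, dtype_kv, use_sliding_window in product(
    f16_dtype_,              # query/output dtypes
    f16_dtype_ + f8_dtype_,  # KV-cache dtypes, including fp8
    [True],                  # sliding window always on for this model family
):
    print(
        f"fa2 variant: qo={dtype_qo}, kv={dtype_kv}, head_dim=64, "
        f"sliding_window={use_sliding_window}, "
        f"logits_soft_cap=False, attention_sink=True"
    )
```
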
@@ -301,7 +334,12 @@ def gen_all_modules(
     use_logits_soft_cap_: List[bool],
     has_sm90: bool,
     has_sm100: bool,
+    add_comm: bool,
     add_gemma: bool,
+    add_oai_oss: bool,
+    add_moe: bool,
+    add_act: bool,
+    add_misc: bool,
 ) -> List[JitSpec]:
     jit_specs: List[JitSpec] = []
 
@@ -315,27 +353,40 @@ def gen_all_modules(
         has_sm90,
         has_sm100,
         add_gemma,
+        add_oai_oss,
     )
-    for act_name in act_func_def_str:
-        jit_specs.append(gen_act_and_mul_module(act_name))
-    jit_specs.append(gen_gemm_module())
+
+    if add_act:
+        for act_name in act_func_def_str:
+            jit_specs.append(gen_act_and_mul_module(act_name))
+
+    if add_moe:
+        jit_specs.append(gen_gemm_module())
+        if has_sm90:
+            jit_specs.append(gen_gemm_sm90_module())
+        if has_sm100:
+            jit_specs.append(gen_cutlass_fused_moe_sm100_module())
+            jit_specs.append(gen_fp4_quantization_sm100_module())
+            jit_specs.append(gen_gemm_sm100_module())
+
+    if add_comm:
+        from .comm import gen_trtllm_comm_module, gen_vllm_comm_module
+
+        if has_sm100:
+            jit_specs.append(gen_trtllm_comm_module())
+        jit_specs.append(gen_vllm_comm_module())
+
+    if add_misc:
+        jit_specs += [
+            gen_cascade_module(),
+            gen_norm_module(),
+            gen_page_module(),
+            gen_quantization_module(),
+            gen_rope_module(),
+            gen_sampling_module(),
+        ]
     if has_sm90:
-        jit_specs.append(gen_gemm_sm90_module())
-    if has_sm100:
-        jit_specs.append(gen_cutlass_fused_moe_sm100_module())
-        jit_specs.append(gen_fp4_quantization_sm100_module())
-        jit_specs.append(gen_gemm_sm100_module())
-        jit_specs.append(gen_trtllm_comm_module())
-
-    jit_specs += [
-        gen_cascade_module(),
-        gen_vllm_comm_module(),
-        gen_norm_module(),
-        gen_page_module(),
-        gen_quantization_module(),
-        gen_rope_module(),
-        gen_sampling_module(),
-    ]
+        jit_specs.append(get_trtllm_utils_spec())
 
     # dedup
     names = set()
@@ -421,11 +472,36 @@ def main():
         nargs="*",
         help="Use logits soft cap",
     )
+    parser.add_argument(
+        "--add-comm",
+        type=parse_bool,
+        help="Add communication kernels (trtllm_comm, vllm_comm)",
+    )
     parser.add_argument(
         "--add-gemma",
         type=parse_bool,
         help="Add kernels for Gemma Model (head_dim=256, use_sliding_window, use_logits_soft_cap)",
     )
+    parser.add_argument(
+        "--add-oai-oss",
+        type=parse_bool,
+        help="Add kernels for OAI OSS Model (head_dim=64, use_sliding_window)",
+    )
+    parser.add_argument(
+        "--add-moe",
+        type=parse_bool,
+        help="Add MoE kernels",
+    )
+    parser.add_argument(
+        "--add-act",
+        type=parse_bool,
+        help="Add activation kernels",
+    )
+    parser.add_argument(
+        "--add-misc",
+        type=parse_bool,
+        help="Add miscellaneous kernels",
+    )
     args = parser.parse_args()
 
     # Default values
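
Editor's note, not part of the commit: each new `--add-*` option is tri-state. When the flag is omitted, the corresponding `args.add_*` attribute stays `None` and the hard-coded default set further below is kept; when the flag is given, `parse_bool` converts its string argument. A minimal standalone sketch of that pattern, with a hypothetical `parse_bool` (the real helper is defined elsewhere in this file and may differ):

```python
# Standalone sketch of the tri-state --add-* flag handling.
# parse_bool here is a hypothetical stand-in for the script's own helper.
import argparse


def parse_bool(s: str) -> bool:
    if s.lower() in ("true", "1"):
        return True
    if s.lower() in ("false", "0"):
        return False
    raise ValueError(f"invalid boolean value: {s!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--add-oai-oss", type=parse_bool)  # defaults to None when omitted

args = parser.parse_args(["--add-oai-oss", "false"])

add_oai_oss = True  # built-in default, as in main()
if args.add_oai_oss is not None:  # flag was given, so override the default
    add_oai_oss = bool(args.add_oai_oss)
print(add_oai_oss)  # False
```
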
@@ -459,7 +535,12 @@ def main():
         False,
         # True,
     ]
+    add_comm = False
     add_gemma = True
+    add_oai_oss = True
+    add_moe = False
+    add_act = True
+    add_misc = True
 
     # Override
     if args.out_dir:
@@ -478,19 +559,33 @@ def main():
         use_sliding_window_ = [parse_bool(s) for s in args.use_sliding_window]
     if args.use_logits_soft_cap:
         use_logits_soft_cap_ = [parse_bool(s) for s in args.use_logits_soft_cap]
+    if args.add_comm is not None:
+        add_comm = bool(args.add_comm)
     if args.add_gemma is not None:
         add_gemma = bool(args.add_gemma)
+    if args.add_oai_oss is not None:
+        add_oai_oss = bool(args.add_oai_oss)
+    if args.add_moe is not None:
+        add_moe = bool(args.add_moe)
+    if args.add_act is not None:
+        add_act = bool(args.add_act)
+    if args.add_misc is not None:
+        add_misc = bool(args.add_misc)
 
     # Cuda Arch
     if "TORCH_CUDA_ARCH_LIST" not in os.environ:
         raise RuntimeError("Please explicitly set env var TORCH_CUDA_ARCH_LIST.")
     gencode_flags = _get_cuda_arch_flags()
-    has_sm90 = any("compute_90" in flag for flag in gencode_flags) and version_at_least(
-        torch.version.cuda, "12.3"
-    )
-    has_sm100 = any(
-        "compute_100" in flag for flag in gencode_flags
-    ) and version_at_least(torch.version.cuda, "12.8")
+
+    def has_sm(compute: str, version: str) -> bool:
+        if not any(compute in flag for flag in gencode_flags):
+            return False
+        if torch.version.cuda is None:
+            return True
+        return version_at_least(torch.version.cuda, version)
+
+    has_sm90 = has_sm("compute_90", "12.3")
+    has_sm100 = has_sm("compute_100", "12.8")
 
     # Update data dir
     jit_env.FLASHINFER_CSRC_DIR = project_root / "csrc"
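
Editor's note, not part of the commit: the new `has_sm` helper first requires the compute capability to appear in the `TORCH_CUDA_ARCH_LIST`-derived gencode flags, and only then checks the CUDA toolkit version; when `torch.version.cuda` is `None` (a non-CUDA PyTorch build), the arch list alone decides. A self-contained sketch of the same decision logic, with an inlined version comparison standing in for `version_at_least`:

```python
# Self-contained sketch of the has_sm() decision logic above.
# The version comparison is an inlined stand-in for version_at_least().
from typing import List, Optional


def has_sm(gencode_flags: List[str], cuda_version: Optional[str],
           compute: str, min_version: str) -> bool:
    if not any(compute in flag for flag in gencode_flags):
        return False  # arch not requested via TORCH_CUDA_ARCH_LIST
    if cuda_version is None:
        return True   # no CUDA toolkit version reported; trust the arch list
    version = tuple(int(x) for x in cuda_version.split(".")[:2])
    required = tuple(int(x) for x in min_version.split(".")[:2])
    return version >= required


flags = ["-gencode=arch=compute_90,code=sm_90"]
print(has_sm(flags, "12.4", "compute_90", "12.3"))  # True: arch listed, CUDA new enough
print(has_sm(flags, "12.1", "compute_90", "12.3"))  # False: CUDA toolkit too old
print(has_sm(flags, None, "compute_100", "12.8"))   # False: compute_100 not in the arch list
```
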
@@ -521,7 +616,12 @@ def main():
     print(" TORCH_CUDA_ARCH_LIST:", os.environ["TORCH_CUDA_ARCH_LIST"])
     print(" has_sm90:", has_sm90)
     print(" has_sm100:", has_sm100)
+    print(" add_comm:", add_comm)
     print(" add_gemma:", add_gemma)
+    print(" add_oai_oss:", add_oai_oss)
+    print(" add_moe:", add_moe)
+    print(" add_act:", add_act)
+    print(" add_misc:", add_misc)
 
     # Generate JIT specs
     print("Generating JIT specs...")
@@ -537,8 +637,6 @@ def main():
             ],
         )
     ]
-    if has_sm90:
-        jit_specs.append(get_trtllm_utils_spec())
     jit_specs += gen_all_modules(
         f16_dtype_,
         f8_dtype_,
@@ -548,7 +646,12 @@ def main():
         use_logits_soft_cap_,
         has_sm90,
         has_sm100,
+        add_comm,
         add_gemma,
+        add_oai_oss,
+        add_moe,
+        add_act,
+        add_misc,
     )
     print("Total ops:", len(jit_specs))
 