@@ -189,8 +189,7 @@ class InputLayout(IntEnum):
 ns_close = r"""
 // clang-format on
 } // namespace kernels
-} // namespace tensorrt_llm
-""" if generate_cu_trtllm else ""
+} // namespace tensorrt_llm""" if generate_cu_trtllm else ""

 copyright = '''\
 /***************************************************************************************************
@@ -1344,7 +1343,7 @@ def get_makefile_code(specs_names):

 #endif // sliding_or_chunked_causal_mask

-void {launcher_name}_nl({params_type} &params,
+void {launcher_name}_nl({fused_multihead_attention_params_v2_str} &params,
  const Launch_params& launch_params, cudaStream_t stream){{
  constexpr int loop_iters = {seq_len} / {noloop_step};
  static_assert(loop_iters * {noloop_step} == {seq_len}, "");
@@ -1431,6 +1430,7 @@ def get_makefile_code(specs_names):
  {loop_step},
  {kv_loop_step},
  {head_size},
+ {head_size_v},
  {q_tile_buffers},
  {kv_tile_buffers},
  NUM_COMPUTE_GROUPS,
@@ -1453,6 +1453,7 @@ def get_makefile_code(specs_names):
  {loop_step},
  {kv_loop_step},
  {head_size},
+ {head_size_v},
  {q_tile_buffers},
  {kv_tile_buffers},
  NUM_COMPUTE_GROUPS,
@@ -1472,6 +1473,7 @@ def get_makefile_code(specs_names):
  {loop_step},
  {kv_loop_step},
  {head_size},
+ {head_size_v},
  {q_tile_buffers},
  {kv_tile_buffers},
  NUM_COMPUTE_GROUPS,
@@ -1491,6 +1493,7 @@ def get_makefile_code(specs_names):
  {loop_step},
  {kv_loop_step},
  {head_size},
+ {head_size_v},
  {q_tile_buffers},
  {kv_tile_buffers},
  NUM_COMPUTE_GROUPS,
@@ -2881,6 +2884,7 @@ def get_kernel_traits_code(specs_names):
  {loop_step},
  {kv_loop_step},
  {head_size},
+ {head_size_v},
  {q_tile_buffers},
  {kv_tile_buffers},
  NUM_COMPUTE_GROUPS,
@@ -3213,7 +3217,7 @@ def get_lname_from_kname(kname: str) -> str:
         return 'nullptr'
     lname = kname.replace('_kernel', '')
     mask_types = [
-        '_sliding_window_causal', '_custom_mask', '_causal'
+        '_sliding_or_chunked_causal', '_custom_mask', '_causal'
     ]
     for mask_type in mask_types:
         lname = lname.replace(mask_type, '')
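Aside on the hunk above: the replacements run in list order, so the renamed '_sliding_or_chunked_causal' entry has to stay ahead of '_causal', otherwise the shorter suffix would be stripped out of the longer one first. Below is a minimal, self-contained sketch of that suffix-stripping; the kernel name is invented for illustration, and the real function can also return 'nullptr', which is omitted here.

# Sketch of the launcher-name derivation shown in the hunk above.
# The kernel name used here is made up for illustration only.
def launcher_name_from_kernel_name(kname: str) -> str:
    lname = kname.replace('_kernel', '')
    # Longer suffixes come first so that stripping '_causal' cannot mangle
    # '_sliding_or_chunked_causal' before it is matched as a whole.
    for mask_type in ['_sliding_or_chunked_causal', '_custom_mask', '_causal']:
        lname = lname.replace(mask_type, '')
    return lname


print(launcher_name_from_kernel_name(
    'fmha_v2_flash_attention_bf16_64_128_S_192_sliding_or_chunked_causal_sm90_kernel'))
# -> fmha_v2_flash_attention_bf16_64_128_S_192_sm90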
@@ -3228,6 +3232,12 @@ def get_lname_from_kname(kname: str) -> str:
 {cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
 {attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
 {is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\
+'''.format(**locals()) if 'sage' in kname and 'sm90' in kname else '''\
+{{ DATA_TYPE_{prec}, DATA_TYPE_{output_prec}, {seq_len}, {q_step}, {kv_step}, {head_size}, {head_size_v}, \
+{sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, nullptr, \
+0, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
+{attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
+{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\
 '''.format(**locals())
     else:
         code = '''\
@@ -3332,7 +3342,6 @@ def get_lname_from_kname(kname: str) -> str:
 {metadata_v2}
 }};
 {local_ns_close}
-
 '''.format(**locals(), copyright=copyright)

     else:
@@ -3540,7 +3549,10 @@ def enumerate_hgmma_ldgsts_kernels(specs, sm=90, dtype='fp16'):


 # Note this will be used in TRT-LLM.
-def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
+def enumerate_hgmma_flash_warpspec_kernels(specs,
+                                           sm=90,
+                                           dtype='fp16',
+                                           head_size_v=0):

     scheduling_mode = int(os.getenv('SCHEDULING_MODE', '1'))

@@ -3563,6 +3575,7 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
             dtype=dtype,
             seq_len=0,  # support any sequence length
             head_size=[32, 40, 48, 64],
+            head_size_v=head_size_v,
             warps_m=4,  #4x1 warpgroups
             warps_n=1,
             version=2,
@@ -3595,6 +3608,7 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
             dtype=dtype,
             seq_len=0,  # support any sequence length
             head_size=[72, 80, 96, 104, 128],
+            head_size_v=head_size_v,
             warps_m=4,  #4x1 warpgroups
             warps_n=1,
             version=2,
@@ -3627,6 +3641,7 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
             dtype=dtype,
             seq_len=0,  # support any sequence length
             head_size=[160, 192, 256],
+            head_size_v=head_size_v,
             warps_m=4,  #4x1 warpgroups
             warps_n=1,
             version=2,
@@ -3652,6 +3667,40 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
             scheduling_mode=scheduling_mode,
             input_layout=input_layout))

+    # for deepseek context 192/128, kv_step=128
+    specs.append(
+        kernel_spec(
+            sm=sm,
+            sm_mma=90,
+            dtype=dtype,
+            seq_len=0,  # support any sequence length
+            head_size=192,
+            head_size_v=128,
+            warps_m=4,  #4x1 warpgroups
+            warps_n=1,
+            version=2,
+            interleaved=False,
+            ldgsts_q=
+            False,  # for Hopper kernels, ldgsts = False signals TMA usage.
+            ldgsts_k=False,
+            ldgsts_v=False,
+            share_smem_k_v=False,
+            loop_step=64,
+            q_tile_buffers=1,  # only used by warp specialized kernels
+            has_noloop=0,
+            noloop_step=64,
+            kv_loop_step=128,
+            kv_tile_buffers=2,  # only used by warp specialized kernels
+            unroll_threshold=1,
+            has_scale_max=False,
+            flash_attention=True,
+            warp_specialization=True,
+            alibi=alibi,
+            enable_attn_logit_softcapping=enable_attn_logit_softcapping,
+            return_softmax_stats=return_softmax,
+            scheduling_mode=scheduling_mode,
+            input_layout=input_layout))
+

 # Note this will be used in TRT-LLM.
 def enumerate_qgmma_flash_warpspec_kernels(specs,
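The 192/128 pair added above is the DeepSeek MLA context shape: Q and K carry a 192-wide head while V, and therefore the attention output, is 128-wide, which is why the generator now threads a separate head_size_v alongside head_size. The numpy sketch below only illustrates those shapes; it is not part of setup.py, and the single-head, unmasked setup is an assumption made for brevity.

import numpy as np

s, d_qk, d_v = 8, 192, 128  # sequence length, QK head dim, V head dim
q = np.random.rand(s, d_qk).astype(np.float32)
k = np.random.rand(s, d_qk).astype(np.float32)
v = np.random.rand(s, d_v).astype(np.float32)

logits = q @ k.T / np.sqrt(d_qk)                  # (s, s): Q and K head dims must match
probs = np.exp(logits - logits.max(-1, keepdims=True))
probs /= probs.sum(-1, keepdims=True)             # row-wise softmax
out = probs @ v                                   # (s, 128): output width follows head_size_v
assert out.shape == (s, d_v)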
@@ -6215,7 +6264,21 @@ def enumerate_kernels():
                  and kspec.cross_mha == False
                  and kspec.flash_attention == True
                  and kspec.warp_specialization == False
-                 and kspec.tiled == True)
+                 and kspec.tiled == True
+                 and not (kspec.sm == 90 and (kspec.head_size, kspec.head_size_v) == (192, 128)))
+            # Deepseek MLA (hopper-style context 192/128 packed + paged)
+            or (kspec.sm == 90
+                and kspec.dtype == 'bf16'
+                and kspec.head_size == 192
+                and kspec.head_size_v == 128
+                and kspec.sage_block_sizes is None
+                and kspec.version == 2
+                and kspec.cross_mha == False
+                and kspec.flash_attention == True
+                and kspec.warp_specialization == True
+                and kspec.input_layout in [InputLayout.PACKED_QKV, InputLayout.Q_PAGED_KV]
+                and kspec.alibi == False
+                and kspec.enable_attn_logit_softcapping == False)
             # SageAttention (warp_spec, head_size in (80, 128), packed QKV, padding mask)
             or (kspec.sm == 90
                 and kspec.head_size in [80, 128]