
Commit 180b91f

update fmha_v2 (NVIDIA#4895)
Signed-off-by: Qidi Sang <[email protected]>
1 parent 51652b9 commit 180b91f

7 files changed: +217 -73 lines changed

cpp/kernels/fmha_v2/fmha_test.py

Lines changed: 1 addition & 1 deletion
@@ -157,7 +157,7 @@ def test_trtllm_context_mla_attention_fmha(dtype, s):
     epsilon += ' -epsilon 0.03'

     sm_version = getSMVersion()
-    if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 89:
+    if sm_version != 89:
         pytest.skip("FP8 MLAs only supported on sm89 currently.")

     # Context phase kernels.
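
The net effect: the context-MLA test is now gated on SM 89 for every dtype, not only the FP8 ("-e4m3") variants. A minimal sketch of the resulting guard, assuming getSMVersion() returns the device's SM number as an integer, as the test already does (the helper name maybe_skip_mla_test is hypothetical; the real check sits inline in test_trtllm_context_mla_attention_fmha):

import pytest

def maybe_skip_mla_test(sm_version: int) -> None:
    # After this commit the skip depends only on the SM version, not on dtype.
    if sm_version != 89:
        pytest.skip("FP8 MLAs only supported on sm89 currently.")
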

cpp/kernels/fmha_v2/setup.py

Lines changed: 70 additions & 7 deletions
@@ -189,8 +189,7 @@ class InputLayout(IntEnum):
 ns_close = r"""
 // clang-format on
 } // namespace kernels
-} // namespace tensorrt_llm
-""" if generate_cu_trtllm else ""
+} // namespace tensorrt_llm""" if generate_cu_trtllm else ""

 copyright = '''\
 /***************************************************************************************************
@@ -1344,7 +1343,7 @@ def get_makefile_code(specs_names):

 #endif // sliding_or_chunked_causal_mask

-void {launcher_name}_nl({params_type} &params,
+void {launcher_name}_nl({fused_multihead_attention_params_v2_str} &params,
     const Launch_params& launch_params, cudaStream_t stream){{
     constexpr int loop_iters = {seq_len} / {noloop_step};
     static_assert(loop_iters * {noloop_step} == {seq_len}, "");
@@ -1431,6 +1430,7 @@ def get_makefile_code(specs_names):
     {loop_step},
     {kv_loop_step},
     {head_size},
+    {head_size_v},
     {q_tile_buffers},
     {kv_tile_buffers},
     NUM_COMPUTE_GROUPS,
@@ -1453,6 +1453,7 @@ def get_makefile_code(specs_names):
     {loop_step},
     {kv_loop_step},
     {head_size},
+    {head_size_v},
     {q_tile_buffers},
     {kv_tile_buffers},
     NUM_COMPUTE_GROUPS,
@@ -1472,6 +1473,7 @@ def get_makefile_code(specs_names):
     {loop_step},
     {kv_loop_step},
     {head_size},
+    {head_size_v},
     {q_tile_buffers},
     {kv_tile_buffers},
     NUM_COMPUTE_GROUPS,
@@ -1491,6 +1493,7 @@ def get_makefile_code(specs_names):
     {loop_step},
     {kv_loop_step},
     {head_size},
+    {head_size_v},
     {q_tile_buffers},
     {kv_tile_buffers},
     NUM_COMPUTE_GROUPS,
@@ -2881,6 +2884,7 @@ def get_kernel_traits_code(specs_names):
     {loop_step},
     {kv_loop_step},
     {head_size},
+    {head_size_v},
     {q_tile_buffers},
     {kv_tile_buffers},
     NUM_COMPUTE_GROUPS,
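
The hunks above all add the same {head_size_v} placeholder to the generated template-argument lists, so the emitted kernel traits now carry the V head size separately from the Q/K head size. A stripped-down sketch of that generator pattern (illustrative only; the real templates in get_makefile_code and get_kernel_traits_code wrap these arguments in full traits instantiations with dtype, flags, and more):

def make_traits_args(loop_step, kv_loop_step, head_size, head_size_v,
                     q_tile_buffers, kv_tile_buffers):
    # Expands a template-argument list the way the generators do, with
    # {head_size_v} now emitted right after {head_size}.
    return '''\
{loop_step},
{kv_loop_step},
{head_size},
{head_size_v},
{q_tile_buffers},
{kv_tile_buffers},
NUM_COMPUTE_GROUPS,'''.format(**locals())

# Values borrowed from the new DeepSeek 192/128 context spec added further down.
print(make_traits_args(64, 128, 192, 128, 1, 2))
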
@@ -3213,7 +3217,7 @@ def get_lname_from_kname(kname: str) -> str:
         return 'nullptr'
     lname = kname.replace('_kernel', '')
     mask_types = [
-        '_sliding_window_causal', '_custom_mask', '_causal'
+        '_sliding_or_chunked_causal', '_custom_mask', '_causal'
     ]
     for mask_type in mask_types:
         lname = lname.replace(mask_type, '')
@@ -3228,6 +3232,12 @@ def get_lname_from_kname(kname: str) -> str:
 {cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
 {attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
 {is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\
+'''.format(**locals()) if 'sage' in kname and 'sm90' in kname else '''\
+{{ DATA_TYPE_{prec}, DATA_TYPE_{output_prec}, {seq_len}, {q_step}, {kv_step}, {head_size}, {head_size_v}, \
+{sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, nullptr, \
+0, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
+{attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
+{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\
 '''.format(**locals())
         else:
             code = '''\
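
The added branch chooses between two metadata templates with a conditional expression: kernel names containing both 'sage' and 'sm90' keep the entry that embeds the compiled cubin ({cubin_name}, {cubin_name}_len), everything else in this branch gets the new entry that passes nullptr/0 and also carries {head_size_v}. A generic sketch of that idiom, with placeholder field layouts rather than the real fmha_v2 metadata struct:

def pick_metadata_entry(kname, head_size, head_size_v, cubin_name):
    # Illustrative only: mirrors the "template-A if ... else template-B" shape
    # of the generator above, not the actual metadata fields.
    entry = '''\
{{ {head_size}, {cubin_name}, {cubin_name}_len }}\
''' if 'sage' in kname and 'sm90' in kname else '''\
{{ {head_size}, {head_size_v}, nullptr, 0 }}\
'''
    return entry.format(**locals())

# Hypothetical kernel/cubin names, purely for illustration.
print(pick_metadata_entry("fmha_v2_bf16_192_128_sm90_kernel", 192, 128, "cubin_fmha_bf16"))
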
@@ -3332,7 +3342,6 @@ def get_lname_from_kname(kname: str) -> str:
 {metadata_v2}
 }};
 {local_ns_close}
-
 '''.format(**locals(), copyright=copyright)

     else:
@@ -3540,7 +3549,10 @@ def enumerate_hgmma_ldgsts_kernels(specs, sm=90, dtype='fp16'):


 # Note this will be used in TRT-LLM.
-def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
+def enumerate_hgmma_flash_warpspec_kernels(specs,
+                                           sm=90,
+                                           dtype='fp16',
+                                           head_size_v=0):

     scheduling_mode = int(os.getenv('SCHEDULING_MODE', '1'))

@@ -3563,6 +3575,7 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
                 dtype=dtype,
                 seq_len=0, # support any sequence length
                 head_size=[32, 40, 48, 64],
+                head_size_v=head_size_v,
                 warps_m=4, #4x1 warpgroups
                 warps_n=1,
                 version=2,
@@ -3595,6 +3608,7 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
                 dtype=dtype,
                 seq_len=0, # support any sequence length
                 head_size=[72, 80, 96, 104, 128],
+                head_size_v=head_size_v,
                 warps_m=4, #4x1 warpgroups
                 warps_n=1,
                 version=2,
@@ -3627,6 +3641,7 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
                 dtype=dtype,
                 seq_len=0, # support any sequence length
                 head_size=[160, 192, 256],
+                head_size_v=head_size_v,
                 warps_m=4, #4x1 warpgroups
                 warps_n=1,
                 version=2,
@@ -3652,6 +3667,40 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
                 scheduling_mode=scheduling_mode,
                 input_layout=input_layout))

+        # for deepseek context 192/128, kv_step=128
+        specs.append(
+            kernel_spec(
+                sm=sm,
+                sm_mma=90,
+                dtype=dtype,
+                seq_len=0, # support any sequence length
+                head_size=192,
+                head_size_v=128,
+                warps_m=4, #4x1 warpgroups
+                warps_n=1,
+                version=2,
+                interleaved=False,
+                ldgsts_q=
+                False, # for Hopper kernels, ldgsts = False signals TMA usage.
+                ldgsts_k=False,
+                ldgsts_v=False,
+                share_smem_k_v=False,
+                loop_step=64,
+                q_tile_buffers=1, # only used by warp specialized kernels
+                has_noloop=0,
+                noloop_step=64,
+                kv_loop_step=128,
+                kv_tile_buffers=2, # only used by warp specialized kernels
+                unroll_threshold=1,
+                has_scale_max=False,
+                flash_attention=True,
+                warp_specialization=True,
+                alibi=alibi,
+                enable_attn_logit_softcapping=enable_attn_logit_softcapping,
+                return_softmax_stats=return_softmax,
+                scheduling_mode=scheduling_mode,
+                input_layout=input_layout))
+

 # Note this will be used in TRT-LLM.
 def enumerate_qgmma_flash_warpspec_kernels(specs,
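
For orientation, the new spec describes a warp-specialized context kernel with a 64-row Q tile, a 128-row KV tile, a 192-wide Q/K head and a 128-wide V head (the bf16 variant is the one selected later in enumerate_kernels). A rough back-of-the-envelope look at what those tile shapes mean in bytes, assuming 2-byte bf16 elements and ignoring swizzling, padding, and any other buffers, so this is not the kernel's real shared-memory budget:

BYTES_PER_BF16 = 2
LOOP_STEP, KV_LOOP_STEP = 64, 128        # Q-tile rows, KV-tile rows
HEAD_SIZE, HEAD_SIZE_V = 192, 128        # Q/K head dim vs. V head dim
Q_TILE_BUFFERS, KV_TILE_BUFFERS = 1, 2   # buffer counts from the spec

q_tile_kib = LOOP_STEP * HEAD_SIZE * BYTES_PER_BF16 * Q_TILE_BUFFERS / 1024
k_tile_kib = KV_LOOP_STEP * HEAD_SIZE * BYTES_PER_BF16 * KV_TILE_BUFFERS / 1024
v_tile_kib = KV_LOOP_STEP * HEAD_SIZE_V * BYTES_PER_BF16 * KV_TILE_BUFFERS / 1024
print(q_tile_kib, k_tile_kib, v_tile_kib)  # 24.0 96.0 64.0 (KiB)
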
@@ -6215,7 +6264,21 @@ def enumerate_kernels():
                 and kspec.cross_mha == False
                 and kspec.flash_attention == True
                 and kspec.warp_specialization == False
-                and kspec.tiled == True)
+                and kspec.tiled == True
+                and not (kspec.sm == 90 and (kspec.head_size, kspec.head_size_v) == (192, 128)))
+            # Deepseek MLA (hopper-style context 192/128 packed + paged)
+            or (kspec.sm == 90
+                and kspec.dtype == 'bf16'
+                and kspec.head_size == 192
+                and kspec.head_size_v == 128
+                and kspec.sage_block_sizes is None
+                and kspec.version == 2
+                and kspec.cross_mha == False
+                and kspec.flash_attention == True
+                and kspec.warp_specialization == True
+                and kspec.input_layout in [InputLayout.PACKED_QKV, InputLayout.Q_PAGED_KV]
+                and kspec.alibi == False
+                and kspec.enable_attn_logit_softcapping == False)
             # SageAttention (warp_spec, head_size in (80, 128), packed QKV, padding mask)
             or (kspec.sm == 90
                 and kspec.head_size in [80, 128]
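
Restated outside the big filter expression: the new branch selects warp-specialized SM90 bf16 kernels with the 192/128 head-size pair for packed-QKV and paged-KV layouts, while the first branch now explicitly excludes that pair from the non-warp-specialized tiled kernels. A standalone sketch of the new predicate (the InputLayout values below are stand-ins, not the real enum in setup.py):

from enum import IntEnum
from types import SimpleNamespace

class InputLayout(IntEnum):  # placeholder values for illustration
    PACKED_QKV = 0
    Q_PAGED_KV = 2

def is_deepseek_mla_context_spec(kspec) -> bool:
    return (kspec.sm == 90
            and kspec.dtype == 'bf16'
            and kspec.head_size == 192
            and kspec.head_size_v == 128
            and kspec.sage_block_sizes is None
            and kspec.version == 2
            and not kspec.cross_mha
            and kspec.flash_attention
            and kspec.warp_specialization
            and kspec.input_layout in (InputLayout.PACKED_QKV, InputLayout.Q_PAGED_KV)
            and not kspec.alibi
            and not kspec.enable_attn_logit_softcapping)

example = SimpleNamespace(sm=90, dtype='bf16', head_size=192, head_size_v=128,
                          sage_block_sizes=None, version=2, cross_mha=False,
                          flash_attention=True, warp_specialization=True,
                          input_layout=InputLayout.PACKED_QKV, alibi=False,
                          enable_attn_logit_softcapping=False)
print(is_deepseek_mla_context_spec(example))  # True
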

cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h

Lines changed: 1 addition & 1 deletion
@@ -173,7 +173,7 @@ struct Compute

     enum
     {
-        TILE_SIZE_V = STEP_KV * Kernel_traits::D
+        TILE_SIZE_V = STEP_KV * Kernel_traits::DV
     };

     enum
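
The fix sizes the V tile with the V head dimension (Kernel_traits::DV) instead of the Q/K head dimension (Kernel_traits::D). For kernels where the two are equal nothing changes; for the new 192/128 MLA kernels the old expression over-sized the tile. A quick check with the values from the new spec (STEP_KV corresponds to kv_loop_step, D/DV to head_size/head_size_v):

STEP_KV = 128      # kv_loop_step of the new 192/128 context spec
D, DV = 192, 128   # Q/K head dim vs. V head dim

tile_size_v_old = STEP_KV * D   # 24576 elements, sized for the Q/K head dim
tile_size_v_new = STEP_KV * DV  # 16384 elements, sized for the V head dim
print(tile_size_v_old, tile_size_v_new)
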
