
Commit 156f645

[TRTLLM-9798][feat] Change to use new DeepGEMM MQA sm100 kernel for MTP-3 (#10226)
Signed-off-by: Fanrong Li <[email protected]>
1 parent f6c3bc1 commit 156f645

2 files changed: +78 -19 lines changed

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 60 additions & 14 deletions
@@ -18,7 +18,7 @@
 from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
 from tensorrt_llm._torch.utils import maybe_compile
-from tensorrt_llm._utils import get_size_in_bytes
+from tensorrt_llm._utils import get_size_in_bytes, get_sm_version
 from tensorrt_llm.bindings import DataType
 from tensorrt_llm.bindings.executor import KvCacheConfig
 from tensorrt_llm.bindings.internal.batch_manager import \
@@ -312,6 +312,8 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
     skip_indexer_for_ctx_reqs: bool = False
     # Whether skip the indexer for generation requests
     skip_indexer_for_gen_reqs: bool = False
+    # Whether to use the expanded buffers for MTP support
+    use_expanded_buffers_for_mtp: bool = False
 
     def __init__(self, *args, **kwargs):
         self.num_sms = tensorrt_llm.deep_gemm.get_num_sms()
@@ -475,10 +477,10 @@ def __post_init__(self):
             device='cpu',
             pin_memory=True,
         )
-        # Create expanded buffers for MTP>1 support
+        # Create expanded buffers for MTP support
         self.create_expanded_buffers(capture_graph=capture_graph)
 
-    # TODO: remove these expanded buffers when fp8_paged_mqa_logits supports MTP > 1.
+    # TODO: remove these expanded buffers when fp8_paged_mqa_logits supports an arbitrary number of MTP draft tokens.
     def create_expanded_buffers(self, capture_graph=False):
         self.kv_lens_expanded_cuda = self.get_empty(
             self.cuda_graph_buffers,
@@ -514,9 +516,18 @@ def create_expanded_buffers(self, capture_graph=False):
             dtype=torch.int32,
             capture_graph=capture_graph,
         )
+        # The fp8_paged_mqa_logits kernel needs different layout of the metadata buffer for MTP=3.
+        if self.max_draft_tokens == 3:
+            self.scheduler_metadata_buffer_mtp3 = self.get_empty(
+                self.cuda_graph_buffers,
+                (self.num_sms // 2 + 1, 2),
+                cache_name="scheduler_metadata_buffer_mtp3",
+                dtype=torch.int32,
+                capture_graph=capture_graph,
+            )
 
     # This function is only used to create the expanded buffers when the max_draft_tokens is changed.
-    # TODO: remove this function when fp8_paged_mqa_logits can support MTP > 1.
+    # TODO: remove this function once fp8_paged_mqa_logits supports an arbitrary number of MTP draft tokens.
     def update_spec_dec_param(
         self,
         batch_size,
@@ -726,11 +737,17 @@ def prepare(self):
         else:
             self.max_gen_seq_len = 0
 
-        # Because the fp8_paged_mqa_logits only supports seq_len == 1 or 2, so it cannot support
-        # MTP > 1. To handle this, when MTP > 1, we flatten the q tensor and expand the kv_lens and
-        # block_table for to use the fp8_paged_mqa_logits.
-        # TODO: remove this when fp8_paged_mqa_logits supports MTP > 1.
-        if self.max_draft_tokens > 1:
+        # Because the fp8_paged_mqa_logits only supports seq_len == 1/2/4 (i.e., max_draft_tokens == 0/1/3) on sm100, and
+        # seq_len == 1/2 (i.e., max_draft_tokens == 0/1) on sm90, for other cases, we need to flatten the q tensor and
+        # expand the kv_lens and block_table for MTP support.
+        # TODO:
+        # - No distinction between sm90 and sm100 is needed once MTP3 is supported on sm90.
+        # - Remove this once fp8_paged_mqa_logits supports an arbitrary number of MTP draft tokens.
+        self.use_expanded_buffers_for_mtp = (
+            (self.max_draft_tokens > 1 and get_sm_version() == 90)
+            or ((self.max_draft_tokens == 2 or self.max_draft_tokens > 3)
+                and get_sm_version() >= 100))
+        if self.use_expanded_buffers_for_mtp:
             # Expand kv_lens_cuda (only generation)
             num_tokens = self.num_generations * (1 + self.max_draft_tokens)
             gen_kv_lens = kv_lens[self.num_contexts:self.num_seqs]
@@ -786,6 +803,26 @@ def on_update_kv_lens(self):
                                                   tokens_per_block, self.num_sms)
         self.scheduler_metadata_buffer.copy_(scheduler_metadata_buffer,
                                              non_blocking=True)
+        if self.use_expanded_buffers_for_mtp:
+            num_draft_tokens = 1 + self.max_draft_tokens
+            num_tokens = self.num_generations * num_draft_tokens
+            gen_kv_lens = self.kv_lens_cuda[self.num_contexts:self.num_seqs]
+            kv_lens_expanded = torch.stack([gen_kv_lens] * num_draft_tokens,
+                                           dim=0)
+            self.kv_lens_expanded_cuda[:num_tokens] = \
+                kv_lens_expanded.transpose(0, 1).contiguous().flatten()
+            # Expand schedule metadata buffer (only generation)
+            kv_lens_expanded = self.kv_lens_expanded_cuda[:num_tokens]
+            scheduler_metadata_buffer_expanded = get_paged_mqa_logits_metadata(
+                kv_lens_expanded, tokens_per_block, self.num_sms)
+            self.scheduler_metadata_buffer_expanded.copy_(
+                scheduler_metadata_buffer_expanded, non_blocking=True)
+        elif self.max_draft_tokens == 3:
+            scheduler_metadata_buffer_mtp3 = get_paged_mqa_logits_metadata(
+                self.kv_lens_cuda[self.num_contexts:self.num_seqs],
+                tokens_per_block, self.num_sms // 2)
+            self.scheduler_metadata_buffer_mtp3.copy_(
+                scheduler_metadata_buffer_mtp3, non_blocking=True)
         self.prepare_dense_topk_indices(self.kv_lens_cuda, device=True)
 
     def update_for_spec_dec(self):
@@ -1058,13 +1095,18 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
         if num_generations > 0:
             # Prepare schedule metadata for fp8_paged_mqa_logits
             # This is a preprocessing step that computes scheduling information for the kernel
-            if metadata.max_draft_tokens <= 1:
+            if not metadata.use_expanded_buffers_for_mtp:
                 gen_seq_lens = metadata.kv_lens_cuda_runtime[
                     num_contexts:num_contexts + num_generations]
                 scheduler_metadata_buffer = get_paged_mqa_logits_metadata(
                     gen_seq_lens, tokens_per_block, metadata.num_sms)
                 metadata.scheduler_metadata_buffer.copy_(
                     scheduler_metadata_buffer, non_blocking=True)
+                if metadata.max_draft_tokens == 3:
+                    scheduler_metadata_buffer_mtp3 = get_paged_mqa_logits_metadata(
+                        gen_seq_lens, tokens_per_block, metadata.num_sms // 2)
+                    metadata.scheduler_metadata_buffer_mtp3.copy_(
+                        scheduler_metadata_buffer_mtp3, non_blocking=True)
             else:
                 # Expand schedule metadata buffer (only generation)
                 num_tokens = metadata.num_generations * (
@@ -1399,18 +1441,22 @@ def sparse_attn_indexer(
                                 ...]
         batch_size = num_generations
         next_n = num_gen_tokens // num_generations
-        # Because fp8_paged_mqa_logits cannot support next_n > 2, we need to flatten the q_decode tensor
+        # Because fp8_paged_mqa_logits can only support next_n == 1/2/4 on sm100, and
+        # next_n == 1/2 on sm90, for other next_n, we need to flatten the q_decode tensor
         # and expand the corresponding metadata.
-        if next_n <= 2:
+        if not metadata.use_expanded_buffers_for_mtp or next_n == 1:
             q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:])
             context_lens = metadata.kv_lens_cuda_runtime[
                 num_contexts:num_contexts + num_generations]
             block_table = metadata.indexer_k_cache_block_offsets[
                 num_contexts:num_contexts + num_generations]
-            scheduler_metadata_buffer = metadata.scheduler_metadata_buffer
+            if q_decode.shape[1] == 4:
+                scheduler_metadata_buffer = metadata.scheduler_metadata_buffer_mtp3
+            else:
+                scheduler_metadata_buffer = metadata.scheduler_metadata_buffer
         else:
             q_decode = q_decode.view(-1, 1, *q_fp8.shape[1:])
-            num_tokens = num_generations * (1 + metadata.max_draft_tokens)
+            num_tokens = q_decode.shape[0]
             context_lens = metadata.kv_lens_expanded_cuda[:num_tokens]
             block_table = metadata.block_table_expanded[:num_tokens]
             scheduler_metadata_buffer = metadata.scheduler_metadata_buffer_expanded
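Note: the gating condition introduced here appears twice in this change (in prepare() and in the test's mock metadata). As a reading aid only, a minimal standalone sketch of that dispatch follows, assuming nothing beyond the SM-version / draft-token rule stated in the new comments; the helper name is hypothetical and not part of the commit.

def needs_expanded_buffers(max_draft_tokens: int, sm_version: int) -> bool:
    # Mirrors use_expanded_buffers_for_mtp: fp8_paged_mqa_logits natively handles
    # seq_len == 1/2/4 on sm100 (max_draft_tokens == 0/1/3) and seq_len == 1/2 on
    # sm90 (max_draft_tokens == 0/1); every other draft-token count falls back to
    # the flattened-q / expanded-metadata path.
    return ((max_draft_tokens > 1 and sm_version == 90)
            or ((max_draft_tokens == 2 or max_draft_tokens > 3)
                and sm_version >= 100))

assert not needs_expanded_buffers(0, 100)  # next_n == 1: plain kernel path
assert not needs_expanded_buffers(1, 100)  # next_n == 2: plain kernel path
assert needs_expanded_buffers(2, 100)      # next_n == 3: expanded buffers
assert not needs_expanded_buffers(3, 100)  # next_n == 4: new MTP-3 kernel path
assert needs_expanded_buffers(3, 90)       # sm90 still expands for MTP-3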

tests/unittest/_torch/attention/sparse/test_dsa_indexer.py

Lines changed: 18 additions & 5 deletions
@@ -20,6 +20,7 @@
 from tensorrt_llm._torch.attention_backend.sparse.dsa import (
     DSACacheManager, DSAtrtllmAttentionMetadata, Indexer,
     compute_cu_seqlen_kv_bounds_with_cache, split_prefill_chunks)
+from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.bindings import DataType
 from tensorrt_llm.bindings.executor import KvCacheConfig
 from tensorrt_llm.bindings.internal.batch_manager import \
@@ -515,7 +516,11 @@ def __init__(self):
 
         self.runtime_features = RuntimeFeatures()
 
-        # Add expanded buffers for MTP>1 support
+        # Add expanded buffers for MTP support
+        self.use_expanded_buffers_for_mtp = (
+            (self.max_draft_tokens > 1 and get_sm_version() == 90)
+            or ((self.max_draft_tokens == 2 or self.max_draft_tokens > 3)
+                and get_sm_version() >= 100))
         self.kv_lens_expanded_cuda = torch.zeros(
             (self.num_seqs * (1 + self.max_draft_tokens), ),
             device='cuda',
@@ -531,7 +536,12 @@ def __init__(self):
             self.block_table_expanded, device='cpu', pin_memory=True)
         self.scheduler_metadata_buffer_expanded = torch.zeros(
             (self.num_sms + 1, 2), device='cuda', dtype=torch.int32)
-        if self.max_draft_tokens > 1:
+        if self.max_draft_tokens == 3:
+            self.scheduler_metadata_buffer_mtp3 = torch.zeros(
+                (self.num_sms // 2 + 1, 2),
+                device='cuda',
+                dtype=torch.int32)
+        if self.use_expanded_buffers_for_mtp:
             gen_kv_lens = kv_lens[num_contexts:self.num_seqs]
             gen_kv_lens_expanded = torch.stack([gen_kv_lens] *
                                                (1 + self.max_draft_tokens),
@@ -885,7 +895,7 @@ def test_fp8_k_cache_roundtrip():
 
 @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
 @skip_pre_hopper
-@pytest.mark.parametrize("batch_size,next_n", [(4, 1), (2, 2), (4, 4)])
+@pytest.mark.parametrize("batch_size,next_n", [(4, 1), (2, 2), (4, 3), (4, 4)])
 def test_indexer_decode_with_paged_kv_cache(batch_size, next_n):
     """
     Test FP8 paged KV cache with two-phase workflow and variable context lengths.
@@ -1013,11 +1023,14 @@ def test_indexer_decode_with_paged_kv_cache(batch_size, next_n):
     kv_cache_fp8_pool = cache_manager.get_indexer_k_cache_buffers(layer_idx)
     q_fp8 = q.to(torch.float8_e4m3fn)
 
-    if next_n <= 2:
+    if not metadata_gen.use_expanded_buffers_for_mtp:
         q_fp8 = q_fp8
         context_lens = metadata_gen.kv_lens_cuda_runtime[0:batch_size]
         block_table = metadata_gen.indexer_k_cache_block_offsets[0:batch_size]
-        scheduler_metadata_buffer = metadata_gen.scheduler_metadata_buffer
+        if q_fp8.shape[1] == 4:
+            scheduler_metadata_buffer = metadata_gen.scheduler_metadata_buffer_mtp3
+        else:
+            scheduler_metadata_buffer = metadata_gen.scheduler_metadata_buffer
     else:
         q_fp8 = q_fp8.view(-1, 1, *q_fp8.shape[2:])
         num_tokens = batch_size * next_n
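Note: the expanded-buffer branch relies on the stack / transpose / flatten pattern used in on_update_kv_lens and in this test's mock metadata. A tiny illustrative sketch of the resulting layout, with made-up lengths (two generation requests, next_n == 3); it is not taken from the commit.

import torch

gen_kv_lens = torch.tensor([17, 33], dtype=torch.int32)  # one kv_len per generation request
next_n = 3  # 1 + max_draft_tokens for an MTP count that takes the expanded path

# Repeat each request's kv_len once per query token so every row of the
# flattened q tensor gets its own context length.
kv_lens_expanded = torch.stack([gen_kv_lens] * next_n,
                               dim=0).transpose(0, 1).contiguous().flatten()
print(kv_lens_expanded.tolist())  # [17, 17, 17, 33, 33, 33]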
