18 | 18 | from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding |
19 | 19 | from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager |
20 | 20 | from tensorrt_llm._torch.utils import maybe_compile |
21 | | -from tensorrt_llm._utils import get_size_in_bytes |
| 21 | +from tensorrt_llm._utils import get_size_in_bytes, get_sm_version |
22 | 22 | from tensorrt_llm.bindings import DataType |
23 | 23 | from tensorrt_llm.bindings.executor import KvCacheConfig |
24 | 24 | from tensorrt_llm.bindings.internal.batch_manager import \ |
@@ -312,6 +312,8 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata): |
312 | 312 | skip_indexer_for_ctx_reqs: bool = False |
313 | 313 | # Whether to skip the indexer for generation requests
314 | 314 | skip_indexer_for_gen_reqs: bool = False |
| 315 | + # Whether to use the expanded buffers for MTP support |
| 316 | + use_expanded_buffers_for_mtp: bool = False |
315 | 317 |
316 | 318 | def __init__(self, *args, **kwargs): |
317 | 319 | self.num_sms = tensorrt_llm.deep_gemm.get_num_sms() |
@@ -475,10 +477,10 @@ def __post_init__(self): |
475 | 477 | device='cpu', |
476 | 478 | pin_memory=True, |
477 | 479 | ) |
478 | | - # Create expanded buffers for MTP>1 support |
| 480 | + # Create expanded buffers for MTP support |
479 | 481 | self.create_expanded_buffers(capture_graph=capture_graph) |
480 | 482 |
481 | | - # TODO: remove these expanded buffers when fp8_paged_mqa_logits supports MTP > 1. |
| 483 | + # TODO: remove these expanded buffers when fp8_paged_mqa_logits supports an arbitrary number of MTP draft tokens. |
482 | 484 | def create_expanded_buffers(self, capture_graph=False): |
483 | 485 | self.kv_lens_expanded_cuda = self.get_empty( |
484 | 486 | self.cuda_graph_buffers, |
@@ -514,9 +516,18 @@ def create_expanded_buffers(self, capture_graph=False): |
514 | 516 | dtype=torch.int32, |
515 | 517 | capture_graph=capture_graph, |
516 | 518 | ) |
| 519 | + # The fp8_paged_mqa_logits kernel needs a different layout of the metadata buffer for MTP=3.
| 520 | + if self.max_draft_tokens == 3: |
| 521 | + self.scheduler_metadata_buffer_mtp3 = self.get_empty( |
| 522 | + self.cuda_graph_buffers, |
| 523 | + (self.num_sms // 2 + 1, 2), |
| 524 | + cache_name="scheduler_metadata_buffer_mtp3", |
| 525 | + dtype=torch.int32, |
| 526 | + capture_graph=capture_graph, |
| 527 | + ) |
517 | 528 |
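For reference, a minimal sketch of the two scheduler-metadata allocations side by side, with an illustrative SM count. The (num_sms + 1, 2) shape of the regular buffer is an assumption, inferred from the (num_sms // 2 + 1, 2) allocation above and from the MTP=3 path below computing its metadata with num_sms // 2.

```python
import torch

num_sms = 132  # illustrative SM count

# Regular buffer, filled via get_paged_mqa_logits_metadata(kv_lens, tokens_per_block, num_sms);
# the (num_sms + 1, 2) shape is an assumption mirroring the MTP=3 allocation above.
scheduler_metadata_buffer = torch.empty((num_sms + 1, 2), dtype=torch.int32)

# MTP=3 buffer, filled with half the SM budget:
# get_paged_mqa_logits_metadata(kv_lens, tokens_per_block, num_sms // 2).
scheduler_metadata_buffer_mtp3 = torch.empty((num_sms // 2 + 1, 2), dtype=torch.int32)
```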
518 | 529 | # This function is only used to create the expanded buffers when max_draft_tokens is changed.
519 | | - # TODO: remove this function when fp8_paged_mqa_logits can support MTP > 1. |
| 530 | + # TODO: remove this function once fp8_paged_mqa_logits supports an arbitrary number of MTP draft tokens. |
520 | 531 | def update_spec_dec_param( |
521 | 532 | self, |
522 | 533 | batch_size, |
@@ -726,11 +737,17 @@ def prepare(self): |
726 | 737 | else: |
727 | 738 | self.max_gen_seq_len = 0 |
728 | 739 |
729 | | - # Because the fp8_paged_mqa_logits only supports seq_len == 1 or 2, so it cannot support |
730 | | - # MTP > 1. To handle this, when MTP > 1, we flatten the q tensor and expand the kv_lens and |
731 | | - # block_table for to use the fp8_paged_mqa_logits. |
732 | | - # TODO: remove this when fp8_paged_mqa_logits supports MTP > 1. |
733 | | - if self.max_draft_tokens > 1: |
| 740 | + # Because fp8_paged_mqa_logits only supports seq_len == 1/2/4 (i.e., max_draft_tokens == 0/1/3) on sm100, and
| 741 | + # seq_len == 1/2 (i.e., max_draft_tokens == 0/1) on sm90, in all other cases we need to flatten the q tensor and
| 742 | + # expand the kv_lens and block_table for MTP support.
| 743 | + # TODO: |
| 744 | + # - No distinction between sm90 and sm100 is needed once MTP3 is supported on sm90. |
| 745 | + # - Remove this once fp8_paged_mqa_logits supports an arbitrary number of MTP draft tokens. |
| 746 | + self.use_expanded_buffers_for_mtp = ( |
| 747 | + (self.max_draft_tokens > 1 and get_sm_version() == 90) |
| 748 | + or ((self.max_draft_tokens == 2 or self.max_draft_tokens > 3) |
| 749 | + and get_sm_version() >= 100)) |
| 750 | + if self.use_expanded_buffers_for_mtp: |
734 | 751 | # Expand kv_lens_cuda (only generation) |
735 | 752 | num_tokens = self.num_generations * (1 + self.max_draft_tokens) |
736 | 753 | gen_kv_lens = kv_lens[self.num_contexts:self.num_seqs] |
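Restating the selection logic above as a standalone predicate, with the resulting truth table. This is a sketch for clarity, not part of the change, and the helper name is hypothetical:

```python
def uses_expanded_buffers(max_draft_tokens: int, sm_version: int) -> bool:
    # Mirrors the use_expanded_buffers_for_mtp condition above.
    if sm_version == 90:
        # The sm90 kernel natively handles seq_len == 1/2 (max_draft_tokens == 0/1).
        return max_draft_tokens > 1
    if sm_version >= 100:
        # The sm100 kernel natively handles seq_len == 1/2/4 (max_draft_tokens == 0/1/3).
        return max_draft_tokens == 2 or max_draft_tokens > 3
    return False

# max_draft_tokens:  0      1      2      3      4
# sm90:              False  False  True   True   True
# sm100:             False  False  True   False  True
```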
@@ -786,6 +803,26 @@ def on_update_kv_lens(self): |
786 | 803 | tokens_per_block, self.num_sms) |
787 | 804 | self.scheduler_metadata_buffer.copy_(scheduler_metadata_buffer, |
788 | 805 | non_blocking=True) |
| 806 | + if self.use_expanded_buffers_for_mtp: |
| 807 | + num_tokens_per_req = 1 + self.max_draft_tokens
| 808 | + num_tokens = self.num_generations * num_tokens_per_req
| 809 | + gen_kv_lens = self.kv_lens_cuda[self.num_contexts:self.num_seqs] |
| 810 | + kv_lens_expanded = torch.stack([gen_kv_lens] * num_tokens_per_req,
| 811 | + dim=0) |
| 812 | + self.kv_lens_expanded_cuda[:num_tokens] = \ |
| 813 | + kv_lens_expanded.transpose(0, 1).contiguous().flatten() |
| 814 | + # Expand schedule metadata buffer (only generation) |
| 815 | + kv_lens_expanded = self.kv_lens_expanded_cuda[:num_tokens] |
| 816 | + scheduler_metadata_buffer_expanded = get_paged_mqa_logits_metadata( |
| 817 | + kv_lens_expanded, tokens_per_block, self.num_sms) |
| 818 | + self.scheduler_metadata_buffer_expanded.copy_( |
| 819 | + scheduler_metadata_buffer_expanded, non_blocking=True) |
| 820 | + elif self.max_draft_tokens == 3: |
| 821 | + scheduler_metadata_buffer_mtp3 = get_paged_mqa_logits_metadata( |
| 822 | + self.kv_lens_cuda[self.num_contexts:self.num_seqs], |
| 823 | + tokens_per_block, self.num_sms // 2) |
| 824 | + self.scheduler_metadata_buffer_mtp3.copy_( |
| 825 | + scheduler_metadata_buffer_mtp3, non_blocking=True) |
789 | 826 | self.prepare_dense_topk_indices(self.kv_lens_cuda, device=True) |
790 | 827 |
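A tiny worked example (made-up numbers) of the expansion pattern used here and in prepare(), showing what ends up in kv_lens_expanded_cuda:

```python
import torch

# Two generation requests, max_draft_tokens == 2, so 3 query tokens per request.
gen_kv_lens = torch.tensor([17, 42], dtype=torch.int32)
next_n = 3  # 1 + max_draft_tokens

expanded = torch.stack([gen_kv_lens] * next_n, dim=0)   # shape [3, 2]
flat = expanded.transpose(0, 1).contiguous().flatten()  # shape [6]
print(flat.tolist())  # [17, 17, 17, 42, 42, 42]
# Each flattened q token becomes its own seq_len == 1 pseudo-request that
# shares its request's kv length (and, analogously, its block-table row).
```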
791 | 828 | def update_for_spec_dec(self): |
@@ -1058,13 +1095,18 @@ def prepare(metadata: DSAtrtllmAttentionMetadata): |
1058 | 1095 | if num_generations > 0: |
1059 | 1096 | # Prepare schedule metadata for fp8_paged_mqa_logits |
1060 | 1097 | # This is a preprocessing step that computes scheduling information for the kernel |
1061 | | - if metadata.max_draft_tokens <= 1: |
| 1098 | + if not metadata.use_expanded_buffers_for_mtp: |
1062 | 1099 | gen_seq_lens = metadata.kv_lens_cuda_runtime[ |
1063 | 1100 | num_contexts:num_contexts + num_generations] |
1064 | 1101 | scheduler_metadata_buffer = get_paged_mqa_logits_metadata( |
1065 | 1102 | gen_seq_lens, tokens_per_block, metadata.num_sms) |
1066 | 1103 | metadata.scheduler_metadata_buffer.copy_( |
1067 | 1104 | scheduler_metadata_buffer, non_blocking=True) |
| 1105 | + if metadata.max_draft_tokens == 3: |
| 1106 | + scheduler_metadata_buffer_mtp3 = get_paged_mqa_logits_metadata( |
| 1107 | + gen_seq_lens, tokens_per_block, metadata.num_sms // 2) |
| 1108 | + metadata.scheduler_metadata_buffer_mtp3.copy_( |
| 1109 | + scheduler_metadata_buffer_mtp3, non_blocking=True) |
1068 | 1110 | else: |
1069 | 1111 | # Expand schedule metadata buffer (only generation) |
1070 | 1112 | num_tokens = metadata.num_generations * ( |
@@ -1399,18 +1441,22 @@ def sparse_attn_indexer( |
1399 | 1441 | ...] |
1400 | 1442 | batch_size = num_generations |
1401 | 1443 | next_n = num_gen_tokens // num_generations |
1402 | | - # Because fp8_paged_mqa_logits cannot support next_n > 2, we need to flatten the q_decode tensor |
| 1444 | + # Because fp8_paged_mqa_logits only supports next_n == 1/2/4 on sm100, and
| 1445 | + # next_n == 1/2 on sm90, for other values of next_n we need to flatten the q_decode tensor
1403 | 1446 | # and expand the corresponding metadata. |
1404 | | - if next_n <= 2: |
| 1447 | + if not metadata.use_expanded_buffers_for_mtp or next_n == 1: |
1405 | 1448 | q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:]) |
1406 | 1449 | context_lens = metadata.kv_lens_cuda_runtime[ |
1407 | 1450 | num_contexts:num_contexts + num_generations] |
1408 | 1451 | block_table = metadata.indexer_k_cache_block_offsets[ |
1409 | 1452 | num_contexts:num_contexts + num_generations] |
1410 | | - scheduler_metadata_buffer = metadata.scheduler_metadata_buffer |
| 1453 | + if q_decode.shape[1] == 4: |
| 1454 | + scheduler_metadata_buffer = metadata.scheduler_metadata_buffer_mtp3 |
| 1455 | + else: |
| 1456 | + scheduler_metadata_buffer = metadata.scheduler_metadata_buffer |
1411 | 1457 | else: |
1412 | 1458 | q_decode = q_decode.view(-1, 1, *q_fp8.shape[1:]) |
1413 | | - num_tokens = num_generations * (1 + metadata.max_draft_tokens) |
| 1459 | + num_tokens = q_decode.shape[0] |
1414 | 1460 | context_lens = metadata.kv_lens_expanded_cuda[:num_tokens] |
1415 | 1461 | block_table = metadata.block_table_expanded[:num_tokens] |
1416 | 1462 | scheduler_metadata_buffer = metadata.scheduler_metadata_buffer_expanded |
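Putting the decode-side dispatch together, a sketch (the helper name is hypothetical) of how the q layout and scheduler buffer are paired before fp8_paged_mqa_logits is invoked:

```python
def select_decode_inputs(q_decode, metadata, num_generations):
    # Sketch of the branch above: returns the q view and the matching
    # scheduler-metadata buffer for fp8_paged_mqa_logits.
    next_n = q_decode.shape[0] // num_generations
    if not metadata.use_expanded_buffers_for_mtp or next_n == 1:
        # Native path: [batch, next_n, ...] with next_n in {1, 2, 4}.
        q = q_decode.view(num_generations, next_n, *q_decode.shape[1:])
        buf = (metadata.scheduler_metadata_buffer_mtp3 if next_n == 4
               else metadata.scheduler_metadata_buffer)
    else:
        # Fallback path: flatten to [batch * next_n] seq_len == 1 pseudo-requests
        # and use the expanded kv_lens/block_table/metadata prepared earlier.
        q = q_decode.view(-1, 1, *q_decode.shape[1:])
        buf = metadata.scheduler_metadata_buffer_expanded
    return q, buf
```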