+import math
 import os
 import weakref
 from dataclasses import dataclass, field
 from typing import Optional

 import torch

+from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.functional import AttentionMaskType
 from tensorrt_llm.logger import logger
 from tensorrt_llm.models.modeling_utils import QuantConfig
@@ -64,6 +66,10 @@ class TrtllmAttentionWrapper:
     qk_nope_head_dim: Optional[int]
     v_head_dim: Optional[int]
     attention_chunk_size: Optional[int]
+    use_spec_decoding: bool
+    spec_decoding_position_offsets: Optional[torch.Tensor]
+    spec_decoding_packed_mask: Optional[torch.Tensor]
+    spec_decoding_generation_lengths: Optional[torch.Tensor]
     kwargs: dict

     def __init__(
@@ -169,6 +175,11 @@ def plan(
         mla_context_paged_kv: Optional[torch.Tensor] = None,
         mla_context_kv_cache_block_offsets: Optional[torch.Tensor] = None,
         softmax_stats_tensor: Optional[torch.Tensor] = None,
+        is_spec_decoding_enabled: bool = False,
+        use_spec_decoding: bool = False,
+        spec_decoding_position_offsets: Optional[torch.Tensor] = None,
+        spec_decoding_packed_mask: Optional[torch.Tensor] = None,
+        spec_decoding_generation_lengths: Optional[torch.Tensor] = None,
         **kwargs,
     ):
         """
@@ -245,7 +256,11 @@ def plan(
         self.rope_params.max_positions = max_sequence_length
         self.rotary_inv_freq, self.rotary_cos_sin = self.rope_params.create_rope_const_params(
         )
-
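+        # Stash the spec-dec arguments here; run() later packs them into
+        # spec_decoding_bool_params / spec_decoding_tensor_params before the op call.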
+        self.is_spec_decoding_enabled = is_spec_decoding_enabled
+        self.use_spec_decoding = use_spec_decoding
+        self.spec_decoding_position_offsets = spec_decoding_position_offsets
+        self.spec_decoding_packed_mask = spec_decoding_packed_mask
+        self.spec_decoding_generation_lengths = spec_decoding_generation_lengths
         self.kwargs.update(kwargs)

     def run(
@@ -374,6 +389,23 @@ def run(
         # output is provided, expect output_sf to be provided as well if it has NVFP4 output.
         assert out_dtype is None or out_dtype != torch.uint8 or output_sf is not None

+        # Pack parameters into small lists so the call below stays under the
+        # 64-argument limit.
+        rotary_embedding_scales = [
+            self.rotary_embedding_scale, self.rotary_embedding_short_m_scale,
+            self.rotary_embedding_long_m_scale
+        ]
+        rotary_embedding_max_position_info = [
+            self.rotary_embedding_max_positions,
+            self.rotary_embedding_original_max_positions
+        ]
+        spec_decoding_bool_params = [
+            self.is_spec_decoding_enabled, self.use_spec_decoding
+        ]
+        spec_decoding_tensor_params = [
+            self.spec_decoding_generation_lengths,
+            self.spec_decoding_position_offsets, self.spec_decoding_packed_mask
+        ]
+
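+        # Note: the element order inside each packed list must match the
+        # positional unpacking order expected by torch.ops.trtllm.attention_inplace below.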
         torch.ops.trtllm.attention_inplace(
             q,
             k,
@@ -420,11 +452,8 @@ def run(
             self.rotary_embedding_dim,
             self.rotary_embedding_base,
             self.rotary_embedding_scale_type,
-            self.rotary_embedding_scale,
-            self.rotary_embedding_short_m_scale,
-            self.rotary_embedding_long_m_scale,
-            self.rotary_embedding_max_positions,
-            self.rotary_embedding_original_max_positions,
+            rotary_embedding_scales,
+            rotary_embedding_max_position_info,
             self.use_paged_context_fmha,
             self.attention_input_type,
             self.is_mla_enable,
@@ -439,6 +468,8 @@ def run(
             self.mla_context_kv_cache_block_offsets,
             self.attention_chunk_size,
             self.softmax_stats_tensor,
+            spec_decoding_bool_params,
+            spec_decoding_tensor_params,
         )

         # reset the planned states (especially tensors) to avoid memory leak
@@ -495,6 +526,23 @@ class TrtllmAttentionMetadata(AttentionMetadata):
                                  init=True,
                                  repr=False)

+    # Flags enabling spec-dec mode (multi-query mode) in the TRTLLM XQA kernels.
+    # Spec-dec mode is only available for non-TRTLLM-gen kernels (pre-Blackwell XQA kernels).
+    # is_spec_decoding_enabled specifies whether spec-dec mode is supported for the entire runtime.
+    is_spec_decoding_enabled: bool = False
+    # use_spec_decoding determines whether the attention layer runs in spec-dec mode at a given step / layer.
+    use_spec_decoding: bool = False
+
+    # whether the spec-dec structure is a full tree or a chain (linear tree)
+    is_spec_dec_tree: bool = False
+    # whether the spec-dec tree changes between steps; if it never changes, the
+    # mask does not need to be recomputed every step.
+    is_spec_dec_dynamic_tree: bool = False
+
+    # parameters required for spec-dec mode
+    spec_decoding_position_offsets: Optional[torch.Tensor] = None
+    spec_decoding_packed_mask: Optional[torch.Tensor] = None
+    spec_decoding_generation_lengths: Optional[torch.Tensor] = None
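+    # Tensor shapes (allocated in update_spec_dec_param below):
+    #   spec_decoding_position_offsets:   [max_num_requests, max_draft_tokens + 1]
+    #   spec_decoding_packed_mask:        [max_num_requests, max_draft_tokens + 1, ceil(max_draft_tokens / 32)]
+    #   spec_decoding_generation_lengths: [max_num_requests]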
+
     @property
     def max_seq_len(self) -> int:
         """
@@ -849,6 +897,76 @@ def prepare_paged_context_mla(self, cached_token_lens: torch.Tensor,
         self.ctx_kv_indptr[:self.num_contexts + 1].copy_(
             self.host_ctx_kv_indptr[:self.num_contexts + 1], non_blocking=True)

+    def update_spec_dec_param(self, is_spec_decoding_enabled, is_spec_dec_tree,
+                              is_spec_dec_dynamic_tree, max_draft_tokens):
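+        # Example (hypothetical call site), e.g. linear-chain drafting with
+        # 3 draft tokens:
+        #   metadata.update_spec_dec_param(is_spec_decoding_enabled=True,
+        #                                  is_spec_dec_tree=False,
+        #                                  is_spec_dec_dynamic_tree=False,
+        #                                  max_draft_tokens=3)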
+        # Spec-dec mode should only be enabled on pre-Blackwell machines
+        # (SM < 100) and when there is a spec-dec tree.
+        self.is_spec_decoding_enabled = is_spec_decoding_enabled and get_sm_version(
+        ) < 100
+
+        # use_spec_decoding defaults to the enabled state; it can be changed at
+        # runtime per layer / request.
+        self.use_spec_decoding = self.is_spec_decoding_enabled
+
+        self.is_spec_dec_tree = is_spec_dec_tree
+        self.is_spec_dec_dynamic_tree = is_spec_dec_dynamic_tree
+
+        # These parameters can be fixed once and left unchanged at runtime if
+        # the spec-dec tree is static.
+        if self.is_spec_decoding_enabled:
+            self.spec_decoding_position_offsets = torch.empty(
+                [self.max_num_requests, max_draft_tokens + 1],
+                dtype=torch.int,
+                device='cuda',
+            )
+
+            self.spec_decoding_packed_mask = torch.empty(
+                [
+                    self.max_num_requests, max_draft_tokens + 1,
+                    math.ceil(max_draft_tokens / 32)
+                ],
+                dtype=torch.int,
+                device='cuda',
+            )
+
+            self.spec_decoding_generation_lengths = torch.empty(
+                [self.max_num_requests],
+                dtype=torch.int,
+                device='cuda',
+            )
+
+            if self.is_spec_dec_dynamic_tree:
+                raise NotImplementedError(
+                    "dynamic tree is currently not supported")
+            else:
+                # Populate the masks and offsets that will not change during
+                # the inference phase.
+                self.generate_spec_decoding_position_offsets(
+                    max_draft_tokens=max_draft_tokens)
+                self.generate_spec_decoding_packed_mask(
+                    max_draft_tokens=max_draft_tokens)
+                self.generate_spec_decoding_generation_length(
+                    max_draft_tokens=max_draft_tokens)
+
+    def generate_spec_decoding_position_offsets(self, max_draft_tokens):
+        assert not self.is_spec_dec_tree, "only chained / linear tree is supported now"
+        position_offset = torch.arange(max_draft_tokens + 1,
+                                       dtype=torch.int,
+                                       device='cpu',
+                                       pin_memory=True)
+
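+        # For a linear chain the offsets are simply [0, 1, ..., max_draft_tokens]
+        # (e.g. max_draft_tokens = 3 gives [0, 1, 2, 3]), identical for every request.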
+        # fill all batches with the same position offsets
+        self.spec_decoding_position_offsets.copy_(position_offset,
+                                                  non_blocking=True)
+
+    def generate_spec_decoding_packed_mask(self, max_draft_tokens):
+        assert not self.is_spec_dec_tree, "only chained / linear tree is supported now"
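+        # For a linear chain, draft token i attends to tokens 0..i, so its packed
+        # mask word has the (i + 1) low bits set, i.e. 2^(i + 1) - 1. For example,
+        # max_draft_tokens = 3 yields masks [0b1, 0b11, 0b111, 0b1111].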
+        dummy_idx = torch.arange(max_draft_tokens + 1)
+        spec_decoding_packed_mask = torch.pow(2, dummy_idx + 1) - 1
+        self.spec_decoding_packed_mask[:, :, 0].copy_(spec_decoding_packed_mask,
+                                                      non_blocking=True)
+
+    def generate_spec_decoding_generation_length(self, max_draft_tokens):
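+        # Every request is assumed to advance by the full draft length, i.e.
+        # max_draft_tokens + 1 tokens per generation step.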
+        spec_decoding_generation_length = torch.full((self.max_num_requests, ),
+                                                     max_draft_tokens + 1)
+        self.spec_decoding_generation_lengths[:self.max_num_requests].copy_(
+            spec_decoding_generation_length, non_blocking=True)
+

 class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):

@@ -984,7 +1102,6 @@ def forward(
             use_paged_context_fmha=use_paged_context_fmha,
             is_mla_enable=self.is_mla_enable,
         )
-
         self.wrapper.plan(
             layer_idx=self.get_local_layer_idx(metadata),
             tokens_per_block=metadata.tokens_per_block,
@@ -1021,6 +1138,13 @@ def forward(
             mla_context_kv_cache_block_offsets=
             mla_context_kv_cache_block_offsets,
             softmax_stats_tensor=softmax_stats_tensor,
+            is_spec_decoding_enabled=metadata.is_spec_decoding_enabled,
+            use_spec_decoding=metadata.use_spec_decoding,
+            spec_decoding_position_offsets=metadata.
+            spec_decoding_position_offsets,
+            spec_decoding_packed_mask=metadata.spec_decoding_packed_mask,
+            spec_decoding_generation_lengths=metadata.
+            spec_decoding_generation_lengths,
         )
         out_dtype = None
         if out_scale is not None: