 import os
 import weakref
 from dataclasses import dataclass, field
-from typing import Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union

 import torch

+if TYPE_CHECKING:
+    from ..speculative.utils import SpecDecodingTensor
+
 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.bindings.internal import thop
 from tensorrt_llm.functional import AttentionMaskType
@@ -1045,12 +1048,32 @@ def prepare_context_mla_with_cached_kv(self,
         self.ctx_kv_indptr[:self.num_contexts + 1].copy_(
             self.host_ctx_kv_indptr[:self.num_contexts + 1], non_blocking=True)

-    def update_spec_dec_param(self, is_spec_decoding_enabled, is_spec_dec_tree,
-                              is_spec_dec_dynamic_tree, max_draft_tokens):
+    def update_spec_dec_param(
+            self,
+            is_spec_decoding_enabled,
+            is_spec_dec_tree,
+            is_spec_dec_dynamic_tree,
+            max_draft_tokens,
+            spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
+    ):
+
+        if spec_decoding_tensor is not None:
+            spec_decoding_position_offsets = spec_decoding_tensor.position_offsets
+            spec_decoding_packed_mask = spec_decoding_tensor.packed_mask
+            spec_decoding_generation_lengths = spec_decoding_tensor.generation_lengths
+        else:
+            spec_decoding_position_offsets = None
+            spec_decoding_packed_mask = None
+            spec_decoding_generation_lengths = None
         # spec_dec mode should only be enabled for pre-Blackwell machines and when there's a spec-dec tree.
         self.is_spec_decoding_enabled = is_spec_decoding_enabled and get_sm_version(
         ) < 100

+        if get_sm_version() >= 100:
+            if is_spec_dec_tree or is_spec_dec_dynamic_tree:
+                assert not is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."
+                assert not is_spec_dec_dynamic_tree, "Spec-dec dynamic tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec dynamic tree."
+
         # use_spec_decoding defaults to true; it can be changed at runtime by layers / requests
         self.use_spec_decoding = self.is_spec_decoding_enabled

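(For reference, a minimal sketch of how the new optional argument could be supplied by a dynamic-tree drafter. The SpecDecodingTensor stand-in and the attn_metadata call below are hypothetical; they only mirror the three attributes unpacked above, not the real class in ..speculative.utils.)

import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class SpecDecodingTensor:  # hypothetical stand-in for ..speculative.utils.SpecDecodingTensor
    # [max_num_requests, max_draft_tokens + 1]
    position_offsets: torch.Tensor
    # [max_num_requests, max_draft_tokens + 1, ceil((max_draft_tokens + 1) / 32)]
    packed_mask: torch.Tensor
    # [max_num_requests]; optional, regenerated from max_draft_tokens if absent
    generation_lengths: Optional[torch.Tensor] = None

# Linear-chain callers keep the old behaviour by omitting the argument;
# a dynamic-tree caller would pass its precomputed tensors instead, e.g.:
#   attn_metadata.update_spec_dec_param(
#       is_spec_decoding_enabled=True,
#       is_spec_dec_tree=True,
#       is_spec_dec_dynamic_tree=True,
#       max_draft_tokens=max_draft_tokens,
#       spec_decoding_tensor=SpecDecodingTensor(position_offsets, packed_mask),
#   )
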
@@ -1068,7 +1091,7 @@ def update_spec_dec_param(self, is_spec_decoding_enabled, is_spec_dec_tree,
             self.spec_decoding_packed_mask = torch.empty(
                 [
                     self.max_num_requests, max_draft_tokens + 1,
-                    math.ceil(max_draft_tokens / 32)
+                    math.ceil((max_draft_tokens + 1) / 32)
                 ],
                 dtype=torch.int,
                 device='cuda',
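
(A quick check of the shape change, assuming one mask bit per visible position, i.e. the root token plus up to max_draft_tokens drafts packed into 32-bit words: the old width is one int32 short exactly when max_draft_tokens is a multiple of 32.)

import math

for max_draft_tokens in (31, 32, 63, 64):
    old_width = math.ceil(max_draft_tokens / 32)        # before this change
    new_width = math.ceil((max_draft_tokens + 1) / 32)  # after this change
    print(max_draft_tokens, old_width, new_width)
# 31 -> 1, 1   (unchanged)
# 32 -> 1, 2   (old formula truncates the last mask bit)
# 63 -> 2, 2   (unchanged)
# 64 -> 2, 3   (old formula truncates the last mask bit)
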
@@ -1081,7 +1104,18 @@ def update_spec_dec_param(self, is_spec_decoding_enabled, is_spec_dec_tree,
             )

             if self.is_spec_dec_dynamic_tree:
-                assert False, "currently dynamic tree is not supported"
+                assert spec_decoding_position_offsets is not None, "spec_decoding_position_offsets is required for dynamic tree"
+                assert spec_decoding_packed_mask is not None, "spec_decoding_packed_mask is required for dynamic tree"
+                self.spec_decoding_position_offsets.copy_(
+                    spec_decoding_position_offsets, non_blocking=True)
+                self.spec_decoding_packed_mask.copy_(spec_decoding_packed_mask,
+                                                     non_blocking=True)
+                if spec_decoding_generation_lengths is not None:
+                    self.spec_decoding_generation_lengths.copy_(
+                        spec_decoding_generation_lengths, non_blocking=True)
+                else:
+                    self.generate_spec_decoding_generation_length(
+                        max_draft_tokens=max_draft_tokens)
             else:
                 # Populate the mask that won't change during inference phase.
                 self.generate_spec_decoding_position_offsets(
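
(To illustrate what a dynamic-tree caller is now expected to precompute, here is a toy mask builder for a single request. The bit convention, bit j of row i set when draft position i can attend to position j, low bits first, is inferred from the linear-chain helper further down and from the int32 packing; treat it as an assumption, not the authoritative layout.)

import torch

# Toy 4-node draft tree for one request: node 0 is the root, 0 -> {1, 2}, 1 -> {3}.
parents = [-1, 0, 0, 1]          # parents[i] = parent of node i, -1 for the root
num_nodes = len(parents)

packed_mask = torch.zeros(num_nodes, (num_nodes + 31) // 32, dtype=torch.int32)
position_offsets = torch.zeros(num_nodes, dtype=torch.int32)

for i in range(num_nodes):
    # Node i attends to itself and to every ancestor on its path to the root.
    node = i
    while node != -1:
        packed_mask[i, node // 32] |= 1 << (node % 32)
        node = parents[node]
    # Position offset = depth of node i in the tree (root at depth 0).
    node, depth = i, 0
    while parents[node] != -1:
        node, depth = parents[node], depth + 1
    position_offsets[i] = depth

# packed_mask[:, 0] -> tensor([0b0001, 0b0011, 0b0101, 0b1011]) == tensor([1, 3, 5, 11])
# position_offsets  -> tensor([0, 1, 1, 2])
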
@@ -1092,7 +1126,6 @@ def update_spec_dec_param(self, is_spec_decoding_enabled, is_spec_dec_tree,
                     max_draft_tokens=max_draft_tokens)

     def generate_spec_decoding_position_offsets(self, max_draft_tokens):
-        assert not self.is_spec_dec_tree, "only chained/linear tree is supported now"
         position_offset = torch.arange(max_draft_tokens + 1,
                                        dtype=torch.int,
                                        device='cpu',
@@ -1103,7 +1136,6 @@ def generate_spec_decoding_position_offsets(self, max_draft_tokens):
                                                   non_blocking=True)

     def generate_spec_decoding_packed_mask(self, max_draft_tokens):
-        assert not self.is_spec_dec_tree, "only chained/linear tree is supported now"
         dummy_idx = torch.arange(max_draft_tokens + 1)
         spec_decoding_packed_mask = torch.pow(2, dummy_idx + 1) - 1
         self.spec_decoding_packed_mask[:, :, 0].copy_(spec_decoding_packed_mask,
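
(And for the unchanged linear-chain path, a small worked example of the values written by the helper above: row i gets its low i + 1 bits set, so draft token i attends to the root token and every earlier draft token. A sketch for illustration only.)

import torch

max_draft_tokens = 3
dummy_idx = torch.arange(max_draft_tokens + 1)
packed_mask = torch.pow(2, dummy_idx + 1) - 1   # 2**(i+1) - 1 == i+1 low bits set
print([bin(v) for v in packed_mask.tolist()])
# ['0b1', '0b11', '0b111', '0b1111']  -> row i keeps positions 0..i visible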