@@ -207,6 +207,7 @@ def from_dict(cls, data: dict):
207207 "Eagle" : EagleDecodingConfig ,
208208 "Lookahead" : LookaheadDecodingConfig ,
209209 "NGram" : NGramDecodingConfig ,
210+ "DraftTarget" : DraftTargetDecodingConfig ,
210211 }
211212
212213 config_class = config_classes .get (decoding_type )
@@ -238,7 +239,7 @@ class EagleDecodingConfig(DecodingBaseConfig):
238239 dynamic_tree_max_topK : Optional [int ] = None
239240 num_eagle_layers : Optional [int ] = None
240241 max_non_leaves_per_layer : Optional [int ] = None
241- pytorch_eagle_weights_path : Optional [str ] = None
242+ pytorch_weights_path : Optional [str ] = None
242243 eagle3_one_model : Optional [bool ] = True
243244
244245 @classmethod
@@ -282,6 +283,16 @@ def from_dict(cls, data: dict):
282283 decoding_type : ClassVar [str ] = "NGram"
283284
284285
class DraftTargetDecodingConfig(DecodingBaseConfig):
    """Speculative-decoding config for the draft-target scheme.

    Mirrors the structure of the sibling decoding configs (e.g. the
    NGram/Eagle configs in this module): a handful of optional fields,
    a ``from_dict`` alternate constructor, and a ``decoding_type``
    discriminator used by the registry in ``DecodingBaseConfig.from_dict``.
    """

    # Path to the draft model's PyTorch weights; later forwarded as
    # ``draft_model_path`` when this config is converted to the internal
    # ``DraftTargetConfig`` in ``validate_speculative_config``.
    pytorch_weights_path: Optional[str] = None

    @classmethod
    def from_dict(cls, data: dict):
        # NOTE(review): assumes ``data`` contains only field names accepted
        # by the constructor (the dispatching ``from_dict`` presumably strips
        # the ``decoding_type`` key first) — TODO confirm against the caller.
        return cls(**data)

    # Registry key: matches the "DraftTarget" entry added to
    # ``config_classes`` in ``DecodingBaseConfig.from_dict``.
    decoding_type: ClassVar[str] = "DraftTarget"
295+
285296class MTPDecodingConfig (DecodingBaseConfig ):
286297 num_nextn_predict_layers : Optional [int ] = 1
287298 use_relaxed_acceptance_for_thinking : Optional [bool ] = False
@@ -896,10 +907,11 @@ class BaseLlmArgs(BaseModel):
896907 default = None , description = "Cache transceiver config." )
897908
898909 # Speculative decoding parameters
899- speculative_config : Optional [Union [
900- LookaheadDecodingConfig , MedusaDecodingConfig , EagleDecodingConfig ,
901- MTPDecodingConfig , NGramDecodingConfig ]] = Field (
902- default = None , description = "Speculative decoding config." )
910+ speculative_config : Optional [
911+ Union [LookaheadDecodingConfig , MedusaDecodingConfig ,
912+ EagleDecodingConfig , MTPDecodingConfig , NGramDecodingConfig ,
913+ DraftTargetDecodingConfig ]] = Field (
914+ default = None , description = "Speculative decoding config." )
903915
904916 batching_type : Optional [BatchingType ] = Field (default = None ,
905917 description = "Batching type." )
@@ -1302,7 +1314,7 @@ def validate_speculative_config(self):
13021314 self .speculative_config = Eagle3Config (
13031315 max_draft_tokens = self .speculative_config .max_draft_len ,
13041316 draft_model_path = self .speculative_config .
1305- pytorch_eagle_weights_path ,
1317+ pytorch_weights_path ,
13061318 eagle3_one_model = self .speculative_config .
13071319 eagle3_one_model )
13081320 elif isinstance (self .speculative_config , NGramDecodingConfig ):
@@ -1320,6 +1332,16 @@ def validate_speculative_config(self):
13201332 is_use_oldest = self .speculative_config .is_use_oldest ,
13211333 is_public_pool = self .speculative_config .is_public_pool ,
13221334 )
1335+ elif isinstance (self .speculative_config , DraftTargetDecodingConfig ):
1336+ self .build_config .speculative_decoding_mode = SpeculativeDecodingMode .DRAFT_TOKENS_EXTERNAL
1337+ assert self .backend == 'pytorch'
1338+ assert self .speculative_config .max_draft_len > 0
1339+ self .build_config .max_draft_len = self .speculative_config .max_draft_len
1340+ from tensorrt_llm ._torch .speculative import DraftTargetConfig
1341+ self .speculative_config = DraftTargetConfig (
1342+ max_draft_tokens = self .speculative_config .max_draft_len ,
1343+ draft_model_path = self .speculative_config .
1344+ pytorch_weights_path )
13231345 elif isinstance (self .speculative_config , MTPDecodingConfig ):
13241346 from tensorrt_llm ._torch .speculative import MTPConfig
13251347 self .speculative_config = MTPConfig (
0 commit comments