@@ -248,7 +248,6 @@ class _ModelFormatKind(Enum):
248248class DecodingBaseConfig (BaseModel ):
249249 max_draft_len : Optional [int ] = None
250250 speculative_model_dir : Optional [Union [str , Path ]] = None
251- num_extra_kv_tokens : int = 0
252251
253252 @classmethod
254253 def from_dict (cls , data : dict ):
@@ -295,13 +294,6 @@ def spec_dec_mode(self):
295294 return TorchSpeculativeDecodingMode .from_string (
296295 self .decoding_type .upper ())
297296
298- def update_from_model_config (self , model_config ):
299- pass
300-
301- def get_draft_model_prompt (self ,
302- input_tokens : torch .Tensor ) -> torch .Tensor :
303- return input_tokens
304-
305297
306298class MedusaDecodingConfig (DecodingBaseConfig ):
307299 medusa_choices : Optional [List [List [int ]]] = None
@@ -345,13 +337,6 @@ def spec_dec_mode(self):
345337 return TorchSpeculativeDecodingMode .EAGLE3_ONE_MODEL
346338 return TorchSpeculativeDecodingMode .EAGLE3
347339
348- def get_draft_model_prompt (self ,
349- input_tokens : torch .Tensor ) -> torch .Tensor :
350- """
351- Eagle3 always throws away the first token when processing draft inputs
352- """
353- return input_tokens [1 :]
354-
355340
356341class UserProvidedDecodingConfig (DecodingBaseConfig ):
357342 # Cannot use real type annotations due to circular imports
@@ -448,11 +433,6 @@ def spec_dec_mode(self):
448433 return TorchSpeculativeDecodingMode .MTP_EAGLE
449434 return TorchSpeculativeDecodingMode .MTP
450435
451- def update_from_model_config (self , model_config ):
452- assert self .num_nextn_predict_layers > 0
453- if model_config .num_nextn_predict_layers == 1 and not self .use_mtp_vanilla :
454- self .num_extra_kv_tokens = self .num_nextn_predict_layers - 1
455-
456436
457437class PybindMirror (ABC ):
458438 ''' A class containing the utilities for mirroring Python classes to
@@ -1468,8 +1448,6 @@ def validate_speculative_config(self):
14681448 assert self .speculative_config .speculative_model_dir is not None , "Path to EAGLE3 weights must be specified."
14691449 self .build_config .max_draft_len = self .speculative_config .max_draft_len
14701450 self .build_config .speculative_decoding_mode = SpeculativeDecodingMode .EAGLE
1471- if self .speculative_config .eagle3_one_model :
1472- self .speculative_config .num_extra_kv_tokens = self .speculative_config .max_draft_len - 1
14731451 if self .backend not in ['pytorch' , '_autodeploy' ]:
14741452 eagle_config = _EagleConfig (
14751453 self .speculative_config .eagle_choices ,
@@ -1490,6 +1468,7 @@ def validate_speculative_config(self):
14901468 elif isinstance (self .speculative_config , DraftTargetDecodingConfig ):
14911469 assert self .backend in ['pytorch' ]
14921470 assert self .speculative_config .max_draft_len > 0
1471+ assert self .speculative_config .speculative_model_dir is not None , "Path to draft model must be specified."
14931472 self .build_config .speculative_decoding_mode = SpeculativeDecodingMode .DRAFT_TOKENS_EXTERNAL
14941473 self .build_config .max_draft_len = self .speculative_config .max_draft_len
14951474
0 commit comments