File tree Expand file tree Collapse file tree 3 files changed +11
-15
lines changed
Expand file tree Collapse file tree 3 files changed +11
-15
lines changed Original file line number Diff line number Diff line change 5050 get_spec_metadata ,
5151 update_spec_config_from_model_config )
5252from ..speculative .drafting_loops import BaseDraftingLoopWrapper
53- from ..speculative .eagle3 import (Eagle3OneModelSpecMetadata ,
54- Eagle3ResourceManager , Eagle3SpecMetadata )
53+ from ..speculative .eagle3 import Eagle3ResourceManager , Eagle3SpecMetadata
5554from ..speculative .mtp import SampleStateTensorsMTP
5655from ..speculative .utils import SpecDecodingTensor
5756from ..utils import (get_model_extra_attrs ,
@@ -2756,9 +2755,9 @@ def previous_seq_slots_device():
27562755 num_accepted_draft_tokens )]
27572756 if isinstance (spec_metadata , Eagle3SpecMetadata ):
27582757 spec_metadata .request_accepted_path = request_accepted_path
2759- if isinstance ( spec_metadata , Eagle3OneModelSpecMetadata ):
2760- spec_metadata .populate_sampling_params_for_one_model (
2761- scheduled_requests .all_requests ())
2758+ # No-op for non 1-model
2759+ spec_metadata .populate_sampling_params_for_one_model (
2760+ scheduled_requests .all_requests ())
27622761 spec_metadata .prepare ()
27632762 inputs ['spec_metadata' ] = spec_metadata
27642763
Original file line number Diff line number Diff line change @@ -282,16 +282,12 @@ def create_py_executor(
282282 )
283283 llm_args .disable_overlap_scheduler = True
284284
285- if spec_config is not None and spec_config .spec_dec_mode .use_one_engine ():
286- if not spec_config .allow_advanced_sampling :
287- logger .warning (
288- f"Falling back to greedy decoding for { spec_config .decoding_type } . If you "
289- "want to use non-greedy sampling, please set allow_advanced_sampling=True."
290- )
291- elif spec_config .spec_dec_mode .is_mtp_one_model ():
292- logger .warning (
293- "Advanced sampling is not supported for MTP yet - this will be added soon."
294- )
285+ if spec_config is not None and spec_config .spec_dec_mode .use_one_engine (
286+ ) and not spec_config .allow_advanced_sampling :
287+ logger .warning (
288+ f"Falling back to greedy decoding for { spec_config .decoding_type } . If you "
289+ "want to use non-greedy sampling, please set allow_advanced_sampling=True."
290+ )
295291
296292 if mm_encoder_only :
297293 llm_args .mm_encoder_only = True
Original file line number Diff line number Diff line change @@ -31,6 +31,7 @@ def get_spec_metadata(spec_config,
3131 mtp_num_modules = spec_config .num_nextn_predict_layers ,
3232 max_num_requests = max_num_requests ,
3333 mtp_hidden_states_manager = spec_resource_manager ,
34+ allow_advanced_sampling = spec_config .allow_advanced_sampling ,
3435 )
3536 if spec_config .spec_dec_mode .is_mtp_eagle ():
3637 return Eagle3SpecMetadata (
You can’t perform that action at this time.
0 commit comments