3 files changed: +11 -15 lines changed
Original file line number Diff line number Diff line change
4848 get_spec_metadata ,
4949 update_spec_config_from_model_config )
5050from ..speculative .drafting_loops import BaseDraftingLoopWrapper
51- from ..speculative .eagle3 import (Eagle3OneModelSpecMetadata ,
52- Eagle3ResourceManager , Eagle3SpecMetadata )
51+ from ..speculative .eagle3 import Eagle3ResourceManager , Eagle3SpecMetadata
5352from ..speculative .mtp import SampleStateTensorsMTP
5453from ..speculative .utils import SpecDecodingTensor
5554from ..utils import (get_model_extra_attrs ,
@@ -2684,9 +2683,9 @@ def previous_seq_slots_device():
26842683 num_accepted_draft_tokens )]
26852684 if isinstance (spec_metadata , Eagle3SpecMetadata ):
26862685 spec_metadata .request_accepted_path = request_accepted_path
2687- if isinstance ( spec_metadata , Eagle3OneModelSpecMetadata ):
2688- spec_metadata .populate_sampling_params_for_one_model (
2689- scheduled_requests .all_requests ())
2686+ # No-op for non 1-model
2687+ spec_metadata .populate_sampling_params_for_one_model (
2688+ scheduled_requests .all_requests ())
26902689 spec_metadata .prepare ()
26912690 inputs ['spec_metadata' ] = spec_metadata
26922691
Original file line number Diff line number Diff line change @@ -281,16 +281,12 @@ def create_py_executor(
281281 )
282282 llm_args .disable_overlap_scheduler = True
283283
284- if spec_config is not None and spec_config .spec_dec_mode .use_one_engine ():
285- if not spec_config .allow_advanced_sampling :
286- logger .warning (
287- f"Falling back to greedy decoding for { spec_config .decoding_type } . If you "
288- "want to use non-greedy sampling, please set allow_advanced_sampling=True."
289- )
290- elif spec_config .spec_dec_mode .is_mtp_one_model ():
291- logger .warning (
292- "Advanced sampling is not supported for MTP yet - this will be added soon."
293- )
284+ if spec_config is not None and spec_config .spec_dec_mode .use_one_engine (
285+ ) and not spec_config .allow_advanced_sampling :
286+ logger .warning (
287+ f"Falling back to greedy decoding for { spec_config .decoding_type } . If you "
288+ "want to use non-greedy sampling, please set allow_advanced_sampling=True."
289+ )
294290
295291 if mm_encoder_only :
296292 llm_args .mm_encoder_only = True
Original file line number Diff line number Diff line change @@ -31,6 +31,7 @@ def get_spec_metadata(spec_config,
3131 mtp_num_modules = spec_config .num_nextn_predict_layers ,
3232 max_num_requests = max_num_requests ,
3333 mtp_hidden_states_manager = spec_resource_manager ,
34+ allow_advanced_sampling = spec_config .allow_advanced_sampling ,
3435 )
3536 if spec_config .spec_dec_mode .is_mtp_eagle ():
3637 return Eagle3SpecMetadata (
You can’t perform that action at this time.
0 commit comments