Skip to content

Commit 6606cdc

Browse files
committed
Fix MTP 1-model sampler
Signed-off-by: Mike Iovine <[email protected]>
1 parent 4092a87 commit 6606cdc

File tree

3 files changed: 11 additions and 15 deletions

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@
4848
get_spec_metadata,
4949
update_spec_config_from_model_config)
5050
from ..speculative.drafting_loops import BaseDraftingLoopWrapper
51-
from ..speculative.eagle3 import (Eagle3OneModelSpecMetadata,
52-
Eagle3ResourceManager, Eagle3SpecMetadata)
51+
from ..speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata
5352
from ..speculative.mtp import SampleStateTensorsMTP
5453
from ..speculative.utils import SpecDecodingTensor
5554
from ..utils import (get_model_extra_attrs,
@@ -2684,9 +2683,9 @@ def previous_seq_slots_device():
26842683
num_accepted_draft_tokens)]
26852684
if isinstance(spec_metadata, Eagle3SpecMetadata):
26862685
spec_metadata.request_accepted_path = request_accepted_path
2687-
if isinstance(spec_metadata, Eagle3OneModelSpecMetadata):
2688-
spec_metadata.populate_sampling_params_for_one_model(
2689-
scheduled_requests.all_requests())
2686+
# No-op for non 1-model
2687+
spec_metadata.populate_sampling_params_for_one_model(
2688+
scheduled_requests.all_requests())
26902689
spec_metadata.prepare()
26912690
inputs['spec_metadata'] = spec_metadata
26922691

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -281,16 +281,12 @@ def create_py_executor(
281281
)
282282
llm_args.disable_overlap_scheduler = True
283283

284-
if spec_config is not None and spec_config.spec_dec_mode.use_one_engine():
285-
if not spec_config.allow_advanced_sampling:
286-
logger.warning(
287-
f"Falling back to greedy decoding for {spec_config.decoding_type}. If you "
288-
"want to use non-greedy sampling, please set allow_advanced_sampling=True."
289-
)
290-
elif spec_config.spec_dec_mode.is_mtp_one_model():
291-
logger.warning(
292-
"Advanced sampling is not supported for MTP yet - this will be added soon."
293-
)
284+
if spec_config is not None and spec_config.spec_dec_mode.use_one_engine(
285+
) and not spec_config.allow_advanced_sampling:
286+
logger.warning(
287+
f"Falling back to greedy decoding for {spec_config.decoding_type}. If you "
288+
"want to use non-greedy sampling, please set allow_advanced_sampling=True."
289+
)
294290

295291
if mm_encoder_only:
296292
llm_args.mm_encoder_only = True

tensorrt_llm/_torch/speculative/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def get_spec_metadata(spec_config,
3131
mtp_num_modules=spec_config.num_nextn_predict_layers,
3232
max_num_requests=max_num_requests,
3333
mtp_hidden_states_manager=spec_resource_manager,
34+
allow_advanced_sampling=spec_config.allow_advanced_sampling,
3435
)
3536
if spec_config.spec_dec_mode.is_mtp_eagle():
3637
return Eagle3SpecMetadata(

0 commit comments