Skip to content

Commit ca82911

Browse files
authored
[None][fix] Fix MTP 2-model (#8115)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Signed-off-by: Mike Iovine <miovine@nvidia.com>
1 parent aaf2c3c commit ca82911

File tree

4 files changed

+29
-2
lines changed

4 files changed

+29
-2
lines changed

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,13 @@ def __init__(self, config: "EagleDecodingConfig", dtype: torch.dtype,
3939
# Reserve one more slot for the dummy request.
4040
slot_size = self.max_seq_len + 1
4141
self.slot_manager = SlotManager(slot_size)
42-
self.max_total_draft_tokens = config.max_total_draft_tokens
42+
# This class is reused by MTP_EAGLE
43+
from ...llmapi.llm_args import EagleDecodingConfig
44+
45+
if isinstance(config, EagleDecodingConfig):
46+
self.max_total_draft_tokens = config.max_total_draft_tokens
47+
else:
48+
self.max_total_draft_tokens = self.max_draft_len
4349

4450
# empty hidden states tensor
4551
max_num_tokens = min(max_num_tokens,
@@ -55,7 +61,9 @@ def __init__(self, config: "EagleDecodingConfig", dtype: torch.dtype,
5561
# whether the next draft forward is the first
5662
self.is_first_draft = True
5763
self.spec_tree_manager = None
58-
if config.eagle_choices is not None:
64+
65+
if isinstance(config,
66+
EagleDecodingConfig) and config.eagle_choices is not None:
5967
self.spec_tree_manager = SpecTreeManager(
6068
max_num_requests=self.max_num_requests,
6169
use_dynamic_tree=config.use_dynamic_tree,

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ def needs_kv_cache_rewind(self):
6767
) or self.is_ngram()
6868

6969
def support_overlap_scheduler(self):
70+
# TODO: fix accuracy issue
71+
if self.is_mtp_eagle():
72+
return False
73+
7074
return self.is_mtp_one_model() or self.is_eagle3_one_model(
7175
) or self.has_draft_model()
7276

tensorrt_llm/_torch/speculative/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,20 @@ def get_spec_metadata(spec_config,
2828
max_num_requests=max_num_requests,
2929
mtp_hidden_states_manager=spec_resource_manager,
3030
)
31+
if spec_config.spec_dec_mode.is_mtp_eagle():
32+
return Eagle3SpecMetadata(
33+
max_draft_len=spec_config.max_draft_len,
34+
spec_dec_mode=spec_config.spec_dec_mode,
35+
max_num_requests=max_num_requests,
36+
num_layers=model_config.num_hidden_layers,
37+
hidden_size=model_config.hidden_size,
38+
max_num_tokens=max_num_tokens,
39+
dtype=model_config.torch_dtype,
40+
is_draft_model=is_draft_model,
41+
eagle3_resource_manager=spec_resource_manager,
42+
layers_to_capture=None,
43+
is_mtp_eagle=True,
44+
)
3145
if spec_config.spec_dec_mode.is_eagle3():
3246
return Eagle3SpecMetadata(
3347
max_draft_len=spec_config.max_draft_len,

tests/integration/test_lists/test-db/l0_b200.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ l0_b200:
5555
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B]
5656
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8]
5757
- test_e2e.py::test_ptp_quickstart_advanced_mtp[DeepSeek-V3-Lite-BF16-DeepSeek-V3-Lite/bf16]
58+
- test_e2e.py::test_ptp_quickstart_advanced_mtp_eagle[DeepSeek-V3-Lite-BF16-DeepSeek-V3-Lite/bf16]
5859
- test_e2e.py::test_ptp_quickstart_advanced_mixed_precision
5960
- test_e2e.py::test_ptp_quickstart_advanced_eagle3[Llama-3.1-8b-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct-EAGLE3-LLaMA3.1-Instruct-8B]
6061
- test_e2e.py::test_ptp_quickstart_advanced_ngram[Llama-3.1-8B-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct]

0 commit comments

Comments (0)