
Commit 7facac0

[None][fix] Fix MTP illegal memory access (#8161)
Signed-off-by: Mike Iovine <[email protected]>
1 parent ca9da1f commit 7facac0

File tree: 3 files changed, +23 -18 lines changed

tensorrt_llm/_torch/speculative/interface.py
Lines changed: 7 additions & 2 deletions

@@ -110,13 +110,18 @@ def has_spec_drafter(self):
     def extend_ctx(self, attention_backend: Type[AttentionBackend]):
         """
         If true, treat generation requests with draft tokens as
-        chunked context requests at the kernel level. Required for
-        any spec dec mode that uses the SpecExecutor.
+        chunked context requests at the kernel level.
         """
 
         if self.use_one_engine():
             # 1-model has separate logic for handling draft tokens
             return False
+
+        if issubclass(attention_backend,
+                      TrtllmAttention) and self.is_mtp_eagle():
+            # TRTLLM MLA does not work with the chunked context mode.
+            return False
+
         return not issubclass(attention_backend,
                               TrtllmAttention) or get_sm_version() != 100

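For context, a minimal self-contained sketch of the decision extend_ctx encodes after this change. The classes and helpers below are hypothetical stand-ins for the repository's real AttentionBackend, TrtllmAttention, and get_sm_version, included only so the snippet runs on its own:

    from typing import Type


    class AttentionBackend:
        """Hypothetical stand-in for the real attention backend base class."""


    class TrtllmAttention(AttentionBackend):
        """Hypothetical stand-in for the TRTLLM attention backend."""


    def get_sm_version() -> int:
        """Stand-in; the real helper queries the GPU architecture."""
        return 100  # pretend we are on an SM100 device


    def extend_ctx_decision(attention_backend: Type[AttentionBackend],
                            use_one_engine: bool, is_mtp_eagle: bool) -> bool:
        # Should generation requests with draft tokens be treated as
        # chunked-context requests at the kernel level?
        if use_one_engine:
            # 1-model mode has separate logic for handling draft tokens.
            return False
        if issubclass(attention_backend, TrtllmAttention) and is_mtp_eagle:
            # The fix: TRTLLM MLA does not work with chunked-context mode,
            # so MTP Eagle must never take that path.
            return False
        return (not issubclass(attention_backend, TrtllmAttention)
                or get_sm_version() != 100)


    # MTP Eagle on the TRTLLM backend now opts out of chunked context.
    assert extend_ctx_decision(TrtllmAttention, use_one_engine=False,
                               is_mtp_eagle=True) is False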
tensorrt_llm/_torch/speculative/model_drafter.py
Lines changed: 5 additions & 4 deletions

@@ -165,6 +165,9 @@ def _create_draft_request_for_request(
         input_tokens = get_draft_model_prompt(self.spec_config.spec_dec_mode,
                                               request.get_tokens(0))
 
+        is_eagle_style = self.spec_config.spec_dec_mode.is_eagle3(
+        ) or self.spec_config.spec_dec_mode.is_mtp_eagle()
+
         # First time seeing this request - context request
         if request.max_beam_num_tokens - 1 == request.py_prompt_len:
             # This is the first time the draft model is seeing this request.
@@ -174,10 +177,8 @@ def _create_draft_request_for_request(
             return self._create_context_request(request, input_tokens)
 
         # For TRTLLM attention backend, we need to create a generation request for both no tokens accepted and tokens accepted
-        elif issubclass(
-                self.draft_model_engine.attn_backend, TrtllmAttention
-        ) and self.use_static_draft_loop and self.spec_config.spec_dec_mode.is_eagle3(
-        ):
+        elif issubclass(self.draft_model_engine.attn_backend, TrtllmAttention
+                        ) and self.use_static_draft_loop and is_eagle_style:
             return self._create_accepted_tokens_request_for_trtllm_attn(
                 request, input_tokens, num_accepted_tokens)

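A simplified sketch of the dispatch this hunk changes; the flags and return strings below are illustrative placeholders, not the real method's signature or request objects:

    def pick_draft_request_kind(is_first_pass: bool, backend_is_trtllm: bool,
                                use_static_draft_loop: bool, is_eagle3: bool,
                                is_mtp_eagle: bool) -> str:
        # Eagle3 and MTP Eagle now share the same "eagle-style" handling.
        is_eagle_style = is_eagle3 or is_mtp_eagle

        if is_first_pass:
            # First time the draft model sees this request: context request.
            return "context"
        if backend_is_trtllm and use_static_draft_loop and is_eagle_style:
            # TRTLLM attention needs a generation request whether or not
            # any drafted tokens were accepted.
            return "accepted_tokens_generation"
        return "generation"


    # Before this commit, MTP Eagle missed the TRTLLM-specific branch;
    # now it takes the same path as Eagle3.
    assert pick_draft_request_kind(False, True, True, False,
                                   True) == "accepted_tokens_generation"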
tests/integration/defs/test_e2e.py
Lines changed: 11 additions & 12 deletions

@@ -1953,18 +1953,17 @@ def test_ptp_quickstart_advanced_mtp_eagle(llm_root, llm_venv, model_name,
                             dir="./",
                             delete=True,
                             delete_on_close=True) as running_log:
-        llm_venv.run_cmd(
-            [
-                str(example_root / "quickstart_advanced.py"),
-                "--use_cuda_graph",
-                "--spec_decode_max_draft_len",
-                "1",  # test 1 MTP module
-                "--spec_decode_algo",
-                "MTP",
-                "--model_dir",
-                f"{llm_models_root()}/{model_path}",
-            ],
-            stdout=running_log)
+        llm_venv.run_cmd([
+            str(example_root / "quickstart_advanced.py"),
+            "--use_cuda_graph",
+            "--spec_decode_max_draft_len",
+            "3",
+            "--spec_decode_algo",
+            "MTP",
+            "--model_dir",
+            f"{llm_models_root()}/{model_path}",
+        ],
+                         stdout=running_log)
         # 74.60 is the memory usage for DeepSeek-V3-Lite-BF16 with MTP Eagle 2 two model style as one extra kv cache is needed for draft model.
         _check_mem_usage(running_log, [74.60, 0, 0, 0])

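To try the updated invocation outside of pytest, a rough sketch using subprocess; the checkout and model paths here are placeholders, not the repository's actual layout:

    import subprocess

    # Placeholder paths: point these at a local TensorRT-LLM checkout
    # and a local copy of the model under test.
    example_root = "/path/to/TensorRT-LLM/examples/llm-api"
    model_dir = "/path/to/DeepSeek-V3-Lite/bf16"

    subprocess.run(
        [
            "python", f"{example_root}/quickstart_advanced.py",
            "--use_cuda_graph",
            "--spec_decode_max_draft_len", "3",  # the test now drafts 3 tokens
            "--spec_decode_algo", "MTP",
            "--model_dir", model_dir,
        ],
        check=True,
    )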