Skip to content

Commit 2ef67ad

Browse files
committed
[None][feat] Make 2-model spec dec use the 1-model kernels (Hopper)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
1 parent 078d3a5 commit 2ef67ad

File tree

3 files changed

+19
-30
lines changed

3 files changed

+19
-30
lines changed

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2629,7 +2629,7 @@ def forward(self,
26292629
# attn_metadata now depends on spec_metadata since it determines the shape/content of spec_dec parameter Tensors
26302630
is_spec_dec_mode = spec_metadata.spec_dec_mode.attention_need_spec_dec_mode(
26312631
spec_resource_manager, self.is_draft_model, self.attn_backend,
2632-
self.model_is_wrapped, spec_metadata.is_spec_dec_tree)
2632+
self.model_is_wrapped)
26332633
attn_metadata.update_spec_dec_param(
26342634
batch_size=scheduled_requests.batch_size,
26352635
is_spec_decoding_enabled=is_spec_dec_mode,

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from tensorrt_llm.logger import logger
1010

11-
from ..._utils import get_sm_version
1211
from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
1312
from ..pyexecutor.resource_manager import BaseResourceManager
1413

@@ -136,21 +135,14 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]):
136135
# 1-model has separate logic for handling draft tokens
137136
return False
138137

139-
if issubclass(attention_backend,
140-
TrtllmAttention) and self.is_mtp_eagle():
141-
# TRTLLM MLA does not work with the chunked context mode.
142-
return False
143-
144-
return not issubclass(attention_backend,
145-
TrtllmAttention) or get_sm_version() != 100
138+
return not issubclass(attention_backend, TrtllmAttention)
146139

147140
def attention_need_spec_dec_mode(
148-
self,
149-
spec_resource_manager: BaseResourceManager,
150-
is_draft_model: bool,
151-
attention_backend: Type[AttentionBackend],
152-
use_chain_drafter: bool, # CDL
153-
is_spec_dec_tree: bool,
141+
self,
142+
spec_resource_manager: Optional[BaseResourceManager],
143+
is_draft_model: bool,
144+
attention_backend: Type[AttentionBackend],
145+
use_chain_drafter: bool, # CDL
154146
):
155147
"""
156148
If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode).
@@ -159,22 +151,19 @@ def attention_need_spec_dec_mode(
159151
is_draft_model: whether the model is a draft model.
160152
attention_backend: the attention backend.
161153
use_chain_drafter: whether to use capturable drafting loops (CDL). For the target model, it is always False.
162-
is_spec_dec_tree: whether the spec-dec mode is a tree, i.e., static tree or dynamic tree.
163154
"""
164155
is_trtllm_attention = issubclass(attention_backend, TrtllmAttention)
165-
# Case 1: one model
156+
157+
# Always use the multi-token query mode for 1-model.
158+
# For 2-model, we need to enable it when we process multiple tokens at once. This occurs with
159+
# the target model (verification) or on the first draft for CDL based speculation.
166160
use_case_1 = self.is_eagle3_one_model()
167-
# Case 2: eagle3 two model + draft model + CDL + is_first_draft + TRTLLM attention
168-
use_case_2 = self.is_eagle3(
169-
) and spec_resource_manager.is_first_draft and use_chain_drafter and is_draft_model and is_trtllm_attention
170-
# Case 3: eagle3 two model + tree decoding + draft model + CDL + TRTLLM attention
171-
use_case_3 = self.is_eagle3(
172-
) and is_spec_dec_tree and is_draft_model and use_chain_drafter and is_trtllm_attention
173-
# Case 4: eagle3 two model + tree decoding + target model + TRTLLM attention
174-
use_case_4 = self.is_eagle3(
175-
) and is_spec_dec_tree and not is_draft_model and is_trtllm_attention
176-
177-
return use_case_1 or use_case_2 or use_case_3 or use_case_4
161+
use_case_2 = self.is_eagle3() and (
162+
not is_draft_model or
163+
(spec_resource_manager.is_first_draft
164+
and use_chain_drafter)) and is_trtllm_attention
165+
166+
return use_case_1 or use_case_2
178167

179168
@staticmethod
180169
def from_string(name: Optional[str]) -> "SpeculativeDecodingMode":

tests/unittest/_torch/speculative/test_eagle3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
206206
num_tokens = len(new_tokens)
207207

208208
accept_rate = num_accepted / num_drafted
209-
assert accept_rate > 0.15
209+
assert accept_rate > 0.10
210210

211211
# Output tests
212212
sampling_params = SamplingParams(max_tokens=10, temperature=0)
@@ -252,7 +252,7 @@ def test_llama_eagle3_long_prompt(use_cuda_graph):
252252
speculative_config=spec_config,
253253
max_batch_size=1,
254254
cuda_graph_config=cuda_graph_config,
255-
disable_overlap_scheduler=False)
255+
disable_overlap_scheduler=True)
256256

257257
prompt = [", ".join(str(i) for i in range(1000))]
258258

0 commit comments

Comments
 (0)