Skip to content

Commit 3e3fe9b

Browse files
committed
[None][feat] Make 2-model spec dec use the 1-model kernels (Hopper)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
1 parent 078d3a5 commit 3e3fe9b

File tree

2 files changed

+6
-12
lines changed

2 files changed

+6
-12
lines changed

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,11 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]):
136136
# 1-model has separate logic for handling draft tokens
137137
return False
138138

139-
if issubclass(attention_backend,
140-
TrtllmAttention) and self.is_mtp_eagle():
141-
# TRTLLM MLA does not work with the chunked context mode.
142-
return False
143-
144-
return not issubclass(attention_backend,
145-
TrtllmAttention) or get_sm_version() != 100
139+
return not issubclass(attention_backend, TrtllmAttention)
146140

147141
def attention_need_spec_dec_mode(
148142
self,
149-
spec_resource_manager: BaseResourceManager,
143+
spec_resource_manager: Optional[BaseResourceManager],
150144
is_draft_model: bool,
151145
attention_backend: Type[AttentionBackend],
152146
use_chain_drafter: bool, # CDL
@@ -164,9 +158,9 @@ def attention_need_spec_dec_mode(
164158
is_trtllm_attention = issubclass(attention_backend, TrtllmAttention)
165159
# Case 1: one model
166160
use_case_1 = self.is_eagle3_one_model()
167-
# Case 2: eagle3 two model + draft model + CDL + is_first_draft + TRTLLM attention
161+
# Case 2: eagle3 two model + is_first_draft + TRTLLM attention
168162
use_case_2 = self.is_eagle3(
169-
) and spec_resource_manager.is_first_draft and use_chain_drafter and is_draft_model and is_trtllm_attention
163+
) and spec_resource_manager.is_first_draft and is_trtllm_attention
170164
# Case 3: eagle3 two model + tree decoding + draft model + CDL + TRTLLM attention
171165
use_case_3 = self.is_eagle3(
172166
) and is_spec_dec_tree and is_draft_model and use_chain_drafter and is_trtllm_attention

tests/unittest/_torch/speculative/test_eagle3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
206206
num_tokens = len(new_tokens)
207207

208208
accept_rate = num_accepted / num_drafted
209-
assert accept_rate > 0.15
209+
assert accept_rate > 0.10
210210

211211
# Output tests
212212
sampling_params = SamplingParams(max_tokens=10, temperature=0)
@@ -252,7 +252,7 @@ def test_llama_eagle3_long_prompt(use_cuda_graph):
252252
speculative_config=spec_config,
253253
max_batch_size=1,
254254
cuda_graph_config=cuda_graph_config,
255-
disable_overlap_scheduler=False)
255+
disable_overlap_scheduler=True)
256256

257257
prompt = [", ".join(str(i) for i in range(1000))]
258258

0 commit comments

Comments (0)