Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/pyexecutor/model_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2643,7 +2643,7 @@ def forward(self,
# attn_metadata now depends on spec_metadata since it determines the shape/content of spec_dec parameter Tensors
is_spec_dec_mode = spec_metadata.spec_dec_mode.attention_need_spec_dec_mode(
spec_resource_manager, self.is_draft_model, self.attn_backend,
self.model_is_wrapped, spec_metadata.is_spec_dec_tree)
self.model_is_wrapped)
attn_metadata.update_spec_dec_param(
batch_size=scheduled_requests.batch_size,
is_spec_decoding_enabled=is_spec_dec_mode,
Expand Down
41 changes: 16 additions & 25 deletions tensorrt_llm/_torch/speculative/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,21 +136,15 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]):
# 1-model has separate logic for handling draft tokens
return False

if issubclass(attention_backend,
TrtllmAttention) and self.is_mtp_eagle():
# TRTLLM MLA does not work with the chunked context mode.
return False

return not issubclass(attention_backend,
TrtllmAttention) or get_sm_version() != 100
TrtllmAttention) or get_sm_version() < 90

def attention_need_spec_dec_mode(
self,
spec_resource_manager: BaseResourceManager,
is_draft_model: bool,
attention_backend: Type[AttentionBackend],
use_chain_drafter: bool, # CDL
is_spec_dec_tree: bool,
self,
spec_resource_manager: Optional[BaseResourceManager],
is_draft_model: bool,
attention_backend: Type[AttentionBackend],
use_chain_drafter: bool, # CDL
):
"""
If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode).
Expand All @@ -159,22 +153,19 @@ def attention_need_spec_dec_mode(
is_draft_model: whether the model is a draft model.
attention_backend: the attention backend.
use_chain_drafter: whether to use capturable drafting loops (CDL). For the target model, it is always False.
is_spec_dec_tree: whether the spec-dec mode is a tree, i.e., static tree or dynamic tree.
"""
is_trtllm_attention = issubclass(attention_backend, TrtllmAttention)
# Case 1: one model

# Always use the multi-token query mode for 1-model.
# For 2-model, we need to enable it when we process multiple tokens at once. This occurs with
# the target model (verification) or on the first draft for CDL based speculation.
use_case_1 = self.is_eagle3_one_model()
# Case 2: eagle3 two model + draft model + CDL + is_first_draft + TRTLLM attention
use_case_2 = self.is_eagle3(
) and spec_resource_manager.is_first_draft and use_chain_drafter and is_draft_model and is_trtllm_attention
# Case 3: eagle3 two model + tree decoding + draft model + CDL + TRTLLM attention
use_case_3 = self.is_eagle3(
) and is_spec_dec_tree and is_draft_model and use_chain_drafter and is_trtllm_attention
# Case 4: eagle3 two model + tree decoding + target model + TRTLLM attention
use_case_4 = self.is_eagle3(
) and is_spec_dec_tree and not is_draft_model and is_trtllm_attention

return use_case_1 or use_case_2 or use_case_3 or use_case_4
use_case_2 = (not is_draft_model or
(spec_resource_manager is not None
and spec_resource_manager.is_first_draft
and use_chain_drafter)) and is_trtllm_attention

return use_case_1 or use_case_2

@staticmethod
def from_string(name: Optional[str]) -> "SpeculativeDecodingMode":
Expand Down
2 changes: 2 additions & 0 deletions tests/unittest/_torch/speculative/test_draft_len_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def enforce_single_worker():
# # ============================================================================
# # test 1: Generation correctness check
# # ============================================================================
@pytest.mark.skip("https://nvbugspro.nvidia.com/bug/5680911")
@pytest.mark.parametrize(
"drafter_type,schedule",
[
Expand Down Expand Up @@ -150,6 +151,7 @@ def test_correctness_across_batch_sizes(drafter_type: str, schedule: dict):
],
)
@pytest.mark.high_cuda_memory
@pytest.mark.skip("https://nvbugspro.nvidia.com/bug/5680911")
def test_draft_len_schedule_functionality(drafter_type: str, draft_schedule: dict):
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
Expand Down
4 changes: 2 additions & 2 deletions tests/unittest/_torch/speculative/test_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
num_tokens = len(new_tokens)

accept_rate = num_accepted / num_drafted
assert accept_rate > 0.15
assert accept_rate > 0.10

# Output tests
sampling_params = SamplingParams(max_tokens=10, temperature=0)
Expand Down Expand Up @@ -252,7 +252,7 @@ def test_llama_eagle3_long_prompt(use_cuda_graph):
speculative_config=spec_config,
max_batch_size=1,
cuda_graph_config=cuda_graph_config,
disable_overlap_scheduler=False)
disable_overlap_scheduler=True)

prompt = [", ".join(str(i) for i in range(1000))]

Expand Down