Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions tensorrt_llm/_torch/speculative/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,9 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]):
# 1-model has separate logic for handling draft tokens
return False

xqa_supported = get_sm_version() < 120
return not issubclass(attention_backend,
TrtllmAttention) or get_sm_version() < 90
TrtllmAttention) or not xqa_supported

def attention_need_spec_dec_mode(
self,
Expand All @@ -161,14 +162,16 @@ def attention_need_spec_dec_mode(
"""
is_trtllm_attention = issubclass(attention_backend, TrtllmAttention)

# Always use the multi-token query mode for 1-model.
# Always use the multi-token query mode for 1-model if the kernels are available.
xqa_supported = get_sm_version() < 120
use_case_1 = self.use_one_engine() and xqa_supported
# For 2-model, we need to enable it when we process multiple tokens at once. This occurs with
# the target model (verification) or on the first draft for CDL based speculation.
use_case_1 = self.is_eagle3_one_model()
use_case_2 = (not is_draft_model or
(spec_resource_manager is not None
and spec_resource_manager.is_first_draft
and use_chain_drafter)) and is_trtllm_attention
use_case_2 = not self.use_one_engine() and (
not is_draft_model or
(spec_resource_manager is not None
and spec_resource_manager.is_first_draft
and use_chain_drafter)) and is_trtllm_attention

return use_case_1 or use_case_2

Expand Down
3 changes: 1 addition & 2 deletions tests/integration/test_lists/waives.txt
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ triton_server/test_triton.py::test_gpt_speculative_decoding[gpt-speculative-deco
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_RocketKV::test_auto_dtype SKIP (https://nvbugs/5762822)
unittest/_torch/sampler/test_return_logits.py SKIP (https://nvbugs/5764627)
examples/serve/test_serve.py::test_config_file_loading[--config] SKIP (https://nvbugs/5754977)
full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5740075)
full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5794313)
examples/test_ray.py::test_ray_disaggregated_serving[tp2] SKIP (https://nvbugs/5612502)
unittest/executor/test_rpc_proxy.py SKIP (https://nvbugs/5605741)
unittest/executor/test_rpc_worker.py SKIP (https://nvbugs/5605741)
Expand Down Expand Up @@ -493,7 +493,6 @@ unittest/_torch/modules/test_fused_moe.py::test_fused_moe_multi_gpu[1-CUTLASS] S
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=FLASHINFER-torch_compile=False] SKIP (https://nvbugs/5707359)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=2] SKIP (https://nvbugs/5673559)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5701445)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5740075)
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0] SKIP (https://nvbugs/5748600)
unittest/_torch/ray_orchestrator/multi_gpu/test_multi_instance.py::test_multi_instance[tp2_2instances] SKIP (https://nvbugs/5784566)
disaggregated/test_auto_scaling.py::test_worker_restart[etcd-round_robin] SKIP (https://nvbugs/5776445)
Expand Down