diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 59a5e0129cf..99e9468f0c9 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -141,8 +141,9 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]): # 1-model has separate logic for handling draft tokens return False + xqa_supported = get_sm_version() < 120 return not issubclass(attention_backend, - TrtllmAttention) or get_sm_version() < 90 + TrtllmAttention) or not xqa_supported def attention_need_spec_dec_mode( self, @@ -161,14 +162,16 @@ def attention_need_spec_dec_mode( """ is_trtllm_attention = issubclass(attention_backend, TrtllmAttention) - # Always use the multi-token query mode for 1-model. + # Always use the multi-token query mode for 1-model if the kernels are available. + xqa_supported = get_sm_version() < 120 + use_case_1 = self.use_one_engine() and xqa_supported # For 2-model, we need to enable it when we process multiple tokens at once. This occurs with # the target model (verification) or on the first draft for CDL based speculation. - use_case_1 = self.is_eagle3_one_model() - use_case_2 = (not is_draft_model or - (spec_resource_manager is not None - and spec_resource_manager.is_first_draft - and use_chain_drafter)) and is_trtllm_attention + use_case_2 = not self.use_one_engine() and ( + not is_draft_model or + (spec_resource_manager is not None + and spec_resource_manager.is_first_draft + and use_chain_drafter)) and is_trtllm_attention return use_case_1 or use_case_2 diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 563a38a76ec..59a1c13dedd 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -443,7 +443,7 @@ triton_server/test_triton.py::test_gpt_speculative_decoding[gpt-speculative-deco accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_RocketKV::test_auto_dtype SKIP (https://nvbugs/5762822) unittest/_torch/sampler/test_return_logits.py SKIP (https://nvbugs/5764627) examples/serve/test_serve.py::test_config_file_loading[--config] SKIP (https://nvbugs/5754977) -full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5740075) +full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugspro.nvidia.com/bug/5794313) examples/test_ray.py::test_ray_disaggregated_serving[tp2] SKIP (https://nvbugs/5612502) unittest/executor/test_rpc_proxy.py SKIP (https://nvbugs/5605741) unittest/executor/test_rpc_worker.py SKIP (https://nvbugs/5605741) @@ -493,7 +493,6 @@ unittest/_torch/modules/test_fused_moe.py::test_fused_moe_multi_gpu[1-CUTLASS] S accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=FLASHINFER-torch_compile=False] SKIP (https://nvbugs/5707359) accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=2] SKIP (https://nvbugs/5673559) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5701445) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5740075) accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0] SKIP (https://nvbugs/5748600) unittest/_torch/ray_orchestrator/multi_gpu/test_multi_instance.py::test_multi_instance[tp2_2instances] SKIP (https://nvbugs/5784566) disaggregated/test_auto_scaling.py::test_worker_restart[etcd-round_robin] SKIP (https://nvbugs/5776445)