
Commit b6a8f58

[https://nvbugs/5740075][fix] Fix sm120 speculation

Signed-off-by: Mike Iovine <[email protected]>
1 parent: e033129

File tree: 2 files changed (+10 / -9 lines)


tensorrt_llm/_torch/speculative/interface.py (10 additions, 7 deletions)
```diff
@@ -141,8 +141,9 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]):
             # 1-model has separate logic for handling draft tokens
             return False

+        xqa_supported = get_sm_version() < 120
         return not issubclass(attention_backend,
-                              TrtllmAttention) or get_sm_version() < 90
+                              TrtllmAttention) or not xqa_supported

     def attention_need_spec_dec_mode(
         self,
```
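This hunk replaces the hardcoded `get_sm_version() < 90` check with an explicit `xqa_supported` flag, so that SM 120 falls back to extending the context rather than relying on the XQA speculative-decoding kernels. A minimal sketch of the new gating as a standalone function, with an `sm_version` parameter standing in for `get_sm_version()` (the function name and parameters here are illustrative, not the repository's API):

```python
def extend_ctx_needed(is_trtllm_attention: bool, sm_version: int,
                      is_one_model: bool) -> bool:
    """Illustrative restatement of the gating in extend_ctx after the fix.

    Draft tokens are appended to the context (instead of being handled by
    the spec-dec attention kernels) when the backend is not TrtllmAttention,
    or when the XQA spec-dec kernels are unavailable on this architecture.
    """
    if is_one_model:
        # 1-model has separate logic for handling draft tokens
        return False
    xqa_supported = sm_version < 120  # kernels unavailable on SM 120
    return (not is_trtllm_attention) or (not xqa_supported)

# SM 90 with the TRT-LLM backend: kernels available, no ctx extension
print(extend_ctx_needed(True, 90, False))   # False
# SM 120: kernels unavailable, fall back to extending the context
print(extend_ctx_needed(True, 120, False))  # True
```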
```diff
@@ -161,14 +162,16 @@ def attention_need_spec_dec_mode(
         """
         is_trtllm_attention = issubclass(attention_backend, TrtllmAttention)

-        # Always use the multi-token query mode for 1-model.
+        # Always use the multi-token query mode for 1-model if the kernels are available.
+        xqa_supported = get_sm_version() < 120
+        use_case_1 = self.use_one_engine() and xqa_supported
         # For 2-model, we need to enable it when we process multiple tokens at once. This occurs with
         # the target model (verification) or on the first draft for CDL based speculation.
-        use_case_1 = self.is_eagle3_one_model()
-        use_case_2 = (not is_draft_model or
-                      (spec_resource_manager is not None
-                       and spec_resource_manager.is_first_draft
-                       and use_chain_drafter)) and is_trtllm_attention
+        use_case_2 = not self.use_one_engine() and (
+            not is_draft_model or
+            (spec_resource_manager is not None
+             and spec_resource_manager.is_first_draft
+             and use_chain_drafter)) and is_trtllm_attention

         return use_case_1 or use_case_2
```
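This hunk generalizes `use_case_1` from the Eagle3-specific `is_eagle3_one_model()` to any 1-model configuration with available kernels, and restricts `use_case_2` to 2-model flows. The combined predicate can be sketched as a pure function (parameter names mirror the diff; flattening `spec_resource_manager` into booleans is illustrative, not the repository's actual signature):

```python
def need_spec_dec_mode(use_one_engine: bool,
                       sm_version: int,
                       is_trtllm_attention: bool,
                       is_draft_model: bool,
                       has_spec_resource_manager: bool,
                       is_first_draft: bool,
                       use_chain_drafter: bool) -> bool:
    """Illustrative restatement of attention_need_spec_dec_mode after the fix."""
    xqa_supported = sm_version < 120
    # 1-model: always use multi-token query mode if the kernels are available.
    use_case_1 = use_one_engine and xqa_supported
    # 2-model: enable when processing multiple tokens at once, i.e. for the
    # target model (verification) or the first draft of CDL-based
    # speculation, and only with the TRT-LLM attention backend.
    use_case_2 = (not use_one_engine) and (
        not is_draft_model or
        (has_spec_resource_manager and is_first_draft and use_chain_drafter)
    ) and is_trtllm_attention
    return use_case_1 or use_case_2
```

Note that on SM 120 a 1-model setup now disables the spec-dec mode entirely (`use_case_1` is false and `use_case_2` requires a 2-model flow), which is the behavioral fix for the linked bug.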

tests/integration/test_lists/waives.txt (0 additions, 2 deletions)
```diff
@@ -443,7 +443,6 @@ triton_server/test_triton.py::test_gpt_speculative_decoding[gpt-speculative-deco
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_RocketKV::test_auto_dtype SKIP (https://nvbugs/5762822)
 unittest/_torch/sampler/test_return_logits.py SKIP (https://nvbugs/5764627)
 examples/serve/test_serve.py::test_config_file_loading[--config] SKIP (https://nvbugs/5754977)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5740075)
 examples/test_ray.py::test_ray_disaggregated_serving[tp2] SKIP (https://nvbugs/5612502)
 unittest/executor/test_rpc_proxy.py SKIP (https://nvbugs/5605741)
 unittest/executor/test_rpc_worker.py SKIP (https://nvbugs/5605741)
@@ -493,7 +492,6 @@ unittest/_torch/modules/test_fused_moe.py::test_fused_moe_multi_gpu[1-CUTLASS] S
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=FLASHINFER-torch_compile=False] SKIP (https://nvbugs/5707359)
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=2] SKIP (https://nvbugs/5673559)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5701445)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5740075)
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0] SKIP (https://nvbugs/5748600)
 unittest/_torch/ray_orchestrator/multi_gpu/test_multi_instance.py::test_multi_instance[tp2_2instances] SKIP (https://nvbugs/5784566)
 disaggregated/test_auto_scaling.py::test_worker_restart[etcd-round_robin] SKIP (https://nvbugs/5776445)
```
