
Commit 0c41aeb

mikeiovine authored and dominicshanshan committed
[https://nvbugs/5814914][fix] Fix llama sm120 spec dec (NVIDIA#10765)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
1 parent 6129e62 commit 0c41aeb
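
Summary of the change (as read from the diffs below): attention_need_spec_dec_mode gains an is_mla parameter, and the XQA availability check changes from get_sm_version() < 120 to not is_mla or get_sm_version() < 120. The SM120 restriction therefore now applies only to MLA models, re-enabling speculative decoding's multi-token query mode for non-MLA models such as Llama on SM 120.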

2 files changed: +13 −8


tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 6 additions & 2 deletions
@@ -3358,9 +3358,13 @@ def forward(self,
                 no_cache=kv_cache_manager
                 is None)
         # attn_metadata now depends on spec_metadata since it determines the shape/content of spec_dec parameter Tensors
+        enable_mla = is_mla(self.model.model_config.pretrained_config)
         is_spec_dec_mode = spec_metadata.spec_dec_mode.attention_need_spec_dec_mode(
-            spec_resource_manager, self.is_draft_model, self.attn_backend,
-            self.model_is_wrapped)
+            spec_resource_manager,
+            self.is_draft_model,
+            self.attn_backend,
+            self.model_is_wrapped,
+            is_mla=enable_mla)
         attn_metadata.update_spec_dec_param(
             batch_size=scheduled_requests.batch_size,
             is_spec_decoding_enabled=is_spec_dec_mode,
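
Taken on its own, the call-site change is: compute an enable_mla flag from the pretrained config once, then pass it through as the new is_mla keyword. A minimal, self-contained sketch of that flow, assuming a hypothetical PretrainedConfigStub and an illustrative is_mla() predicate (the real is_mla() lives in TensorRT-LLM and may check different fields):

```python
# Illustrative sketch only, not the real TRT-LLM code. PretrainedConfigStub
# and the kv_lora_rank heuristic are assumptions made for this example.
from dataclasses import dataclass
from typing import Optional


@dataclass
class PretrainedConfigStub:
    # Assumed marker: MLA-style models carry a KV LoRA rank,
    # Llama-style configs do not.
    kv_lora_rank: Optional[int] = None


def is_mla(config: PretrainedConfigStub) -> bool:
    """Illustrative MLA predicate: a set kv_lora_rank means MLA."""
    return getattr(config, "kv_lora_rank", None) is not None


if __name__ == "__main__":
    llama_cfg = PretrainedConfigStub()                # not MLA
    mla_cfg = PretrainedConfigStub(kv_lora_rank=512)  # MLA

    # Mirrors the new call site: the flag is computed once and passed
    # as the is_mla keyword to attention_need_spec_dec_mode(...).
    print(is_mla(llama_cfg))  # False
    print(is_mla(mla_cfg))    # True
```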

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 7 additions & 6 deletions
@@ -151,11 +151,12 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]):
             TrtllmAttention) or not xqa_supported
 
     def attention_need_spec_dec_mode(
-        self,
-        spec_resource_manager: Optional[BaseResourceManager],
-        is_draft_model: bool,
-        attention_backend: Type[AttentionBackend],
-        use_chain_drafter: bool,  # CDL
+        self,
+        spec_resource_manager: Optional[BaseResourceManager],
+        is_draft_model: bool,
+        attention_backend: Type[AttentionBackend],
+        use_chain_drafter: bool,  # CDL
+        is_mla: bool,
     ):
         """
         If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode).
@@ -168,7 +169,7 @@ def attention_need_spec_dec_mode(
         is_trtllm_attention = issubclass(attention_backend, TrtllmAttention)
 
         # Always use the multi-token query mode for 1-model if the kernels are available.
-        xqa_supported = get_sm_version() < 120
+        xqa_supported = not is_mla or get_sm_version() < 120
         use_case_1 = self.use_one_engine() and xqa_supported
         # For 2-model, we need to enable it when we process multiple tokens at once. This occurs with
         # the target model (verification) or on the first draft for CDL based speculation.
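
The substantive change is the xqa_supported condition. A standalone sketch of the old versus new gating, with get_sm_version() swapped for an explicit sm_version parameter so the example runs anywhere (the condition itself is copied from the diff):

```python
# Standalone sketch of the gating change in attention_need_spec_dec_mode.
def xqa_spec_dec_supported(is_mla: bool, sm_version: int) -> bool:
    # Old: sm_version < 120 -> spec-dec (multi-token query) mode was
    #      disabled for every model on SM >= 120.
    # New: not is_mla or sm_version < 120 -> only MLA models lose XQA
    #      spec-dec support on SM >= 120; Llama keeps it.
    return not is_mla or sm_version < 120


if __name__ == "__main__":
    assert xqa_spec_dec_supported(is_mla=False, sm_version=120)     # Llama on SM120: fixed
    assert not xqa_spec_dec_supported(is_mla=True, sm_version=120)  # MLA on SM120: still off
    assert xqa_spec_dec_supported(is_mla=True, sm_version=90)       # MLA on Hopper: on
    print("gating checks pass")
```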
