2 files changed: +13 -8

@@ -3358,9 +3358,13 @@ def forward(self,
             no_cache=kv_cache_manager
             is None)
         # attn_metadata now depends on spec_metadata since it determines the shape/content of spec_dec parameter Tensors
+        enable_mla = is_mla(self.model.model_config.pretrained_config)
         is_spec_dec_mode = spec_metadata.spec_dec_mode.attention_need_spec_dec_mode(
-            spec_resource_manager, self.is_draft_model, self.attn_backend,
-            self.model_is_wrapped)
+            spec_resource_manager,
+            self.is_draft_model,
+            self.attn_backend,
+            self.model_is_wrapped,
+            is_mla=enable_mla)
         attn_metadata.update_spec_dec_param(
             batch_size=scheduled_requests.batch_size,
             is_spec_decoding_enabled=is_spec_dec_mode,
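For context, the is_mla helper used above is assumed to detect Multi-head Latent Attention (MLA) models (e.g. the DeepSeek family) from the pretrained config. A minimal sketch of such a check; the kv_lora_rank attribute test is an assumption and not shown in this diff:

    # Sketch only: assumes MLA configs carry a non-null kv_lora_rank field.
    def is_mla(pretrained_config) -> bool:
        return getattr(pretrained_config, "kv_lora_rank", None) is not None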
@@ -151,11 +151,12 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]):
             TrtllmAttention) or not xqa_supported

     def attention_need_spec_dec_mode(
-        self,
-        spec_resource_manager: Optional[BaseResourceManager],
-        is_draft_model: bool,
-        attention_backend: Type[AttentionBackend],
-        use_chain_drafter: bool,  # CDL
+        self,
+        spec_resource_manager: Optional[BaseResourceManager],
+        is_draft_model: bool,
+        attention_backend: Type[AttentionBackend],
+        use_chain_drafter: bool,  # CDL
+        is_mla: bool,
     ):
         """
         If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode).
@@ -168,7 +169,7 @@ def attention_need_spec_dec_mode(
         is_trtllm_attention = issubclass(attention_backend, TrtllmAttention)

         # Always use the multi-token query mode for 1-model if the kernels are available.
-        xqa_supported = get_sm_version() < 120
+        xqa_supported = not is_mla or get_sm_version() < 120
         use_case_1 = self.use_one_engine() and xqa_supported
         # For 2-model, we need to enable it when we process multiple tokens at once. This occurs with
         # the target model (verification) or on the first draft for CDL based speculation.
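Read as a predicate, the changed line treats the XQA spec-dec kernels as supported for non-MLA models on any SM version, while MLA models still require an SM version below 120. A standalone sketch of the 1-model gating, using plain arguments instead of the real spec-dec mode object, so the function and parameter names here are illustrative only:

    # Sketch only: mirrors the use_case_1 gating above with plain inputs.
    def one_model_needs_spec_dec_mode(use_one_engine: bool, is_mla: bool,
                                      sm_version: int) -> bool:
        # XQA multi-token-query kernels are assumed unavailable for MLA on SM >= 120.
        xqa_supported = not is_mla or sm_version < 120
        return use_one_engine and xqa_supported

    assert one_model_needs_spec_dec_mode(True, is_mla=False, sm_version=120)      # non-MLA: enabled even on SM 120
    assert not one_model_needs_spec_dec_mode(True, is_mla=True, sm_version=120)   # MLA: still requires SM < 120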