88
99from tensorrt_llm .logger import logger
1010
11- from ..._utils import get_sm_version
1211from ..attention_backend .trtllm import AttentionBackend , TrtllmAttention
1312from ..pyexecutor .resource_manager import BaseResourceManager
1413
@@ -136,21 +135,14 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]):
136135 # 1-model has separate logic for handling draft tokens
137136 return False
138137
139- if issubclass (attention_backend ,
140- TrtllmAttention ) and self .is_mtp_eagle ():
141- # TRTLLM MLA does not work with the chunked context mode.
142- return False
143-
144- return not issubclass (attention_backend ,
145- TrtllmAttention ) or get_sm_version () != 100
138+ return not issubclass (attention_backend , TrtllmAttention )
146139
147140 def attention_need_spec_dec_mode (
148- self ,
149- spec_resource_manager : BaseResourceManager ,
150- is_draft_model : bool ,
151- attention_backend : Type [AttentionBackend ],
152- use_chain_drafter : bool , # CDL
153- is_spec_dec_tree : bool ,
141+ self ,
142+ spec_resource_manager : Optional [BaseResourceManager ],
143+ is_draft_model : bool ,
144+ attention_backend : Type [AttentionBackend ],
145+ use_chain_drafter : bool , # CDL
154146 ):
155147 """
156148 If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode).
@@ -159,22 +151,19 @@ def attention_need_spec_dec_mode(
159151 is_draft_model: whether the model is a draft model.
160152 attention_backend: the attention backend.
161153 use_chain_drafter: whether to use capturable drafting loops (CDL). For the target model, it is always False.
162- is_spec_dec_tree: whether the spec-dec mode is a tree, i.e., static tree or dynamic tree.
163154 """
164155 is_trtllm_attention = issubclass (attention_backend , TrtllmAttention )
165- # Case 1: one model
156+
157+ # Always use the multi-token query mode for 1-model.
158+ # For 2-model, we need to enable it when we process multiple tokens at once. This occurs with
159+ # the target model (verification) or on the first draft for CDL based speculation.
166160 use_case_1 = self .is_eagle3_one_model ()
167- # Case 2: eagle3 two model + draft model + CDL + is_first_draft + TRTLLM attention
168- use_case_2 = self .is_eagle3 (
169- ) and spec_resource_manager .is_first_draft and use_chain_drafter and is_draft_model and is_trtllm_attention
170- # Case 3: eagle3 two model + tree decoding + draft model + CDL + TRTLLM attention
171- use_case_3 = self .is_eagle3 (
172- ) and is_spec_dec_tree and is_draft_model and use_chain_drafter and is_trtllm_attention
173- # Case 4: eagle3 two model + tree decoding + target model + TRTLLM attention
174- use_case_4 = self .is_eagle3 (
175- ) and is_spec_dec_tree and not is_draft_model and is_trtllm_attention
176-
177- return use_case_1 or use_case_2 or use_case_3 or use_case_4
161+ use_case_2 = self .is_eagle3 () and (
162+ not is_draft_model or
163+ (spec_resource_manager .is_first_draft
164+ and use_chain_drafter )) and is_trtllm_attention
165+
166+ return use_case_1 or use_case_2
178167
179168 @staticmethod
180169 def from_string (name : Optional [str ]) -> "SpeculativeDecodingMode" :
0 commit comments