@@ -136,21 +136,15 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]):
136136 # 1-model has separate logic for handling draft tokens
137137 return False
138138
139- if issubclass (attention_backend ,
140- TrtllmAttention ) and self .is_mtp_eagle ():
141- # TRTLLM MLA does not work with the chunked context mode.
142- return False
143-
144139 return not issubclass (attention_backend ,
145- TrtllmAttention ) or get_sm_version () != 100
140+ TrtllmAttention ) or get_sm_version () < 90
146141
147142 def attention_need_spec_dec_mode (
148- self ,
149- spec_resource_manager : BaseResourceManager ,
150- is_draft_model : bool ,
151- attention_backend : Type [AttentionBackend ],
152- use_chain_drafter : bool , # CDL
153- is_spec_dec_tree : bool ,
143+ self ,
144+ spec_resource_manager : Optional [BaseResourceManager ],
145+ is_draft_model : bool ,
146+ attention_backend : Type [AttentionBackend ],
147+ use_chain_drafter : bool , # CDL
154148 ):
155149 """
156150 If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode).
@@ -159,22 +153,19 @@ def attention_need_spec_dec_mode(
159153 is_draft_model: whether the model is a draft model.
160154 attention_backend: the attention backend.
161155 use_chain_drafter: whether to use capturable drafting loops (CDL). For the target model, it is always False.
162- is_spec_dec_tree: whether the spec-dec mode is a tree, i.e., static tree or dynamic tree.
163156 """
164157 is_trtllm_attention = issubclass (attention_backend , TrtllmAttention )
165- # Case 1: one model
158+
159+ # Always use the multi-token query mode for 1-model.
160+ # For 2-model, we need to enable it when we process multiple tokens at once. This occurs with
161+ # the target model (verification) or on the first draft for CDL based speculation.
166162 use_case_1 = self .is_eagle3_one_model ()
167- # Case 2: eagle3 two model + draft model + CDL + is_first_draft + TRTLLM attention
168- use_case_2 = self .is_eagle3 (
169- ) and spec_resource_manager .is_first_draft and use_chain_drafter and is_draft_model and is_trtllm_attention
170- # Case 3: eagle3 two model + tree decoding + draft model + CDL + TRTLLM attention
171- use_case_3 = self .is_eagle3 (
172- ) and is_spec_dec_tree and is_draft_model and use_chain_drafter and is_trtllm_attention
173- # Case 4: eagle3 two model + tree decoding + target model + TRTLLM attention
174- use_case_4 = self .is_eagle3 (
175- ) and is_spec_dec_tree and not is_draft_model and is_trtllm_attention
176-
177- return use_case_1 or use_case_2 or use_case_3 or use_case_4
163+ use_case_2 = (not is_draft_model or
164+ (spec_resource_manager is not None
165+ and spec_resource_manager .is_first_draft
166+ and use_chain_drafter )) and is_trtllm_attention
167+
168+ return use_case_1 or use_case_2
178169
179170 @staticmethod
180171 def from_string (name : Optional [str ]) -> "SpeculativeDecodingMode" :
0 commit comments