     context_attention_fwd_no_prompt_cache,
 )
 
-from lightllm.models.deepseek2.triton_kernel.flash_decoding import token_decode_attention_flash_decoding
 from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding
 from lightllm.models.deepseek2.layer_infer.fused_moe import fused_experts, grouped_topk
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
@@ -60,14 +59,9 @@ def __init__( |
 
     def _bind_attention(self):
         self._context_attention_kernel = partial(Deepseek2TransformerLayerInfer._context_attention_kernel, self)
-        if "triton_flashdecoding" in self.mode:
-            self._token_attention_kernel = partial(
-                Deepseek2TransformerLayerInfer._token_decode_attention_flashdecoding, self
-            )
-        else:
-            self._token_attention_kernel = partial(
-                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
-            )
+        self._token_attention_kernel = partial(
+            Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
+        )
         self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
         if self.is_moe:
             self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn, self)
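With this hunk, _bind_attention no longer branches on "triton_flashdecoding" in self.mode; token decode is always routed through the GQA flash-decoding kernel. The binding relies on functools.partial applied to the unbound method with self as the first argument. A minimal, runnable sketch of that pattern, using illustrative class and method names rather than the real lightllm ones:

from functools import partial


class LayerInfer:
    def __init__(self):
        self._bind_attention()

    def _bind_attention(self):
        # partial(UnboundMethod, self) yields a callable equivalent to the
        # bound method; storing it picks the decode kernel once at init time.
        self._token_attention_kernel = partial(LayerInfer._token_gqa_decode_attention, self)

    def _token_gqa_decode_attention(self, q):
        # Stand-in for the real Triton GQA flash-decoding call.
        return f"gqa_decode({q})"


layer = LayerInfer()
print(layer._token_attention_kernel("q"))  # -> gqa_decode(q)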
@@ -171,24 +165,6 @@ def _context_attention_kernel( |
         q_rope = None
         return o_tensor
 
-    def _token_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
-        q_nope, q_rope = q
-        kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank]
-        kv_rope = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank :]
-        return token_decode_attention_flash_decoding(
-            q_nope,
-            q_rope,
-            kv,
-            kv_rope,
-            infer_state,
-            self.tp_q_head_num_,
-            self.kv_lora_rank,
-            self.qk_rope_head_dim,
-            self.qk_nope_head_dim,
-            self.softmax_scale,
-            alloc_tensor_func=self.alloc_tensor,
-        )
-
     def _token_gqa_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
         q_nope, q_rope = q
         kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank]
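Both the removed kernel and the remaining GQA kernel split one layer's slot of the shared KV cache along the last dimension into the compressed-KV (kv_lora_rank) portion and the rotary (qk_rope_head_dim) portion before calling into Triton. A small sketch of that slicing with illustrative tensor shapes (the real sizes come from the model config), assuming PyTorch:

import torch

# Illustrative sizes; in the real model these come from the DeepSeek-V2 config.
num_tokens, head_num = 8, 1
kv_lora_rank, qk_rope_head_dim = 512, 64

# One layer's slot in the shared KV cache: the last dim packs the compressed
# KV (first kv_lora_rank entries) and the rotary part (remaining entries).
kv_buffer = torch.randn(num_tokens, head_num, kv_lora_rank + qk_rope_head_dim)

kv = kv_buffer[:, :, :kv_lora_rank]       # compressed-KV ("nope") portion
kv_rope = kv_buffer[:, :, kv_lora_rank:]  # rotary-position portion

assert kv.shape[-1] == kv_lora_rank
assert kv_rope.shape[-1] == qk_rope_head_dim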