    context_attention_fwd,
    context_attention_fwd_no_prompt_cache,
)
+from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_with_v import (
+    context_attention_fwd_with_v,
+    context_attention_fwd_no_prompt_cache_with_v,
+)

from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding
from lightllm.models.deepseek2.layer_infer.fused_moe import fused_experts, grouped_topk
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from functools import partial
from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
+import os


class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer):
@@ -55,6 +60,12 @@ def __init__(
        self.softmax_scale = self.softmax_scale * mscale * mscale
        super().__init__(layer_num, tp_rank, world_size, network_config, mode)
        self.tp_o_head_num_ = self.tp_q_head_num_
+
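+        # Total head counts from the model config; the CC/ACC helpers below divide
+        # them by the tensor-parallel world size. ENABLE_OPT_DECODE_MHA switches the
+        # ACC decode path to the lightllm_ppl_mla kernel, and mla_type ("ACCM" or
+        # "MIX") selects between the original absorbed attention path and the mixed
+        # CC-prefill / ACC-decode path.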
+        self.num_heads = network_config["num_attention_heads"]
+        self.num_kv_heads = network_config["num_key_value_heads"]
+        self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]
+        self.mla_type = "ACCM"
+
        return

    def _bind_attention(self):
@@ -97,7 +108,12 @@ def _get_qkv(

        q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
        q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
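+        # Prefill against a dynamic prompt cache keeps the absorbed ("ACCM") path;
+        # otherwise the weight's mla_type decides. Only the ACCM path folds
+        # k_b_proj_ into q_nope here; the MIX path leaves q_nope untouched and
+        # handles the projection inside the CC/ACC attention kernels.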
+        if infer_state.use_dynamic_prompt_cache and infer_state.is_prefill:
+            self.mla_type = "ACCM"
+        else:
+            self.mla_type = layer_weight.mla_type
+        if self.mla_type == "ACCM":
+            q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)

        layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))

@@ -123,11 +139,153 @@ def _get_o(
            input = input.view(-1, self.tp_q_head_num_ * self.kv_lora_rank)
            o_tensor = layer_weight.fuse_vo_weight_.mm(input)
        else:
-            input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1)
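+            # In MIX mode the attention kernels already return per-head outputs of
+            # size qk_nope_head_dim, so v_b_proj_ is only applied on the ACCM path.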
+            if self.mla_type == "ACCM":
+                input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1)
            o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim))
        return o_tensor

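+    # CC path (prefill): expand the compressed latent KV back into full per-head K
+    # and V with k_b_proj_ / v_b_proj_, then run ordinary context attention with an
+    # explicit V. Used when mla_type == "MIX".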
+    def _CC_method(
+        self, q, compressed_kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+    ):
+        num_local_heads = self.num_heads
+        num_local_kv_heads = self.num_kv_heads
+        if self.world_size_ > 1:
+            num_local_heads //= self.world_size_
+            num_local_kv_heads //= self.world_size_
+        if infer_state.use_dynamic_prompt_cache:
+            compressed_kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+        # CC
+        compressed_kv, k_pe = torch.split(  # (b*s, 1, kv_lora + qk_r)
+            compressed_kv, [layer_weight.kv_lora_rank, layer_weight.qk_rope_head_dim], dim=-1
+        )
+        compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank)
+        k = self.alloc_tensor(
+            [k_pe.shape[0], num_local_kv_heads, layer_weight.qk_nope_head_dim + layer_weight.qk_rope_head_dim],
+            dtype=q[0].dtype,
+        )
+        k[..., layer_weight.qk_nope_head_dim :] = k_pe
+        wk = layer_weight.k_b_proj_.weight.view(-1, layer_weight.k_b_proj_.weight.shape[-1])
+        o_tensor = self.alloc_tensor([compressed_kv.shape[0], wk.shape[0]], dtype=q[0].dtype)
+        torch.mm(compressed_kv, wk.transpose(0, 1), out=o_tensor)
+        k[..., : layer_weight.qk_nope_head_dim] = o_tensor.view(-1, num_local_kv_heads, layer_weight.qk_nope_head_dim)
+        trans_weight = layer_weight.v_b_proj_.weight.transpose(1, 2)
+        wv = trans_weight.view(-1, trans_weight.shape[-1])
+        o_tensor = self.alloc_tensor([compressed_kv.shape[0], wv.shape[0]], dtype=q[0].dtype)
+        torch.mm(compressed_kv, wv.transpose(0, 1), out=o_tensor)
+        v = o_tensor.view(-1, num_local_kv_heads, layer_weight.qk_nope_head_dim)
+        return self._context_attention_kernel_with_v(q, k, v, infer_state, layer_weight)
+
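+    # ACC path (decode): fold k_b_proj_ into the query so attention runs directly
+    # against the compressed latent cache, then project the attention output back
+    # to qk_nope_head_dim with v_b_proj_. With ENABLE_OPT_DECODE_MHA set, the
+    # lightllm_ppl_mla.decode_mla kernel is used instead of the triton
+    # flash-decoding fallback.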
+    def _ACC_method(
+        self, q, compressed_kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+    ):
+        q_ne, q_pe = q
+        num_local_heads = self.num_heads
+        num_local_kv_heads = self.num_kv_heads
+        if self.world_size_ > 1:
+            num_local_heads //= self.world_size_
+            num_local_kv_heads //= self.world_size_
+        # ACC
+        q = self.alloc_tensor(
+            [q_ne.shape[0], num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim], dtype=q_ne.dtype
+        )
+        q[..., self.kv_lora_rank :] = q_pe
+        torch.bmm(  # TODO: switch to einsum or cuBLAS
+            q_ne.transpose(0, 1),  # (h, b*s, qk_n)
+            layer_weight.k_b_proj_.weight,  # (h, qk_n, kv_lora)
+            out=q[..., : self.kv_lora_rank].view(q_ne.shape[1], q_ne.shape[0], self.kv_lora_rank),
+        ).transpose(
+            0, 1
+        )  # (b*s, h, kv_lora)
+        q_nope, q_rope = torch.split(  # (b*s, h, qk_n + qk_r) -> (b*s, h, qk_n), (b*s, h, qk_r)
+            q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+        )
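+        # lightllm_ppl_mla is imported lazily so the dependency is only required
+        # when ENABLE_OPT_DECODE_MHA is actually enabled.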
+        if self.enable_opt_decoding_mha:
+            import lightllm_ppl_mla
+
+            o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype)
+            kvstarts = torch.cat(
+                [infer_state.b_start_loc, infer_state.b_start_loc[-1:] + infer_state.b_seq_len[-1:]], dim=0
+            )
+            lightllm_ppl_mla.decode_mla(
+                o_tensor,
+                q,
+                compressed_kv[: infer_state.mem_end, :, :],
+                infer_state.b_start_loc,
+                kvstarts,
+                self.softmax_scale,
+                q.shape[-1],
+                q_nope.shape[-1],
+            )
+            output_parallel = o_tensor
+        else:
+            output_parallel = self._token_gqa_decode_attention_flashdecoding_origin(
+                (q_nope, q_rope), infer_state, layer_weight
+            )
+        o_tensor = self.alloc_tensor(
+            [output_parallel.shape[1], output_parallel.shape[0], self.qk_nope_head_dim], dtype=q_ne.dtype
+        )
+        torch.bmm(  # TODO: switch to einsum or cuBLAS
+            output_parallel.transpose(0, 1),  # (h, b*s, kv_lora)
+            layer_weight.v_b_proj_.weight,  # (h, kv_lora, vo_d)
+            out=o_tensor,
+        ).transpose(
+            0, 1
+        )  # (b*s, h, vo_d)
+        return o_tensor
+
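+    # Prefill dispatcher: "MIX" routes context attention through the CC path above,
+    # anything else keeps the original triton kernel.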
    def _context_attention_kernel(
+        self, q, kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ) -> torch.Tensor:
+        if self.mla_type == "MIX":
+            return self._context_attention_kernel_with_CC(q, kv, infer_state, layer_weight, out)
+        else:
+            return self._context_attention_kernel_origin(q, kv, infer_state, layer_weight, out)
+
+    def _context_attention_kernel_with_CC(
+        self, q, kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ) -> torch.Tensor:
+        return self._CC_method(q, kv, infer_state, layer_weight)
+
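+    # Variant of the context kernel that consumes the explicit V materialized by
+    # the CC path.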
+    def _context_attention_kernel_with_v(
+        self, q: Tuple[torch.Tensor, torch.Tensor], kv, v, infer_state: LlamaInferStateInfo, layer_weight, out=None
+    ) -> torch.Tensor:
+        q_nope, q_rope = q
+        nope_head_dim = q_nope.shape[-1]
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
+        if infer_state.use_dynamic_prompt_cache:
+            context_attention_fwd_with_v(
+                q_nope,
+                q_rope,
+                kv[:, :, :nope_head_dim],
+                kv[:, :, nope_head_dim:],
+                v,
+                o_tensor.view(-1, self.tp_q_head_num_, nope_head_dim),
+                infer_state.b_req_idx,
+                infer_state.b_start_loc,
+                infer_state.b_seq_len,
+                infer_state.b_ready_cache_len,
+                infer_state.max_len_in_batch,
+                infer_state.req_manager.req_to_token_indexs,
+                self.softmax_scale,
+            )
+        else:
+            context_attention_fwd_no_prompt_cache_with_v(
+                q_nope,
+                q_rope,
+                kv[:, :, :nope_head_dim],
+                kv[:, :, nope_head_dim:],
+                v,
+                o_tensor.view(-1, self.tp_q_head_num_, nope_head_dim),
+                infer_state.b_start_loc,
+                infer_state.b_seq_len,
+                infer_state.max_len_in_batch,
+                self.softmax_scale,
+            )
+        q_nope = None
+        q_rope = None
+        return o_tensor
+
+    def _context_attention_kernel_origin(
        self, q: Tuple[torch.Tensor, torch.Tensor], kv, infer_state: LlamaInferStateInfo, layer_weight, out=None
    ) -> torch.Tensor:
        q_nope, q_rope = q
@@ -166,6 +324,20 @@ def _context_attention_kernel(
        return o_tensor

    def _token_gqa_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
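+        # Decode dispatcher: "MIX" uses the ACC path against the compressed cache,
+        # anything else falls back to the original flash-decoding kernel.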
+        if self.mla_type == "MIX":
+            return self._token_gqa_decode_attention_flashdecoding_with_ACC(q, infer_state, layer_weight, out)
+        else:
+            return self._token_gqa_decode_attention_flashdecoding_origin(q, infer_state, layer_weight, out)
+
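+    # Slices the layer's compressed KV buffer up to infer_state.mem_end and hands
+    # it to the ACC decode attention.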
+    def _token_gqa_decode_attention_flashdecoding_with_ACC(
+        self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None
+    ):
+        compressed_kv = infer_state.mem_manager.kv_buffer[self.layer_num_][: infer_state.mem_end, :, :]
+        return self._ACC_method(q, compressed_kv, infer_state, layer_weight)
+
+    def _token_gqa_decode_attention_flashdecoding_origin(
+        self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None
+    ):
        q_nope, q_rope = q
        kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank]
        kv_rope = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank :]