@@ -185,14 +185,9 @@ def _ACC_method(
         num_local_heads //= self.world_size_
         num_local_kv_heads //= self.world_size_
         # ACC
-        q_nope_up_ = self.alloc_tensor([q_nope.shape[1], q_nope.shape[0], self.kv_lora_rank], dtype=q_nope.dtype)
-        q_nope_up_ = torch.bmm(  # TODO: convert to einsum or cuBLAS
-            q_nope.transpose(0, 1),  # (h, b*s, qk_n)
-            layer_weight.k_b_proj_.weight,  # (h, qk_n, kv_lora)
-            out=q_nope_up_.view(q_nope.shape[1], q_nope.shape[0], self.kv_lora_rank),
-        ).transpose(
-            0, 1
-        )  # (b*s, h, kv_lora)
+        q_nope = layer_weight.k_b_proj_.weight.bmm(
+            q_nope.transpose(0, 1),
+        ).transpose(0, 1)
         if self.enable_opt_decoding_mha:
             import lightllm_ppl_mla

@@ -213,19 +208,10 @@ def _ACC_method(
             output_parallel = o_tensor
         else:
             output_parallel = self._token_gqa_decode_attention_flashdecoding_origin(
-                (q_nope_up_, q_rope), infer_state, layer_weight
+                (q_nope, q_rope), infer_state, layer_weight
             )
-        o_tensor = self.alloc_tensor(
-            [output_parallel.shape[1], output_parallel.shape[0], self.qk_nope_head_dim], dtype=q_rope.dtype
-        )
-        o_tensor = torch.bmm(  # TODO: convert to einsum or cuBLAS
-            output_parallel.transpose(0, 1),  # (h, b*s, kv_lora)
-            layer_weight.v_b_proj_.weight,  # (h, kv_lora, vo_d)
-            out=o_tensor,
-        ).transpose(
-            0, 1
-        )  # (b*s, h, vo_d)
-        return o_tensor
+        vo = layer_weight.v_b_proj_.bmm(output_parallel.transpose(0, 1)).transpose(0, 1)
+        return vo

     def _context_attention_kernel(
         self, q, kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
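A minimal sketch of the shape bookkeeping behind the per-head batched matmuls in this decode path, assuming illustrative sizes; the tensor names and dimensions below are stand-ins for this note only, not the lightllm API:

```python
import torch

# Illustrative sizes only (not taken from any model config).
bs, heads, qk_nope_dim, kv_lora_rank, vo_dim = 4, 16, 128, 512, 128

# Absorb step: project q_nope into the kv_lora space with a per-head weight.
# (h, b*s, qk_nope_dim) @ (h, qk_nope_dim, kv_lora_rank) -> (h, b*s, kv_lora_rank)
q_nope = torch.randn(bs, heads, qk_nope_dim)
k_b_proj = torch.randn(heads, qk_nope_dim, kv_lora_rank)
q_nope_up = torch.bmm(q_nope.transpose(0, 1), k_b_proj).transpose(0, 1)
assert q_nope_up.shape == (bs, heads, kv_lora_rank)

# Output step: map the attention output back from kv_lora space to vo_dim.
# (h, b*s, kv_lora_rank) @ (h, kv_lora_rank, vo_dim) -> (h, b*s, vo_dim)
attn_out = torch.randn(bs, heads, kv_lora_rank)  # stand-in for output_parallel
v_b_proj = torch.randn(heads, kv_lora_rank, vo_dim)
o = torch.bmm(attn_out.transpose(0, 1), v_b_proj).transpose(0, 1)
assert o.shape == (bs, heads, vo_dim)
```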