 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
 from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
 from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd
-from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
 import os
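The import swap above is the heart of this diff: every signature below retypes `infer_state` from the generic `LlamaInferStateInfo` to a DeepSeek-specific `Deepseek2InferStateInfo`, which carries the extra decode metadata (notably `kv_starts`) that the reworked `decode_mla` call consumes. A minimal sketch of what that class plausibly looks like; only the `kv_starts` attribute is visible in this diff, so the hook name and the prefix-sum construction are assumptions, not lightllm's actual code:

import torch
from lightllm.models.llama.infer_struct import LlamaInferStateInfo

class Deepseek2InferStateInfo(LlamaInferStateInfo):
    # Sketch only: `kv_starts` is the one attribute this diff shows; the rest
    # is illustrative.
    def __init__(self):
        super().__init__()
        self.kv_starts = None

    def init_some_extra_state(self, model, input_ids):
        super().init_some_extra_state(model, input_ids)
        # kv_starts[i] : kv_starts[i + 1] delimits request i's tokens in a
        # flattened per-batch view; plays the role of the old inline kvstarts.
        self.kv_starts = torch.nn.functional.pad(
            torch.cumsum(self.b_seq_len, dim=0), (1, 0)
        )

Building this once per batch would replace the `torch.cat` that the old `_ACC_method` re-derived from `b_start_loc` on every layer, as the hunk below shows.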
@@ -83,7 +83,7 @@ def _get_qkv(
         self,
         input: torch.Tensor,
         cache_kv,
-        infer_state: LlamaInferStateInfo,
+        infer_state: Deepseek2InferStateInfo,
         layer_weight: Deepseek2TransformerLayerWeight,
     ) -> torch.Tensor:
         input = input.view(-1, self.embed_dim_)
@@ -133,7 +133,7 @@ def _get_qkv(
         return (q_nope, q_rope), cache_kv
 
     def _get_o(
-        self, input, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+        self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
     ) -> torch.Tensor:
         if not self.disable_vo_absorb:
             input = input.view(-1, self.tp_q_head_num_ * self.kv_lora_rank)
@@ -145,7 +145,7 @@ def _get_o(
         return o_tensor
 
     def _CC_method(
-        self, q, compressed_kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+        self, q, compressed_kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
     ):
         num_local_heads = self.num_heads
         num_local_kv_heads = self.num_kv_heads
@@ -176,7 +176,7 @@ def _CC_method(
         return self._context_attention_kernel_with_v(q, [k_nope, k_pe], v, infer_state, layer_weight)
 
     def _ACC_method(
-        self, q, compressed_kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+        self, q, compressed_kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
     ):
         q_nope, q_rope = q
         num_local_heads = self.num_heads
@@ -185,22 +185,21 @@ def _ACC_method(
             num_local_heads //= self.world_size_
             num_local_kv_heads //= self.world_size_
         # ACC
-        q_nope = layer_weight.k_b_proj_.weight.bmm(
+        q_nope = layer_weight.k_b_proj_.bmm(
             q_nope.transpose(0, 1),
         ).transpose(0, 1)
         if self.enable_opt_decoding_mha:
             import lightllm_ppl_mla
 
             o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype)
-            kvstarts = torch.cat(
-                [infer_state.b_start_loc, infer_state.b_start_loc[-1:] + infer_state.b_seq_len[-1:]], dim=0
-            )
+            q = torch.cat([q_nope, q_rope], dim=-1)
             lightllm_ppl_mla.decode_mla(
                 o_tensor,
                 q,
-                compressed_kv[: infer_state.mem_end, :, :],
-                infer_state.b_start_loc,
-                kvstarts,
+                compressed_kv,
+                infer_state.req_manager.req_to_token_indexs,
+                infer_state.kv_starts,
+                infer_state.b_req_idx,
                 self.softmax_scale,
                 q.shape[-1],
                 q_nope.shape[-1],
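Two changes land in this hunk. First, the nope-part projection now calls `layer_weight.k_b_proj_.bmm(...)` rather than reaching through `.weight`, which suggests the weight object now owns its batched matmul (so, for instance, quantized wrappers can override it); that reading is an inference, not stated in the diff. Second, `decode_mla` stops assuming the compressed KV cache is one contiguous slab ending at `mem_end`: the query is concatenated explicitly from its nope and rope halves, and the kernel receives the whole per-layer buffer plus the request manager's paging tables (`req_to_token_indexs`, `b_req_idx`) and the precomputed `infer_state.kv_starts` in place of the inline `kvstarts`. A pure-PyTorch sketch of the per-request gather the kernel is now expected to perform (reference semantics only, not the ppl CUDA implementation):

def gather_request_kv(compressed_kv, req_to_token_indexs, b_req_idx, kv_starts, i):
    # Request i owns kv_starts[i + 1] - kv_starts[i] cached tokens; they live
    # at arbitrary slots of compressed_kv, recorded by the request manager.
    seq_len_i = kv_starts[i + 1] - kv_starts[i]
    token_slots = req_to_token_indexs[b_req_idx[i], :seq_len_i]
    return compressed_kv[token_slots.long()]  # [seq_len_i, num_kv_heads, head_dim]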
@@ -214,20 +213,20 @@ def _ACC_method(
         return vo
 
     def _context_attention_kernel(
-        self, q, kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+        self, q, kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ) -> torch.Tensor:
         if self.mla_type == "MIX":
             return self._context_attention_kernel_with_CC(q, kv, infer_state, layer_weight, out)
         else:
             return self._context_attention_kernel_origin(q, kv, infer_state, layer_weight, out)
 
     def _context_attention_kernel_with_CC(
-        self, q, kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+        self, q, kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ) -> torch.Tensor:
         return self._CC_method(q, kv, infer_state, layer_weight)
 
     def _context_attention_kernel_with_v(
-        self, q: Tuple[torch.Tensor, torch.Tensor], k, v, infer_state: LlamaInferStateInfo, layer_weight, out=None
+        self, q: Tuple[torch.Tensor, torch.Tensor], k, v, infer_state: Deepseek2InferStateInfo, layer_weight, out=None
     ) -> torch.Tensor:
         q_nope, q_rope = q
         k_nope, k_rope = k
@@ -267,7 +266,7 @@ def _context_attention_kernel_with_v(
         return o_tensor
 
     def _context_attention_kernel_origin(
-        self, q: Tuple[torch.Tensor, torch.Tensor], kv, infer_state: LlamaInferStateInfo, layer_weight, out=None
+        self, q: Tuple[torch.Tensor, torch.Tensor], kv, infer_state: Deepseek2InferStateInfo, layer_weight, out=None
     ) -> torch.Tensor:
         q_nope, q_rope = q
         o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
@@ -304,20 +303,22 @@ def _context_attention_kernel_origin(
         q_rope = None
         return o_tensor
 
-    def _token_gqa_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
+    def _token_gqa_decode_attention_flashdecoding(
+        self, q, infer_state: Deepseek2InferStateInfo, layer_weight, out=None
+    ):
         if self.mla_type == "MIX":
             return self._token_gqa_decode_attention_flashdecoding_with_ACC(q, infer_state, layer_weight, out)
         else:
             return self._token_gqa_decode_attention_flashdecoding_origin(q, infer_state, layer_weight, out)
 
     def _token_gqa_decode_attention_flashdecoding_with_ACC(
-        self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None
+        self, q, infer_state: Deepseek2InferStateInfo, layer_weight, out=None
     ):
-        # compressed_kv = infer_state.mem_manager.kv_buffer[self.layer_num_][: infer_state.mem_end, :, :]
-        return self._ACC_method(q, None, infer_state, layer_weight)
+        compressed_kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+        return self._ACC_method(q, compressed_kv, infer_state, layer_weight)
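With the paged call above, the ACC decode entry no longer pre-slices the cache (the dead `[: infer_state.mem_end]` comment is gone for good) and simply forwards the full per-layer buffer. The buffer layout can be read off `_token_gqa_decode_attention_flashdecoding_origin` below, which keeps only the first `kv_lora_rank` channels; a small illustrative helper, with names of my choosing rather than lightllm's:

def split_latent_kv(kv_buffer_layer, kv_lora_rank):
    # Assumed per-layer MLA cache layout, inferred from the slicing below:
    # [max_total_token_num, num_kv_heads, kv_lora_rank + qk_rope_head_dim],
    # latent channels first, decoupled RoPE key channels after.
    compressed_kv = kv_buffer_layer[:, :, :kv_lora_rank]
    k_rope = kv_buffer_layer[:, :, kv_lora_rank:]
    return compressed_kv, k_rope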
 
     def _token_gqa_decode_attention_flashdecoding_origin(
-        self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None
+        self, q, infer_state: Deepseek2InferStateInfo, layer_weight, out=None
     ):
         q_nope, q_rope = q
         kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank]
@@ -347,7 +348,7 @@ def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager):
         return
 
     def _moe_ffn(
-        self, input, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+        self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
     ) -> torch.Tensor:
         hidden_states = input.view(-1, self.embed_dim_)
         num_tokens, hidden_dim = hidden_states.shape
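`_moe_ffn` is only retyped here, and the hunk shows just its entry, which flattens the input to `[num_tokens, hidden_dim]` before routing. For orientation, a generic top-k routed MoE forward over that flattened shape is sketched below; this is a plain reference, not lightllm's fused implementation, and every name in it is illustrative:

import torch

def moe_ffn_reference(hidden_states, router_weight, experts, top_k=2):
    # hidden_states: [num_tokens, hidden_dim]; experts: list of FFN callables.
    probs = torch.softmax(hidden_states @ router_weight, dim=-1)  # [tokens, n_experts]
    weights, expert_ids = torch.topk(probs, top_k, dim=-1)        # [tokens, top_k]
    out = torch.zeros_like(hidden_states)
    for slot in range(top_k):
        for e, expert in enumerate(experts):
            mask = expert_ids[:, slot] == e
            if mask.any():
                out[mask] += weights[mask, slot].unsqueeze(-1) * expert(hidden_states[mask])
    return out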