diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
index ba752a4e8..eccbe430d 100644
--- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -23,6 +23,7 @@
 from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd
 from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
 from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo
+from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo
 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
 from lightllm.distributed.communication_op import all_gather, all_gather_into_tensor, all_reduce, reduce_scatter_tensor
@@ -302,7 +303,7 @@ def _context_attention_flashattention_kernel_with_CC(
         self,
         q: torch.Tensor,
         kv,
-        infer_state: Deepseek2FlashInferStateInfo,
+        infer_state: Deepseek2FlashAttentionStateInfo,
         layer_weight: Deepseek2TransformerLayerWeight,
         out=None,
     ) -> torch.Tensor:
@@ -323,7 +324,7 @@ def _context_attention_flashattention_kernel_with_CC(
             k=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
             v=v.view(-1, self.tp_v_head_num_, self.v_head_dim),
             cu_seqlens_q=infer_state.cu_seqlens_q,
-            cu_seqlens_k=infer_state.cu_seqlens_k,
+            cu_seqlens_k=infer_state.cu_seqlens_q,
             max_seqlen_q=infer_state.q_max_seq_len,
             max_seqlen_k=infer_state.max_seq_len,
             softmax_scale=self.softmax_scale,
@@ -547,7 +548,7 @@ def _context_attention_kernel_origin_fp8(
         return o_tensor
 
     def _token_gqa_decode_attention_flashattention(
-        self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+        self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
         q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
         q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
diff --git a/lightllm/models/llama/flashinfer_struct.py b/lightllm/models/llama/flashinfer_struct.py
index cea95a203..a0c40b57a 100644
--- a/lightllm/models/llama/flashinfer_struct.py
+++ b/lightllm/models/llama/flashinfer_struct.py
@@ -81,8 +81,8 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                 self.req_manager.req_to_token_indexs,
                 self.b_req_idx,
                 self.b_seq_len,
-                kv_starts,
-                self.max_len_in_batch,
+                kv_starts[:-1],
+                self.max_kv_seq_len,
                 kv_indices,
             )
             self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
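
Both hunks revolve around the cumulative-sequence-length ("cu_seqlens" / indptr) tensors that variable-length attention kernels take instead of a padded batch. Below is a minimal, purely illustrative sketch of how such tensors relate to per-request lengths; it is not code from this PR, and `b_seq_len` plus the example values are assumptions standing in for whatever the infer-state objects actually carry.

```python
# Illustrative sketch only (not from this PR): relationship between per-request
# lengths and the indptr-style tensors used by varlen attention kernels.
import torch

b_seq_len = torch.tensor([5, 3, 7], dtype=torch.int32)  # assumed per-request token counts

# Cumulative starts with a leading 0, length batch_size + 1: [0, 5, 8, 15].
kv_starts = torch.nn.functional.pad(
    torch.cumsum(b_seq_len, dim=0, dtype=torch.int32), (1, 0)
)

cu_seqlens = kv_starts                  # full indptr, the shape flash_attn_varlen_func-style APIs expect
max_kv_seq_len = int(b_seq_len.max())   # longest request, the usual max_seqlen_* companion argument
starts_only = kv_starts[:-1]            # per-request start offsets, length batch_size: [0, 5, 8]
```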