Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo
from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo
from functools import partial
from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
from lightllm.distributed.communication_op import all_gather, all_gather_into_tensor, all_reduce, reduce_scatter_tensor
Expand Down Expand Up @@ -302,7 +303,7 @@ def _context_attention_flashattention_kernel_with_CC(
self,
q: torch.Tensor,
kv,
infer_state: Deepseek2FlashInferStateInfo,
infer_state: Deepseek2FlashAttentionStateInfo,
layer_weight: Deepseek2TransformerLayerWeight,
out=None,
) -> torch.Tensor:
Expand All @@ -323,7 +324,7 @@ def _context_attention_flashattention_kernel_with_CC(
k=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
v=v.view(-1, self.tp_v_head_num_, self.v_head_dim),
cu_seqlens_q=infer_state.cu_seqlens_q,
cu_seqlens_k=infer_state.cu_seqlens_k,
cu_seqlens_k=infer_state.cu_seqlens_q,
max_seqlen_q=infer_state.q_max_seq_len,
max_seqlen_k=infer_state.max_seq_len,
softmax_scale=self.softmax_scale,
Expand Down Expand Up @@ -547,7 +548,7 @@ def _context_attention_kernel_origin_fp8(
return o_tensor

def _token_gqa_decode_attention_flashattention(
self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
):
q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
Expand Down
4 changes: 2 additions & 2 deletions lightllm/models/llama/flashinfer_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
self.req_manager.req_to_token_indexs,
self.b_req_idx,
self.b_seq_len,
kv_starts,
self.max_len_in_batch,
kv_starts[:-1],
self.max_kv_seq_len,
kv_indices,
)
self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
Expand Down