diff --git a/lightllm/models/llama/flashattention_infer_struct.py b/lightllm/models/llama/flashattention_infer_struct.py
index 02da92d9e..98f628f07 100644
--- a/lightllm/models/llama/flashattention_infer_struct.py
+++ b/lightllm/models/llama/flashattention_infer_struct.py
@@ -24,8 +24,7 @@ def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int):
             ]
         return cls._shared_page_table_buffer

-    def init_some_extra_state(self, model, input_ids: torch.Tensor):
-        super().init_some_extra_state(model, input_ids)
+    def _init_flash_attention_state(self, model, input_ids: torch.Tensor):
         if self.is_prefill:
             self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
             self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
@@ -93,3 +92,8 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                 )
             )
         return
+
+    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+        super().init_some_extra_state(model, input_ids)
+        self._init_flash_attention_state(model, input_ids)
+        return
diff --git a/lightllm/models/qwen2_vl/flashattention_infer_struct.py b/lightllm/models/qwen2_vl/flashattention_infer_struct.py
new file mode 100644
index 000000000..7d96d7370
--- /dev/null
+++ b/lightllm/models/qwen2_vl/flashattention_infer_struct.py
@@ -0,0 +1,30 @@
+import os
+import torch
+import numpy as np
+import torch.distributed as dist
+from lightllm.common.basemodel.infer_struct import InferStateInfo
+from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.dist_utils import get_current_device_id
+from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
+from lightllm.common.basemodel.batch_objs import ModelInput
+
+
+class Qwen2VLFlashAttentionStateInfo(FlashAttentionStateInfo):
+    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+        InferStateInfo.init_some_extra_state(self, model, input_ids)
+        if self.is_prefill:
+            self.max_seq_len = self.max_kv_seq_len
+            self.q_max_seq_len = self.max_q_seq_len
+            position_ids = self.position_ids
+            self.position_sin = model._sin_cached[:, position_ids, :].unsqueeze(1)
+            self.position_cos = model._cos_cached[:, position_ids, :].unsqueeze(1)
+            position_ids = None
+        else:
+            position_ids = self.position_ids
+            self.position_sin = model._sin_cached[:, position_ids, :].unsqueeze(1)
+            self.position_cos = model._cos_cached[:, position_ids, :].unsqueeze(1)
+
+        # init flash attention state
+        self._init_flash_attention_state(model, input_ids)
+        return
diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py
index f87c3d6ba..70c8bf32e 100644
--- a/lightllm/models/qwen2_vl/model.py
+++ b/lightllm/models/qwen2_vl/model.py
@@ -12,6 +12,7 @@
 from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
 from typing import List, Optional, Union
 from transformers.utils import TensorType, logging
+from lightllm.models.qwen2_vl.flashattention_infer_struct import Qwen2VLFlashAttentionStateInfo
 from lightllm.common.build_utils import repair_config
 from lightllm.models.registry import ModelRegistry
 from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo
@@ -20,6 +21,7 @@
 import torch
 from PIL import Image
 from .vision_process import smart_resize
+from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
 from lightllm.models.qwen2.layer_weights import transformer_layer_weight, pre_and_post_layer_weight
 from lightllm.models.qwen2.model import Qwen2TpPartModel
 import os
@@ -103,6 +105,10 @@ def __init__(self, kvargs):
         super().__init__(kvargs)
         return

+    def _init_inferstate_cls(self):
+        if get_env_start_args().enable_fa3:
+            self.infer_state_class = Qwen2VLFlashAttentionStateInfo
+
     def _init_config(self):
         with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
             self.config = json.load(json_file)
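For reviewers, a condensed sketch of how the pieces fit together after this change. Class and method names introduced above are real; the parent class of FlashAttentionStateInfo, the model class name Qwen2VLTpPartModel, and the assumption that the base model calls _init_inferstate_cls() during construction are illustrative guesses about the surrounding lightllm code, not part of this diff:

# Sketch only: bodies elided; parent class and call site are assumptions.

class FlashAttentionStateInfo(LlamaInferStateInfo):  # parent class assumed
    def _init_flash_attention_state(self, model, input_ids):
        ...  # cu_seqlens / page_table setup, now reusable on its own

    def init_some_extra_state(self, model, input_ids):
        super().init_some_extra_state(model, input_ids)      # llama rotary state
        self._init_flash_attention_state(model, input_ids)   # overall behavior unchanged


class Qwen2VLFlashAttentionStateInfo(FlashAttentionStateInfo):
    def init_some_extra_state(self, model, input_ids):
        InferStateInfo.init_some_extra_state(self, model, input_ids)  # skip llama rotary init
        ...  # build Qwen2-VL position_sin / position_cos from model._sin_cached / _cos_cached
        self._init_flash_attention_state(model, input_ids)            # reuse the FA3 setup


class Qwen2VLTpPartModel(Qwen2TpPartModel):  # class name assumed for qwen2_vl/model.py
    def _init_inferstate_cls(self):
        # With --enable_fa3, the FA3-aware state replaces the default infer state class.
        if get_env_start_args().enable_fa3:
            self.infer_state_class = Qwen2VLFlashAttentionStateInfo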