8 changes: 6 additions & 2 deletions lightllm/models/llama/flashattention_infer_struct.py
@@ -24,8 +24,7 @@ def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int):
]
return cls._shared_page_table_buffer

def init_some_extra_state(self, model, input_ids: torch.Tensor):
super().init_some_extra_state(model, input_ids)
def _init_flash_attention_state(self, model, input_ids: torch.Tensor):
if self.is_prefill:
self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
@@ -93,3 +92,8 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
)
)
return

def init_some_extra_state(self, model, input_ids: torch.Tensor):
super().init_some_extra_state(model, input_ids)
self._init_flash_attention_state(model, input_ids)
return
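
Note: the hunk above extracts the FlashAttention-specific setup into _init_flash_attention_state, leaving init_some_extra_state as a thin wrapper around the parent call plus the helper. A minimal sketch of the pattern, with illustrative class names only (not the real lightllm classes): a subclass can reuse the shared setup while swapping out the parent's rope handling.

class LlamaState:
    def _init_flash_attention_state(self):
        # shared FlashAttention bookkeeping (cu_seqlens, page table, ...)
        print("flash attention state")

    def init_some_extra_state(self):
        print("llama rope setup")
        self._init_flash_attention_state()

class Qwen2VLState(LlamaState):
    def init_some_extra_state(self):
        # skip the llama rope setup, keep the shared FA bookkeeping
        print("mrope setup")
        self._init_flash_attention_state()

Qwen2VLState().init_some_extra_state()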
30 changes: 30 additions & 0 deletions lightllm/models/qwen2_vl/flashattention_infer_struct.py
@@ -0,0 +1,30 @@
import os
import torch
import numpy as np
import torch.distributed as dist
from lightllm.common.basemodel.infer_struct import InferStateInfo
from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
from lightllm.common.basemodel.batch_objs import ModelInput


class Qwen2VLFlashAttentionStateInfo(FlashAttentionStateInfo):
def init_some_extra_state(self, model, input_ids: torch.Tensor):
InferStateInfo.init_some_extra_state(self, model, input_ids)
if self.is_prefill:
self.max_seq_len = self.max_kv_seq_len
self.q_max_seq_len = self.max_q_seq_len
position_ids = self.position_ids
self.position_sin = model._sin_cached[:, position_ids, :].unsqueeze(1)
self.position_cos = model._cos_cached[:, position_ids, :].unsqueeze(1)
position_ids = None
else:
position_ids = self.position_ids
self.position_sin = model._sin_cached[:, position_ids, :].unsqueeze(1)
self.position_cos = model._cos_cached[:, position_ids, :].unsqueeze(1)

# init flash attention state
self._init_flash_attention_state(model, input_ids)
return
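
Note: the new state class computes its own rotary sin/cos slices (Qwen2-VL indexes the cached tables with its own position ids) and only then runs the shared FlashAttention setup. A standalone sketch of the gather-by-position step with made-up shapes; the actual _sin_cached / _cos_cached layout and position_ids shape in lightllm may differ.

import torch

max_pos, rot_dim = 4096, 64
inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
angles = torch.arange(max_pos).float()[:, None] * inv_freq    # (max_pos, rot_dim // 2)
sin_cached = torch.sin(angles)                                 # cached rotary table
position_ids = torch.tensor([0, 1, 2, 7, 8])                   # positions of this batch's tokens

# Gather per-token rows once per forward pass; the extra dim broadcasts
# over attention heads when the rotation is applied.
position_sin = sin_cached[position_ids].unsqueeze(1)           # (num_tokens, 1, rot_dim // 2)
print(position_sin.shape)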
6 changes: 6 additions & 0 deletions lightllm/models/qwen2_vl/model.py
@@ -12,6 +12,7 @@
from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from typing import List, Optional, Union
from transformers.utils import TensorType, logging
from lightllm.models.qwen2_vl.flashattention_infer_struct import Qwen2VLFlashAttentionStateInfo
from lightllm.common.build_utils import repair_config
from lightllm.models.registry import ModelRegistry
from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo
@@ -20,6 +21,7 @@
import torch
from PIL import Image
from .vision_process import smart_resize
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
Reviewer comment (Contributor), severity: medium

The function enable_env_vars is imported but not used within this file. To maintain code cleanliness and avoid potential confusion, it's best to remove unused imports.

Suggested change
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
from lightllm.utils.envs_utils import get_env_start_args

from lightllm.models.qwen2.layer_weights import transformer_layer_weight, pre_and_post_layer_weight
from lightllm.models.qwen2.model import Qwen2TpPartModel
import os
@@ -103,6 +105,10 @@ def __init__(self, kvargs):
super().__init__(kvargs)
return

def _init_inferstate_cls(self):
if get_env_start_args().enable_fa3:
self.infer_state_class = Qwen2VLFlashAttentionStateInfo

def _init_config(self):
with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
self.config = json.load(json_file)
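
Note: the new _init_inferstate_cls hook swaps in Qwen2VLFlashAttentionStateInfo only when FA3 is enabled via the start args. A minimal sketch of that kind of hook; the base-class wiring shown here is assumed for illustration, not lightllm's actual code.

class DefaultStateInfo: ...
class FA3StateInfo: ...          # stand-in for Qwen2VLFlashAttentionStateInfo

class BaseModel:
    def __init__(self, enable_fa3: bool):
        self.enable_fa3 = enable_fa3          # stand-in for get_env_start_args().enable_fa3
        self.infer_state_class = DefaultStateInfo
        self._init_inferstate_cls()

    def _init_inferstate_cls(self):
        pass                                  # default: keep the parent's choice

class Qwen2VLModel(BaseModel):
    def _init_inferstate_cls(self):
        if self.enable_fa3:
            self.infer_state_class = FA3StateInfo

print(Qwen2VLModel(enable_fa3=True).infer_state_class.__name__)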