Skip to content

Commit 3626756

Browse files
authored
[FIX] update flash-attention for ViT to sync with official repo (#793)
Update flash-attention for ViT to sync with the official repo's [update](Dao-AILab/flash-attention@7ae5f8c#diff-e3790eb114f13873b06146deb854ad12d785a3997ad79b6cf4dd9485419eb632R39)
1 parent 49e8997 commit 3626756

File tree

4 files changed

+38
-2
lines changed

4 files changed

+38
-2
lines changed

Dockerfile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,8 @@ RUN pip install -r /lightllm/requirements.txt --no-cache-dir --ignore-installed
4040

4141
RUN pip install --no-cache-dir nvidia-nccl-cu12==2.25.1 # for allreduce hang issues in multinode H100
4242

43+
RUN git clone https://github.com/Dao-AILab/flash-attention.git -b v2.7.4.post1
44+
RUN cd flash-attention/hopper && NVCC_THREADS=128 python setup.py install
45+
4346
COPY . /lightllm
4447
RUN pip install -e /lightllm --no-cache-dir

lightllm/models/internvl/model.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from lightllm.models.llama.model import LlamaTpPartModel
55
from lightllm.models.phi3.model import Phi3TpPartModel
66
from lightllm.models.qwen2.model import Qwen2TpPartModel
7+
from lightllm.models.deepseek2.model import Deepseek2TpPartModel
78
from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
89
from lightllm.server.multimodal_params import MultimodalParams, ImageItem
910
from lightllm.common.build_utils import repair_config
@@ -26,10 +27,10 @@
2627
IMG_END_TOKEN = "</img>"
2728
IMG_TOKEN = "<image>"
2829

30+
2931
# Warp of the origal tokenizer
3032
class InternvlTokenizer:
3133
def __init__(self, tokenizer, model_cfg, **kwargs):
32-
3334
self.llm_model_type = model_cfg.get("llm_config").get("model_type")
3435
self.tokenizer = tokenizer
3536
self.image_length = int(os.environ.get("INTERNVL_IMAGE_LENGTH", 256))
@@ -200,3 +201,27 @@ def _init_config(self):
200201
if self.finetune_config:
201202
self.config["vocab_size"] = self.finetune_config.vocab_size
202203
return
204+
205+
206+
class InternVLDeepSeek2TpPartModel(Deepseek2TpPartModel):
207+
# support Deepseek2,3,R1
208+
# weight class
209+
pre_and_post_weight_class = InternVLLlamaPreAndPostLayerWeight
210+
211+
# infer class
212+
pre_layer_infer_class = LlamaMultimodalPreLayerInfer
213+
214+
def __init__(self, kvargs):
215+
super().__init__(kvargs)
216+
return
217+
218+
def _init_config(self):
219+
with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
220+
self.config = json.load(json_file)["llm_config"]
221+
# rename keys
222+
repair_config(self.config, same_names=["num_attention_heads", "n_head"])
223+
repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
224+
repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
225+
if self.finetune_config:
226+
self.config["vocab_size"] = self.finetune_config.vocab_size
227+
return

lightllm/models/vit/triton_kernel/flashattention_nopad.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ def flash_attention_v3_fwd(
192192
None,
193193
None,
194194
None,
195+
None,
195196
softmax_scale,
196197
causal=False,
197198
window_size=(-1, -1),

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@
2727
from lightllm.models.gemma_2b.model import Gemma_2bTpPartModel
2828
from lightllm.models.phi3.model import Phi3TpPartModel
2929
from lightllm.models.deepseek2.model import Deepseek2TpPartModel
30-
from lightllm.models.internvl.model import InternVLLlamaTpPartModel, InternVLPhi3TpPartModel, InternVLQwen2TpPartModel
30+
from lightllm.models.internvl.model import (
31+
InternVLLlamaTpPartModel,
32+
InternVLPhi3TpPartModel,
33+
InternVLQwen2TpPartModel,
34+
InternVLDeepSeek2TpPartModel,
35+
)
3136
from lightllm.models.internvl.model import InternVLInternlm2TpPartModel
3237
from lightllm.models.qwen2_vl.model import Qwen2VLTpPartModel
3338
from lightllm.models.qwen2_reward.model import Qwen2RewardTpPartModel
@@ -199,6 +204,8 @@ def init_model(self, kvargs):
199204
self.model = InternVLLlamaTpPartModel(model_kvargs)
200205
elif llm_model_type == "qwen2":
201206
self.model = InternVLQwen2TpPartModel(model_kvargs)
207+
elif llm_model_type == "deepseek_v3":
208+
self.model = InternVLDeepSeek2TpPartModel(model_kvargs)
202209
self.is_multimodal = True
203210
else:
204211
raise Exception(f"can not support {self.model_type} now")

0 commit comments

Comments
 (0)