Skip to content

Commit ff4c412

Browse files
committed
add qwen backend for internvl
1 parent 9845569 commit ff4c412

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

lightllm/models/internvl/model.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from lightllm.models.internlm2.model import Internlm2TpPartModel
44
from lightllm.models.llama.model import LlamaTpPartModel
55
from lightllm.models.phi3.model import Phi3TpPartModel
6+
from lightllm.models.qwen2.model import Qwen2TpPartModel
67
from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
78
from lightllm.server.multimodal_params import MultimodalParams, ImageItem
89
from lightllm.common.build_utils import repair_config
@@ -145,3 +146,26 @@ def _init_config(self):
145146
if self.finetune_config:
146147
self.config["vocab_size"] = self.finetune_config.vocab_size
147148
return
149+
150+
151+
class InternVLQwen2TpPartModel(Qwen2TpPartModel):
    """InternVL multimodal model variant whose language backbone is Qwen2.

    Reuses the Llama multimodal pre/post weight loading and pre-layer
    inference classes (shared across the InternVL family), and overrides
    config loading because InternVL checkpoints nest the LLM settings
    under the "llm_config" key of config.json.
    """

    # weight class: pre/post layer weights shared with the other InternVL backbones
    pre_and_post_weight_class = InternVLLlamaPreAndPostLayerWeight

    # infer class: multimodal pre-layer inference (image token injection)
    pre_layer_infer_class = LlamaMultimodalPreLayerInfer

    def __init__(self, kvargs):
        super().__init__(kvargs)

    def _init_config(self):
        """Load the LLM sub-config from the checkpoint's config.json and
        canonicalize alternate key spellings via ``repair_config``."""
        config_path = os.path.join(self.weight_dir_, "config.json")
        with open(config_path, "r") as f:
            # InternVL stores the language-model settings under "llm_config"
            self.config = json.load(f)["llm_config"]
        # rename keys: map legacy/alternate names onto the canonical ones
        for aliases in (
            ["num_attention_heads", "n_head"],
            ["hidden_size", "n_embd", "n_embed"],
            ["num_hidden_layers", "n_layer"],
        ):
            repair_config(self.config, same_names=aliases)
        if self.finetune_config:
            self.config["vocab_size"] = self.finetune_config.vocab_size
        return

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from lightllm.models.gemma_2b.model import Gemma_2bTpPartModel
2929
from lightllm.models.phi3.model import Phi3TpPartModel
3030
from lightllm.models.deepseek2.model import Deepseek2TpPartModel
31-
from lightllm.models.internvl.model import InternVLLlamaTpPartModel, InternVLPhi3TpPartModel
31+
from lightllm.models.internvl.model import InternVLLlamaTpPartModel, InternVLPhi3TpPartModel, InternVLQwen2TpPartModel
3232
from lightllm.models.internvl.model import InternVLInternlm2TpPartModel
3333
from lightllm.models.qwen2_vl.model import Qwen2VLTpPartModel
3434
from lightllm.models.qwen2_reward.model import Qwen2RewardTpPartModel
@@ -184,6 +184,8 @@ def init_model(self, kvargs):
184184
self.model = InternVLInternlm2TpPartModel(model_kvargs)
185185
elif llm_model_type == "llama":
186186
self.model = InternVLLlamaTpPartModel(model_kvargs)
187+
elif llm_model_type == "qwen2":
188+
self.model = InternVLQwen2TpPartModel(model_kvargs)
187189
self.is_multimodal = True
188190
else:
189191
raise Exception(f"can not support {self.model_type} now")

0 commit comments

Comments
 (0)