diff --git a/lightllm/models/internvl/model.py b/lightllm/models/internvl/model.py
index f5a6ef4b8..a724d5668 100644
--- a/lightllm/models/internvl/model.py
+++ b/lightllm/models/internvl/model.py
@@ -10,6 +10,7 @@
 from lightllm.models.phi3.model import Phi3TpPartModel
 from lightllm.models.qwen2.model import Qwen2TpPartModel
 from lightllm.models.qwen3.model import Qwen3TpPartModel
+from lightllm.models.qwen3_moe.model import Qwen3MOEModel
 from lightllm.models.deepseek2.model import Deepseek2TpPartModel
 from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
 from lightllm.models.internvl.layer_weights.pre_and_post_layer_weight import (
@@ -297,3 +298,27 @@ def _init_config(self):
         if self.finetune_config:
             self.config["vocab_size"] = self.finetune_config.vocab_size
         return
+
+
+@ModelRegistry(["internvl_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3_moe"))
+class InternVLQwen3MOETpPartModel(Qwen3MOEModel):
+    # weight class
+    pre_and_post_weight_class = InternVLLlamaPreAndPostLayerWeight
+
+    # infer class
+    pre_layer_infer_class = LlamaMultimodalPreLayerInfer
+
+    def __init__(self, kvargs):
+        super().__init__(kvargs)
+        return
+
+    def _init_config(self):
+        with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
+            self.config = json.load(json_file)["llm_config"]
+        # rename keys
+        repair_config(self.config, same_names=["num_attention_heads", "n_head"])
+        repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
+        repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
+        if self.finetune_config:
+            self.config["vocab_size"] = self.finetune_config.vocab_size
+        return