|
4 | 4 | from lightllm.models.llama.model import LlamaTpPartModel |
5 | 5 | from lightllm.models.phi3.model import Phi3TpPartModel |
6 | 6 | from lightllm.models.qwen2.model import Qwen2TpPartModel |
| 7 | +from lightllm.models.deepseek2.model import Deepseek2TpPartModel |
7 | 8 | from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer |
8 | 9 | from lightllm.server.multimodal_params import MultimodalParams, ImageItem |
9 | 10 | from lightllm.common.build_utils import repair_config |
|
26 | 27 | IMG_END_TOKEN = "</img>" |
27 | 28 | IMG_TOKEN = "<image>" |
28 | 29 |
|
| 30 | + |
29 | 31 | # Wrap of the original tokenizer |
30 | 32 | class InternvlTokenizer: |
31 | 33 | def __init__(self, tokenizer, model_cfg, **kwargs): |
32 | | - |
33 | 34 | self.llm_model_type = model_cfg.get("llm_config").get("model_type") |
34 | 35 | self.tokenizer = tokenizer |
35 | 36 | self.image_length = int(os.environ.get("INTERNVL_IMAGE_LENGTH", 256)) |
@@ -200,3 +201,27 @@ def _init_config(self): |
200 | 201 | if self.finetune_config: |
201 | 202 | self.config["vocab_size"] = self.finetune_config.vocab_size |
202 | 203 | return |
| 204 | + |
| 205 | + |
class InternVLDeepSeek2TpPartModel(Deepseek2TpPartModel):
    """InternVL vision-language model backed by a DeepSeek-V2-family LLM.

    Supports DeepSeek 2 / 3 / R1 backbones. Overrides the base model's weight
    and pre-layer-infer classes with the InternVL multimodal variants so that
    image embeddings can be fed through the shared LLM pipeline, and loads the
    LLM sub-config from the InternVL checkpoint layout.
    """

    # Weight class: InternVL reuses the Llama-style pre/post layer weights
    # (embedding + output head) for the multimodal wrapper.
    pre_and_post_weight_class = InternVLLlamaPreAndPostLayerWeight

    # Infer class: multimodal pre-layer (presumably merges image features into
    # the token embedding stream — same wiring as the other InternVL variants).
    pre_layer_infer_class = LlamaMultimodalPreLayerInfer

    # NOTE: no __init__ override — the original one only forwarded kvargs to
    # super().__init__, which Python does by default when __init__ is absent.

    def _init_config(self):
        """Load and normalize the language-model config.

        InternVL checkpoints nest the LLM settings under the "llm_config" key
        of config.json; repair_config renames alternative key spellings to the
        canonical names the rest of the stack expects.
        """
        with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
            self.config = json.load(json_file)["llm_config"]
        # Rename keys to the canonical names used downstream.
        repair_config(self.config, same_names=["num_attention_heads", "n_head"])
        repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
        repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
        # A fine-tune config may override the vocabulary size.
        if self.finetune_config:
            self.config["vocab_size"] = self.finetune_config.vocab_size
0 commit comments