Skip to content

Commit f25d1fe

Browse files
authored
model factory (#882)
1 parent 0739a5a commit f25d1fe

File tree

34 files changed

+212
-224
lines changed

34 files changed

+212
-224
lines changed

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
FROM nvcr.io/nvidia/tritonserver:24.04-py3-min as base
2-
ARG PYTORCH_VERSION=2.5.1
2+
ARG PYTORCH_VERSION=2.6.0
33
ARG PYTHON_VERSION=3.9
44
ARG CUDA_VERSION=12.4
55
ARG MAMBA_VERSION=23.1.0-1

lightllm/models/__init__.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from lightllm.models.cohere.model import CohereTpPartModel
2+
from lightllm.models.mixtral.model import MixtralTpPartModel
3+
from lightllm.models.bloom.model import BloomTpPartModel
4+
from lightllm.models.llama.model import LlamaTpPartModel
5+
from lightllm.models.starcoder.model import StarcoderTpPartModel
6+
from lightllm.models.starcoder2.model import Starcoder2TpPartModel
7+
from lightllm.models.qwen.model import QWenTpPartModel
8+
from lightllm.models.qwen2.model import Qwen2TpPartModel
9+
from lightllm.models.qwen3.model import Qwen3TpPartModel
10+
from lightllm.models.qwen3_moe.model import Qwen3MOEModel
11+
from lightllm.models.chatglm2.model import ChatGlm2TpPartModel
12+
from lightllm.models.internlm.model import InternlmTpPartModel
13+
from lightllm.models.stablelm.model import StablelmTpPartModel
14+
from lightllm.models.internlm2.model import Internlm2TpPartModel
15+
from lightllm.models.internlm2_reward.model import Internlm2RewardTpPartModel
16+
from lightllm.models.mistral.model import MistralTpPartModel
17+
from lightllm.models.minicpm.model import MiniCPMTpPartModel
18+
from lightllm.models.llava.model import LlavaTpPartModel
19+
from lightllm.models.qwen_vl.model import QWenVLTpPartModel
20+
from lightllm.models.gemma_2b.model import Gemma_2bTpPartModel
21+
from lightllm.models.phi3.model import Phi3TpPartModel
22+
from lightllm.models.deepseek2.model import Deepseek2TpPartModel
23+
from lightllm.models.internvl.model import (
24+
InternVLLlamaTpPartModel,
25+
InternVLPhi3TpPartModel,
26+
InternVLQwen2TpPartModel,
27+
InternVLDeepSeek2TpPartModel,
28+
)
29+
from lightllm.models.internvl.model import InternVLInternlm2TpPartModel
30+
from lightllm.models.qwen2_vl.model import Qwen2VLTpPartModel
31+
from lightllm.models.qwen2_reward.model import Qwen2RewardTpPartModel
32+
from lightllm.models.gemma3.model import Gemma3TpPartModel
33+
from lightllm.models.tarsier2.model import (
34+
Tarsier2Qwen2TpPartModel,
35+
Tarsier2Qwen2VLTpPartModel,
36+
Tarsier2LlamaTpPartModel,
37+
)
38+
from .registry import get_model

lightllm/models/bloom/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import json
33
import torch
4+
from lightllm.models.registry import ModelRegistry
45
from lightllm.models.bloom.layer_infer.pre_layer_infer import BloomPreLayerInfer
56
from lightllm.models.bloom.layer_infer.post_layer_infer import BloomPostLayerInfer
67
from lightllm.models.bloom.layer_infer.transformer_layer_infer import BloomTransformerLayerInfer
@@ -12,6 +13,7 @@
1213
from lightllm.common.build_utils import repair_config
1314

1415

16+
@ModelRegistry("bloom")
1517
class BloomTpPartModel(TpPartBaseModel):
1618
# weight class
1719
pre_and_post_weight_class = BloomPreAndPostLayerWeight

lightllm/models/chatglm2/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import torch
44

5+
from lightllm.models.registry import ModelRegistry
56
from lightllm.models.chatglm2.layer_infer.transformer_layer_infer import ChatGLM2TransformerLayerInfer
67
from lightllm.models.chatglm2.layer_weights.transformer_layer_weight import ChatGLM2TransformerLayerWeight
78
from lightllm.models.chatglm2.layer_weights.pre_and_post_layer_weight import ChatGLM2PreAndPostLayerWeight
@@ -12,6 +13,7 @@
1213
logger = init_logger(__name__)
1314

1415

16+
@ModelRegistry("chatglm")
1517
class ChatGlm2TpPartModel(LlamaTpPartModel):
1618
# Please use the fast tokenizer from:
1719
# [THUDM/chatglm3-6b PR #12](https://huggingface.co/THUDM/chatglm3-6b/discussions/12).

lightllm/models/cohere/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
TransformerLayerCohereInferTpl,
66
)
77
from lightllm.common.mem_manager import MemoryManager
8+
from lightllm.models.registry import ModelRegistry
89
from lightllm.models.cohere.infer_struct import CohereInferStateInfo
910
from lightllm.models.cohere.layer_infer.post_layer_infer import CoherePostLayerInfer
1011
from lightllm.models.cohere.layer_infer.transformer_layer_infer import CohereTransformerLayerInfer
@@ -17,6 +18,7 @@
1718
logger = init_logger(__name__)
1819

1920

21+
@ModelRegistry("cohere")
2022
class CohereTpPartModel(LlamaTpPartModel):
2123
pre_and_post_weight_class = CoherePreAndPostLayerWeight
2224
transformer_weight_class = CohereTransformerLayerWeight

lightllm/models/deepseek2/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
from typing import final
3+
from lightllm.models.registry import ModelRegistry
34
from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer
45
from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight
56
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
@@ -49,6 +50,7 @@ def __init__(self, model):
4950
self.softmax_scale = self.softmax_scale * mscale * mscale
5051

5152

53+
@ModelRegistry(["deepseek_v2", "deepseek_v3"])
5254
class Deepseek2TpPartModel(LlamaTpPartModel):
5355
# weight class
5456
transformer_weight_class = Deepseek2TransformerLayerWeight

lightllm/models/gemma3/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import numpy as np
55
import torch
6+
from lightllm.models.registry import ModelRegistry
67
from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer
78
from lightllm.common.mem_utils import select_mem_manager_class
89
from lightllm.models.gemma3.infer_struct import Gemma3InferStateInfo
@@ -22,6 +23,7 @@
2223

2324
logger = init_logger(__name__)
2425

26+
2527
# Wrapper around the original tokenizer
2628
class Gemma3Tokenizer(BaseMultiModalTokenizer):
2729
def __init__(self, tokenizer, model_cfg):
@@ -77,6 +79,7 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, add_special
7779
return input_ids
7880

7981

82+
@ModelRegistry("gemma3")
8083
class Gemma3TpPartModel(LlamaTpPartModel):
8184
# weight class
8285
pre_and_post_weight_class = Gemma3PreAndPostLayerWeight

lightllm/models/gemma_2b/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from lightllm.models.registry import ModelRegistry
12
from lightllm.models.gemma_2b.layer_weights.transformer_layer_weight import Gemma_2bTransformerLayerWeight
23
from lightllm.models.gemma_2b.layer_weights.pre_and_post_layer_weight import Gemma_2bPreAndPostLayerWeight
34
from lightllm.models.gemma_2b.layer_infer.pre_layer_infer import Gemma_2bPreLayerInfer
@@ -8,6 +9,7 @@
89
from lightllm.common.mem_utils import select_mem_manager_class
910

1011

12+
@ModelRegistry("gemma")
1113
class Gemma_2bTpPartModel(LlamaTpPartModel):
1214
# weight class
1315
pre_and_post_weight_class = Gemma_2bPreAndPostLayerWeight

lightllm/models/internlm/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import os
22
import json
33
import torch
4-
4+
from lightllm.models.registry import ModelRegistry
55
from lightllm.models.internlm.layer_weights.transformer_layer_weight import InternlmTransformerLayerWeight
66
from lightllm.models.llama.model import LlamaTpPartModel
77

88

9+
@ModelRegistry("internlm")
910
class InternlmTpPartModel(LlamaTpPartModel):
1011
# weight class
1112
transformer_weight_class = InternlmTransformerLayerWeight

lightllm/models/internlm2/model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,17 @@
22
import json
33
import torch
44

5+
from lightllm.models.registry import ModelRegistry
56
from lightllm.models.internlm2.layer_weights.transformer_layer_weight import Internlm2TransformerLayerWeight
6-
from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight
7+
from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight
78
from lightllm.models.internlm.model import InternlmTpPartModel
89

910

11+
@ModelRegistry("internlm2")
1012
class Internlm2TpPartModel(InternlmTpPartModel):
1113
# weight class
12-
pre_and_post_weight_class = Internlm2PreAndPostLayerWeight
14+
pre_and_post_weight_class = Internlm2PreAndPostLayerWeight
1315
transformer_weight_class = Internlm2TransformerLayerWeight
1416

1517
def __init__(self, kvargs):
1618
super().__init__(kvargs)
17-

0 commit comments

Comments
 (0)