Commit 04bc0cf

model factory
1 parent 42e8199 commit 04bc0cf


33 files changed: +211 −223 lines changed


lightllm/models/__init__.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+import os
+import importlib
+import inspect
+from pathlib import Path
+
+
+def auto_import_models():
+    """
+    Automatically imports all classes from model.py files in model directories
+    """
+    base_dir = os.path.dirname(os.path.abspath(__file__))
+    models_dir = Path(base_dir)
+    for model_dir in models_dir.iterdir():
+        if not model_dir.is_dir():
+            continue
+        model_file = model_dir / "model.py"
+        if not model_file.exists():
+            continue
+        module_path = f"lightllm.models.{model_dir.name}.model"
+
+        try:
+            importlib.import_module(module_path)
+        except:
+            pass
+
+
+auto_import_models()
+from .registry import get_model
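
The new lightllm/models/registry.py is not part of this section of the diff, so its API is only visible through the call sites below: ModelRegistry("bloom"), ModelRegistry(["deepseek_v2", "deepseek_v3"]), ModelRegistry("internlm2", condition=is_reward_model()), and get_model. A minimal sketch of one way such a decorator-based registry could be implemented follows; every name, signature, and the "is_reward_model" config key are assumptions inferred from those call sites, not code from this commit.

# Hypothetical sketch only; the commit's actual registry.py may differ.
from typing import Callable, Dict, List, Optional, Tuple, Type, Union


class ModelRegistry:
    # model_type -> list of (condition, model_class) pairs
    _registry: Dict[str, List[Tuple[Optional[Callable[[dict], bool]], Type]]] = {}

    def __init__(
        self,
        model_types: Union[str, List[str]],
        condition: Optional[Callable[[dict], bool]] = None,
    ):
        self.model_types = [model_types] if isinstance(model_types, str) else model_types
        self.condition = condition

    def __call__(self, model_class: Type) -> Type:
        # used as a class decorator: record the class, return it unchanged
        for model_type in self.model_types:
            self._registry.setdefault(model_type, []).append((self.condition, model_class))
        return model_class


def is_reward_model() -> Callable[[dict], bool]:
    # condition factory; the "is_reward_model" config key is an assumption
    return lambda model_cfg: bool(model_cfg.get("is_reward_model", False))


def get_model(model_type: str, model_cfg: dict) -> Type:
    entries = ModelRegistry._registry.get(model_type, [])
    # prefer an entry whose condition matches this config,
    # then fall back to an unconditional registration
    for condition, model_class in entries:
        if condition is not None and condition(model_cfg):
            return model_class
    for condition, model_class in entries:
        if condition is None:
            return model_class
    raise ValueError(f"no model registered for model_type={model_type!r}")

Under this reading, auto_import_models() exists purely for its side effect: importing every models/<name>/model.py executes the @ModelRegistry class decorators seen in the files below, so get_model can resolve a config's model_type without a hand-maintained if/elif chain.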

lightllm/models/bloom/model.py

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,7 @@
 import os
 import json
 import torch
+from lightllm.models.registry import ModelRegistry
 from lightllm.models.bloom.layer_infer.pre_layer_infer import BloomPreLayerInfer
 from lightllm.models.bloom.layer_infer.post_layer_infer import BloomPostLayerInfer
 from lightllm.models.bloom.layer_infer.transformer_layer_infer import BloomTransformerLayerInfer
@@ -12,6 +13,7 @@
 from lightllm.common.build_utils import repair_config


+@ModelRegistry("bloom")
 class BloomTpPartModel(TpPartBaseModel):
     # weight class
     pre_and_post_weight_class = BloomPreAndPostLayerWeight

lightllm/models/chatglm2/model.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
 import json
 import torch

+from lightllm.models.registry import ModelRegistry
 from lightllm.models.chatglm2.layer_infer.transformer_layer_infer import ChatGLM2TransformerLayerInfer
 from lightllm.models.chatglm2.layer_weights.transformer_layer_weight import ChatGLM2TransformerLayerWeight
 from lightllm.models.chatglm2.layer_weights.pre_and_post_layer_weight import ChatGLM2PreAndPostLayerWeight
@@ -12,6 +13,7 @@
 logger = init_logger(__name__)


+@ModelRegistry("chatglm")
 class ChatGlm2TpPartModel(LlamaTpPartModel):
     # Please use the fast tokenizer from:
     # [THUDM/chatglm3-6b PR #12](https://huggingface.co/THUDM/chatglm3-6b/discussions/12).

lightllm/models/cohere/model.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
     TransformerLayerCohereInferTpl,
 )
 from lightllm.common.mem_manager import MemoryManager
+from lightllm.models.registry import ModelRegistry
 from lightllm.models.cohere.infer_struct import CohereInferStateInfo
 from lightllm.models.cohere.layer_infer.post_layer_infer import CoherePostLayerInfer
 from lightllm.models.cohere.layer_infer.transformer_layer_infer import CohereTransformerLayerInfer
@@ -17,6 +18,7 @@
 logger = init_logger(__name__)


+@ModelRegistry("cohere")
 class CohereTpPartModel(LlamaTpPartModel):
     pre_and_post_weight_class = CoherePreAndPostLayerWeight
     transformer_weight_class = CohereTransformerLayerWeight

lightllm/models/deepseek2/model.py

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,6 @@
 import torch
 from typing import final
+from lightllm.models.registry import ModelRegistry
 from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer
 from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight
 from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
@@ -22,6 +23,7 @@

 class DeepSeek2FlashInferStateExtraInfo:
     def __init__(self, model):
+        print(model)
         num_heads = model.config["num_attention_heads"]
         self.tp_q_head_num = num_heads // get_dp_world_size()
         self.qk_nope_head_dim = model.qk_nope_head_dim
@@ -49,6 +51,7 @@ def __init__(self, model):
         self.softmax_scale = self.softmax_scale * mscale * mscale


+@ModelRegistry(["deepseek_v2", "deepseek_v3"])
 class Deepseek2TpPartModel(LlamaTpPartModel):
     # weight class
     transformer_weight_class = Deepseek2TransformerLayerWeight
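
Note that a single class can register itself under several model_type strings, as Deepseek2TpPartModel does with ["deepseek_v2", "deepseek_v3"]. Under the hypothetical registry sketch shown earlier (not the commit's actual registry.py), both keys would resolve to the same class:

# assumes the hypothetical ModelRegistry/get_model sketch above
assert get_model("deepseek_v2", model_cfg={}) is get_model("deepseek_v3", model_cfg={})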

lightllm/models/gemma3/model.py

Lines changed: 3 additions & 0 deletions
@@ -3,6 +3,7 @@
 import json
 import numpy as np
 import torch
+from lightllm.models.registry import ModelRegistry
 from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer
 from lightllm.common.mem_utils import select_mem_manager_class
 from lightllm.models.gemma3.infer_struct import Gemma3InferStateInfo
@@ -22,6 +23,7 @@

 logger = init_logger(__name__)

+
 # Warp of the origal tokenizer
 class Gemma3Tokenizer(BaseMultiModalTokenizer):
     def __init__(self, tokenizer, model_cfg):
@@ -77,6 +79,7 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, add_special
         return input_ids


+@ModelRegistry("gemma3")
 class Gemma3TpPartModel(LlamaTpPartModel):
     # weight class
     pre_and_post_weight_class = Gemma3PreAndPostLayerWeight

lightllm/models/gemma_2b/model.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,4 @@
+from lightllm.models.registry import ModelRegistry
 from lightllm.models.gemma_2b.layer_weights.transformer_layer_weight import Gemma_2bTransformerLayerWeight
 from lightllm.models.gemma_2b.layer_weights.pre_and_post_layer_weight import Gemma_2bPreAndPostLayerWeight
 from lightllm.models.gemma_2b.layer_infer.pre_layer_infer import Gemma_2bPreLayerInfer
@@ -8,6 +9,7 @@
 from lightllm.common.mem_utils import select_mem_manager_class


+@ModelRegistry("gemma")
 class Gemma_2bTpPartModel(LlamaTpPartModel):
     # weight class
     pre_and_post_weight_class = Gemma_2bPreAndPostLayerWeight

lightllm/models/internlm/model.py

Lines changed: 2 additions & 1 deletion
@@ -1,11 +1,12 @@
 import os
 import json
 import torch
-
+from lightllm.models.registry import ModelRegistry
 from lightllm.models.internlm.layer_weights.transformer_layer_weight import InternlmTransformerLayerWeight
 from lightllm.models.llama.model import LlamaTpPartModel


+@ModelRegistry("internlm")
 class InternlmTpPartModel(LlamaTpPartModel):
     # weight class
     transformer_weight_class = InternlmTransformerLayerWeight

lightllm/models/internlm2/model.py

Lines changed: 4 additions & 3 deletions
@@ -2,16 +2,17 @@
 import json
 import torch

+from lightllm.models.registry import ModelRegistry
 from lightllm.models.internlm2.layer_weights.transformer_layer_weight import Internlm2TransformerLayerWeight
-from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight
+from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight
 from lightllm.models.internlm.model import InternlmTpPartModel


+@ModelRegistry("internlm2")
 class Internlm2TpPartModel(InternlmTpPartModel):
     # weight class
-    pre_and_post_weight_class = Internlm2PreAndPostLayerWeight
+    pre_and_post_weight_class = Internlm2PreAndPostLayerWeight
     transformer_weight_class = Internlm2TransformerLayerWeight

     def __init__(self, kvargs):
         super().__init__(kvargs)
-
lightllm/models/internlm2_reward/model.py

Lines changed: 2 additions & 1 deletion
@@ -1,14 +1,15 @@
 import os
 import json
 import torch
-
+from lightllm.models.registry import ModelRegistry, is_reward_model
 from lightllm.models.internlm2_reward.layer_infer.post_layer_infer import Internlm2RewardPostLayerInfer
 from lightllm.models.internlm2_reward.layer_weights.pre_and_post_layer_weight import (
     Internlm2RewardPreAndPostLayerWeight,
 )
 from lightllm.models.internlm2.model import Internlm2TpPartModel


+@ModelRegistry("internlm2", condition=is_reward_model())
 class Internlm2RewardTpPartModel(Internlm2TpPartModel):
     # weight class
     pre_and_post_weight_class = Internlm2RewardPreAndPostLayerWeight
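
This file registers a second class under the same "internlm2" key, guarded by condition=is_reward_model(), so the registry has to pick between Internlm2TpPartModel and Internlm2RewardTpPartModel at lookup time. Under the hypothetical sketch shown earlier, the condition would be evaluated against the model config inside get_model:

# assumes the hypothetical sketch above; the "is_reward_model" config key is an assumption
get_model("internlm2", model_cfg={})                         # -> Internlm2TpPartModel
get_model("internlm2", model_cfg={"is_reward_model": True})  # -> Internlm2RewardTpPartModel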
