Skip to content

Commit 6df4156

Browse files
committed
refactor weight
1 parent a0c8bf0 commit 6df4156

File tree

13 files changed

+57
-156
lines changed

13 files changed

+57
-156
lines changed

lightllm/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
from lightllm.models.internvl.model import InternVLInternlm2TpPartModel
3030
from lightllm.models.qwen2_vl.model import Qwen2VLTpPartModel
3131
from lightllm.models.qwen2_reward.model import Qwen2RewardTpPartModel
32-
from lightllm.models.qwen3_vl.model import Qwen3VLTpPartModel, Qwen3VLMOETpPartModel
32+
from lightllm.models.qwen3_vl.model import Qwen3VLTpPartModel
33+
from lightllm.models.qwen3_vl_moe.model import Qwen3VLMOETpPartModel
3334
from lightllm.models.gemma3.model import Gemma3TpPartModel
3435
from lightllm.models.tarsier2.model import (
3536
Tarsier2Qwen2TpPartModel,

lightllm/models/qwen3_vl/__init__.py

Whitespace-only changes.

lightllm/models/qwen3_vl/layer_infer/__init__.py

Whitespace-only changes.

lightllm/models/qwen3_vl/layer_weights/__init__.py

Whitespace-only changes.
Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,22 @@
11
import numpy as np
2-
from lightllm.common.basemodel import PreAndPostLayerWeight
2+
from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight
33

4+
# add key: language_model.xxx -> xxx
5+
# only change keys at PreAndPostLayerWeight load, TransformLayerWeight is correct now
6+
def rename_weight_keys(weights):
7+
prefix = "model.language_model."
8+
keys = list(weights.keys())
9+
for k in keys:
10+
if prefix in k:
11+
weights[k.replace(prefix, "model.")] = weights.pop(k)
412

5-
class Qwen3VLPreAndPostLayerWeight(PreAndPostLayerWeight):
13+
14+
class Qwen3VLPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
615
def __init__(self, data_type, network_config, mode):
716
super().__init__(data_type, network_config, mode)
817
return
918

1019
def load_hf_weights(self, weights):
11-
vob_size = self.network_config_["vocab_size"]
12-
split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64)
13-
split_start = split_indexes[self.tp_rank_]
14-
split_end = split_indexes[self.tp_rank_ + 1]
15-
if "model.language_model.embed_tokens.weight" in weights:
16-
self.wte_weight_ = self._cuda(weights["model.language_model.embed_tokens.weight"][split_start:split_end, :])
17-
tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False)
18-
if tie_word_embeddings:
19-
self.lm_head_weight_ = self.wte_weight_
20-
if "lm_head.weight" in weights:
21-
self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :])
22-
if "model.language_model.norm.weight" in weights:
23-
self.final_norm_weight_ = self._cuda(weights["model.language_model.norm.weight"])
24-
25-
return
26-
27-
def verify_load(self):
28-
errors = "weights load not ok"
29-
weights = [
30-
self.wte_weight_,
31-
self.lm_head_weight_,
32-
self.final_norm_weight_,
33-
]
34-
for i in range(len(weights)):
35-
assert weights[i] is not None, "index:" + str(i) + " " + errors
20+
rename_weight_keys(weights)
21+
super().load_hf_weights(weights)
3622
return

lightllm/models/qwen3_vl/layer_weights/transformers_layer_weight.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

lightllm/models/qwen3_vl/model.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,13 @@
11
import os
22
import json
3-
import numpy as np
43
from lightllm.common.build_utils import repair_config
54
from lightllm.models.registry import ModelRegistry
65
from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo
76
from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer
87
from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer
98
from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight
10-
from lightllm.models.qwen3_vl.layer_weights.transformers_layer_weight import Qwen3VLTransformerLayerWeight
11-
from lightllm.models.qwen3_vl_moe.layer_weights.transformers_layer_weight import Qwen3VLMOETransformerLayerWeight
12-
from lightllm.models.qwen3_vl_moe.layer_infer.transformer_layer_infer import Qwen3VLMOETransformerLayerInfer
139
from lightllm.models.qwen2_vl.model import QWen2VLTokenizer
1410
from lightllm.models.qwen3.model import Qwen3TpPartModel
15-
from lightllm.models.qwen3_moe.model import Qwen3MOEModel
1611

1712

1813
class QWen3VLTokenizer(QWen2VLTokenizer):
@@ -35,38 +30,6 @@ class Qwen3VLTpPartModel(Qwen3TpPartModel):
3530
transformer_layer_infer_class = Qwen3VLTransformerLayerInfer
3631

3732
pre_and_post_weight_class = Qwen3VLPreAndPostLayerWeight
38-
transformer_weight_class = Qwen3VLTransformerLayerWeight
39-
40-
infer_state_class = Qwen3VLInferStateInfo
41-
42-
def __init__(self, kvargs):
43-
super().__init__(kvargs)
44-
return
45-
46-
def _init_inferstate_cls(self):
47-
pass
48-
49-
def _init_config(self):
50-
with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
51-
all_config = json.load(json_file)
52-
self.config = all_config["text_config"]
53-
# rename keys
54-
repair_config(self.config, same_names=["num_attention_heads", "n_head"])
55-
repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
56-
repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
57-
if self.finetune_config:
58-
self.config["vocab_size"] = self.finetune_config.vocab_size
59-
return
60-
61-
62-
@ModelRegistry(["qwen3_vl_moe"], is_multimodal=True)
63-
class Qwen3VLMOETpPartModel(Qwen3MOEModel):
64-
65-
pre_layer_infer_class = Qwen3VLMultimodalPreLayerInfer
66-
transformer_layer_infer_class = Qwen3VLMOETransformerLayerInfer
67-
68-
pre_and_post_weight_class = Qwen3VLPreAndPostLayerWeight
69-
transformer_weight_class = Qwen3VLMOETransformerLayerWeight
7033

7134
infer_state_class = Qwen3VLInferStateInfo
7235

lightllm/models/qwen3_vl/triton_kernel/__init__.py

Whitespace-only changes.

lightllm/models/qwen3_vl_moe/__init__.py

Whitespace-only changes.

lightllm/models/qwen3_vl_moe/layer_infer/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)