Commit 5046c53

shihaobai authored and wangzaijun committed
update
1 parent 71bcd72 commit 5046c53

File tree: 2 files changed (+10, -2 lines)


lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -11,12 +11,20 @@ def __init__(self, data_type, network_config, mode):
         return
 
     def load_hf_weights(self, weights):
+        vob_size = self.network_config_["vocab_size"]
+        split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64)
+        split_start = split_indexes[self.tp_rank_]
+        split_end = split_indexes[self.tp_rank_ + 1]
         if "model.layers.0.proj.weight" in weights:
             self.eh_proj_weight_ = self._cuda(weights["model.layers.0.proj.weight"]).t()
         if "model.layers.0.norm_after_embedding.weight" in weights:
             self.enorm_weight_ = self._cuda(weights["model.layers.0.norm_after_embedding.weight"])
         if "model.layers.0.norm_before_output.weight" in weights:
             self.hnorm_weight_ = self._cuda(weights["model.layers.0.norm_before_output.weight"])
+        if "lm_head.weight" in weights:
+            self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :])
+        if "model.norm.weight" in weights:
+            self.final_norm_weight_ = self._cuda(weights["model.norm.weight"])
         return
 
     def verify_load(self):
```
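The added lines tensor-parallel-split the vocabulary dimension of `lm_head.weight`: `np.linspace` over `tp_world_size_ + 1` boundary points yields contiguous, near-equal row ranges, and each rank slices out only its own rows. A minimal standalone sketch of that splitting logic, using illustrative vocabulary and world sizes (the real values come from `network_config_["vocab_size"]` and the distributed setup):

```python
import numpy as np

# Illustrative values only; in load_hf_weights these come from
# self.network_config_["vocab_size"], self.tp_world_size_ and self.tp_rank_.
vocab_size = 151936
tp_world_size = 4

# tp_world_size + 1 boundary points give tp_world_size contiguous,
# near-equal row ranges covering [0, vocab_size).
split_indexes = np.linspace(0, vocab_size, tp_world_size + 1, dtype=np.int64)

for tp_rank in range(tp_world_size):
    split_start = split_indexes[tp_rank]
    split_end = split_indexes[tp_rank + 1]
    # Each rank would keep only its own rows:
    # weights["lm_head.weight"][split_start:split_end, :]
    print(f"rank {tp_rank}: rows [{split_start}, {split_end})")
```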

lightllm/models/qwen3_moe_mtp/model.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -64,8 +64,8 @@ def _init_weights(self):
         self.pre_post_weight.verify_load()
         [weight.verify_load() for weight in self.trans_layers_weight]
         self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_
-        self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_
-        self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_
+        # self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_
+        # self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_
         return
 
     def _init_infer_layer(self):
```
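Together with the first file, this hunk changes where the MTP head gets its output weights: the draft model still shares the token embedding (`wte_weight_`) with the main model, but `lm_head_weight_` and `final_norm_weight_` are now loaded from the MTP checkpoint itself in `load_hf_weights` rather than copied from the main model. A hypothetical sanity check, assuming both models expose the same `pre_post_weight` attributes as above (not part of the commit):

```python
def check_weight_sharing(mtp_model, main_model):
    """Hypothetical helper: verify which pre/post weights are shared after _init_weights."""
    mtp_w = mtp_model.pre_post_weight
    main_w = main_model.pre_post_weight
    # The token embedding is still taken from the main model, so it should be
    # the exact same tensor object.
    assert mtp_w.wte_weight_ is main_w.wte_weight_
    # lm_head and the final norm are now loaded from the MTP checkpoint itself
    # (see load_hf_weights above), so they are no longer the shared tensors.
    assert mtp_w.lm_head_weight_ is not main_w.lm_head_weight_
    assert mtp_w.final_norm_weight_ is not main_w.final_norm_weight_
```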
