Skip to content

Commit 545b35e

Browse files
fix
1 parent 89cb65d commit 545b35e

File tree

3 files changed

+43
-66
lines changed

3 files changed

+43
-66
lines changed

lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,6 @@
22
from .base_weight import BaseWeightTpl
33
from lightllm.utils.dist_utils import get_current_device_id
44

5-
# For special weight
6-
class DummyWeight(BaseWeightTpl):
7-
def __init__(self, weight_name, data_type):
8-
super().__init__()
9-
self.weight_name = weight_name
10-
self.data_type_ = data_type
11-
self.weight = None
12-
13-
def load_hf_weights(self, weights):
14-
if self.weight_name in weights:
15-
self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id())
16-
17-
def verify_load(self):
18-
load_ok = True
19-
load_ok = load_ok and self.weight is not None
20-
return load_ok
21-
225

236
class NormWeight(BaseWeightTpl):
247
def __init__(self, weight_name, data_type, bias_name=None):

lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
import numpy as np
44

55
from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.rowmm_weight import ROWMMWeight
6-
from lightllm.common.basemodel.layer_weights.meta_weights.norm_weight import DummyWeight
7-
from lightllm.models.bloom import model
6+
from lightllm.common.basemodel.layer_weights.meta_weights.norm_weight import NormWeight
87
from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
98
from lightllm.utils.log_utils import init_logger
109

@@ -54,14 +53,20 @@ def _init_moe(self):
5453
tp_rank=0,
5554
tp_world_size=1,
5655
)
57-
self.down_proj_bias = DummyWeight(self._down_bias_name, torch.bfloat16)
58-
self.down_proj_weight_blocks = DummyWeight(self._down_blocks_name, torch.uint8)
59-
self.down_proj_weight_scales = DummyWeight(self._down_scales_name, torch.uint8)
6056

61-
self.gate_up_proj_bias = DummyWeight(self._gate_up_bias_name, torch.bfloat16)
62-
self.gate_up_proj_weight_blocks = DummyWeight(self._gate_up_blocks_name, torch.uint8)
63-
self.gate_up_proj_weight_scales = DummyWeight(self._gate_up_scales_name, torch.uint8)
64-
self.attn_sinks = DummyWeight(self._attn_sink_name, torch.bfloat16)
57+
# Current definition of experts
58+
self.down_proj_bias = NormWeight(self._down_bias_name, torch.bfloat16)
59+
self.down_proj_weight_blocks = NormWeight(self._down_blocks_name, torch.uint8)
60+
self.down_proj_weight_scales = NormWeight(self._down_scales_name, torch.uint8)
61+
62+
self.gate_up_proj_bias = NormWeight(self._gate_up_bias_name, torch.bfloat16)
63+
self.gate_up_proj_weight_blocks = NormWeight(self._gate_up_blocks_name, torch.uint8)
64+
self.gate_up_proj_weight_scales = NormWeight(self._gate_up_scales_name, torch.uint8)
65+
self.attn_sinks = NormWeight(self._attn_sink_name, torch.bfloat16)
66+
67+
def load_hf_weights(self, weights):
68+
super().load_hf_weights(weights)
69+
self._post_weight_process()
6570

6671
def _init_weight_names(self):
6772
super()._init_weight_names()
@@ -105,31 +110,34 @@ def _post_weight_process(self):
105110
self.moe_intermediate_size = self.network_config_["intermediate_size"]
106111
self.split_inter_size = self.moe_intermediate_size // self.tp_world_size_
107112

108-
self.down_proj_weight = self._convert_moe_packed_tensors(
109-
blocks=self.down_proj_weight_blocks.weight,
110-
scales=self.down_proj_weight_scales.weight,
111-
dtype=torch.bfloat16,
112-
)[
113-
:, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
114-
] # (32, 1440, 2880)
115-
116-
self.gate_up_proj_weight = self._convert_moe_packed_tensors(
117-
blocks=self.gate_up_proj_weight_blocks.weight,
118-
scales=self.gate_up_proj_weight_scales.weight,
119-
dtype=torch.bfloat16,
120-
) # (32, 2880, 5760)
121-
expert_num = self.gate_up_proj_weight.shape[0]
122-
self.gate_up_proj_weight = self.gate_up_proj_weight.reshape(expert_num, -1, 2, self.moe_intermediate_size)[
123-
:, :, :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
124-
].reshape(
125-
expert_num, -1, 2 * self.split_inter_size
126-
) # (32, 2880, 2880)
127-
128-
self.gate_up_proj_bias.weight = self.gate_up_proj_bias.weight.reshape(
129-
expert_num, 2, self.moe_intermediate_size
130-
)[:, :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)].reshape(
131-
expert_num, 2 * self.split_inter_size
132-
) # (32, 2880)
113+
if self.down_proj_weight_blocks.verify_load() and self.down_proj_weight_scales.verify_load():
114+
self.down_proj_weight = self._convert_moe_packed_tensors(
115+
blocks=self.down_proj_weight_blocks.weight,
116+
scales=self.down_proj_weight_scales.weight,
117+
dtype=torch.bfloat16,
118+
)[
119+
:, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
120+
] # (32, 1440, 2880)
121+
122+
if self.gate_up_proj_weight_blocks.verify_load() and self.gate_up_proj_weight_scales.verify_load():
123+
self.gate_up_proj_weight = self._convert_moe_packed_tensors(
124+
blocks=self.gate_up_proj_weight_blocks.weight,
125+
scales=self.gate_up_proj_weight_scales.weight,
126+
dtype=torch.bfloat16,
127+
) # (32, 2880, 5760)
128+
expert_num = self.gate_up_proj_weight.shape[0]
129+
self.gate_up_proj_weight = self.gate_up_proj_weight.reshape(expert_num, -1, 2, self.moe_intermediate_size)[
130+
:, :, :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
131+
].reshape(
132+
expert_num, -1, 2 * self.split_inter_size
133+
) # (32, 2880, 2880)
134+
135+
if self.gate_up_proj_bias.verify_load():
136+
self.gate_up_proj_bias.weight = self.gate_up_proj_bias.weight.reshape(
137+
expert_num, 2, self.moe_intermediate_size
138+
)[:, :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)].reshape(
139+
expert_num, 2 * self.split_inter_size
140+
) # (32, 2880)
133141

134142
def _convert_moe_packed_tensors(
135143
self,

lightllm/models/gpt_oss/model.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,15 @@
1717

1818
logger = init_logger(__name__)
1919

20+
2021
@ModelRegistry("gpt_oss")
2122
class GptOssTpPartModel(LlamaTpPartModel):
2223
# weight class
23-
pre_and_post_weight_class = LlamaPreAndPostLayerWeight
2424
transformer_weight_class = GptOssTransformerLayerWeight
2525

2626
# infer class
27-
pre_layer_infer_class = LlamaPreLayerInfer
28-
post_layer_infer_class = LlamaPostLayerInfer
2927
transformer_layer_infer_class = GptOssTransformerLayerInfer
3028

31-
# infer state class
32-
infer_state_class = LlamaInferStateInfo
33-
3429
def __init__(self, kvargs):
3530
super().__init__(kvargs)
3631
assert get_env_start_args().enable_fa3, "For now GPT-OSS type model only support flashattention-3"
37-
38-
def _init_weights(self):
39-
super()._init_weights()
40-
self._post_weight_process()
41-
42-
def _post_weight_process(self):
43-
for i in range(self.config["n_layer"]):
44-
self.trans_layers_weight[i]._post_weight_process()
45-

0 commit comments

Comments (0)