 import numpy as np
 
 from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.rowmm_weight import ROWMMWeight
-from lightllm.common.basemodel.layer_weights.meta_weights.norm_weight import DummyWeight
-from lightllm.models.bloom import model
+from lightllm.common.basemodel.layer_weights.meta_weights.norm_weight import NormWeight
 from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
 from lightllm.utils.log_utils import init_logger
 
@@ -54,14 +53,20 @@ def _init_moe(self):
             tp_rank=0,
             tp_world_size=1,
         )
-        self.down_proj_bias = DummyWeight(self._down_bias_name, torch.bfloat16)
-        self.down_proj_weight_blocks = DummyWeight(self._down_blocks_name, torch.uint8)
-        self.down_proj_weight_scales = DummyWeight(self._down_scales_name, torch.uint8)
 
-        self.gate_up_proj_bias = DummyWeight(self._gate_up_bias_name, torch.bfloat16)
-        self.gate_up_proj_weight_blocks = DummyWeight(self._gate_up_blocks_name, torch.uint8)
-        self.gate_up_proj_weight_scales = DummyWeight(self._gate_up_scales_name, torch.uint8)
-        self.attn_sinks = DummyWeight(self._attn_sink_name, torch.bfloat16)
+        # Current definition of experts
+        self.down_proj_bias = NormWeight(self._down_bias_name, torch.bfloat16)
+        self.down_proj_weight_blocks = NormWeight(self._down_blocks_name, torch.uint8)
+        self.down_proj_weight_scales = NormWeight(self._down_scales_name, torch.uint8)
+
+        self.gate_up_proj_bias = NormWeight(self._gate_up_bias_name, torch.bfloat16)
+        self.gate_up_proj_weight_blocks = NormWeight(self._gate_up_blocks_name, torch.uint8)
+        self.gate_up_proj_weight_scales = NormWeight(self._gate_up_scales_name, torch.uint8)
+        self.attn_sinks = NormWeight(self._attn_sink_name, torch.bfloat16)
+
+    def load_hf_weights(self, weights):
+        super().load_hf_weights(weights)
+        self._post_weight_process()
 
     def _init_weight_names(self):
         super()._init_weight_names()
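The `load_hf_weights` override added above relies on a load-then-post-process pattern: checkpoint shards can arrive across several `load_hf_weights` calls, so the post-processing step should only touch tensors whose `verify_load()` reports them as fully loaded (see the next hunk). Below is a minimal, self-contained sketch of that pattern; the `PackedWeight` stand-in class, the checkpoint key names, and the toy shapes are illustrative assumptions, not the lightllm NormWeight API.

import torch


class PackedWeight:
    # Stand-in for a single named checkpoint tensor; illustrative only, not lightllm's NormWeight.
    def __init__(self, name, dtype):
        self.name, self.dtype, self.weight = name, dtype, None

    def load_hf_weights(self, weights):
        # A given shard may or may not contain this tensor.
        if self.name in weights:
            self.weight = weights[self.name].to(self.dtype)

    def verify_load(self):
        return self.weight is not None


class MoeLayerSketch:
    def __init__(self):
        # Hypothetical checkpoint keys, only for the sketch.
        self.down_proj_bias = PackedWeight("experts.down_proj_bias", torch.bfloat16)
        self.down_proj_blocks = PackedWeight("experts.down_proj_blocks", torch.uint8)

    def load_hf_weights(self, weights):
        for w in (self.down_proj_bias, self.down_proj_blocks):
            w.load_hf_weights(weights)
        self._post_weight_process()  # may run while some tensors are still missing

    def _post_weight_process(self):
        if self.down_proj_bias.verify_load() and self.down_proj_blocks.verify_load():
            print("all expert tensors present; safe to dequantize and TP-slice")


layer = MoeLayerSketch()
layer.load_hf_weights({"experts.down_proj_bias": torch.zeros(32, 2880, dtype=torch.bfloat16)})
layer.load_hf_weights({"experts.down_proj_blocks": torch.zeros(32, 2880, 1440, dtype=torch.uint8)})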
@@ -105,31 +110,34 @@ def _post_weight_process(self):
         self.moe_intermediate_size = self.network_config_["intermediate_size"]
         self.split_inter_size = self.moe_intermediate_size // self.tp_world_size_
 
-        self.down_proj_weight = self._convert_moe_packed_tensors(
-            blocks=self.down_proj_weight_blocks.weight,
-            scales=self.down_proj_weight_scales.weight,
-            dtype=torch.bfloat16,
-        )[
-            :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
-        ]  # (32, 1440, 2880)
-
-        self.gate_up_proj_weight = self._convert_moe_packed_tensors(
-            blocks=self.gate_up_proj_weight_blocks.weight,
-            scales=self.gate_up_proj_weight_scales.weight,
-            dtype=torch.bfloat16,
-        )  # (32, 2880, 5760)
-        expert_num = self.gate_up_proj_weight.shape[0]
-        self.gate_up_proj_weight = self.gate_up_proj_weight.reshape(expert_num, -1, 2, self.moe_intermediate_size)[
-            :, :, :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
-        ].reshape(
-            expert_num, -1, 2 * self.split_inter_size
-        )  # (32, 2880, 2880)
-
-        self.gate_up_proj_bias.weight = self.gate_up_proj_bias.weight.reshape(
-            expert_num, 2, self.moe_intermediate_size
-        )[:, :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)].reshape(
-            expert_num, 2 * self.split_inter_size
-        )  # (32, 2880)
+        if self.down_proj_weight_blocks.verify_load() and self.down_proj_weight_scales.verify_load():
+            self.down_proj_weight = self._convert_moe_packed_tensors(
+                blocks=self.down_proj_weight_blocks.weight,
+                scales=self.down_proj_weight_scales.weight,
+                dtype=torch.bfloat16,
+            )[
+                :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
+            ]  # (32, 1440, 2880)
+
+        if self.gate_up_proj_weight_blocks.verify_load() and self.gate_up_proj_weight_scales.verify_load():
+            self.gate_up_proj_weight = self._convert_moe_packed_tensors(
+                blocks=self.gate_up_proj_weight_blocks.weight,
+                scales=self.gate_up_proj_weight_scales.weight,
+                dtype=torch.bfloat16,
+            )  # (32, 2880, 5760)
+            expert_num = self.gate_up_proj_weight.shape[0]
+            self.gate_up_proj_weight = self.gate_up_proj_weight.reshape(expert_num, -1, 2, self.moe_intermediate_size)[
+                :, :, :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
+            ].reshape(
+                expert_num, -1, 2 * self.split_inter_size
+            )  # (32, 2880, 2880)
+
+        if self.gate_up_proj_bias.verify_load():
+            self.gate_up_proj_bias.weight = self.gate_up_proj_bias.weight.reshape(
+                expert_num, 2, self.moe_intermediate_size
+            )[:, :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)].reshape(
+                expert_num, 2 * self.split_inter_size
+            )  # (32, 2880)
 
     def _convert_moe_packed_tensors(
         self,
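The least obvious part of the new `_post_weight_process` is the gate_up slice: the last dimension of the dequantized tensor fuses the gate and up projections, so each rank's columns must be taken inside each of the two groups rather than as one contiguous cut. The code does this by viewing the fused axis as (2, moe_intermediate_size), slicing the intermediate axis, and flattening back; the down_proj slice, by contrast, is a plain contiguous cut along dim 1. A scaled-down, self-contained check of that arithmetic (toy dimensions standing in for the real 32 x 2880 x 5760 tensors; not the lightllm code itself):

import torch

expert_num, hidden_size, inter_size = 4, 6, 8    # toy stand-ins for 32, 2880, 2880
tp_world_size, tp_rank = 2, 1
split_inter_size = inter_size // tp_world_size   # 1440 in the real model

# Full fused gate/up weight: (experts, hidden, 2 * intermediate)
gate_up = torch.randn(expert_num, hidden_size, 2 * inter_size)

# View the fused axis as (2, intermediate), slice this rank's columns in each group, flatten back.
sliced = gate_up.reshape(expert_num, -1, 2, inter_size)[
    :, :, :, split_inter_size * tp_rank : split_inter_size * (tp_rank + 1)
].reshape(expert_num, -1, 2 * split_inter_size)
assert sliced.shape == (expert_num, hidden_size, 2 * split_inter_size)

# The fused bias follows the same regrouping on a (experts, 2 * intermediate) tensor.
bias = torch.randn(expert_num, 2 * inter_size)
bias_sliced = bias.reshape(expert_num, 2, inter_size)[
    :, :, split_inter_size * tp_rank : split_inter_size * (tp_rank + 1)
].reshape(expert_num, 2 * split_inter_size)
assert bias_sliced.shape == (expert_num, 2 * split_inter_size)

# down_proj needs no regrouping: a contiguous slice along the intermediate (dim 1) axis.
down = torch.randn(expert_num, inter_size, hidden_size)
down_sliced = down[:, split_inter_size * tp_rank : split_inter_size * (tp_rank + 1), :]
assert down_sliced.shape == (expert_num, split_inter_size, hidden_size)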