Commit 449c551

function private

1 parent daa6fc2 commit 449c551

4 files changed: +14 -41 lines changed

lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py

Lines changed: 0 additions & 26 deletions
@@ -26,29 +26,3 @@ def load_hf_weights(self, weights):

     def verify_load(self):
         pass
-
-
-# class BaseWeightTpl(BaseWeight):
-#     def __init__(self, weight_name, data_type, bias_name):
-#         self.weight_name = weight_name
-#         self.bias_name = bias_name
-#         self.data_type_ = data_type
-#         self.world_size_ = get_world_size()
-#         self.tp_rank_ = get_rank()
-#         self.weight = None
-#         self.bias = None
-
-#     def load_hf_weights(self, weights):
-#         if self.weight_name in weights:
-#             self.weight = weights[self.weight_name].to(self.data_type_).cuda(self.tp_rank_)
-#         if self.bias_name in weights:
-#             self.bias = weights[self.bias_name].to(self.data_type_).cuda(self.tp_rank_)
-
-#     def verify_load(self):
-#         load_ok = True
-#         #Verify weight. The weight must be not None.
-#         load_ok = load_ok and self.weight is not None
-#         #Verify bias. If bias_name is set, it must be not None.
-#         if self.bias_name is not None:
-#             load_ok = load_ok and self.bias is not None
-#         return load_ok
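
The deletion above is pure cleanup: a commented-out early draft of the tensor-parallel weight template. What remains of BaseWeight in this file is the two-hook contract visible in the context lines, load_hf_weights(weights) and verify_load(). A minimal standalone sketch of a class following that contract; the class name, dtype default, and tensor handling are illustrative, not code from the repo:

import torch


class DummyWeight:
    # Illustrative only: mirrors the load_hf_weights / verify_load contract of BaseWeight.
    def __init__(self, weight_name, data_type=torch.float16):
        self.weight_name = weight_name
        self.data_type_ = data_type
        self.weight = None

    def load_hf_weights(self, weights):
        # `weights` is the checkpoint dict of name -> tensor; keep only the tensor this object owns.
        if self.weight_name in weights:
            self.weight = weights[self.weight_name].to(self.data_type_)

    def verify_load(self):
        # Loading counts as successful only if the expected tensor was actually present.
        return self.weight is not None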

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py

Lines changed: 11 additions & 11 deletions
@@ -15,12 +15,6 @@ def __init__(self, data_type, split_n_embed):
     def set_quant_method(self, quant_method):
         self.quant_method = quant_method

-    def post_load_weights(self):
-        if self.quant_method is not None:
-            self.weight = self.quant_method.quantize(self.weight.cuda(self.tp_rank_))
-            return
-        self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)
-
     def mm(self, input_tensor, out=None, use_custom_tensor_mananger=True):
         if self.quant_method is not None:
             return self.quant_method.apply(input_tensor, self.weight, self.bias, out)
@@ -36,6 +30,12 @@ def mm(self, input_tensor, out=None, use_custom_tensor_mananger=True):
             return torch.mm(input_tensor, self.weight, out=out)
         return torch.addmm(self.bias, input_tensor, self.weight, out=out)

+    def _post_load_weights(self):
+        if self.quant_method is not None:
+            self.weight = self.quant_method.quantize(self.weight.cuda(self.tp_rank_))
+            return
+        self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)
+

 class MMWeight(MMWeightTpl):
     def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
@@ -69,7 +69,7 @@ def load_hf_weights(self, weights):
             self.bias = bias.cuda(self.tp_rank_)
         if weight is None:
             return
-        self.post_load_weights()
+        self._post_load_weights()
         return


@@ -89,7 +89,7 @@ def load_hf_weights(self, weights):
             self.bias = bias.cuda(self.tp_rank_) / self.world_size_
         if weight is None:
             return
-        self.post_load_weights()
+        self._post_load_weights()
        return


@@ -116,10 +116,10 @@ class MultiROWMMWeight(MultiMMWeight):
     def __init__(self, weight_names, data_type, split_n_embed, bias_names=None):
         super().__init__(weight_names, data_type, split_n_embed, bias_names)

-    def fuse(self):
+    def _fuse(self):
        if self.weight is None and all(w is not None for w in self.weights):
            self.weight = torch.cat(self.weights, dim=0)
-            self.post_load_weights()
+            self._post_load_weights()
        if self.has_bias:
            if self.bias is None and all(b is not None for b in self.biases):
                self.bias = torch.cat(self.bias, dim=0).cuda(self.tp_rank_)
@@ -136,7 +136,7 @@ def load_hf_weights(self, weights):
            if self.has_bias and self.bias_names[i] in weights:
                bias = weights[self.bias_names[i]].to(self.data_type_)
                self.biases[i] = bias[start:end]
-        self.fuse()
+        self._fuse()
         return
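
Beyond the renames (post_load_weights to _post_load_weights, fuse to _fuse, both now internal helpers driven from load_hf_weights), the hunks above show the only interface a quant_method has to satisfy: quantize(weight) once after load, and apply(input_tensor, weight, bias, out) at matmul time. A no-op sketch of that interface, purely illustrative and not a class shipped with lightllm:

import torch


class IdentityQuantMethod:
    # Illustrative stand-in for a quant_method object as used by MMWeightTpl.

    def quantize(self, weight):
        # A real method would pack/scale the weight here; this sketch returns it unchanged,
        # assuming the [out_features, in_features] checkpoint layout that apply() below expects
        # (on the quantized path _post_load_weights does not transpose the weight).
        return weight

    def apply(self, input_tensor, weight, bias=None, out=None):
        # Matches the positional call in MMWeightTpl.mm: (input, weight, bias, out).
        result = torch.nn.functional.linear(input_tensor, weight, bias)
        if out is not None:
            out.copy_(result)
            return out
        return result


# Hypothetical wiring, done before load_hf_weights runs:
#     mm_weight.set_quant_method(IdentityQuantMethod())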

lightllm/common/basemodel/layer_weights/transformer_layer_weight.py

Lines changed: 2 additions & 3 deletions
@@ -18,14 +18,13 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo
         self.network_config_ = network_config
         self.mode = mode
         self.quant_cfg = quant_cfg
-        self.init_static_params()
-        self._init_config()
+        self._parse_config()
         self._init_weight_names()
         self._init_weight()
         self.set_quantization()
         return

-    def _init_config(self):
+    def _parse_config(self):
         pass

     def _init_weight_names(self):
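
After this change the base constructor drives four hooks in a fixed order, and init_static_params() is no longer called from it: _parse_config(), then _init_weight_names(), _init_weight(), and set_quantization(). A schematic of how a model-specific subclass slots into that order; only the hook names come from the diff, the class and attribute names are illustrative:

class ToyLayerWeight:
    # Illustrative skeleton mirroring the hook order TransformerLayerWeight.__init__ now uses.
    def __init__(self, network_config):
        self.network_config_ = network_config
        self._parse_config()       # 1. read sizes out of the HF config dict
        self._init_weight_names()  # 2. decide which checkpoint keys this layer owns
        self._init_weight()        # 3. build the weight objects for those keys
        self.set_quantization()    # 4. attach quant methods where configured

    def _parse_config(self):
        self.n_embed = self.network_config_["hidden_size"]

    def _init_weight_names(self):
        self.weight_names = []

    def _init_weight(self):
        pass

    def set_quantization(self):
        pass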

lightllm/models/llama/layer_weights/transformer_layer_weight.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ def _init_weight(self):
         self._init_ffn()
         self._init_norm()

-    def _init_config(self):
+    def _parse_config(self):
         self.n_embed = self.network_config_["hidden_size"]
         self.n_head = self.network_config_["num_attention_heads"]
         self.n_inter = self.network_config_["intermediate_size"]
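
The in-tree Llama weights are updated to the new hook name here, but the rename is silently breaking for any out-of-tree subclass that still overrides _init_config: the base _parse_config (a no-op) runs instead and no error is raised. A toy demonstration of that failure mode, with made-up class names:

class Base:
    def __init__(self):
        self._parse_config()      # new hook name after this commit

    def _parse_config(self):
        pass                      # base default is a no-op, as in transformer_layer_weight.py


class StaleSubclass(Base):
    def _init_config(self):       # old hook name; never invoked any more
        self.n_embed = 4096


layer = StaleSubclass()
print(hasattr(layer, "n_embed"))  # False: the stale override was skipped without any error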
