1+ import os
12import torch
23from .base_weight import BaseWeightTpl
34from lightllm .common .basemodel .layer_infer .cache_tensor_manager import g_cache_manager
45
56
def generate_scale_name(name):
    """Derive the static-quant scale tensor names that pair with a weight tensor.

    E.g. "model.layers.0.q_proj.weight" ->
    ("model.layers.0.q_proj.weight_scale", "model.layers.0.q_proj.input_scale").

    Args:
        name: dotted HF-style weight tensor name.

    Returns:
        Tuple of (weight_scale_name, input_scale_name) strings.
    """
    # Split once instead of twice; drop the trailing component (e.g. "weight").
    prefix = name.split(".")[:-1]
    weight_scale_name = ".".join(prefix + ["weight_scale"])
    input_scale_name = ".".join(prefix + ["input_scale"])
    return weight_scale_name, input_scale_name
11+
12+
# Static (pre-computed scale) quantization is opt-in via the STATIC_QUANT env var.
STATIC_QUANT = os.environ.get("STATIC_QUANT", "0").upper() in ("1", "TRUE", "ON")
14+
15+
616class MMWeightTpl (BaseWeightTpl ):
717 def __init__ (self , data_type ):
818 super ().__init__ ()
919 self .data_type_ = data_type
1020 self .quant_method = None
1121 self .weight = None
1222 self .bias = None
23+ self .weight_scale = None
24+ self .input_scale = None
1325
    def set_quant_method(self, quant_method):
        """Attach a quantization backend; _post_load_weights() quantizes with it when set."""
        self.quant_method = quant_method
@@ -31,7 +43,11 @@ def mm(self, input_tensor, out=None, use_custom_tensor_mananger=True):
3143
    def _post_load_weights(self):
        """Finalize self.weight after (partial) loading: quantize it or move it to the GPU."""
        if self.quant_method is not None:
            if STATIC_QUANT:
                # Static quant: only quantize once the weight AND both scales have all
                # been loaded; the backend receives the (weight, weight_scale, input_scale)
                # triple as a single tuple argument.
                if all(w is not None for w in [self.weight, self.weight_scale, self.input_scale]):
                    self.weight = self.quant_method.quantize((self.weight, self.weight_scale, self.input_scale))
            else:
                # Dynamic quant: cast to the target dtype and quantize on this rank's device.
                self.weight = self.quant_method.quantize(self.weight.to(self.data_type_).cuda(self.tp_rank_))
            return
        # Unquantized path: store the weight transposed on this rank's device —
        # NOTE(review): presumably so mm() can use it directly; confirm against mm().
        self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)
3753
@@ -43,6 +59,7 @@ def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
4359 self .end = split_n_embed * (self .tp_rank_ + 1 )
4460 self .weight_name = weight_name
4561 self .bias_name = bias_name
62+ self .weight_scale_name , self .input_scale_name = generate_scale_name (weight_name )
4663
4764 def verify_load (self ):
4865 load_ok = True
@@ -60,13 +77,24 @@ def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
6077
6178 def load_hf_weights (self , weights ):
6279 weight = None
80+ weight_scale = None
81+ input_scale = None
6382 if self .weight_name in weights :
64- weight = weights [self .weight_name ]. to ( self . data_type_ )
83+ weight = weights [self .weight_name ]
6584 self .weight = weight [self .start : self .end ]
6685 if self .bias_name in weights :
6786 bias = weights [self .bias_name ].to (self .data_type_ )[self .start : self .end ]
6887 self .bias = bias .cuda (self .tp_rank_ )
69- if weight is None :
88+
89+ if STATIC_QUANT and self .weight_scale_name in weights :
90+ weight_scale = weights [self .weight_scale_name ].to (torch .float )[self .start : self .end ]
91+ self .weight_scale = weight_scale .cuda ()
92+
93+ if STATIC_QUANT and self .input_scale_name in weights :
94+ input_scale = weights [self .input_scale_name ].to (torch .float )
95+ self .input_scale = input_scale .cuda ()
96+
97+ if weight is None and weight_scale is None and input_scale is None :
7098 return
7199 self ._post_load_weights ()
72100 return
@@ -85,13 +113,24 @@ def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
85113
    def load_hf_weights(self, weights):
        """Load this rank's column-parallel shard (weight, bias, static-quant scales)."""
        weight = None
        weight_scale = None
        input_scale = None
        if self.weight_name in weights:
            weight = weights[self.weight_name]
            # Column-parallel: slice the second (input) dimension for this rank.
            self.weight = weight[:, self.start : self.end]
        if self.bias_name in weights:
            bias = weights[self.bias_name]
            # Bias is pre-divided by world_size — presumably because every rank adds it
            # before partial outputs are all-reduced; TODO(review): confirm in the mm path.
            self.bias = (bias / self.world_size_).to(self.data_type_).cuda(self.tp_rank_)

        if STATIC_QUANT and self.weight_scale_name in weights:
            # Scales are kept in fp32 and NOT sliced (they cover the full output dim).
            weight_scale = weights[self.weight_scale_name].to(torch.float)
            self.weight_scale = weight_scale.cuda()

        if STATIC_QUANT and self.input_scale_name in weights:
            input_scale = weights[self.input_scale_name].to(torch.float)
            self.input_scale = input_scale.cuda()

        # Only re-finalize when the weight or a scale arrived; a bias alone does not.
        if weight is None and weight_scale is None and input_scale is None:
            return
        self._post_load_weights()
        return
@@ -109,8 +148,17 @@ def __init__(self, weight_names, data_type, split_n_embeds, bias_names=[]):
109148 self .ends = [i * (self .tp_rank_ + 1 ) for i in self .split_n_embeds ]
110149 self .weight_names = weight_names
111150 self .bias_names = bias_names
151+ self .weight_scale_names = []
152+ self .input_scale_names = []
153+ for weight_name in weight_names :
154+ weight_scale_name , input_scale_name = generate_scale_name (weight_name )
155+ self .weight_scale_names .append (weight_scale_name )
156+ self .input_scale_names .append (input_scale_name )
157+
112158 self .weights = [None ] * len (self .weight_names )
113159 self .biases = [None ] * len (self .bias_names )
160+ self .input_scales = [None ] * len (self .weight_names )
161+ self .weight_scales = [None ] * len (self .weight_names )
114162 self .has_bias = all (b is not None for b in self .bias_names ) and len (bias_names ) > 0
115163
116164 def verify_load (self ):
    def _fuse(self):
        """Concatenate per-member shards into fused tensors once every part has arrived.

        Each branch fires at most once (guarded by the fused attribute still being
        None) and re-runs _post_load_weights() so quantization can proceed as soon
        as the needed pieces are complete.
        """
        if self.weight is None and all(w is not None for w in self.weights):
            self.weight = torch.cat(self.weights, dim=0)
            self._post_load_weights()

        if self.weight_scale is None and all(w is not None for w in self.weight_scales):
            self.weight_scale = torch.cat(self.weight_scales, dim=0).cuda()
            self._post_load_weights()

        if self.input_scale is None and all(w is not None for w in self.input_scales):
            # Single shared input scale: the max over member scales — presumably so the
            # fused activation quantization covers every member's range; confirm upstream.
            input_scales = torch.stack(self.input_scales, dim=0)
            self.input_scale = torch.max(input_scales).cuda()
            self._post_load_weights()

        if self.has_bias:
            if self.bias is None and all(b is not None for b in self.biases):
                self.bias = torch.cat(self.biases, dim=0).cuda(self.tp_rank_)
    def load_hf_weights(self, weights):
        """Load this rank's row-parallel shard of each fused member, then try to fuse."""
        weight = None
        for i in range(len(self.weight_names)):
            if self.weight_names[i] in weights:
                weight = weights[self.weight_names[i]]
                # Row-parallel: slice the first (output) dimension for this rank.
                self.weights[i] = weight[self.starts[i] : self.ends[i]]
            if self.has_bias and self.bias_names[i] in weights:
                bias = weights[self.bias_names[i]].to(self.data_type_)
                self.biases[i] = bias[self.starts[i] : self.ends[i]]
            if STATIC_QUANT and self.weight_scale_names[i] in weights:
                # Slice the scale like the weight, then keep it in fp32; it is moved
                # to the GPU later, in _fuse().
                weight_scale = weights[self.weight_scale_names[i]][self.starts[i] : self.ends[i]]
                self.weight_scales[i] = weight_scale.to(torch.float)
            if STATIC_QUANT and self.input_scale_names[i] in weights:
                input_scale = weights[self.input_scale_names[i]].to(torch.float)
                self.input_scales[i] = input_scale

        self._fuse()
        return
150215
    def load_hf_weights(self, weights):
        """Load this rank's column-parallel shard of each fused member, then try to fuse."""
        weight = None
        for i in range(len(self.weight_names)):
            if self.weight_names[i] in weights:
                weight = weights[self.weight_names[i]]
                # Column-parallel: slice the second (input) dimension for this rank.
                self.weights[i] = weight[:, self.starts[i] : self.ends[i]]
            if self.has_bias and self.bias_names[i] in weights:
                bias = weights[self.bias_names[i]].to(self.data_type_)
                # NOTE(review): the bias is sliced on dim 1 here, which implies a 2-D
                # bias tensor — biases are usually 1-D; verify against the checkpoints.
                self.biases[i] = bias[:, self.starts[i] : self.ends[i]]
            if STATIC_QUANT and self.weight_scale_names[i] in weights:
                # Full (unsliced) scale in fp32; moved to the GPU later, in _fuse().
                weight_scale = weights[self.weight_scale_names[i]]
                self.weight_scales[i] = weight_scale.to(torch.float)
            if STATIC_QUANT and self.input_scale_names[i] in weights:
                input_scale = weights[self.input_scale_names[i]].to(torch.float)
                self.input_scales[i] = input_scale
        self._fuse()
        return
174245
0 commit comments