@@ -10,7 +10,7 @@ def generate_scale_name(name):
1010 return weight_scale_name , input_scale_name
1111
1212
13- QUANTED_WEIGHT = os .getenv ("QUANTED_WEIGHT " , "0" ).upper () in ["1" , "TRUE" , "ON" ]
13+ STATIC_QUANT = os .getenv ("STATIC_QUANT " , "0" ).upper () in ["1" , "TRUE" , "ON" ]
1414
1515
1616class MMWeightTpl (BaseWeightTpl ):
@@ -43,7 +43,7 @@ def mm(self, input_tensor, out=None, use_custom_tensor_mananger=True):
4343
4444 def _post_load_weights (self ):
4545 if self .quant_method is not None :
46- if QUANTED_WEIGHT :
46+ if STATIC_QUANT :
4747 if all (w is not None for w in [self .weight , self .weight_scale , self .input_scale ]):
4848 self .weight = self .quant_method .quantize ((self .weight , self .weight_scale , self .input_scale ))
4949 else :
@@ -86,11 +86,11 @@ def load_hf_weights(self, weights):
8686 bias = weights [self .bias_name ].to (self .data_type_ )[self .start : self .end ]
8787 self .bias = bias .cuda (self .tp_rank_ )
8888
89- if QUANTED_WEIGHT and self .weight_scale_name in weights :
89+ if STATIC_QUANT and self .weight_scale_name in weights :
9090 weight_scale = weights [self .weight_scale_name ].to (torch .float )[self .start : self .end ]
9191 self .weight_scale = weight_scale .cuda ()
9292
93- if QUANTED_WEIGHT and self .input_scale_name in weights :
93+ if STATIC_QUANT and self .input_scale_name in weights :
9494 input_scale = weights [self .input_scale_name ].to (torch .float )
9595 self .input_scale = input_scale .cuda ()
9696
@@ -122,11 +122,11 @@ def load_hf_weights(self, weights):
122122 bias = weights [self .bias_name ]
123123 self .bias = (bias / self .world_size_ ).to (self .data_type_ ).cuda (self .tp_rank_ )
124124
125- if QUANTED_WEIGHT and self .weight_scale_name in weights :
125+ if STATIC_QUANT and self .weight_scale_name in weights :
126126 weight_scale = weights [self .weight_scale_name ].to (torch .float )
127127 self .weight_scale = weight_scale .cuda ()
128128
129- if QUANTED_WEIGHT and self .input_scale_name in weights :
129+ if STATIC_QUANT and self .input_scale_name in weights :
130130 input_scale = weights [self .input_scale_name ].to (torch .float )
131131 self .input_scale = input_scale .cuda ()
132132
@@ -203,10 +203,10 @@ def load_hf_weights(self, weights):
203203 if self .has_bias and self .bias_names [i ] in weights :
204204 bias = weights [self .bias_names [i ]].to (self .data_type_ )
205205 self .biases [i ] = bias [self .starts [i ] : self .ends [i ]]
206- if QUANTED_WEIGHT and self .weight_scale_names [i ] in weights :
206+ if STATIC_QUANT and self .weight_scale_names [i ] in weights :
207207 weight_scale = weights [self .weight_scale_names [i ]][self .starts [i ] : self .ends [i ]]
208208 self .weight_scales [i ] = weight_scale .to (torch .float )
209- if QUANTED_WEIGHT and self .input_scale_names [i ] in weights :
209+ if STATIC_QUANT and self .input_scale_names [i ] in weights :
210210 input_scale = weights [self .input_scale_names [i ]].to (torch .float )
211211 self .input_scales [i ] = input_scale
212212
@@ -234,10 +234,10 @@ def load_hf_weights(self, weights):
234234 if self .has_bias and self .bias_names [i ] in weights :
235235 bias = weights [self .bias_names [i ]].to (self .data_type_ )
236236 self .biases [i ] = bias [:, self .starts [i ] : self .ends [i ]]
237- if QUANTED_WEIGHT and self .weight_scale_names [i ] in weights :
237+ if STATIC_QUANT and self .weight_scale_names [i ] in weights :
238238 weight_scale = weights [self .weight_scale_names [i ]]
239239 self .weight_scales [i ] = weight_scale .to (torch .float )
240- if QUANTED_WEIGHT and self .input_scale_names [i ] in weights :
240+ if STATIC_QUANT and self .input_scale_names [i ] in weights :
241241 input_scale = weights [self .input_scale_names [i ]].to (torch .float )
242242 self .input_scales [i ] = input_scale
243243 self ._fuse ()
0 commit comments