@@ -122,6 +122,7 @@ def process_smoothquant(
 
 
 def recompute_weight_with_sawb(
+    weight_pre_quant: torch.Tensor,
     weight_int_as_fp: torch.Tensor,
     weight_per_channel: bool,
     sq_a_scale: torch.Tensor | None,
@@ -165,7 +166,7 @@ def recompute_weight_with_sawb(
     quantizer.training = True  # set SAWB to recompute clips
     # some SAWB quantizers only process FP32 inputs, so weights are
     # temporarily upscaled
-    weight_int_sawb = quantizer(weight_int_as_fp.to(torch.float32))
+    weight_int_sawb = quantizer(weight_pre_quant.to(torch.float32))
 
     # 2. Recompute clip values using new SAWB quantizer
     w_cv_key = layer_name + ".quantize_weight.clip_val"
@@ -181,7 +182,7 @@ def recompute_weight_with_sawb(
     new_sd[w_cvn_key] = -quantizer.clip_val.to("cpu").to(torch.float16)
 
     # 3. [optional] Recompute standard deviation of integer weights
-    if verbose and weight_int_sawb is not None:
+    if verbose:
         weight_int_sawb_as_fp = deepcopy(weight_int_sawb).to(torch.float32)
         if weight_per_channel:
             weight_int_sawb_std_min = weight_int_sawb_as_fp.std(dim=-1)[0].min()
@@ -236,6 +237,7 @@ def process_weight(
     weight_int_sawb = None
     if recompute_narrow_weights:
         weight_int_sawb, is_w_recomputed = recompute_weight_with_sawb(
+            weight_pre_quant,
             weight_int_as_fp,
             weight_per_channel,
             sq_a_scale,
@@ -378,13 +380,17 @@ def convert_sd_for_aiu(
             process_zero_shift(model, layer_name, weight_int, new_sd)
 
         elif all(excluded_key not in k for excluded_key in excluded_keys_from_new_sd):
-            # guarding FP16 cast
-            if v.abs().max() > torch.finfo(torch.float16).max:
-                raise ValueError(
-                    f"Quantization parameters ({k}) exceeds float16 range. "
-                    "Aborted state dict saving."
-                )
-            new_sd[k] = v.to("cpu").to(torch.float16)
+            if k not in new_sd:
+                # guarding FP16 cast
+                if v.abs().max() > torch.finfo(torch.float16).max:
+                    raise ValueError(
+                        f"Quantization parameters ({k}) exceeds float16 range. "
+                        "Aborted state dict saving."
+                    )
+                logger.info(f" Save key: {k}")
+                new_sd[k] = v.to("cpu").to(torch.float16)
+            else:
+                logger.info(f" Skip parameter already processed: {k}")
 
     logger.info("New state dict processed.")
     if verbose:
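
Note: the convert_sd_for_aiu hunk above now casts only keys that were not already written earlier (e.g., clip values recomputed by recompute_weight_with_sawb). Below is a minimal standalone sketch of that guard, assuming new_sd already holds the recomputed entries and logger is a standard logging.Logger; the helper name cast_remaining_params_to_fp16 is illustrative and not part of the module.

import logging

import torch

logger = logging.getLogger(__name__)


def cast_remaining_params_to_fp16(state_dict: dict, new_sd: dict) -> None:
    """Copy leftover parameters into new_sd as FP16, skipping keys already written."""
    for k, v in state_dict.items():
        if k in new_sd:
            logger.info(f" Skip parameter already processed: {k}")
            continue
        # guarding FP16 cast: refuse values outside the representable range
        if v.abs().max() > torch.finfo(torch.float16).max:
            raise ValueError(
                f"Quantization parameters ({k}) exceeds float16 range. "
                "Aborted state dict saving."
            )
        logger.info(f" Save key: {k}")
        new_sd[k] = v.to("cpu").to(torch.float16)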