@@ -89,6 +89,7 @@ def process_smoothquant(
     model: PreTrainedModel,
     layer_name: str,
     new_sd: dict,
+    verbose: bool = False,
 ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
     """Check if smoothquant was in use and, if so:
     1. compute combined weight/activation scaling factor
@@ -115,9 +116,10 @@ def process_smoothquant(
115116 "Quantization parameters (qscale) exceeds float16 range. "
116117 "Aborted state dict saving."
117118 )
118- new_sd [layer_name + ".smoothq_scale" ] = (
119- sq_scale .squeeze ().to (torch .float16 ).to ("cpu" )
120- )
119+ k = layer_name + ".smoothq_scale"
120+ if verbose :
121+ logger .info (f" Save key: { k } " )
122+ new_sd [k ] = sq_scale .squeeze ().to (torch .float16 ).to ("cpu" )
121123 return weight_scaled , sq_a_scale
122124
123125
@@ -128,7 +130,7 @@ def recompute_weight_with_sawb(
     sq_a_scale: torch.Tensor | None,
     layer_name: str,
     new_sd: dict,
-    verbose: bool,
+    verbose: bool = False,
 ) -> tuple[torch.Tensor | None, bool]:
     """Use SAWB quantizer to recompute weights showing narrow distributions in the
     integer domain.
@@ -219,7 +221,7 @@ def process_weight(
     weight_per_channel: bool,
     sq_a_scale: torch.Tensor | None,
     new_sd: dict,
-    verbose: bool,
+    verbose: bool = False,
 ) -> tuple[torch.Tensor | None, bool | None]:
     """Compute integer weights and store them into new state dictionary.
     If recomputation is enabled, int weights are updated using SAWB quantizer.
@@ -259,6 +261,7 @@ def process_zero_shift(
     layer_name: str,
     weight_int: torch.Tensor | None,
     new_sd: dict,
+    verbose: bool = False,
 ) -> None:
     """Compute and store the zero shift, a correction factor that compensates the
     output of (W integer, X integer) matmuls to match the corresponding FP operation.
@@ -288,6 +291,9 @@ def process_zero_shift(
     # sum (squash) along in_feat dimension: dim=1
     zero_shift = torch.sum(weight_int, dim=1)
 
+    if verbose:
+        logger.info(f"  Save key: {k}")
+
     # zero shift can exceed FP16 max value, especially if INT weights have
     # been recomputed, so it is saved as FP32
     new_sd[k] = zero_shift.to(torch.float32).to("cpu")
@@ -360,6 +366,7 @@ def convert_sd_for_aiu(
                 model=model,
                 layer_name=layer_name,
                 new_sd=new_sd,
+                verbose=verbose,
             )
 
             weight_int, is_w_recomputed = process_weight(
@@ -377,7 +384,13 @@ def convert_sd_for_aiu(
             else:
                 num_w_preserved += 1
 
-            process_zero_shift(model, layer_name, weight_int, new_sd)
+            process_zero_shift(
+                model=model,
+                layer_name=layer_name,
+                weight_int=weight_int,
+                new_sd=new_sd,
+                verbose=verbose,
+            )
 
         elif all(excluded_key not in k for excluded_key in excluded_keys_from_new_sd):
             if k not in new_sd:
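
For reference, below is a minimal, standalone sketch of the verbose-save pattern this change threads through the conversion helpers: each helper now accepts verbose: bool = False and logs the state-dict key just before storing the tensor. The save_tensor helper, the logger name, and the example key and tensor are illustrative assumptions, not part of the patched module.

import logging

import torch

logger = logging.getLogger(__name__)


def save_tensor(
    new_sd: dict, key: str, tensor: torch.Tensor, verbose: bool = False
) -> None:
    """Store a CPU copy of `tensor` under `key`, logging the key when verbose."""
    if verbose:
        logger.info(f"  Save key: {key}")
    new_sd[key] = tensor.to("cpu")


# Usage mirroring process_smoothquant: save a smoothquant scale as fp16.
new_sd: dict = {}
sq_scale = torch.ones(1, 64)  # hypothetical per-channel scale
save_tensor(
    new_sd,
    "model.layers.0.smoothq_scale",  # hypothetical layer name
    sq_scale.squeeze().to(torch.float16),
    verbose=True,
)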