 # limitations under the License.
 
 # Standard
-from copy import deepcopy
 from pathlib import Path
 import logging
 
@@ -136,6 +135,9 @@ def recompute_weight_with_sawb(
     integer domain.
     """
 
+    weight_pre_quant = weight_pre_quant.to("cpu")
+    weight_int_as_fp = weight_int_as_fp.to("cpu")
+
     is_w_recomputed = False
     weight_int_sawb: torch.Tensor | None = None
     weight_int_std: torch.Tensor | float | None = None
@@ -169,6 +171,7 @@ def recompute_weight_with_sawb(
         # some SAWB quantizers only process FP32 inputs, so weights are
         # temporarily upscaled
         weight_int_sawb = quantizer(weight_pre_quant.to(torch.float32))
+        assert weight_int_sawb is not None
 
         # 2. Recompute clip values using new SAWB quantizer
         w_cv_key = layer_name + ".quantize_weight.clip_val"
@@ -180,14 +183,33 @@ def recompute_weight_with_sawb(
         logger.info(
             f" {'Overwrite' if w_cvn_key in new_sd else 'Add'} key: {w_cvn_key}"
         )
-        new_sd[w_cv_key] = quantizer.clip_val.to("cpu").to(torch.float16)
-        new_sd[w_cvn_key] = -quantizer.clip_val.to("cpu").to(torch.float16)
+
+        cv_sawb = quantizer.clip_val.to("cpu").to(torch.float16)
+        if weight_per_channel:
+            # Select SAWB rows only where clip value does not exceed row max
+            cv_max = weight_pre_quant.abs().max(dim=-1)[0]
+            weight_int_guarded = torch.where(
+                (cv_sawb < cv_max)[:, None],
+                weight_int_sawb,
+                weight_int_as_fp,
+            )
+            cv_guarded = torch.where(cv_sawb < cv_max, cv_sawb, cv_max)
+            weight_int_sawb = weight_int_guarded
+        else:
+            cv_max = weight_pre_quant.abs().max()
+            weight_int_guarded = (
+                weight_int_sawb if cv_sawb < cv_max else weight_int_as_fp
+            )
+            cv_guarded = torch.min(cv_sawb, cv_max)
+
+        new_sd[w_cv_key] = cv_guarded
+        new_sd[w_cvn_key] = -cv_guarded
 
         # 3. [optional] Recompute standard deviation of integer weights
         if verbose:
-            weight_int_sawb_as_fp = deepcopy(weight_int_sawb).to(torch.float32)
+            weight_int_sawb_as_fp = weight_int_guarded.to(torch.float32)
             if weight_per_channel:
-                weight_int_sawb_std_min = weight_int_sawb_as_fp.std(dim=-1)[0].min()
+                weight_int_sawb_std_min = weight_int_sawb_as_fp.std(dim=-1).min()
                 if verbose:
                     logger.info(
                         " Reprocessed weights "
@@ -204,11 +226,10 @@ def recompute_weight_with_sawb(
                         f"-> {weight_int_sawb_as_fp_std:.1f}) "
                         f"and clips of {layer_name + '.weight'}"
                     )
-    else:
-        log_min_std = "min_" if weight_per_channel else ""
-        log_w_std = weight_int_std_min if weight_per_channel else weight_int_std
-        if verbose:
-            logger.info(f" Weights preserved ({log_min_std}std={log_w_std:.1f})")
+    elif verbose:
+        log_min_std = "min_" if weight_per_channel else ""
+        log_w_std = weight_int_std_min if weight_per_channel else weight_int_std
+        logger.info(f" Weights preserved ({log_min_std}std={log_w_std:.1f})")
 
     return weight_int_sawb, is_w_recomputed
 
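The heart of this patch is the guard on recomputed SAWB clip values: a SAWB result is accepted only where its clip value stays below the weight row's absolute maximum; otherwise the original integer weights and the row max are kept. Below is a minimal, self-contained sketch of that selection on toy tensors; the names mirror the diff, but the values are made up and this is illustrative, not the library's API.

import torch

# Toy stand-ins for the tensors in the diff (2 output channels, 3 weights each).
weight_pre_quant = torch.tensor([[0.5, -2.0, 1.0], [0.1, 0.3, -0.2]])
weight_int_sawb = torch.tensor([[32.0, -127.0, 64.0], [42.0, 127.0, -85.0]])
weight_int_as_fp = torch.tensor([[30.0, -120.0, 60.0], [40.0, 120.0, -80.0]])
cv_sawb = torch.tensor([1.8, 0.4])  # per-channel SAWB clip values

# Per-channel guard: accept SAWB only where its clip value is below the row max.
cv_max = weight_pre_quant.abs().max(dim=-1)[0]  # -> tensor([2.0, 0.3])
take_sawb = cv_sawb < cv_max                    # -> tensor([True, False])
weight_int_guarded = torch.where(take_sawb[:, None], weight_int_sawb, weight_int_as_fp)
cv_guarded = torch.where(take_sawb, cv_sawb, cv_max)

Note also the return-type asymmetry behind the one-character std fix: Tensor.max(dim=-1) returns a (values, indices) named tuple, so the [0] after it is required, while Tensor.std(dim=-1) returns a plain tensor, which is why the stray [0] is dropped from the per-channel std computation.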
@@ -441,8 +462,20 @@ def save_sd_for_aiu(
         logger.info(
             "Attention: saving state dictionary without specifying a quantization "
             "configuration (qcfg) performs no recomputation for narrow weight "
-            "distributions and assumes the weight quantizer used was per-tensor."
+            "distributions and assumes the weight quantizer used was 8-bit per-tensor."
         )
+    else:
+        nbits_w = qcfg.get("nbits_w", None)
+        if nbits_w is None:
+            logger.info(
+                "Number of bits for weight quantization is not set in qcfg. "
+                "Assuming default (nbits_w=8)."
+            )
+        elif nbits_w != 8:
+            raise ValueError(
+                "Saving checkpoint in AIU-compliant format only supports INT8 "
+                f"quantization for now, but found {nbits_w=} in qcfg."
+            )
 
     converted_sd = convert_sd_for_aiu(
         model=model,
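For reference, the new qcfg validation admits three paths: no qcfg at all (warn and assume an 8-bit per-tensor quantizer), a qcfg without nbits_w (assume the default of 8), and a non-8 nbits_w (reject). A standalone sketch of just that branch logic, assuming qcfg behaves like a plain dict, with print standing in for logger.info; the helper name is hypothetical and this mirrors the diff, not the actual save_sd_for_aiu signature.

def _check_qcfg_nbits_w(qcfg: dict | None) -> None:
    # Hypothetical helper reproducing only the guard added above.
    if qcfg is None:
        print("no qcfg: no recomputation, 8-bit per-tensor quantizer assumed")
    elif (nbits_w := qcfg.get("nbits_w", None)) is None:
        print("nbits_w not set in qcfg, assuming default (nbits_w=8)")
    elif nbits_w != 8:
        raise ValueError(f"AIU-compliant saving only supports INT8, found {nbits_w=}")

_check_qcfg_nbits_w(None)                # warns, proceeds
_check_qcfg_nbits_w({})                  # assumes 8, proceeds
_check_qcfg_nbits_w({"nbits_w": 8})      # passes silently
# _check_qcfg_nbits_w({"nbits_w": 4})    # would raise ValueError

The {nbits_w=} form in the error message is the self-documenting f-string expression from Python 3.8+, rendering as, e.g., nbits_w=4.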