1313# limitations under the License.
1414
1515# Standard
16- from copy import deepcopy
1716from pathlib import Path
1817import logging
1918
@@ -136,6 +135,9 @@ def recompute_weight_with_sawb(
136135 integer domain.
137136 """
138137
138+ weight_pre_quant = weight_pre_quant .to ("cpu" )
139+ weight_int_as_fp = weight_int_as_fp .to ("cpu" )
140+
139141 is_w_recomputed = False
140142 weight_int_sawb : torch .Tensor | None = None
141143 weight_int_std : torch .Tensor | float | None = None
@@ -169,6 +171,7 @@ def recompute_weight_with_sawb(
169171 # some SAWB quantizers only process FP32 inputs, so weights are
170172 # temporarily upscaled
171173 weight_int_sawb = quantizer (weight_pre_quant .to (torch .float32 ))
174+ assert weight_int_sawb is not None
172175
173176 # 2. Recompute clip values using new SAWB quantizer
174177 w_cv_key = layer_name + ".quantize_weight.clip_val"
@@ -180,14 +183,33 @@ def recompute_weight_with_sawb(
180183 logger .info (
181184 f" { 'Overwrite' if w_cvn_key in new_sd else 'Add' } key: { w_cvn_key } "
182185 )
183- new_sd [w_cv_key ] = quantizer .clip_val .to ("cpu" ).to (torch .float16 )
184- new_sd [w_cvn_key ] = - quantizer .clip_val .to ("cpu" ).to (torch .float16 )
186+
187+ cv_sawb = quantizer .clip_val .to ("cpu" ).to (torch .float16 )
188+ if weight_per_channel :
189+ # Select SAWB rows only where clip value does not exceed row max
190+ cv_max = weight_pre_quant .abs ().max (dim = - 1 )[0 ]
191+ weight_int_guarded = torch .where (
192+ (cv_sawb < cv_max )[:, None ],
193+ weight_int_sawb ,
194+ weight_int_as_fp ,
195+ )
196+ cv_guarded = torch .where (cv_sawb < cv_max , cv_sawb , cv_max )
197+ weight_int_sawb = weight_int_guarded
198+ else :
199+ cv_max = weight_pre_quant .abs ().max ()
200+ weight_int_guarded = (
201+ weight_int_sawb if cv_sawb < cv_max else weight_int_as_fp
202+ )
203+ cv_guarded = torch .min (cv_sawb , cv_max )
204+
205+ new_sd [w_cv_key ] = cv_guarded
206+ new_sd [w_cvn_key ] = - cv_guarded
185207
186208 # 3. [optional] Recompute standard deviation of integer weights
187209 if verbose :
188- weight_int_sawb_as_fp = deepcopy ( weight_int_sawb ) .to (torch .float32 )
210+ weight_int_sawb_as_fp = weight_int_guarded .to (torch .float32 )
189211 if weight_per_channel :
190- weight_int_sawb_std_min = weight_int_sawb_as_fp .std (dim = - 1 )[ 0 ] .min ()
212+ weight_int_sawb_std_min = weight_int_sawb_as_fp .std (dim = - 1 ).min ()
191213 if verbose :
192214 logger .info (
193215 " Reprocessed weights "
@@ -204,11 +226,11 @@ def recompute_weight_with_sawb(
204226 f"-> { weight_int_sawb_as_fp_std :.1f} ) "
205227 f"and clips of { layer_name + '.weight' } "
206228 )
207- else :
208- log_min_std = "min_" if weight_per_channel else ""
209- log_w_std = weight_int_std_min if weight_per_channel else weight_int_std
210- if verbose :
211- logger .info (f" Weights preserved ({ log_min_std } std={ log_w_std :.1f} )" )
229+ else :
230+ log_min_std = "min_" if weight_per_channel else ""
231+ log_w_std = weight_int_std_min if weight_per_channel else weight_int_std
232+ if verbose :
233+ logger .info (f" Weights preserved ({ log_min_std } std={ log_w_std :.1f} )" )
212234
213235 return weight_int_sawb , is_w_recomputed
214236
0 commit comments