
Commit 9623337

Merge pull request #131 from andrea-fasoli/guard_sawb_recompute
feat: add guards to sawb recomputation
2 parents: 7467f68 + f5db8df

File tree

1 file changed (+44, -11)

fms_mo/utils/aiu_utils.py

Lines changed: 44 additions & 11 deletions
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 # Standard
-from copy import deepcopy
 from pathlib import Path
 import logging
 
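A note on the dropped import: the only consumer of the copied tensor was the verbose statistics branch, which (per the hunk further down) now reads weight_int_guarded and only computes a standard deviation from it. A read-only use needs no defensive copy, and Tensor.to(torch.float32) already returns a fresh tensor whenever the dtype actually changes, so deepcopy was redundant.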
@@ -136,6 +135,9 @@ def recompute_weight_with_sawb(
     integer domain.
     """
 
+    weight_pre_quant = weight_pre_quant.to("cpu")
+    weight_int_as_fp = weight_int_as_fp.to("cpu")
+
     is_w_recomputed = False
     weight_int_sawb: torch.Tensor | None = None
     weight_int_std: torch.Tensor | float | None = None
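The two new .to("cpu") calls pin both tensors to the host up front, presumably because the guard added further down compares quantizer.clip_val (also moved to CPU) against statistics of weight_pre_quant, and PyTorch refuses elementwise comparisons and torch.where selections across devices. A minimal illustration of the failure mode this avoids; tensor names and values are made up for this sketch:

import torch

cv_sawb = torch.tensor([0.5, 0.7])  # host tensor, like clip_val after .to("cpu")
if torch.cuda.is_available():
    cv_max = torch.tensor([0.6, 0.6], device="cuda")
    try:
        _ = cv_sawb < cv_max  # fails: operands live on different devices
    except RuntimeError as err:
        print(f"device mismatch: {err}")

# Moving both operands to the CPU first, as the patch does, keeps the
# comparison (and the torch.where built on it) well-defined.
cv_max = torch.tensor([0.6, 0.6])
print(cv_sawb < cv_max)  # tensor([ True, False])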
@@ -169,6 +171,7 @@ def recompute_weight_with_sawb(
         # some SAWB quantizers only process FP32 inputs, so weights are
         # temporarily upscaled
         weight_int_sawb = quantizer(weight_pre_quant.to(torch.float32))
+        assert weight_int_sawb is not None
 
         # 2. Recompute clip values using new SAWB quantizer
         w_cv_key = layer_name + ".quantize_weight.clip_val"
@@ -180,14 +183,33 @@
         logger.info(
             f" {'Overwrite' if w_cvn_key in new_sd else 'Add'} key: {w_cvn_key}"
         )
-        new_sd[w_cv_key] = quantizer.clip_val.to("cpu").to(torch.float16)
-        new_sd[w_cvn_key] = -quantizer.clip_val.to("cpu").to(torch.float16)
+
+        cv_sawb = quantizer.clip_val.to("cpu").to(torch.float16)
+        if weight_per_channel:
+            # Select SAWB rows only where clip value does not exceed row max
+            cv_max = weight_pre_quant.abs().max(dim=-1)[0]
+            weight_int_guarded = torch.where(
+                (cv_sawb < cv_max)[:, None],
+                weight_int_sawb,
+                weight_int_as_fp,
+            )
+            cv_guarded = torch.where(cv_sawb < cv_max, cv_sawb, cv_max)
+            weight_int_sawb = weight_int_guarded
+        else:
+            cv_max = weight_pre_quant.abs().max()
+            weight_int_guarded = (
+                weight_int_sawb if cv_sawb < cv_max else weight_int_as_fp
+            )
+            cv_guarded = torch.min(cv_sawb, cv_max)
+
+        new_sd[w_cv_key] = cv_guarded
+        new_sd[w_cvn_key] = -cv_guarded
 
         # 3. [optional] Recompute standard deviation of integer weights
         if verbose:
-            weight_int_sawb_as_fp = deepcopy(weight_int_sawb).to(torch.float32)
+            weight_int_sawb_as_fp = weight_int_guarded.to(torch.float32)
             if weight_per_channel:
-                weight_int_sawb_std_min = weight_int_sawb_as_fp.std(dim=-1)[0].min()
+                weight_int_sawb_std_min = weight_int_sawb_as_fp.std(dim=-1).min()
             if verbose:
                 logger.info(
                     " Reprocessed weights "
@@ -204,11 +226,10 @@ def recompute_weight_with_sawb(
                     f"-> {weight_int_sawb_as_fp_std:.1f}) "
                     f"and clips of {layer_name + '.weight'}"
                 )
-    else:
-        log_min_std = "min_" if weight_per_channel else ""
-        log_w_std = weight_int_std_min if weight_per_channel else weight_int_std
-        if verbose:
-            logger.info(f" Weights preserved ({log_min_std}std={log_w_std:.1f})")
+    elif verbose:
+        log_min_std = "min_" if weight_per_channel else ""
+        log_w_std = weight_int_std_min if weight_per_channel else weight_int_std
+        logger.info(f" Weights preserved ({log_min_std}std={log_w_std:.1f})")
 
     return weight_int_sawb, is_w_recomputed

@@ -441,8 +462,20 @@ def save_sd_for_aiu(
         logger.info(
             "Attention: saving state dictionary without specifying a quantization "
             "configuration (qcfg) performs no recomputation for narrow weight "
-            "distributions and assumes the weight quantizer used was per-tensor."
+            "distributions and assumes the weight quantizer used was 8-bit per-tensor."
         )
+    else:
+        nbits_w = qcfg.get("nbits_w", None)
+        if nbits_w is None:
+            logger.info(
+                "Number of bits for weight quantization is not set in qcfg. "
+                "Assuming default (nbits_w=8)."
+            )
+        elif nbits_w != 8:
+            raise ValueError(
+                "Saving checkpoint in AIU-compliant format only supports INT8 "
+                f"quantization for now, but found {nbits_w=} in qcfg."
+            )
 
     converted_sd = convert_sd_for_aiu(
         model=model,