Commit 5d78549

Fix ckpt saving bug
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent ca44bf0 commit 5d78549

File tree

1 file changed: +15 -9 lines changed

fms_mo/utils/aiu_utils.py

Lines changed: 15 additions & 9 deletions
@@ -122,6 +122,7 @@ def process_smoothquant(
 
 
 def recompute_weight_with_sawb(
+    weight_pre_quant: torch.Tensor,
     weight_int_as_fp: torch.Tensor,
     weight_per_channel: bool,
     sq_a_scale: torch.Tensor | None,
@@ -165,7 +166,7 @@ def recompute_weight_with_sawb(
     quantizer.training = True  # set SAWB to recompute clips
     # some SAWB quantizers only process FP32 inputs, so weights are
     # temporarily upscaled
-    weight_int_sawb = quantizer(weight_int_as_fp.to(torch.float32))
+    weight_int_sawb = quantizer(weight_pre_quant.to(torch.float32))
 
     # 2. Recompute clip values using new SAWB quantizer
     w_cv_key = layer_name + ".quantize_weight.clip_val"
@@ -181,7 +182,7 @@ def recompute_weight_with_sawb(
     new_sd[w_cvn_key] = -quantizer.clip_val.to("cpu").to(torch.float16)
 
     # 3. [optional] Recompute standard deviation of integer weights
-    if verbose and weight_int_sawb is not None:
+    if verbose:
         weight_int_sawb_as_fp = deepcopy(weight_int_sawb).to(torch.float32)
         if weight_per_channel:
             weight_int_sawb_std_min = weight_int_sawb_as_fp.std(dim=-1)[0].min()
@@ -236,6 +237,7 @@ def process_weight(
     weight_int_sawb = None
     if recompute_narrow_weights:
         weight_int_sawb, is_w_recomputed = recompute_weight_with_sawb(
+            weight_pre_quant,
             weight_int_as_fp,
             weight_per_channel,
             sq_a_scale,
@@ -378,13 +380,17 @@ def convert_sd_for_aiu(
             process_zero_shift(model, layer_name, weight_int, new_sd)
 
         elif all(excluded_key not in k for excluded_key in excluded_keys_from_new_sd):
-            # guarding FP16 cast
-            if v.abs().max() > torch.finfo(torch.float16).max:
-                raise ValueError(
-                    f"Quantization parameters ({k}) exceeds float16 range. "
-                    "Aborted state dict saving."
-                )
-            new_sd[k] = v.to("cpu").to(torch.float16)
+            if k not in new_sd:
+                # guarding FP16 cast
+                if v.abs().max() > torch.finfo(torch.float16).max:
+                    raise ValueError(
+                        f"Quantization parameters ({k}) exceeds float16 range. "
+                        "Aborted state dict saving."
+                    )
+                logger.info(f" Save key: {k}")
+                new_sd[k] = v.to("cpu").to(torch.float16)
+            else:
+                logger.info(f" Skip parameter already processed: {k}")
 
     logger.info("New state dict processed.")
     if verbose:
