
Commit b8df73f

Fix handling of smoothq in conversion and add verbosity option
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent 199e3d1 commit b8df73f

File tree

1 file changed: +19 −6 lines

fms_mo/utils/aiu_utils.py

Lines changed: 19 additions & 6 deletions
@@ -89,6 +89,7 @@ def process_smoothquant(
     model: PreTrainedModel,
     layer_name: str,
     new_sd: dict,
+    verbose: bool = False,
 ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
     """Check if smoothquant was in use and, if so:
     1. compute combined weight/activation scaling factor
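The docstring's step 1 refers to SmoothQuant-style scale folding. As background, here is a minimal sketch of the idea; all names are assumed for illustration and this is not fms_mo's implementation:

import torch

# Hedged sketch: SmoothQuant migrates quantization difficulty from
# activations to weights with a per-input-channel factor s, so that
# (X / s) @ (W * s).T == X @ W.T up to rounding.
def fold_smoothquant_scale(
    weight: torch.Tensor,    # (out_features, in_features)
    sq_scale: torch.Tensor,  # (in_features,) smoothing factors
) -> torch.Tensor:
    # Scale weight columns by s; activations are divided by s at runtime.
    return weight * sq_scale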
@@ -115,9 +116,10 @@ def process_smoothquant(
             "Quantization parameters (qscale) exceeds float16 range. "
             "Aborted state dict saving."
         )
-    new_sd[layer_name + ".smoothq_scale"] = (
-        sq_scale.squeeze().to(torch.float16).to("cpu")
-    )
+    k = layer_name + ".smoothq_scale"
+    if verbose:
+        logger.info(f" Save key: {k}")
+    new_sd[k] = sq_scale.squeeze().to(torch.float16).to("cpu")
     return weight_scaled, sq_a_scale
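The context lines above are the tail of an overflow guard on the SmoothQuant scale. A guard of that shape might look like the following sketch (assumed for illustration; the actual check sits outside this hunk):

import torch

# Hedged sketch of a float16 range check matching the error message above.
def assert_fits_fp16(qscale: torch.Tensor) -> None:
    fp16_max = torch.finfo(torch.float16).max  # 65504.0
    if qscale.abs().max().item() > fp16_max:
        raise ValueError(
            "Quantization parameters (qscale) exceeds float16 range. "
            "Aborted state dict saving."
        )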

@@ -128,7 +130,7 @@ def recompute_weight_with_sawb(
     sq_a_scale: torch.Tensor | None,
     layer_name: str,
     new_sd: dict,
-    verbose: bool,
+    verbose: bool = False,
 ) -> tuple[torch.Tensor | None, bool]:
     """Use SAWB quantizer to recompute weights showing narrow distributions in the
     integer domain.
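SAWB (statistics-aware weight binning) picks a clipping threshold from weight statistics rather than the raw min/max, which is why it helps weights whose integer distributions come out narrow. A rough sketch of the clip computation; the coefficients below are placeholders, not the calibrated bit-width-dependent constants, and this is not fms_mo's quantizer:

import torch

# Hedged sketch: SAWB chooses clip value alpha as a linear combination of
# sqrt(E[w^2]) and E[|w|]; c1 and c2 here are illustrative placeholders.
def sawb_clip_value(w: torch.Tensor, c1: float = 3.2, c2: float = 2.1) -> torch.Tensor:
    return c1 * w.pow(2).mean().sqrt() - c2 * w.abs().mean()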
@@ -219,7 +221,7 @@ def process_weight(
     weight_per_channel: bool,
     sq_a_scale: torch.Tensor | None,
     new_sd: dict,
-    verbose: bool,
+    verbose: bool = False,
 ) -> tuple[torch.Tensor | None, bool | None]:
     """Compute integer weights and store them into new state dictionary.
     If recomputation is enabled, int weights are updated using SAWB quantizer.
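For context on what "compute integer weights" usually involves, a minimal per-output-channel symmetric INT8 sketch (names and details are assumptions, not the fms_mo code):

import torch

# Hedged sketch: per-output-channel symmetric INT8 weight quantization.
def quantize_weight_int8(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # w: (out_features, in_features); clamp_min avoids divide-by-zero rows
    w_scale = (w.abs().amax(dim=1, keepdim=True) / 127.0).clamp_min(1e-8)
    w_int = torch.clamp(torch.round(w / w_scale), -127, 127).to(torch.int8)
    return w_int, w_scale.squeeze(1)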
@@ -259,6 +261,7 @@ def process_zero_shift(
     layer_name: str,
     weight_int: torch.Tensor | None,
     new_sd: dict,
+    verbose: bool = False,
 ) -> None:
     """Compute and store the zero shift, a correction factor that compensates the
     output of (W integer, X integer) matmuls to match the corresponding FP operation.
@@ -288,6 +291,9 @@ def process_zero_shift(
     # sum (squash) along in_feat dimension: dim=1
     zero_shift = torch.sum(weight_int, dim=1)
 
+    if verbose:
+        logger.info(f" Save key: {k}")
+
     # zero shift can exceed FP16 max value, especially if INT weights have
     # been recomputed, so it is saved as FP32
     new_sd[k] = zero_shift.to(torch.float32).to("cpu")
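The row sum above is the standard correction for an asymmetric activation zero point: z_x contributes a spurious z_x * sum_j W_int[i, j] term to every integer matmul output. A tiny numeric check of that identity (generic quantization algebra, not fms_mo code):

import torch

# Verify: W_int @ (x_int - z_x) == W_int @ x_int - z_x * row_sum(W_int)
w_int = torch.randint(-127, 128, (4, 8), dtype=torch.int32)
x_int = torch.randint(0, 256, (8,), dtype=torch.int32)
z_x = 128  # assumed activation zero point

zero_shift = w_int.sum(dim=1)                        # as stored above
exact = (w_int * (x_int - z_x)).sum(dim=1)           # matmul on shifted X
corrected = (w_int * x_int).sum(dim=1) - z_x * zero_shift
assert torch.equal(exact, corrected)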
@@ -360,6 +366,7 @@ def convert_sd_for_aiu(
                 model=model,
                 layer_name=layer_name,
                 new_sd=new_sd,
+                verbose=verbose,
             )
 
             weight_int, is_w_recomputed = process_weight(
@@ -377,7 +384,13 @@ def convert_sd_for_aiu(
             else:
                 num_w_preserved += 1
 
-            process_zero_shift(model, layer_name, weight_int, new_sd)
+            process_zero_shift(
+                model=model,
+                layer_name=layer_name,
+                weight_int=weight_int,
+                new_sd=new_sd,
+                verbose=verbose,
+            )
 
         elif all(excluded_key not in k for excluded_key in excluded_keys_from_new_sd):
             if k not in new_sd:
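With this change the verbosity flag threads from the conversion entry point through every helper. A hypothetical call (all parameter names except verbose are assumptions; see the full signature in fms_mo/utils/aiu_utils.py):

# Hypothetical usage; only the verbose flag is confirmed by this diff.
new_sd = convert_sd_for_aiu(model=model, verbose=True)
# With verbose=True each saved key is logged, e.g. (illustrative name):
#  Save key: transformer.h.0.attn.q_proj.smoothq_scale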
