Skip to content

Commit ca44bf0

Browse files
committed
Use variable for std threshold and save zero_shift as FP32
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent e294bf9 commit ca44bf0

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

fms_mo/utils/aiu_utils.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
# pylint: disable=logging-not-lazy
3131

3232

33+
# empirical threshold on the standard deviation of INT weights to trigger their recomputation
34+
STD_THRESHOLD = 20
35+
3336
logger = logging.getLogger()
3437

3538

@@ -137,11 +140,11 @@ def recompute_weight_with_sawb(
137140
# recompute if any channel shows narrow int weights
138141
weight_int_std = weight_int_as_fp.std(dim=-1)
139142
weight_int_std_min = weight_int_std.min()
140-
recompute = any(w < 20 for w in weight_int_std)
143+
recompute = any(w < STD_THRESHOLD for w in weight_int_std)
141144
else:
142145
# recompute if full tensor shows narrow int weights
143146
weight_int_std = weight_int_as_fp.std().item()
144-
recompute = weight_int_std < 20
147+
recompute = weight_int_std < STD_THRESHOLD
145148

146149
if recompute:
147150
is_w_recomputed = True
@@ -283,13 +286,9 @@ def process_zero_shift(
283286
# sum (squash) along in_feat dimension: dim=1
284287
zero_shift = torch.sum(weight_int, dim=1)
285288

286-
# guarding FP16 cast
287-
if zero_shift.abs().max() > torch.finfo(torch.float16).max:
288-
raise ValueError(
289-
f"Zero shift ({k}) exceeds float16 range. "
290-
"Aborted state dict saving."
291-
)
292-
new_sd[k] = zero_shift.to(torch.float16).to("cpu")
289+
# zero shift can exceed FP16 max value, especially if INT weights have
290+
# been recomputed, so it is saved as FP32
291+
new_sd[k] = zero_shift.to(torch.float32).to("cpu")
293292
else:
294293
raise NotImplementedError(
295294
"Zero shift computation for tensor "

0 commit comments

Comments (0)