Commit e294bf9

Add max value guarding vs FP16 range for zero_shift

Signed-off-by: Andrea Fasoli <[email protected]>

1 parent: a6f6fcf

1 file changed: +8 −7 lines

fms_mo/utils/aiu_utils.py

Lines changed: 8 additions & 7 deletions
@@ -281,14 +281,15 @@ def process_zero_shift(
         elif weight_int.dim() == 2:
             # weight_int: [out_feat, in_feat]
             # sum (squash) along in_feat dimension: dim=1
-            new_sd[k] = (
-                torch.sum(
-                    weight_int,
-                    dim=1,
+            zero_shift = torch.sum(weight_int, dim=1)
+
+            # guarding FP16 cast
+            if zero_shift.abs().max() > torch.finfo(torch.float16).max:
+                raise ValueError(
+                    f"Zero shift ({k}) exceeds float16 range. "
+                    "Aborted state dict saving."
                 )
-                .to(torch.float16)
-                .to("cpu")
-            )
+            new_sd[k] = zero_shift.to(torch.float16).to("cpu")
         else:
             raise NotImplementedError(
                 "Zero shift computation for tensor "
