1 parent a6f6fcf commit e294bf9
fms_mo/utils/aiu_utils.py
@@ -281,14 +281,15 @@ def process_zero_shift(
         elif weight_int.dim() == 2:
             # weight_int: [out_feat, in_feat]
             # sum (squash) along in_feat dimension: dim=1
-            new_sd[k] = (
-                torch.sum(
-                    weight_int,
-                    dim=1,
+            zero_shift = torch.sum(weight_int, dim=1)
+
+            # guarding FP16 cast
+            if zero_shift.abs().max() > torch.finfo(torch.float16).max:
+                raise ValueError(
+                    f"Zero shift ({k}) exceeds float16 range. "
+                    "Aborted state dict saving."
                 )
-                .to(torch.float16)
-                .to("cpu")
-            )
+            new_sd[k] = zero_shift.to(torch.float16).to("cpu")
         else:
             raise NotImplementedError(
                 "Zero shift computation for tensor "