|
30 | 30 | # pylint: disable=logging-not-lazy |
31 | 31 |
|
32 | 32 |
|
| 33 | +# empirical threshold on the standard deviation of INT weights that triggers their recomputation
| 34 | +STD_THRESHOLD = 20 |
| 35 | + |
33 | 36 | logger = logging.getLogger() |
34 | 37 |
|
35 | 38 |
|
@@ -137,11 +140,11 @@ def recompute_weight_with_sawb( |
137 | 140 | # recompute if any channel shows narrow int weights |
138 | 141 | weight_int_std = weight_int_as_fp.std(dim=-1) |
139 | 142 | weight_int_std_min = weight_int_std.min() |
140 | | - recompute = any(w < 20 for w in weight_int_std) |
| 143 | + recompute = any(w < STD_THRESHOLD for w in weight_int_std) |
141 | 144 | else: |
142 | 145 | # recompute if full tensor shows narrow int weights |
143 | 146 | weight_int_std = weight_int_as_fp.std().item() |
144 | | - recompute = weight_int_std < 20 |
| 147 | + recompute = weight_int_std < STD_THRESHOLD |
145 | 148 |
|
146 | 149 | if recompute: |
147 | 150 | is_w_recomputed = True |
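
A minimal sketch of the trigger this hunk encodes, assuming weight_int_as_fp holds the INT weights cast to float, shaped (out_channels, in_features) in the per-channel case; needs_recompute and per_channel are illustrative names, not the repo's API:

import torch

STD_THRESHOLD = 20  # empirical value introduced above

def needs_recompute(weight_int_as_fp: torch.Tensor, per_channel: bool) -> bool:
    # hypothetical helper; the real check is inlined in recompute_weight_with_sawb
    if per_channel:
        # one std per output channel, taken along the in_feature dimension
        weight_int_std = weight_int_as_fp.std(dim=-1)
        # vectorized equivalent of any(w < STD_THRESHOLD for w in weight_int_std)
        return bool((weight_int_std < STD_THRESHOLD).any())
    # single std over the full tensor
    return weight_int_as_fp.std().item() < STD_THRESHOLD

w = 5.0 * torch.randn(8, 1024)               # narrow distribution, std ~ 5
assert needs_recompute(w, per_channel=True)  # flags recomputation

Hoisting the magic number 20 into STD_THRESHOLD also keeps the per-channel and per-tensor branches in sync.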
@@ -283,13 +286,9 @@ def process_zero_shift( |
283 | 286 | # sum (squash) along in_feat dimension: dim=1 |
284 | 287 | zero_shift = torch.sum(weight_int, dim=1) |
285 | 288 |
|
286 | | - # guarding FP16 cast |
287 | | - if zero_shift.abs().max() > torch.finfo(torch.float16).max: |
288 | | - raise ValueError( |
289 | | - f"Zero shift ({k}) exceeds float16 range. " |
290 | | - "Aborted state dict saving." |
291 | | - ) |
292 | | - new_sd[k] = zero_shift.to(torch.float16).to("cpu") |
| 289 | +# the zero shift can exceed the FP16 max value, especially if INT
| 290 | +# weights have been recomputed, so it is saved as FP32
| 291 | + new_sd[k] = zero_shift.to(torch.float32).to("cpu") |
293 | 292 | else: |
294 | 293 | raise NotImplementedError( |
295 | 294 | "Zero shift computation for tensor " |
|
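For intuition on why the FP16 guard is replaced by an FP32 save: the zero shift is a channel-wise sum of INT weights, and a single row of INT8 weights over a typical in_features width already exceeds float16's maximum of 65504, while float32 represents such integer sums exactly (they stay well below 2**24). The shapes below are illustrative assumptions, not taken from the source:

import torch

in_features = 4096                            # assumed layer width
fp16_max = torch.finfo(torch.float16).max     # 65504.0
worst_case = 127 * in_features                # all-maximal INT8 row -> 520192
assert worst_case > fp16_max                  # would overflow FP16...
assert torch.tensor(float(worst_case), dtype=torch.float16).isinf()  # ...to inf

weight_int = torch.randint(-127, 128, (64, in_features), dtype=torch.int32)
zero_shift = torch.sum(weight_int, dim=1)     # sum along in_feat dimension
saved = zero_shift.to(torch.float32).to("cpu")  # exact: |sum| < 2**24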