[Fix] Mixed Precision (float16) numerical instability in GroupNormalization with small epsilon #22589
ChiragSW wants to merge 5 commits into keras-team:master from
Conversation
Code Review
This pull request updates the GroupNormalization layer to improve numerical stability during mixed precision training. It disables automatic casting for the layer and its weights (gamma and beta) and adds an explicit cast to compute_dtype at the end of the call method. New test cases are included to ensure that large input values do not result in NaNs when running within a float16 autocast scope. I have no feedback to provide.
Codecov Report
❌ Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff             @@
##           master   #22589      +/-   ##
==========================================
- Coverage   83.28%   83.28%   -0.01%
==========================================
  Files         596      596
  Lines       68089    68110      +21
  Branches    10607    10611       +4
==========================================
+ Hits        56711    56728      +17
- Misses       8634     8637       +3
- Partials     2744     2745      +1
```
@hertschuh please review
```python
    ):
        super().__init__(**kwargs)
        self.supports_masking = True
        self.autocast = False
```
We should not hardcode `self.autocast = False`. The fix is indeed to do this:

```python
keras.layers.GroupNormalization(groups=8, epsilon=1e-12, autocast=False)
```

But this should be controlled by users, not hardcoded.
The contract of autocast is to accept lower precision to improve speed, and that option should remain open to people who want it.
We could, however, print a warning when the epsilon is smaller than the precision of the compute dtype, since such a value cannot take effect.
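The warning the reviewer suggests could be sketched as follows. This is an illustration, not the actual Keras code; `warn_if_epsilon_below_precision` is a hypothetical helper, and using `np.finfo(...).tiny` as the precision threshold is an assumption.

```python
import warnings

import numpy as np


def warn_if_epsilon_below_precision(epsilon, compute_dtype):
    # Hypothetical helper sketching the reviewer's suggestion: if epsilon is
    # smaller than the smallest positive normal value of the compute dtype,
    # the stabilizing term underflows to zero and has no effect.
    tiny = np.finfo(compute_dtype).tiny
    if epsilon < tiny:
        warnings.warn(
            f"epsilon={epsilon} is below the precision of {np.dtype(compute_dtype).name} "
            f"(smallest normal value {tiny}); consider a larger epsilon or "
            "autocast=False."
        )


# epsilon=1e-12 underflows in float16 (tiny is about 6.1e-05), so this warns.
warn_if_epsilon_below_precision(1e-12, np.float16)
```

With `autocast=False`, the math runs in float32 and an epsilon of 1e-12 remains representable, so no warning would be needed there.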
Root Cause
With autocast=True (the default), inputs were cast to float16 before reaching call(). Values exceeding float16's maximum (65504) overflowed to inf, and the inf propagated as NaN through the normalization math. The existing internal float32 upcast could not recover values that were already lost.
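The overflow-to-NaN chain described above can be reproduced in plain NumPy (a standalone illustration, not the layer's code):

```python
import numpy as np

# Values above float16's max (65504) overflow to inf when cast down,
# and inf then propagates as NaN through the normalization arithmetic.
x32 = np.array([70000.0, 70000.0], dtype=np.float32)
x16 = x32.astype(np.float16)  # overflows to [inf, inf]
mean = x16.mean()             # inf
centered = x16 - mean         # inf - inf -> nan
print(np.isinf(x16).all(), np.isnan(centered).all())  # True True
```

Upcasting `x16` back to float32 afterwards still yields inf, which is why the layer's internal float32 upcast alone could not fix this.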
Fix
- `self.autocast = False` keeps inputs in their original dtype (float32), preventing overflow
- `autocast=False` on the gamma/beta weights stores the weights in float32 for precision
- `ops.cast(outputs, self.compute_dtype)` returns a proper float16 output for mixed precision

I also added regression tests:

- `test_large_value_within_autocast_scope`: verifies weights aren't corrupted by autocast (the same test exists in BatchNormalization and LayerNormalization)
- `test_mixed_float16_large_inputs`: catches the actual NaN bug
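The fix pattern above (compute in float32, cast to the compute dtype only at the end) can be sketched in NumPy. This is a simplified stand-in for the PR's change, not the actual Keras implementation; `group_norm_stable` is a hypothetical name and it normalizes over the whole array rather than per group:

```python
import numpy as np


def group_norm_stable(x, epsilon=1e-3):
    # Sketch of the fix pattern: keep the input and the math in float32
    # (mirroring autocast=False), then cast to float16 only at the end
    # (mirroring ops.cast(outputs, self.compute_dtype)).
    x32 = np.asarray(x, dtype=np.float32)           # no premature float16 cast
    mean = x32.mean()
    var = x32.var()
    normed = (x32 - mean) / np.sqrt(var + epsilon)  # safe in float32
    return normed.astype(np.float16)                # cast at the very end


# Inputs that would overflow float16 stay finite end to end.
out = group_norm_stable([70000.0, -70000.0, 70000.0, -70000.0])
print(out.dtype, np.isnan(out).any())  # float16 False
```

Had the input been cast to float16 first, the 70000.0 values would already be inf before the mean was taken, and the output would be NaN.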