fix: guard against division by zero in masked_var Bessel correction #1262
kbhujbal wants to merge 1 commit into google:main
Conversation
When `mask_sum == 1`, Bessel's correction computes `mask_sum / (mask_sum - 1)`, which is `1/0 = inf`. This propagates through `masked_whiten` and corrupts advantage normalization in PPO training. Use `jnp.where` to fall back to uncorrected variance (`bessel_corr = 1.0`) when `mask_sum <= 1`, since Bessel's correction is undefined for `n <= 1`. Adds regression tests for single-element and empty masks.
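The failure mode is easy to reproduce in isolation: under IEEE-754 floats, the division silently overflows to `inf` rather than raising. A minimal standalone sketch (not the tunix code):

```python
import jax.numpy as jnp

# With a single unmasked element, Bessel's correction divides by zero.
mask_sum = jnp.array(1.0)
bessel_corr = mask_sum / (mask_sum - 1.0)  # 1.0 / 0.0 -> inf, no exception raised
print(bool(jnp.isinf(bessel_corr)))  # True
```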
Code Review
This pull request correctly addresses a division-by-zero error in masked_var when calculating Bessel's correction for a single data point. The use of jnp.where is an idiomatic and JIT-safe way to handle this edge case in JAX. The addition of regression tests for single-element and empty masks is also excellent. I have a couple of suggestions to make the new tests more specific by asserting the exact expected output (0.0) instead of just checking for finiteness, which will make them more robust.
```python
x = np.array([1.0, 2.0, 3.0])
mask = np.array([False, True, False])
computed = ppo_helpers.masked_var(x, mask=mask)
self.assertTrue(jnp.isfinite(computed), f'Expected finite, got {computed}')
```
The test correctly checks for finiteness, which solves the immediate bug. However, this test can be made more specific. For a single unmasked element, the variance should be exactly 0. Asserting this would make the test more robust.
```diff
-self.assertTrue(jnp.isfinite(computed), f'Expected finite, got {computed}')
+self.assertEqual(computed, 0.0)
```
```python
x = np.array([1.0, 2.0, 3.0])
mask = np.array([False, False, False])
computed = ppo_helpers.masked_var(x, mask=mask)
self.assertTrue(jnp.isfinite(computed), f'Expected finite, got {computed}')
```
Similar to the single-element case, this test can be made more specific. When all elements are masked, the variance should be 0. Asserting equality with 0.0 is a stronger guarantee than just checking for finiteness.
```diff
-self.assertTrue(jnp.isfinite(computed), f'Expected finite, got {computed}')
+self.assertEqual(computed, 0.0)
```
Changes
`masked_var` in `tunix/rl/ppo/ppo_helpers.py` computes Bessel's correction as `mask_sum / (mask_sum - 1)`. When only one token is unmasked (`mask_sum == 1`), this evaluates to `1/0 = inf`, which propagates through `masked_whiten` → `compute_gae_advantages` and corrupts PPO advantage normalization. This is triggered in practice when a batch contains very short completions or highly sparse completion masks.
Fix
Use `jnp.where` to fall back to uncorrected variance (`bessel_corr = 1.0`) when `mask_sum <= 1`, since Bessel's correction is mathematically undefined for `n <= 1`. This is JIT-safe (no Python-level branching) and preserves the existing behavior for all cases where `mask_sum > 1`.
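The guard can be sketched as follows. The function name and signature here are illustrative only; the actual helper in `tunix/rl/ppo/ppo_helpers.py` may differ:

```python
import jax.numpy as jnp

def masked_var(x, mask):
    """Variance over unmasked elements with a guarded Bessel's correction."""
    mask = mask.astype(x.dtype)
    mask_sum = mask.sum()
    # Clamp the denominator to avoid 0/0 when every element is masked.
    safe_n = jnp.maximum(mask_sum, 1.0)
    mean = (x * mask).sum() / safe_n
    var = (jnp.square(x - mean) * mask).sum() / safe_n
    # Bessel's correction n / (n - 1) is undefined for n <= 1; fall back to 1.0.
    # The denominator is also clamped so neither jnp.where branch produces inf.
    bessel_corr = jnp.where(
        mask_sum > 1.0, mask_sum / jnp.maximum(mask_sum - 1.0, 1.0), 1.0
    )
    return var * bessel_corr
```

With `mask = [False, True, False]` this returns `0.0` instead of `inf`, and for `mask_sum > 1` it matches the usual `ddof=1` sample variance.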
Checklist