
fix: guard against division by zero in masked_var Bessel correction#1262

Open
kbhujbal wants to merge 1 commit into google:main from kbhujbal:fix/ppo-masked-var-div-by-zero

Conversation

@kbhujbal

When mask_sum == 1, Bessel's correction computes mask_sum / (mask_sum - 1) which is 1/0 = inf. This propagates through masked_whiten and corrupts advantage normalization in PPO training.

Use jnp.where to fall back to uncorrected variance (bessel_corr=1.0) when mask_sum <= 1, since Bessel's correction is undefined for n <= 1.

Adds regression tests for single element and empty masks.

Changes

masked_var in tunix/rl/ppo/ppo_helpers.py computes Bessel's correction as mask_sum / (mask_sum - 1). When only one token is unmasked (mask_sum == 1), this evaluates to 1/0 = inf, which propagates through masked_whiten and compute_gae_advantages and corrupts PPO advantage normalization.

This is triggered in practice when a batch contains very short completions or highly sparse completion masks.
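The failure mode is easy to reproduce in isolation. A minimal sketch of the unguarded correction (not the actual tunix code):

```python
import jax.numpy as jnp

# A completion mask with a single unmasked token.
mask = jnp.array([0.0, 1.0, 0.0])
mask_sum = mask.sum()  # 1.0

# Unguarded Bessel's correction: 1 / (1 - 1) = 1 / 0.
# JAX floating-point division by zero yields inf, not an error.
bessel_corr = mask_sum / (mask_sum - 1.0)
print(bessel_corr)  # inf
```

Any variance scaled by this factor becomes inf, and downstream whitening turns the whole advantage tensor into inf/nan.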

Fix

Use jnp.where to fall back to uncorrected variance (bessel_corr = 1.0) when mask_sum <= 1, since Bessel's correction is mathematically undefined for n <= 1:

bessel_corr = jnp.where(mask_sum > 1, mask_sum / (mask_sum - 1), 1.0)

This is JIT-safe (no Python-level branching) and preserves the existing behavior for all cases where mask_sum > 1.
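For context, a self-contained sketch of masked_var with the guard applied (the signature and helper details are illustrative and may differ from the actual tunix/rl/ppo/ppo_helpers.py implementation):

```python
import jax.numpy as jnp

def masked_var(x, mask):
    """Variance over unmasked elements with a guarded Bessel correction.

    Illustrative sketch, not the actual tunix implementation.
    """
    mask = mask.astype(x.dtype)
    mask_sum = mask.sum()
    # jnp.maximum avoids 0/0 for an all-masked input.
    mean = (x * mask).sum() / jnp.maximum(mask_sum, 1.0)
    variance = (((x - mean) * mask) ** 2).sum() / jnp.maximum(mask_sum, 1.0)
    # Bessel's correction n / (n - 1) is undefined for n <= 1;
    # fall back to the uncorrected variance in that case. Note that
    # jnp.where evaluates both branches, but the inf branch is discarded.
    bessel_corr = jnp.where(mask_sum > 1, mask_sum / (mask_sum - 1), 1.0)
    return variance * bessel_corr
```

With a single unmasked element this now returns a finite 0.0, and for mask_sum > 1 it matches the sample variance with ddof=1.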

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and all unit tests pass.
  • I have added all appropriate doc-strings/documentation.
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have signed the Contributor License Agreement.
  • I have followed Contribution Guidelines.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical numerical stability issue in the masked_var function, which is used in PPO training. Previously, when only a single element was unmasked, the Bessel's correction calculation would result in division by zero, propagating inf values and corrupting advantage normalization. The changes introduce a conditional application of Bessel's correction, ensuring finite variance values in edge cases and enhancing the robustness of the PPO implementation.

Highlights

  • Bessel Correction Safety: Implemented a guard in masked_var to prevent division by zero when mask_sum is 1, which previously led to inf values.
  • Numerical Stability: Ensured that masked_whiten and subsequent PPO advantage normalization remain numerically stable by handling edge cases where Bessel's correction is undefined.
  • Test Coverage: Added new regression tests to specifically cover scenarios with single-element and empty masks, verifying the fix.




@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses a division-by-zero error in masked_var when calculating Bessel's correction for a single data point. The use of jnp.where is an idiomatic and JIT-safe way to handle this edge case in JAX. The addition of regression tests for single-element and empty masks is also excellent. I have a couple of suggestions to make the new tests more specific by asserting the exact expected output (0.0) instead of just checking for finiteness, which will make them more robust.

x = np.array([1.0, 2.0, 3.0])
mask = np.array([False, True, False])
computed = ppo_helpers.masked_var(x, mask=mask)
self.assertTrue(jnp.isfinite(computed), f'Expected finite, got {computed}')
Severity: medium

The test correctly checks for finiteness, which solves the immediate bug. However, this test can be made more specific. For a single unmasked element, the variance should be exactly 0. Asserting this would make the test more robust.

Suggested change
self.assertTrue(jnp.isfinite(computed), f'Expected finite, got {computed}')
self.assertEqual(computed, 0.0)

x = np.array([1.0, 2.0, 3.0])
mask = np.array([False, False, False])
computed = ppo_helpers.masked_var(x, mask=mask)
self.assertTrue(jnp.isfinite(computed), f'Expected finite, got {computed}')
Severity: medium

Similar to the single-element case, this test can be made more specific. When all elements are masked, the variance should be 0. Asserting for 0.0 is a stronger guarantee than just checking for finiteness.

Suggested change
self.assertTrue(jnp.isfinite(computed), f'Expected finite, got {computed}')
self.assertEqual(computed, 0.0)
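Put together, the reviewer's suggested assertions might look like the following self-contained test. The inline masked_var here is a stand-in for the patched helper; the real tests would import it from tunix.rl.ppo.ppo_helpers:

```python
import unittest

import jax.numpy as jnp

# Stand-in for the patched tunix.rl.ppo.ppo_helpers.masked_var.
def masked_var(x, mask):
    mask = mask.astype(jnp.float32)
    mask_sum = mask.sum()
    mean = (x * mask).sum() / jnp.maximum(mask_sum, 1.0)
    var = (((x - mean) * mask) ** 2).sum() / jnp.maximum(mask_sum, 1.0)
    bessel = jnp.where(mask_sum > 1, mask_sum / (mask_sum - 1), 1.0)
    return var * bessel

class MaskedVarEdgeCases(unittest.TestCase):

    def test_single_element_mask_variance_is_zero(self):
        # One unmasked element: variance is exactly 0, not just finite.
        x = jnp.asarray([1.0, 2.0, 3.0])
        mask = jnp.asarray([False, True, False])
        computed = masked_var(x, mask)
        self.assertEqual(float(computed), 0.0)

    def test_empty_mask_variance_is_zero(self):
        # All elements masked: variance is exactly 0, not just finite.
        x = jnp.asarray([1.0, 2.0, 3.0])
        mask = jnp.asarray([False, False, False])
        computed = masked_var(x, mask)
        self.assertEqual(float(computed), 0.0)
```

Asserting the exact value 0.0 (rather than only finiteness) would also catch a regression that returned some other finite but wrong value.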
