Skip to content

Conversation

@nandan2003
Copy link

Description

Fixes #30426.

jax.nn.standardize previously returned NaN when the input array had a variance slightly less than 0 (due to floating point imprecision) but larger in magnitude than -epsilon. In these cases, variance + epsilon remained negative, causing rsqrt to produce NaNs.

This PR changes the normalization logic to clamp the variance using jnp.maximum(variance, epsilon) instead of adding epsilon. This ensures the denominator is always valid.

Reproduction

import jax
import jax.numpy as jnp

# Input that causes slight negative variance due to precision
val = -11.0
x = val * jnp.ones((3,))
noise = jax.random.normal(jax.random.key(0)) * 2e-6
bad_input = x + noise

# Before: Returns [nan, nan, nan]
# After: Returns valid floats
res = jax.nn.standardize(bad_input)
print(res)

@google-cla
Copy link

google-cla bot commented Nov 24, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @nandan2003, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the numerical stability of the jax.nn.standardize function by preventing the generation of NaN values. It specifically targets a scenario where floating-point inaccuracies could lead to a computed variance that is slightly negative, causing the reciprocal square root operation to fail. The change ensures that the variance used in the standardization process is always clamped to a non-negative value, thereby making the function more robust and reliable for various inputs.

Highlights

  • Fixes NaN in jax.nn.standardize: This pull request resolves an issue where jax.nn.standardize could return NaN values.
  • Addresses negative variance: The NaNs were caused by floating-point imprecision leading to slightly negative variance values, which then resulted in rsqrt producing NaNs.
  • Clamps variance with jnp.maximum: The fix involves changing the normalization logic to use jnp.maximum(variance, epsilon) instead of variance + epsilon, ensuring the argument to rsqrt is always non-negative and preventing NaNs.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request addresses a NaN issue in jax.nn.standardize that occurs when floating-point imprecision leads to a small negative variance. The proposed solution of using jnp.maximum(variance, epsilon) instead of variance + epsilon is a robust fix. It correctly clamps the variance at a small positive value, preventing the argument to rsqrt from becoming negative and thus avoiding NaNs in the output. The change is correct and improves the numerical stability of the function. I've added one minor suggestion for code cleanup.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@jakevdp jakevdp self-assigned this Nov 24, 2025
@jakevdp
Copy link
Collaborator

jakevdp commented Nov 24, 2025

Thanks for this – I don't think this is the correct fix for the issue. See my comment at #30426 (comment)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

jax.nn.standardize returns nan when variance is lower than -epsilon

2 participants