Skip to content

Conversation

@copybara-service
Copy link

jax.nn.standardize fix NaN output.

Fixes #30426. Using abs here is probably better than any solution based on clip or maximum, because it maintains differentiability in the valid regime. I'll run some additional tests before merging this in order to ensure that it doesn't change numerics in important contexts.

Fixes #30426. Using `abs` here is probably better than any solution based on `clip` or `maximum`, because it maintains differentiability in the valid regime. I'll run some additional tests before merging this in order to ensure that it doesn't change numerics in important contexts.

PiperOrigin-RevId: 836332119
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

jax.nn.standardize returns nan when variance is lower than -epsilon

1 participant