
jax.nn.standardize returns nan when variance is lower than -epsilon #30426

@bondquant

Description


Sirs,

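jax.nn.standardize could easily be improved by clamping the variance from below at epsilon instead of adding epsilon to it. By default the variance is computed as mean(x**2) - mean(x)**2, and for (nearly) constant inputs floating-point cancellation can make that value slightly negative; once it falls below -epsilon, rsqrt returns NaN. I found a minimal reproducible example and hacked together a quick fix (below). Your comment in the code admits that this formula is less accurate, but at least it should not return NaNs unnecessarily, don't you think?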
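In formulas, a sketch of what the current code and the proposed change compute, assuming the default path estimates the variance with the one-pass expression below. In exact arithmetic the estimate is never negative, but in float32 the subtraction can round to a small negative number. The current form is NaN whenever the estimate is below -ε, while the proposed form's denominator is always at least √ε.

```latex
\begin{aligned}
\mu &= \frac{1}{n}\sum_{i=1}^{n} x_i, \qquad
\hat\sigma^2 = \frac{1}{n}\sum_{i=1}^{n} x_i^2 - \mu^2,\\
y_i^{\text{current}} &= \frac{x_i - \mu}{\sqrt{\hat\sigma^2 + \epsilon}}, \qquad
y_i^{\text{proposed}} = \frac{x_i - \mu}{\sqrt{\max(\hat\sigma^2, \epsilon)}}.
\end{aligned}
```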

Replace

return jnp.subtract(x, jnp.asarray(mean)) * lax.rsqrt(jnp.asarray(variance) + epsilon)

with

return jnp.subtract(x, jnp.asarray(mean)) * lax.rsqrt(jnp.maximum(jnp.asarray(variance), epsilon))
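A standalone sketch of the proposed fix, for trying it without patching JAX. `standardize_clamped` is a hypothetical name, and it only mirrors the default path (no user-supplied `mean`/`variance`, no `where`); it is not the library's implementation.

```python
import jax
import jax.numpy as jnp
from jax import lax


def standardize_clamped(x, axis=-1, epsilon=1e-5):
    """Hypothetical variant of jax.nn.standardize using the proposed clamp.

    Only the default path is sketched: mean and variance are computed here,
    and masking via `where` is not supported.
    """
    x = jnp.asarray(x)
    mean = jnp.mean(x, axis, keepdims=True)
    # One-pass variance, E[x^2] - E[x]^2: can round to a small negative
    # number for (nearly) constant inputs.
    variance = jnp.mean(jnp.square(x), axis, keepdims=True) - jnp.square(mean)
    # Proposed fix: clamp from below at epsilon instead of adding epsilon,
    # so rsqrt never receives a negative argument.
    return (x - mean) * lax.rsqrt(jnp.maximum(variance, epsilon))


# Same input as the reproducer in this issue.
x = -11.0 * jnp.ones((3,))
noise = jax.random.normal(jax.random.key(0)) * 2e-6
y = standardize_clamped(x + noise)
print(y)
```

One design consequence: when the variance is well above epsilon, the clamp adds nothing, so results differ very slightly from the `variance + epsilon` form. Below epsilon, the gradient with respect to the variance is zero.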

Example:

import jax
from jax import numpy as jnp, random as jrnd
x = -11. * jnp.ones((3,))
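# a single scalar of noise, broadcast to all entries: the input stays constant, so its true variance is 0
noise = jrnd.normal(jrnd.key(0)) * 2e-6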
jax.nn.standardize(x + noise)
# returns Array([nan, nan, nan], dtype=float32)
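A sketch for inspecting the cancellation on this input and for a possible user-side workaround until the library changes. It assumes the `variance` keyword of `jax.nn.standardize` and uses `jnp.var`, which averages squared deviations and therefore cannot return a negative value. The exact one-pass value printed will depend on the JAX version and backend.

```python
import jax
import jax.numpy as jnp

x = -11.0 * jnp.ones((3,)) + jax.random.normal(jax.random.key(0)) * 2e-6

# One-pass formula E[x^2] - E[x]^2. With x^2 near 121, float32 rounding
# error in these sums can be on the order of 1e-5, the same size as the
# default epsilon, so the result can land below -epsilon.
mean = jnp.mean(x, axis=-1, keepdims=True)
naive_var = jnp.mean(jnp.square(x), axis=-1, keepdims=True) - jnp.square(mean)

# Two-pass formula: mean of squared deviations from the mean (never negative).
two_pass_var = jnp.var(x, axis=-1, keepdims=True)

print("one-pass variance:", naive_var)
print("two-pass variance:", two_pass_var)

# Possible workaround: pass a non-negative variance explicitly.
y = jax.nn.standardize(x, axis=-1, variance=two_pass_var)
print(y)
```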

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.5.2
jaxlib: 0.5.1
numpy:  2.0.2
python: 3.11.13 (main, Jun  4 2025, 08:57:29) [GCC 11.4.0]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='5453835f9caa', release='6.1.123+', version='#1 SMP PREEMPT_DYNAMIC Sun Mar 30 16:01:29 UTC 2025', machine='x86_64')

Labels

bug (Something isn't working)
