Labels: bug (Something isn't working)
Description
Sirs,
jax.nn.standardize could easily be improved by putting a lower bound on the variance instead of adding epsilon to it. I found a minimal reproducible example and hacked together a quick fix (below). The comment in the code admits lower accuracy, but at least it should not return NaNs unnecessarily, don't you think?
Replace
return jnp.subtract(x, jnp.asarray(mean)) * lax.rsqrt(jnp.asarray(variance) + epsilon)
with
return jnp.subtract(x, jnp.asarray(mean)) * lax.rsqrt(jnp.maximum(jnp.asarray(variance), epsilon))
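A minimal sketch of the proposed change as a standalone function, so it can be tried without patching JAX. The name `standardize_clamped` is hypothetical, and its variance formula is assumed to mirror the default one in `jax.nn.standardize` (mean of squares minus squared mean); the `mean`, `variance` and `where` arguments of the real function are omitted.

```python
import jax
import jax.numpy as jnp
from jax import lax


def standardize_clamped(x, axis=-1, epsilon=1e-5):
    """Hypothetical variant of jax.nn.standardize with the proposed fix.

    Assumes the same one-pass variance, E[x^2] - E[x]^2, but floors the
    variance at epsilon instead of adding epsilon to it.
    """
    x = jnp.asarray(x)
    mean = jnp.mean(x, axis, keepdims=True)
    # One-pass variance: cheap, but in float32 it can round to a small
    # negative number when the true variance is (near) zero.
    variance = jnp.mean(jnp.square(x), axis, keepdims=True) - jnp.square(mean)
    # Proposed fix: clamp from below so rsqrt never sees a negative argument.
    return jnp.subtract(x, mean) * lax.rsqrt(jnp.maximum(variance, epsilon))


# Same near-constant input as in the example.
x = -11.0 * jnp.ones((3,)) + jax.random.normal(jax.random.key(0)) * 2e-6
print(standardize_clamped(x))
```

For well-conditioned inputs the only difference is that a variance above epsilon is used as-is rather than shifted by epsilon, so results differ from `jax.nn.standardize` only slightly; for near-constant inputs the output stays finite and close to zero.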
Example:
import jax
from jax import numpy as jnp, random as jrnd
x = -11. * jnp.ones((3,))
noise = jrnd.normal(jrnd.key(0)) * 2e-6
jax.nn.standardize(x + noise)
# returns Array([nan, nan, nan], dtype=float32)
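A short diagnostic sketch of where the NaNs likely come from, assuming (as the code comment suggests) that the default variance in `jax.nn.standardize` is the one-pass E[x²] − E[x]². Since `noise` is a scalar, all three entries of the input are equal and the true variance is exactly zero. Both terms are about 121, and one float32 ulp near 121 is about 7.6e-6, so rounding can leave a small negative difference; once it falls below -epsilon, `rsqrt(variance + epsilon)` is NaN. `bad_var` below is an illustrative value, not a measured one.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = -11.0 * jnp.ones((3,)) + jax.random.normal(jax.random.key(0)) * 2e-6
mean = jnp.mean(x)

# One-pass estimate (assumed default): difference of two numbers near 121,
# so float32 rounding error of order 1e-5 can push it below zero.
var_one_pass = jnp.mean(jnp.square(x)) - jnp.square(mean)
# Two-pass estimate (centered, like jnp.var): a mean of squares, never negative.
var_two_pass = jnp.mean(jnp.square(x - mean))
print(var_one_pass, var_two_pass)

# With a variance estimate below -epsilon, adding epsilon leaves a negative
# argument for rsqrt; clamping with maximum does not.
eps = 1e-5
bad_var = jnp.float32(-2e-5)  # illustrative value, not measured
print(lax.rsqrt(bad_var + eps))              # nan
print(lax.rsqrt(jnp.maximum(bad_var, eps)))  # finite, about 316.2
```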
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.5.2
jaxlib: 0.5.1
numpy: 2.0.2
python: 3.11.13 (main, Jun 4 2025, 08:57:29) [GCC 11.4.0]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='5453835f9caa', release='6.1.123+', version='#1 SMP PREEMPT_DYNAMIC Sun Mar 30 16:01:29 UTC 2025', machine='x86_64')