Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion jax/_src/nn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,8 @@ def standardize(x: ArrayLike,
# when used in neural network normalization layers
variance = jnp.mean(
jnp.square(x), axis, keepdims=True, where=where) - jnp.square(mean)
return jnp.subtract(x, jnp.asarray(mean)) * lax.rsqrt(jnp.asarray(variance) + epsilon)
variance = lax.abs(variance)
return jnp.subtract(x, jnp.asarray(mean)) * lax.rsqrt(variance + epsilon)

# TODO(slebedev): Change the type of `x` to `ArrayLike`.
@api.jit(static_argnames=("num_classes", "dtype", "axis"))
Expand Down
20 changes: 20 additions & 0 deletions tests/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,26 @@ def testLog1mExpGrad(self):
atol=1e-3,
)

def testStandardizeNegativeVariance(self):
# This input provokes a negative variance due to floating point error
x = jnp.array([-11., -11., -11.]) + 3e-6
result = jax.nn.standardize(x)
self.assertFalse(jnp.any(jnp.isnan(result)))

def testStandardizeGradientStability(self):
# Verifies that fixing negative variance via abs() preserves gradients,
# unlike clipping which would result in dead gradients (zeros).
x = jnp.array([-11., -11., -11.]) + 3e-6

def loss(input_x):
return jnp.sum(jax.nn.standardize(input_x))

grads = jax.grad(loss)(x)

# Assert gradients are not NaN
self.assertFalse(jnp.any(jnp.isnan(grads)))
# Assert gradients are not zero (dead)
self.assertFalse(jnp.all(grads == 0), "Gradients should not be zero (dead neuron issue)")

InitializerRecord = collections.namedtuple(
"InitializerRecord",
Expand Down