Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jax/_src/nn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ def standardize(x: ArrayLike,
# when used in neural network normalization layers
variance = jnp.mean(
jnp.square(x), axis, keepdims=True, where=where) - jnp.square(mean)
return jnp.subtract(x, jnp.asarray(mean)) * lax.rsqrt(jnp.asarray(variance) + epsilon)
return jnp.subtract(x, jnp.asarray(mean)) * lax.rsqrt(jnp.maximum(jnp.asarray(variance), epsilon))

# TODO(slebedev): Change the type of `x` to `ArrayLike`.
@api.jit(static_argnames=("num_classes", "dtype", "axis"))
Expand Down