Skip to content

Commit 0880114

Browse files
committed
Add test of relu grad at zero. Update paper links.
1 parent e88b578 commit 0880114

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

jax/_src/nn/functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def relu(x: ArrayLike) -> Array:
6767
6868
For more information see
6969
`Numerical influence of ReLU’(0) on backpropagation
70-      <https://openreview.net/forum?id=urrcVI-_jRm>`_.
70+      <https://dl.acm.org/doi/10.5555/3540261.3540297>`_.
7171
7272
Args:
7373
x : input array
@@ -84,7 +84,7 @@ def relu(x: ArrayLike) -> Array:
8484
8585
"""
8686
return jnp.maximum(x, 0)
87-  # For behavior at 0, see https://openreview.net/forum?id=urrcVI-_jRm
87+  # For behavior at 0, see https://dl.acm.org/doi/10.5555/3540261.3540297
8888
relu.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))
8989

9090
@jax.jit

tests/nn_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,11 @@ def testReluGrad(self):
317317
jaxpr = jax.make_jaxpr(jax.grad(nn.relu))(0.)
318318
self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 2)
319319

320+    def testReluGradAtZero(self):
321+      # https://dl.acm.org/doi/10.5555/3540261.3540297
322+      grad = jax.grad(nn.relu)(0.)
323+      self.assertEqual(grad, 0.)
324+
320325
def testRelu6Grad(self):
321326
rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
322327
check_grads(nn.relu6, (1.,), order=3, rtol=rtol)

0 commit comments

Comments (0)