Some tests can be sped up (2x). Is this an acceptable first contribution? ❤️ #9656
So, I noticed that this test checks that FD=AG:

# File tests/nn_test.py
def testSoftplusGradNegInf(self):
  ## This is twice as fast
  ## ====================== slowest durations =======================
  ## 0.13s call tests/nn_test.py::NNFunctionsTest::testSoftplusGradNegInf
  # self.assertAllClose(
  #     0., jax.grad(nn.softplus)(-float('inf')))
  ## Than this
  ## ====================== slowest durations =======================
  ## 0.27s call tests/nn_test.py::NNFunctionsTest::testSoftplusGradNegInf
  check_grads(nn.softplus, (-float('inf'),), order=1,
              rtol=1e-2 if jtu.device_under_test() == "tpu" else None)

I want to carefully optimize tests like this and later add a few more tests. Is a PR like this likely to get accepted? Please also consider giving me some pointers to other janitorial work that could be done by a JAX newb like myself. Thanks ❤️
Replies: 1 comment 1 reply
It seems that the original code checks both reverse-mode AD and forward-mode AD, but your code only checks reverse-mode AD.
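
To make that difference concrete, here is a minimal sketch (not code from the JAX test suite) of what each approach exercises. It assumes `check_grads` is imported from `jax.test_util` and uses a finite test point `x = 0.5` chosen purely for illustration, rather than the `-inf` input of the original test. `jax.grad` only runs reverse mode, `jax.jvp` only runs forward mode, and `check_grads` compares the requested modes against numerical differences.

```python
import jax
from jax import nn
from jax.test_util import check_grads

# Illustrative finite input (an assumption; the original test uses -inf).
x = 0.5

# Reverse mode only: this is what the faster, commented-out assertion exercises.
rev = jax.grad(nn.softplus)(x)

# Forward mode only: push a unit tangent through softplus with jax.jvp.
_, fwd = jax.jvp(nn.softplus, (x,), (1.0,))

# Analytically, d/dx softplus(x) = sigmoid(x).
expected = nn.sigmoid(x)

# check_grads compares the selected AD modes against numerical differences.
# By default it checks both forward and reverse mode.
check_grads(nn.softplus, (x,), order=1)

# Restricting it to reverse mode keeps the numerical comparison
# while skipping the forward-mode check.
check_grads(nn.softplus, (x,), order=1, modes=["rev"])
```

If only reverse mode matters for a given test, passing `modes=["rev"]` may be a middle ground: it does less work than the default while still comparing against numerical differences, which a bare `jax.grad` assertion does not do.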