Skip to content

Commit a2e9699

Browse files
dfmGoogle-ML-Automation
authored andcommitted
Fix some flaky LAX autodiff tests.
The primitive autodiff tests for `lax.cosh` and `lax.cbrt` were disabled in #25443 because they were found to be flaky in some configurations. This change re-enables these tests with the following updates: 1. For `lax.cbrt`, we now test with random inputs offset from zero since the derivative of `cbrt` is steep at the origin. 2. For `lax.cosh`, we (further) relax the test tolerance for `complex64` types. PiperOrigin-RevId: 706758349
1 parent 2b06f93 commit a2e9699

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/lax_autodiff_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import jax
2828
from jax import dtypes
2929
from jax import lax
30-
from jax._src import config
3130
from jax._src import test_util as jtu
3231
from jax._src.util import NumpyComplexWarning
3332
from jax.test_util import check_grads
@@ -134,7 +133,7 @@ def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):
134133
dtypes=grad_float_dtypes),
135134
grad_test_spec(lax.rsqrt, nargs=1, order=2, rng_factory=jtu.rand_default,
136135
dtypes=grad_complex_dtypes, tol={np.float64: 2e-3}),
137-
grad_test_spec(lax.cbrt, nargs=1, order=2, rng_factory=jtu.rand_default,
136+
grad_test_spec(lax.cbrt, nargs=1, order=2, rng_factory=jtu.rand_not_small,
138137
dtypes=grad_float_dtypes, tol={np.float64: 5e-3}),
139138
grad_test_spec(lax.logistic, nargs=1, order=2,
140139
rng_factory=jtu.rand_default,
@@ -206,8 +205,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
206205
))
207206
def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
208207
rng = rng_factory(self.rng())
209-
if jtu.test_device_matches(["cpu"]) and (op is lax.cosh or op is lax.cbrt) and config.enable_x64.value:
210-
raise SkipTest("cosh and cbrt grad fails in x64 mode on CPU") # b/383756018
208+
if jtu.test_device_matches(["cpu"]):
209+
if op is lax.cosh and dtype == np.complex64:
210+
tol = 3e-1 # 2nd-order gradients are noisy on CPU
211211
if jtu.test_device_matches(["tpu"]):
212212
if op is lax.pow:
213213
raise SkipTest("pow grad imprecise on tpu")

0 commit comments

Comments
 (0)