Skip to content

Commit f747112

Browse files
Fix lax_autodiff_test on v5p
PiperOrigin-RevId: 738565192
1 parent 16dc0ad commit f747112

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tests/lax_autodiff_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,14 +205,16 @@ class LaxAutodiffTest(jtu.JaxTestCase):
205205
))
206206
def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
207207
rng = rng_factory(self.rng())
208-
if jtu.test_device_matches(["cpu"]):
208+
if jtu.test_device_matches(["cpu", "tpu"]):
209209
if op is lax.cosh and dtype == np.complex64:
210-
tol = 3e-1 # 2nd-order gradients are noisy on CPU
210+
tol = 3e-1 # 2nd-order gradients are noisy on CPU and TPU
211211
if jtu.test_device_matches(["tpu"]):
212212
if op is lax.pow:
213213
raise SkipTest("pow grad imprecise on tpu")
214214
if op is lax.cos:
215215
order = 1 # 2nd-order gradient is imprecise on TPU.
216+
if op is lax.sin:
217+
order = 1 # 2nd-order gradient is imprecise on TPUv5p.
216218
if op is lax.log:
217219
order = 1 # 2nd-order gradient is imprecise on TPU.
218220

0 commit comments

Comments
 (0)