Skip to content

Commit f1ebb1e

Browse files
ayaka14732Google-ML-Automation
authored andcommitted
Skip failing tests on TPU v6+
PiperOrigin-RevId: 741515935
1 parent 28f63ee commit f1ebb1e

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

tests/lax_numpy_reducers_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -905,8 +905,8 @@ def testCumulativeSumBool(self):
905905
@jtu.ignore_warning(category=NumpyComplexWarning)
906906
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
907907
def testCumulativeProd(self, shape, axis, dtype, out_dtype, include_initial):
908-
if jtu.is_device_tpu(6):
909-
raise unittest.SkipTest("TODO(b/364258243): Test fails on TPU v6e")
908+
if jtu.is_device_tpu_at_least(6):
909+
raise unittest.SkipTest("TODO(b/364258243): Test fails on TPU v6+")
910910
rng = jtu.rand_some_zero(self.rng())
911911

912912
# We currently "cheat" to ensure we have JAX arrays, not NumPy arrays as

tests/lax_scipy_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,8 @@ def scipy_fun(z):
339339
)
340340
@jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*")
341341
def testLpmn(self, l_max, shape, dtype):
342-
if jtu.is_device_tpu(6, "e"):
343-
self.skipTest("TODO(b/364258243): fails on TPU v6e")
342+
if jtu.is_device_tpu_at_least(6):
343+
self.skipTest("TODO(b/364258243): fails on TPU v6+")
344344
rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
345345
args_maker = lambda: [rng(shape, dtype)]
346346

@@ -461,8 +461,8 @@ def testSphHarmOrderOneDegreeOne(self):
461461
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
462462
def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
463463
"""Tests against JIT compatibility and Numpy."""
464-
if jtu.is_device_tpu(6, "e"):
465-
self.skipTest("TODO(b/364258243): fails on TPU v6e")
464+
if jtu.is_device_tpu_at_least(6):
465+
self.skipTest("TODO(b/364258243): fails on TPU v6+")
466466
n_max = l_max
467467
shape = (num_z,)
468468
rng = jtu.rand_int(self.rng(), -l_max, l_max + 1)
@@ -508,8 +508,8 @@ def testSphHarmCornerCaseWithWrongNmax(self):
508508
)
509509
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
510510
def testSphHarmY(self, l_max, num_z, dtype):
511-
if jtu.is_device_tpu(6, "e"):
512-
self.skipTest("TODO(b/364258243): fails on TPU v6e")
511+
if jtu.is_device_tpu_at_least(6):
512+
self.skipTest("TODO(b/364258243): fails on TPU v6+")
513513
n_max = l_max
514514
shape = (num_z,)
515515
rng = jtu.rand_int(self.rng(), -l_max, l_max + 1)

0 commit comments

Comments
 (0)