@@ -339,8 +339,8 @@ def scipy_fun(z):
339339 )
340340 @jtu .ignore_warning (category = DeprecationWarning , message = ".*scipy.special.lpmn.*" )
341341 def testLpmn (self , l_max , shape , dtype ):
342- if jtu .is_device_tpu ( 6 , "e" ):
343- self .skipTest ("TODO(b/364258243): fails on TPU v6e " )
342+ if jtu .is_device_tpu_at_least ( 6 ):
343+ self .skipTest ("TODO(b/364258243): fails on TPU v6+ " )
344344 rng = jtu .rand_uniform (self .rng (), low = - 0.2 , high = 0.9 )
345345 args_maker = lambda : [rng (shape , dtype )]
346346
@@ -461,8 +461,8 @@ def testSphHarmOrderOneDegreeOne(self):
461461 @jax .numpy_dtype_promotion ('standard' ) # This test explicitly exercises dtype promotion
462462 def testSphHarmForJitAndAgainstNumpy (self , l_max , num_z , dtype ):
463463 """Tests against JIT compatibility and Numpy."""
464- if jtu .is_device_tpu ( 6 , "e" ):
465- self .skipTest ("TODO(b/364258243): fails on TPU v6e " )
464+ if jtu .is_device_tpu_at_least ( 6 ):
465+ self .skipTest ("TODO(b/364258243): fails on TPU v6+ " )
466466 n_max = l_max
467467 shape = (num_z ,)
468468 rng = jtu .rand_int (self .rng (), - l_max , l_max + 1 )
@@ -508,8 +508,8 @@ def testSphHarmCornerCaseWithWrongNmax(self):
508508 )
509509 @jax .numpy_dtype_promotion ('standard' ) # This test explicitly exercises dtype promotion
510510 def testSphHarmY (self , l_max , num_z , dtype ):
511- if jtu .is_device_tpu ( 6 , "e" ):
512- self .skipTest ("TODO(b/364258243): fails on TPU v6e " )
511+ if jtu .is_device_tpu_at_least ( 6 ):
512+ self .skipTest ("TODO(b/364258243): fails on TPU v6+ " )
513513 n_max = l_max
514514 shape = (num_z ,)
515515 rng = jtu .rand_int (self .rng (), - l_max , l_max + 1 )
0 commit comments