@@ -1056,16 +1056,12 @@ def test_binary(self, f, dtype):
10561056 if jtu .test_device_matches (["tpu" ]) and jnp .dtype (dtype ).itemsize == 2 :
10571057 self .skipTest ("16-bit types are not supported on TPU" )
10581058
1059- # TODO: skipped due to https://github.com/jax-ml/jax/issues/24027
1059+ # TODO(ayx): Fix these operations on TPU
10601060 if (
10611061 jtu .test_device_matches (["tpu" ])
1062- and f == jnp .remainder
1063- and not self . INTERPRET
1062+ and f in ( jnp .floor_divide , jnp . subtract )
1063+ and dtype == "uint32"
10641064 ):
1065- self .skipTest ("jnp.remainder on TPU is only supported in interpret mode" )
1066-
1067- # TODO(ayx): fix this on TPU
1068- if jtu .test_device_matches (["tpu" ]) and dtype == "uint32" :
10691065 self .skipTest ("Not supported on TPU" )
10701066
10711067 @functools .partial (
@@ -1092,16 +1088,13 @@ def test_binary_scalar(self, f, dtype):
10921088 self .skipTest ("Test only supported on TPU." )
10931089 if jtu .test_device_matches (["tpu" ]) and jnp .dtype (dtype ).itemsize == 2 :
10941090 self .skipTest ("16-bit types are not supported on TPU" )
1095- # TODO: skipped due to https://github.com/jax-ml/jax/issues/24027
1091+
1092+ # TODO(ayx): Fix these operations on TPU
10961093 if (
10971094 jtu .test_device_matches (["tpu" ])
1098- and f == jnp .remainder
1099- and not self . INTERPRET
1095+ and f in ( jnp .floor_divide , jnp . subtract )
1096+ and dtype == "uint32"
11001097 ):
1101- self .skipTest ("jnp.remainder on TPU is only supported in interpret mode" )
1102-
1103- # TODO: skipped due to https://github.com/jax-ml/jax/issues/23972
1104- if jtu .test_device_matches (["tpu" ]) and dtype == "uint32" :
11051098 self .skipTest ("Not supported on TPU" )
11061099
11071100 @functools .partial (
0 commit comments