@@ -803,7 +803,15 @@ def kernel(x_ref, o_ref):
803803 ELEMENTWISE_OPS = [
804804 (
805805 [jnp .abs , jnp .negative ],
806- ["int16" , "int32" , "int64" , "float16" , "float32" , "float64" ],
806+ [
807+ "int16" ,
808+ "int32" ,
809+ "int64" ,
810+ "bfloat16" ,
811+ "float16" ,
812+ "float32" ,
813+ "float64" ,
814+ ],
807815 ),
808816 ([jnp .ceil , jnp .floor ], ["bfloat16" , "float32" , "float64" , "int32" ]),
809817 (
@@ -819,7 +827,7 @@ def kernel(x_ref, o_ref):
819827 ["float32" , "float64" ],
820828 ),
821829 ([lax .population_count , lax .clz , jnp .invert ], ["int32" , "int64" ]),
822- ([jnp .logical_not ], ["bool" ])
830+ ([jnp .logical_not ], ["bool" ]),
823831 ]
824832
825833 @parameterized .named_parameters (
@@ -831,8 +839,21 @@ def test_elementwise(self, fn, dtype):
831839 if not jax .config .x64_enabled and jnp .dtype (dtype ).itemsize == 8 :
832840 self .skipTest ("64-bit types require x64_enabled" )
833841
834- if jtu .test_device_matches (["tpu" ]) and dtype in ("int16" , "float16" ):
835- self .skipTest ("int16 and float16 are not supported on TPU" )
842+ if jtu .test_device_matches (["tpu" ]):
843+ if dtype in ("int16" , "float16" ):
844+ self .skipTest ("int16 and float16 are not supported on TPU" )
845+ if (
846+ fn in (jnp .ceil , jnp .floor , jnp .negative )
847+ and dtype == "bfloat16"
848+ and not jtu .is_device_tpu_at_least (6 )
849+ ):
850+ self .skipTest (f"bfloat16 { fn .__name__ } is only supported on TPU v6+" )
851+ # TODO(b/370578663): implement these lowerings on TPU
852+ if fn in (
853+ jnp .acos , jnp .acosh , jnp .asin , jnp .asinh , jnp .atan , jnp .atanh ,
854+ jnp .cbrt , jnp .cosh , jnp .expm1 , jnp .sinh ,
855+ ):
856+ self .skipTest (f"{ fn .__name__ } not implemented on TPU" )
836857
837858 if (
838859 jtu .test_device_matches (["gpu" ])
@@ -841,21 +862,6 @@ def test_elementwise(self, fn, dtype):
841862 ):
842863 self .skipTest (f"bfloat16 { fn .__name__ } is not supported on GPU" )
843864
844- if (
845- jtu .test_device_matches (["tpu" ])
846- and not jtu .is_device_tpu_at_least (6 )
847- and fn in (jnp .ceil , jnp .floor )
848- and dtype == "bfloat16"
849- ):
850- self .skipTest (f"bfloat16 { fn .__name__ } is only supported on TPU v6+" )
851-
852- # TODO(b/370578663): implement these lowerings on TPU
853- if jtu .test_device_matches (["tpu" ]) and fn in (
854- jnp .acos , jnp .acosh , jnp .asin , jnp .asinh , jnp .atan , jnp .atanh ,
855- jnp .cbrt , jnp .cosh , jnp .expm1 , jnp .sinh ,
856- ):
857- self .skipTest (f"{ fn .__name__ } not implemented on TPU" )
858-
859865 @functools .partial (
860866 self .pallas_call ,
861867 out_shape = jax .ShapeDtypeStruct ((8 , 128 ), dtype ),
0 commit comments