@@ -742,7 +742,7 @@ def kernel(x_ref, o_ref):
742742 [jnp .abs , jnp .negative ],
743743 ["int16" , "int32" , "int64" , "float16" , "float32" , "float64" ],
744744 ),
745- ([jnp .ceil , jnp .floor ], ["float32" , "float64" , "int32" ]),
745+ ([jnp .ceil , jnp .floor ], ["bfloat16" , "float32" , "float64" , "int32" ]),
746746 (
747747 [jnp .exp , jnp .exp2 , jnp .sin , jnp .cos , jnp .log , jnp .sqrt ],
748748 ["float16" , "float32" , "float64" ],
@@ -767,8 +767,23 @@ def test_elementwise(self, fn, dtype):
767767 if not jax .config .x64_enabled and jnp .dtype (dtype ).itemsize == 8 :
768768 self .skipTest ("64-bit types require x64_enabled" )
769769
770- if jtu .test_device_matches (["tpu" ]) and jnp .dtype (dtype ).itemsize == 2 :
771- self .skipTest ("16-bit types are not supported on TPU" )
770+ if jtu .test_device_matches (["tpu" ]) and dtype in ("int16" , "float16" ):
771+ self .skipTest ("int16 and float16 are not supported on TPU" )
772+
773+ if (
774+ jtu .test_device_matches (["gpu" ])
775+ and fn in (jnp .ceil , jnp .floor )
776+ and dtype == "bfloat16"
777+ ):
778+ self .skipTest (f"bfloat16 { fn .__name__ } is not supported on GPU" )
779+
780+ if (
781+ jtu .test_device_matches (["tpu" ])
782+ and not jtu .is_device_tpu_at_least (6 )
783+ and fn in (jnp .ceil , jnp .floor )
784+ and dtype == "bfloat16"
785+ ):
786+ self .skipTest (f"bfloat16 { fn .__name__ } is only supported on TPU v6+" )
772787
773788 # TODO(b/370578663): implement these lowerings on TPU
774789 if jtu .test_device_matches (["tpu" ]) and fn in (
@@ -784,7 +799,7 @@ def kernel(x_ref, o_ref):
784799 o_ref [:] = fn (x_ref [...])
785800
786801 x = jnp .array ([0.42 , 2.4 ]).astype (dtype )
787- np . testing . assert_allclose (kernel (x ), fn (x ), rtol = 1e-6 )
802+ self . assertAllClose (kernel (x ), fn (x ), rtol = 1e-6 )
788803
789804 @parameterized .named_parameters (
790805 (f"{ fn .__name__ } _{ dtype } " , fn , dtype )
@@ -798,6 +813,13 @@ def test_elementwise_scalar(self, fn, dtype):
798813 if jtu .test_device_matches (["tpu" ]) and jnp .dtype (dtype ).itemsize == 2 :
799814 self .skipTest ("16-bit types are not supported on TPU" )
800815
816+ if (
817+ jtu .test_device_matches (["gpu" ])
818+ and fn in (jnp .ceil , jnp .floor )
819+ and dtype == "bfloat16"
820+ ):
821+ self .skipTest (f"bfloat16 { fn .__name__ } is not supported on GPU" )
822+
801823 if (
802824 jtu .test_device_matches (["tpu" ])
803825 and fn == lax .population_count
@@ -826,7 +848,7 @@ def kernel(x_ref, o_ref):
826848 o_ref [1 ] = fn (x_ref [1 ])
827849
828850 x = jnp .array ([0.42 , 2.4 ]).astype (dtype )
829- np . testing . assert_allclose (kernel (x ), fn (x ), rtol = 1e-6 )
851+ self . assertAllClose (kernel (x ), fn (x ), rtol = 1e-6 )
830852
831853 def test_abs_weak_type (self ):
832854 # see https://github.com/jax-ml/jax/issues/23191
0 commit comments