@@ -1895,35 +1895,35 @@ def reduce(x_ref, y_ref):
       for axis in [0, 1, (1,), (0, 1)]
       for dtype in [
           "float16",
+          "bfloat16",
           "float32",
           "float64",
           "int32",
           "int64",
           "uint32",
           "uint64",
       ]
-      if isinstance(axis, int) or "arg" not in op_name
   ])
   def test_array_reduce(self, op, dtype, axis):
-    if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
-      self.skipTest("16-bit types are not supported on TPU")
+    if not isinstance(axis, int):
+      self.skipTest("TODO: tuple axes are not yet supported")
 
     if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:
       self.skipTest("64-bit types require x64_enabled")
 
+    # The Pallas TPU lowering currently supports only blocks of rank >= 1
+    if jtu.test_device_matches(["tpu"]):
+      self.skipTest("Not implemented on TPU")
+
     # Skip argmin/argmax on GPU in 64-bit mode because Pallas expects
     # `index_type` to be i32
     if (
         jax.config.x64_enabled
         and jtu.test_device_matches(["gpu"])
-        and op in {jnp.argmin, jnp.argmax}
+        and op in (jnp.argmin, jnp.argmax)
     ):
       self.skipTest("Not supported on GPU in 64-bit mode")
 
-    # The Pallas TPU lowering currently supports only blocks of rank >= 1
-    if jtu.test_device_matches(["tpu"]):
-      self.skipTest("Not supported on TPU")
-
     m, n = 32, 8
 
     def make_x(key):
@@ -1955,7 +1955,7 @@ def reduce(x_ref, y_ref):
       x = make_x(key)
       y = reduce(x)
       y_ref = op(x, axis=axis)
-      np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i)
+      self.assertAllClose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i)
 
   @parameterized.product(
       axis=[0, 1],