@@ -1112,6 +1112,21 @@ def testDotAlgorithm(self, algorithm, dtype):
       }:
         raise SkipTest(
             f"The dot algorithm '{algorithm}' is not supported on GPU.")
+    if jtu.test_device_matches(["tpu"]):
+      if algorithm not in {
+          lax.DotAlgorithmPreset.DEFAULT,
+          lax.DotAlgorithmPreset.BF16_BF16_F32,
+          lax.DotAlgorithmPreset.BF16_BF16_F32_X3,
+          lax.DotAlgorithmPreset.BF16_BF16_F32_X6,
+      }:
+        raise SkipTest(
+            f"The dot algorithm '{algorithm}' is not supported on TPU."
+        )
+      if algorithm != lax.DotAlgorithmPreset.DEFAULT and dtype != np.float32:
+        raise SkipTest(
+            f"The dot algorithm '{algorithm}' is only supported for float32 on"
+            " TPU."
+        )
     lhs_shape = (3, 4)
     rhs_shape = (4, 3)
     rng = jtu.rand_default(self.rng())
@@ -1136,6 +1151,8 @@ def testDotAlgorithmCasting(self):
     if xla_bridge.using_pjrt_c_api():
       raise SkipTest(
           "The dot algorithm attribute is not supported by PJRT C API.")
+    if jtu.test_device_matches(["tpu"]):
+      raise SkipTest("F32_F32_F32 is not supported on TPU.")
     def fun(lhs, rhs):
       return lax.dot(lhs, rhs, precision="F32_F32_F32")
     lhs_shape = (3, 4)
@@ -1188,12 +1205,14 @@ def testDotPreferredElement(self, lhs_shape, rhs_shape, dtype,
   )
   def test_mixed_fp8_dot_general(self, lhs_shape, rhs_shape, dtype_lhs, dtype_rhs):
     if jtu.test_device_matches(["tpu"]):
-      raise SkipTest("Mixed fp8 precision matmul is not yet supported on TPU")
+      raise SkipTest("Mixed fp8 precision matmul is not yet supported on TPU")
     if not jtu.is_device_rocm() and (
         dtype_lhs in [dtypes.float8_e4m3fnuz, dtypes.float8_e5m2fnuz] or
         dtype_rhs in [dtypes.float8_e4m3fnuz, dtypes.float8_e5m2fnuz]
     ):
-      raise SkipTest("float8_e4m3fnuz and float8_e5m2fnuz types are only supported on ROCm")
+      raise SkipTest(
+          "float8_e4m3fnuz and float8_e5m2fnuz types are only supported on ROCm"
+      )
     rng = jtu.rand_default(self.rng())
     lhs = rng(lhs_shape, dtype=dtype_lhs)
     rhs = rng(rhs_shape, dtype=dtype_rhs)
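For context, a minimal sketch (not part of the commit) of the API these tests exercise: lax.dot accepts a dot-algorithm preset through its precision argument, and the skips added above restrict TPU runs to the DEFAULT and BF16_BF16_F32/_X3/_X6 presets with float32 inputs. The shapes and values below are illustrative only.

import jax.numpy as jnp
from jax import lax

# float32 operands, matching the (3, 4) x (4, 3) shapes used in testDotAlgorithm.
lhs = jnp.arange(12.0, dtype=jnp.float32).reshape(3, 4)
rhs = jnp.arange(12.0, dtype=jnp.float32).reshape(4, 3)

# One of the presets the diff above allows on TPU (besides DEFAULT); the test
# for casting shows the preset can also be passed by name, e.g.
# precision="F32_F32_F32".
out = lax.dot(lhs, rhs, precision=lax.DotAlgorithmPreset.BF16_BF16_F32_X3)
print(out.shape)  # (3, 3)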