Skip to content

Commit ad1aff0

Browse files
pravnarGoogle-ML-Automation
authored andcommitted
Respect dot algorithm spec on TPU backends.
PiperOrigin-RevId: 688274131
1 parent 441aeeb commit ad1aff0

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

tests/lax_test.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,6 +1112,21 @@ def testDotAlgorithm(self, algorithm, dtype):
11121112
}:
11131113
raise SkipTest(
11141114
f"The dot algorithm '{algorithm}' is not supported on GPU.")
1115+
if jtu.test_device_matches(["tpu"]):
1116+
if algorithm not in {
1117+
lax.DotAlgorithmPreset.DEFAULT,
1118+
lax.DotAlgorithmPreset.BF16_BF16_F32,
1119+
lax.DotAlgorithmPreset.BF16_BF16_F32_X3,
1120+
lax.DotAlgorithmPreset.BF16_BF16_F32_X6,
1121+
}:
1122+
raise SkipTest(
1123+
f"The dot algorithm '{algorithm}' is not supported on TPU."
1124+
)
1125+
if algorithm != lax.DotAlgorithmPreset.DEFAULT and dtype != np.float32:
1126+
raise SkipTest(
1127+
f"The dot algorithm '{algorithm}' is only supported for float32 on"
1128+
" TPU."
1129+
)
11151130
lhs_shape = (3, 4)
11161131
rhs_shape = (4, 3)
11171132
rng = jtu.rand_default(self.rng())
@@ -1136,6 +1151,8 @@ def testDotAlgorithmCasting(self):
11361151
if xla_bridge.using_pjrt_c_api():
11371152
raise SkipTest(
11381153
"The dot algorithm attribute is not supported by PJRT C API.")
1154+
if jtu.test_device_matches(["tpu"]):
1155+
raise SkipTest("F32_F32_F32 is not supported on TPU.")
11391156
def fun(lhs, rhs):
11401157
return lax.dot(lhs, rhs, precision="F32_F32_F32")
11411158
lhs_shape = (3, 4)
@@ -1188,12 +1205,14 @@ def testDotPreferredElement(self, lhs_shape, rhs_shape, dtype,
11881205
)
11891206
def test_mixed_fp8_dot_general(self, lhs_shape, rhs_shape, dtype_lhs, dtype_rhs):
11901207
if jtu.test_device_matches(["tpu"]):
1191-
raise SkipTest("Mixed fp8 precision matmul is not yet supported on TPU")
1208+
raise SkipTest("Mixed fp8 precision matmul is not yet supported on TPU")
11921209
if not jtu.is_device_rocm() and (
11931210
dtype_lhs in [dtypes.float8_e4m3fnuz, dtypes.float8_e5m2fnuz] or
11941211
dtype_rhs in [dtypes.float8_e4m3fnuz, dtypes.float8_e5m2fnuz]
11951212
):
1196-
raise SkipTest("float8_e4m3fnuz and float8_e5m2fnuz types are only supported on ROCm")
1213+
raise SkipTest(
1214+
"float8_e4m3fnuz and float8_e5m2fnuz types are only supported on ROCm"
1215+
)
11971216
rng = jtu.rand_default(self.rng())
11981217
lhs = rng(lhs_shape, dtype=dtype_lhs)
11991218
rhs = rng(rhs_shape, dtype=dtype_rhs)

0 commit comments

Comments
 (0)