Skip to content

Commit 03968e6

Browse files
authored
[AMD] Enable all existing scaled_dot data type tests on MI300 (#5074)
triton-lang/triton#5062 enabled upcasting fp8E4M3FN to bf16; so now we can support that variant too.
1 parent d6739d3 commit 03968e6

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

python/test/unit/language/test_core.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
is_hip,
3232
is_hip_cdna,
3333
is_hip_mi200,
34+
is_hip_mi300,
3435
get_arch,
3536
torch_float8_dtypes,
3637
torch_dtypes,
@@ -3371,8 +3372,8 @@ def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack
33713372
if is_hip():
33723373
if not is_hip_cdna():
33733374
pytest.skip("scaled_dot only implemented for HIP CDNA")
3374-
if (type_a not in ["e2m1", "e5m2"]) or (type_b not in ["e2m1", "e5m2", "bf16"]):
3375-
pytest.skip(f"scaled_dot({type_a}, {type_b}) not yet implemented for HIP")
3375+
if "e4m3" in (type_a, type_b) and not is_hip_mi300():
3376+
pytest.skip(f"scaled_dot({type_a}, {type_b}) only implemented for MI300")
33763377
if mma == 16 and K == 64:
33773378
pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
33783379

0 commit comments

Comments
 (0)