Skip to content

Commit bacc812

Browse files
authored
[Tests][Kernels] Enable passing tests for AMD GFX942 (#7365)
1 parent 54411be commit bacc812

File tree

2 files changed

+2
-5
lines changed

2 files changed

+2
-5
lines changed

python/triton_kernels/tests/test_matmul.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,8 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
257257
if split_k > 1:
258258
pytest.skip("splitK hasn't been fully tested on AMD GPU.")
259259

260-
if is_hip_cdna3() and ("float8_e4m3fn" in (weight_dtype_str, act_dtype_str)
261-
or "float8_e5m2" in (weight_dtype_str, act_dtype_str)):
262-
pytest.skip("float8_e4m3fn and float8_e5m2 hasn't been fully tested on AMD CDNA3 platform.")
260+
if is_hip_cdna3() and ("float8_e4m3fn" in (weight_dtype_str, act_dtype_str)):
261+
pytest.skip("float8_e4m3fn hasn't been fully tested on AMD CDNA3 platform.")
263262

264263
if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3():
265264
pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform")

python/triton_kernels/tests/test_mxfp.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,6 @@ def test_mxfp_casting(
148148
pytest.skip("Other swizzling patterns are not supported by AMD GPU")
149149
if quant_dtype == 'float8_e4m3fn' and is_hip_cdna3():
150150
pytest.skip("float8_e4m3fn cast hasn't been fully tested on AMD CDNA3")
151-
if quant_dtype == 'float8_e5m2' and is_hip_cdna3():
152-
pytest.skip("float8_e5m2 cast hasn't been fully tested on AMD CDNA3")
153151

154152
swizzle_axis = swizzle_axis if (swizzle_value or swizzle_scale) else None
155153
quant_torch_type = dtype_str_to_torch(quant_dtype)

0 commit comments

Comments
 (0)