
Commit c6fa27b

[AMD] Enable OCP fp8/bf8 tests on MI350 (#6021)
This patch enables tests for OCP fp8/bf8 MFMA ops on MI350.
Parent: c4c8bac

4 files changed: 29 additions, 10 deletions


python/test/unit/language/test_compile_errors.py

Lines changed: 3 additions & 1 deletion
@@ -7,7 +7,7 @@
 import triton.language as tl
 from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure
 import traceback
-from triton._internal_testing import is_cuda, is_hip, is_hip_mi300
+from triton._internal_testing import is_cuda, is_hip, is_hip_mi300, is_hip_mi350


 def format_exception(type, value, tb):
@@ -379,6 +379,8 @@ def test_fp8_support(fresh_triton_cache, dtype):
     elif is_hip():
         if is_hip_mi300():
             supported_dtypes += [tl.float8e4nv, tl.float8e4b8, tl.float8e5b16]
+        if is_hip_mi350():
+            supported_dtypes += [tl.float8e4nv]

     @triton.jit
     def dtype_kernel(dtype: tl.constexpr):
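The new import relies on an is_hip_mi350() helper in triton._internal_testing that this commit does not show. A plausible sketch of such a helper, assuming it mirrors is_hip_mi300() and keys off the gfx950 architecture that the compiler.py change below targets (names and structure are assumptions, not the actual definition):

# Hypothetical sketch only; the real helper lives in triton._internal_testing
# and may be written differently.
import triton


def is_hip_mi350():
    # Assumption: MI350-class GPUs report the gfx950 architecture, matching
    # the 'gfx950' branch added to third_party/amd/backend/compiler.py below.
    target = triton.runtime.driver.active.get_current_target()
    return target is not None and target.backend == "hip" and "gfx950" in target.arch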

python/test/unit/language/test_core.py

Lines changed: 21 additions & 7 deletions
@@ -3424,6 +3424,10 @@ def convert_fp8_to_fp32(x, device, dtype_str):
         return torch.tensor(x, device=device).view(torch.float8_e4m3fn).to(torch.float32)
     elif dtype_str == 'float8e5':
         return torch.tensor(x, device=device).view(torch.float8_e5m2).to(torch.float32)
+    elif dtype_str == 'float8e4b8':
+        return torch.tensor(x, device=device).view(torch.float8_e4m3fnuz).to(torch.float32)
+    elif dtype_str == 'float8e5b16':
+        return torch.tensor(x, device=device).view(torch.float8_e5m2fnuz).to(torch.float32)
     assert "Unsupported float8 dtype"
@@ -3553,12 +3557,15 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty
     if capability[0] < 9 and in_dtype == 'float8e4nv':
         pytest.skip("float8e4nv not supported on sm <= 80")

-    if is_hip() and (in_dtype == 'float8e4nv' or in_dtype == 'float8e5'):
-        pytest.skip("float8e4nv and float8e5 not supported on HIP")
-    if is_hip() and not ((input_precision == "ieee") or (input_precision == "tf32" and is_hip_mi300())):
-        pytest.skip(f"{input_precision} not supported on HIP")
-    if is_hip() and (kpack == 2 and in_dtype == 'int8' and K < 64):
-        pytest.skip("kpack too large for K")
+    if is_hip():
+        if in_dtype in ("float8e5", "float8e4nv") and not is_hip_mi350():
+            pytest.skip(f"{in_dtype} only supported on mi350")
+        if in_dtype in ("float8e5b16", "float8e4b8") and not is_hip_mi300():
+            pytest.skip(f"{in_dtype} only supported on mi300")
+        if not ((input_precision == "ieee") or (input_precision == "tf32" and is_hip_mi300())):
+            pytest.skip(f"{input_precision} not supported on HIP")
+        if kpack == 2 and in_dtype == 'int8' and K < 64:
+            pytest.skip("kpack too large for K")
     if not is_hip() and kpack == 2:
         pytest.skip("Skip duplicated tests on nv path")
@@ -3686,6 +3693,10 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
         z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e4m3fn)
     elif in_dtype == 'float8e5':
         z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e5m2)
+    elif in_dtype == 'float8e4b8':
+        z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e4m3fnuz)
+    elif in_dtype == 'float8e5b16':
+        z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e5m2fnuz)
     else:
         assert "Unsupported float8 dtype"
     z_ref = to_numpy(z_fp8.to(torch.float32))
@@ -6411,7 +6422,8 @@ def matmul_kernel( #
 @pytest.mark.parametrize("M, N, K", [(128, 256, 256)])
 @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 256, 128), (64, 64, 64)])
 @pytest.mark.parametrize(
-    "in_type_str", ['float8e5', 'float8e5b16', 'float8e4b8'] if is_hip() else ['float8e5', 'float8e4nv', 'float8e4b15'])
+    "in_type_str",
+    ['float8e5', 'float8e5b16', 'float8e4b8', 'float8e4nv'] if is_hip() else ['float8e5', 'float8e4nv', 'float8e4b15'])
 @pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128])
 def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_str, low_precision_acc, device):
     num_stages = 3
@@ -6423,6 +6435,8 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s
         num_stages = 2
     if in_type_str in ("float8e5b16", "float8e4b8") and not is_hip_mi300():
         pytest.skip(f"{in_type_str} only supported on mi300")
+    if in_type_str in ("float8e5", "float8e4nv") and not is_hip_mi350():
+        pytest.skip(f"{in_type_str} only supported on mi350")

     check_type_supported(in_type_str, device)
     A = numpy_random((M, K), dtype_str=in_type_str)
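For readers unfamiliar with the knob this test drives: tl.dot accepts a max_num_imprecise_acc argument that bounds how many fp8 products may be accumulated at low precision before being flushed into the fp32 accumulator. A minimal sketch of that usage (kernel and names are illustrative, not the test's own matmul_kernel):

import triton
import triton.language as tl


@triton.jit
def fp8_dot_sketch(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr, LOW_PREC_ACC: tl.constexpr):
    # a_ptr/b_ptr point at contiguous BLOCK x BLOCK uint8 tiles.
    offs = tl.arange(0, BLOCK)
    # Reinterpret the raw bytes as OCP fp8 (e4m3); on supported MFMA hardware
    # this lowers to fp8 matrix instructions.
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :]).to(tl.float8e4nv, bitcast=True)
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :]).to(tl.float8e4nv, bitcast=True)
    c = tl.dot(a, b, max_num_imprecise_acc=LOW_PREC_ACC, out_dtype=tl.float32)
    tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], c)

The test above sweeps that limit over 0, 32, 64 and 128 and checks the result against a reference computed on the host.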

third_party/amd/backend/compiler.py

Lines changed: 3 additions & 1 deletion
@@ -102,8 +102,10 @@ def parse_options(self, opts) -> Any:

         if "supported_fp8_dtypes" not in opts:
             supported_fp8_dtypes = set(HIPOptions.supported_fp8_dtypes)
-            if self.target.arch in ('gfx940', 'gfx941', 'gfx942', 'gfx950'):
+            if self.target.arch in ('gfx940', 'gfx941', 'gfx942'):
                 supported_fp8_dtypes.update({'fp8e4nv', 'fp8e4b8', 'fp8e5b16'})
+            elif self.target.arch in ('gfx950'):
+                supported_fp8_dtypes.update({'fp8e4nv', 'fp8e5'})
             args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))

         if "enable_fp_fusion" not in opts:

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 2 additions & 1 deletion
@@ -467,7 +467,8 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
     // store instructions, except for fp8 matmul kernels due to regression
     // TODO (lixun): investigate the regression and enable this feature again
     auto aElemTy = mfmaInstr->aElementType;
-    bool isFP8 = llvm::isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(aElemTy);
+    bool isFP8 = llvm::isa<Float8E5M2FNUZType, Float8E4M3FNUZType,
+                           Float8E4M3FNType, Float8E5M2Type>(aElemTy);
     bool isTransposed =
         isChainDotHead(dotOp) || isChainDotTail(dotOp) || !isFP8;
     mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
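The change widens the fp8 check so the OCP types (Float8E4M3FN, Float8E5M2) get the same treatment as the fnuz types when choosing the MFMA output layout. A Python paraphrase of that decision, purely for readability (illustrative pseudologic, not Triton API; the type-name strings are made up):

def use_transposed_mfma_layout(is_chain_dot_head, is_chain_dot_tail, a_elem_ty):
    # After this patch the OCP formats count as fp8 too, so fp8 dots keep the
    # non-transposed output layout unless they are part of a chained dot.
    is_fp8 = a_elem_ty in ("f8E5M2FNUZ", "f8E4M3FNUZ", "f8E4M3FN", "f8E5M2")
    return is_chain_dot_head or is_chain_dot_tail or not is_fp8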
