
Commit d17e222

[release/2.7] Enable mx fp8 support on ROCm (#2199)
Ported the mx fp8 part from #2046.

Current test stats (counting only the blockwise-scale tests):

PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v
Ran 225 tests in 8.256s
FAILED (failures=1, skipped=150)
74 tests passed.

Sample fp8 mx data type test case:

test_blockwise_mxfp8_numerics_test_case_name_data_random_scales_one_fast_accum_True_512_128_256_cuda (__main__.TestFP8MatmulCudaCUDA)

hipblaslt-bench --api_method c -m 256 -n 512 -k 128 --lda 128 --ldb 128 --ldc 256 --ldd 256 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1 --beta 0 --transA T --transB N --batch_count 1 --scaleA 3 --scaleB 3 --a_type f8_r --b_type f8_r --c_type bf16_r --d_type bf16_r --compute_type f32_r --algo_method index --solution_index -2146957310 --rotating 0 --cold_iters 0 --iters 0

---------

Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
1 parent ba48d6f commit d17e222
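For context, a minimal usage sketch of what this change enables on the Python side (assumptions: the private torch._scaled_mm entry point with positional scale arguments and the torch.float8_e8m0fnu scale dtype used by the tests in this commit; the exact internal signature may differ between releases):

import torch

# Requires a ROCm gfx950 build (or CUDA SM100+); all dims must be multiples of 32.
M, K, N = 512, 128, 256
BLOCK_SIZE = 32
device = "cuda"

A = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
B = torch.randn(N, K, device=device).to(torch.float8_e4m3fn)

# One e8m0 scale per 32-element block along K (all ones, as in the
# "data_random_scales_one" test case). On CUDA these would additionally
# need the to_blocked swizzle shown in the test diff below; on ROCm they
# are passed as-is.
A_scale = torch.full((M, K // BLOCK_SIZE), 1.0, device=device, dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, K // BLOCK_SIZE), 1.0, device=device, dtype=torch.float8_e8m0fnu)

out = torch._scaled_mm(A, B.t(), A_scale, B_scale, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([512, 256])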

File tree: 5 files changed (+67 lines, -33 lines)

aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 12 additions & 2 deletions
@@ -1532,6 +1532,16 @@ void scaled_gemm(
       matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
       matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
     }
+    else if(mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
+#if ROCM_VERSION >= 70000
+      if (at::detail::getCUDAHooks().isGPUArch(0, {"gfx950"})) {
+        // Validate matrix dimensions for MX format
+        TORCH_CHECK((m % 32 == 0) && (n % 32 == 0) && (k % 32 == 0),
+            "Matrix dimensions must be multiples of 32 for MX format. ",
+            "Got m=", m, ", n=", n, ", k=", k);
+      }
+#endif
+    }
 #else
     // rowwise isn't supported using older hipblaslt
     TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with older hipblaslt");
@@ -1570,11 +1580,11 @@ void scaled_gemm(
   }
 
   if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
-#if CUDA_VERSION >= 12080
+#if (!defined(USE_ROCM) && CUDA_VERSION >= 12080) || (defined(USE_ROCM) && ROCM_VERSION >= 70000)
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
 #else
-    TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 and above");
+    TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 or ROCm 7.0 (with gfx950) and above");
 #endif // CUDA_VERSION >= 12080
   } else if (mat1_scale_dtype == kFloat && mat2_scale_dtype == kFloat && use_rowwise) {
 #if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC))

aten/src/ATen/native/cuda/Blas.cpp

Lines changed: 33 additions & 3 deletions
@@ -1020,12 +1020,15 @@ ScalingType get_scaling_type(
     auto expected_b_size =
         BLOCK_SIZE_MN * ceil_div(dim_n, BLOCK_SIZE_MN) * padded_num_k_blocks;
 
+    //TODO: enable the checks for ROCm
+#ifndef USE_ROCM
     TORCH_CHECK(scale_a.numel() == expected_a_size,
         "For BlockWise scaling: Expected scale_a size to be ",
         expected_a_size, " but got ", scale_a.numel());
     TORCH_CHECK(scale_b.numel() == expected_b_size,
         "For BlockWise scaling: Expected scale_b size to be ",
         expected_b_size, " but got ", scale_b.numel());
+#endif
 
     TORCH_CHECK(
         scale_a.is_contiguous() && scale_b.is_contiguous(),
@@ -1092,6 +1095,7 @@ ScalingType get_scaling_type(
 
 } // namespace
 
+
 // Computes matrix multiply + bias while applying scaling to input and output matrices
 // Scales are only applicable when matrices are of Float8 type and assumed to be equal to 1.0 by default.
 // If output matrix type is 16 or 32-bit type, scale_result is not applied.
@@ -1155,6 +1159,14 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   // Type restrictions imposed by CuBLASLt as of CUDA-12.1
   TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
         "Multiplication of two Float8_e5m2 matrices is not supported");
+#endif
+#ifdef USE_ROCM
+  if (mat1.scalar_type() == ScalarType::Float8_e5m2 || mat2.scalar_type() == ScalarType::Float8_e5m2) {
+    TORCH_CHECK(ROCM_VERSION >= 60000, "Float8_e5m2 is only supported for ROCm 6.0 and above");
+  }
+  if (mat1.scalar_type() == ScalarType::Float8_e4m3fn || mat2.scalar_type() == ScalarType::Float8_e4m3fn) {
+    TORCH_CHECK(ROCM_VERSION >= 60000, "Float8_e4m3fn is only supported for ROCm 6.0 and above");
+  }
 #endif
   if (bias) {
     TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32");
@@ -1211,17 +1223,33 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   }
 #else
   if (scaling_choice == ScalingType::RowWise) {
-    // For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes.
+    // For ROCm, match behavior of f8f8bf16_rowwise type checking
     Tensor b = mat2;
     if (_scaled_mm_is_fnuz()) {
       TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fnuz);
     }
     else {
      TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fn);
     }
-    // Until more than bf16 is supported.
+    // Until more than bf16 is supported
     TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16,
-        "hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type());
+        "hipblaslt rowwise _scaled_mm only supports BFloat16 output");
+  }
+  else if (scaling_choice == ScalingType::BlockWise) {
+#if ROCM_VERSION >= 70000
+    TORCH_CHECK(at::detail::getCUDAHooks().isGPUArch(0, {"gfx950"}),
+        "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");
+
+    TORCH_CHECK(mat1.size(0) % 32 == 0 && mat1.size(1) % 32 == 0 &&
+        mat2.size(0) % 32 == 0 && mat2.size(1) % 32 == 0,
+        "Matrix dimensions must be multiples of 32 for block-wise scaling");
+
+    TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16 ||
+        out.scalar_type() == ScalarType::Half,
+        "Block-wise scaling only supports BFloat16 or Half output types");
+#else
+    TORCH_CHECK(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
+#endif
   }
 #endif
 
@@ -1300,10 +1328,12 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   params.k = args.k;
   params.a = args.mata->data_ptr();
   params.a_scale_ptr = args.scale_mata_ptr;
+  params.a_scale_dtype = scale_a.scalar_type();
   params.lda = args.lda;
   params.a_dtype = args.mata->scalar_type();
   params.b = args.matb->data_ptr();
   params.b_scale_ptr = args.scale_matb_ptr;
+  params.b_scale_dtype = scale_b.scalar_type();
   params.ldb = args.ldb;
   params.b_dtype = args.matb->scalar_type();
   params.bias_ptr = bias ? bias->data_ptr(): nullptr;
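The new ROCm branch in _scaled_mm_out_cuda only accepts block-wise (Float8_e8m0fnu) scaling on gfx950, with every matrix dimension a multiple of 32 and a BFloat16 or Half output. A caller-side pre-check mirroring those TORCH_CHECKs could look like the sketch below (the helper name is illustrative and not part of this commit; the gcnArchName probe is only meaningful on ROCm builds):

import torch

def rocm_mx_gemm_inputs_ok(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype) -> bool:
    # Mirrors the ROCm >= 7.0 checks added in Blas.cpp: gfx950 only,
    # every dimension a multiple of 32, BFloat16/Half output.
    if torch.version.hip is None:
        return True  # the CUDA path has its own restrictions
    if "gfx950" not in torch.cuda.get_device_properties(0).gcnArchName:
        return False
    if any(d % 32 != 0 for d in (*a.shape, *b.shape)):
        return False
    return out_dtype in (torch.bfloat16, torch.float16)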

test/test_matmul_cuda.py

Lines changed: 8 additions & 27 deletions
@@ -929,12 +929,16 @@ def test_blockwise_mxfp8_numerics(self, test_case_name, fast_accum, mkn) -> None
         BLOCK_SIZE = 32
         require_exact_match = True
 
+        if torch.version.hip:
+            if not (M % 32 == 0 and K % 32 == 0 and N % 32 == 0):
+                raise unittest.SkipTest("Matrix dimensions must be multiples of 32 on ROCm, skipping")
+
         def ceil_div(a, b):
             return (a + b - 1) // b
 
         if test_case_name == "a_eye_b_eye":
             if not ((M == K) and (M == N)):
-                return unittest.skip("this test is only defined for M == K == N, skipping")
+                raise unittest.SkipTest("this test is only defined for M == K == N, skipping")
             A_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
             B_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
 
@@ -943,9 +947,6 @@ def ceil_div(a, b):
 
             A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
             B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-            # convert to swizzled format
-            A_scale = to_blocked(A_scale)
-            B_scale = to_blocked(B_scale)
 
         elif test_case_name == "a_ones_b_ones":
             A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -956,9 +957,6 @@ def ceil_div(a, b):
 
             A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
             B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-            # convert to swizzled format
-            A_scale = to_blocked(A_scale)
-            B_scale = to_blocked(B_scale)
 
         elif test_case_name == "a_ones_modified_b_ones":
             A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -972,9 +970,6 @@ def ceil_div(a, b):
 
             A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
             B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-            # convert to swizzled format
-            A_scale = to_blocked(A_scale)
-            B_scale = to_blocked(B_scale)
 
         elif test_case_name == "a_ones_b_ones_modified":
             A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -988,9 +983,6 @@ def ceil_div(a, b):
 
             A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
             B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-            # convert to swizzled format
-            A_scale = to_blocked(A_scale)
-            B_scale = to_blocked(B_scale)
 
         elif test_case_name == "a_scale_modified_b_ones":
             A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1006,10 +998,6 @@ def ceil_div(a, b):
             A[1][0:BLOCK_SIZE] = 2
             A_scale[1][0] = 2
 
-            # convert to swizzled format
-            A_scale = to_blocked(A_scale)
-            B_scale = to_blocked(B_scale)
-
         elif test_case_name == "a_ones_b_scale_modified":
             A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
             B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
@@ -1024,10 +1012,6 @@ def ceil_div(a, b):
             B[1][0:BLOCK_SIZE] = 2
             B_scale[1][0] = 2
 
-            # convert to swizzled format
-            A_scale = to_blocked(A_scale)
-            B_scale = to_blocked(B_scale)
-
         elif test_case_name == "data_random_scales_one":
             require_exact_match = False
             # scales all-ones, element data random while being exactly representable in float8_e4m3fn
@@ -1045,13 +1029,9 @@ def ceil_div(a, b):
             A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
             B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
 
-            # convert to swizzled format
-            A_scale = to_blocked(A_scale)
-            B_scale = to_blocked(B_scale)
-
         elif test_case_name == "data_random_scales_from_data":
             if not K % BLOCK_SIZE == 0:
-                return unittest.skip(f"this test is only defined for K a multiple of {BLOCK_SIZE}, skipping")
+                raise unittest.SkipTest(f"this test is only defined for K a multiple of {BLOCK_SIZE}, skipping")
             require_exact_match = False
             # random data, scales from data
             A_ref = torch.randn((M, K), device=device, dtype=torch.bfloat16) * 1000
@@ -1069,7 +1049,8 @@ def ceil_div(a, b):
             B = (B_ref.reshape(-1, BLOCK_SIZE) / B_scale.reshape(N * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(N, K)
             B = B.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
 
-            # convert to swizzled format
+        # convert to swizzled format
+        if not torch.version.hip:
             A_scale = to_blocked(A_scale)
             B_scale = to_blocked(B_scale)
 
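Net effect of the test change: scale tensors keep their natural (rows, K // BLOCK_SIZE) layout on ROCm, and the to_blocked swizzle now runs once, only on the CUDA path. Condensed, the relevant branch is (a sketch assuming the test's to_blocked helper and the A/B tensors are in scope):

# A_scale, B_scale: torch.float8_e8m0fnu, shapes (M, K // 32) and (N, K // 32)
if not torch.version.hip:
    # cuBLASLt consumes the swizzled ("blocked") scale layout
    A_scale = to_blocked(A_scale)
    B_scale = to_blocked(B_scale)
# hipBLASLt takes the unswizzled layout directly
out = torch._scaled_mm(A, B.t(), A_scale, B_scale, out_dtype=torch.bfloat16)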

torch/testing/_internal/common_cuda.py

Lines changed: 9 additions & 1 deletion
@@ -100,7 +100,15 @@ def evaluate_platform_supports_fp8():
 
 PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8())
 
-PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: TEST_CUDA and SM100OrLater)
+def _platform_supports_mx_gemm():
+    if torch.cuda.is_available():
+        if torch.version.hip:
+            return 'gfx95' in torch.cuda.get_device_properties(0).gcnArchName
+        else:
+            return SM100OrLater
+    return False
+
+PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: _platform_supports_mx_gemm())
 
 if TEST_NUMBA:
     try:
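With this change PLATFORM_SUPPORTS_MX_GEMM is true on SM100+ CUDA devices and on ROCm gfx95* devices, so tests can gate on it in the usual way (a sketch; the test class, test name, and skip message are illustrative):

import unittest
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_MX_GEMM

class TestMXGemm(unittest.TestCase):
    @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, "MX gemm needs SM100+ or ROCm gfx950")
    def test_mx_gemm_smoke(self):
        ...  # exercise an mxfp8 matmul here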

torch/utils/hipify/cuda_to_hip_mappings.py

Lines changed: 5 additions & 0 deletions
@@ -3865,6 +3865,7 @@
     ("CUDA_C_64U", ("HIP_C_64U", CONV_TYPE, API_RUNTIME)),
     ("CUDA_R_8F_E4M3", ("HIP_R_8F_E4M3", CONV_TYPE, API_RUNTIME)),
     ("CUDA_R_8F_E5M2", ("HIP_R_8F_E5M2", CONV_TYPE, API_RUNTIME)),
+    ("CUDA_R_4F_E2M1", ("HIP_R_4F_E2M1", CONV_TYPE, API_RUNTIME)),
     (
         "MAJOR_VERSION",
         ("hipLibraryMajorVersion", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED),
@@ -7325,6 +7326,10 @@
     ("CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F", ("HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F", CONV_MATH_FUNC, API_BLAS)),
     ("CUBLASLT_MATMUL_DESC_AMAX_D_POINTER", ("HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER", CONV_MATH_FUNC, API_BLAS)),
     ("CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", ("HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_DESC_A_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_A_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_DESC_B_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_B_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0", ("HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3", ("HIPBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3", CONV_MATH_FUNC, API_BLAS)),
     ("cublasLtMatrixLayout_t", ("hipblasLtMatrixLayout_t", CONV_MATH_FUNC, API_BLAS)),
     ("cublasLtMatrixLayoutOpaque_t", ("hipblasLtMatrixLayoutOpaque_t", CONV_MATH_FUNC, API_BLAS)),
     ("cublasLtMatrixLayoutAttribute_t", ("hipblasLtMatrixLayoutAttribute_t", CONV_MATH_FUNC, API_BLAS)),
