
Commit e62e394

jagadish-amd authored and jithunnair-amd committed
[release/2.7] Enable mx fp8 support on ROCm (#2199)
Ported mx fp8 part from #2046

Current test stats (accounting only blockwise scale tests):

PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v
Ran 225 tests in 8.256s
FAILED (failures=1, skipped=150)
74 tests pass.

fp8 mx data type sample test case:
test_blockwise_mxfp8_numerics_test_case_name_data_random_scales_one_fast_accum_True_512_128_256_cuda (__main__.TestFP8MatmulCudaCUDA)

hipblaslt-bench --api_method c -m 256 -n 512 -k 128 --lda 128 --ldb 128 --ldc 256 --ldd 256 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1 --beta 0 --transA T --transB N --batch_count 1 --scaleA 3 --scaleB 3 --a_type f8_r --b_type f8_r --c_type bf16_r --d_type bf16_r --compute_type f32_r --algo_method index --solution_index -2146957310 --rotating 0 --cold_iters 0 --iters 0

---------
Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
(cherry picked from commit d17e222)
1 parent dbb9f2a commit e62e394
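Editor's note: a minimal sketch (not part of the commit) of the shape constraint this change enforces on ROCm. The diffs below require every GEMM dimension to be a multiple of 32 for MX-format (block-wise) scaling on gfx950; the helper name here is illustrative only.

    # Illustrative only: mirrors the m/n/k % 32 checks added in CUDABlas.cpp and Blas.cpp.
    def mx_dims_ok(m: int, k: int, n: int) -> bool:
        return m % 32 == 0 and k % 32 == 0 and n % 32 == 0

    # The sample test case above uses (M, K, N) = (512, 128, 256), which satisfies the check.
    assert mx_dims_ok(512, 128, 256)
    # A shape such as (512, 100, 256) is skipped on ROCm (see test_matmul_cuda.py below).
    assert not mx_dims_ok(512, 100, 256)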

File tree

5 files changed (+65 -8 lines)


aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 12 additions & 2 deletions
@@ -1879,6 +1879,16 @@ void scaled_gemm(
       matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
       matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
     }
+    else if(mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
+#if ROCM_VERSION >= 70000
+      if (at::detail::getCUDAHooks().isGPUArch(0, {"gfx950"})) {
+        // Validate matrix dimensions for MX format
+        TORCH_CHECK((m % 32 == 0) && (n % 32 == 0) && (k % 32 == 0),
+                    "Matrix dimensions must be multiples of 32 for MX format. ",
+                    "Got m=", m, ", n=", n, ", k=", k);
+      }
+#endif
+    }
 #else
     // rowwise isn't supported using older hipblaslt
     TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with older hipblaslt");
@@ -1917,11 +1927,11 @@ void scaled_gemm(
   }

   if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
-#if CUDA_VERSION >= 12080
+#if (!defined(USE_ROCM) && CUDA_VERSION >= 12080) || (defined(USE_ROCM) && ROCM_VERSION >= 70000)
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
 #else
-    TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 and above");
+    TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 or ROCm 7.0(with gfx950) and above");
 #endif // if CUDA_VERSION >= 12080
   } else if (mat1_scale_dtype == kFloat8_e4m3fn && mat2_scale_dtype == kFloat8_e4m3fn) {
 #if CUDA_VERSION >= 12080

aten/src/ATen/native/cuda/Blas.cpp

Lines changed: 33 additions & 3 deletions
@@ -1133,12 +1133,15 @@ ScalingType get_scaling_type(
   auto expected_b_size =
       BLOCK_SIZE_MN * ceil_div(dim_n, BLOCK_SIZE_MN) * padded_num_k_blocks;

+  //TODO: enable the checks for ROCm
+#ifndef USE_ROCM
   TORCH_CHECK(scale_a.numel() == expected_a_size,
       "For BlockWise scaling: Expected scale_a size to be ",
       expected_a_size, " but got ", scale_a.numel());
   TORCH_CHECK(scale_b.numel() == expected_b_size,
       "For BlockWise scaling: Expected scale_b size to be ",
       expected_b_size, " but got ", scale_b.numel());
+#endif

   TORCH_CHECK(
       scale_a.is_contiguous() && scale_b.is_contiguous(),
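Editor's note on the scale-size check that is now compiled out on ROCm above: a rough sketch of the expected blockwise scale sizes it validates. The 1x32 scale granularity and 128x4 padding constants below are assumptions inferred from the surrounding code, not something this diff states.

    # Sketch only; constants are assumptions (one e8m0 scale per 32 K elements,
    # scales padded to a 128x4 swizzled layout), not guaranteed by this diff.
    def ceil_div(a: int, b: int) -> int:
        return -(-a // b)

    def expected_blockwise_scale_numel(dim_mn: int, dim_k: int,
                                       scale_block_k: int = 32,   # assumed scale granularity along K
                                       pad_mn: int = 128,         # assumed row padding of the swizzled layout
                                       pad_k_blocks: int = 4) -> int:  # assumed column padding of the swizzled layout
        num_k_blocks = ceil_div(dim_k, scale_block_k)
        padded_num_k_blocks = ceil_div(num_k_blocks, pad_k_blocks) * pad_k_blocks
        return pad_mn * ceil_div(dim_mn, pad_mn) * padded_num_k_blocks

    # For the sample case (M=512, K=128, N=256):
    print(expected_blockwise_scale_numel(512, 128))  # expected scale_a.numel()
    print(expected_blockwise_scale_numel(256, 128))  # expected scale_b.numel()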
@@ -1205,6 +1208,7 @@ ScalingType get_scaling_type(

 } // namespace

+
 // Computes matrix multiply + bias while applying scaling to input and output matrices
 // Scales are only applicable when matrices are of Float8 type and assumed to be equal to 1.0 by default.
 // If output matrix type is 16 or 32-bit type, scale_result is not applied.
@@ -1268,6 +1272,14 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   // Type restrictions imposed by CuBLASLt as of CUDA-12.1
   TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
         "Multiplication of two Float8_e5m2 matrices is not supported");
+#endif
+#ifdef USE_ROCM
+  if (mat1.scalar_type() == ScalarType::Float8_e5m2 || mat2.scalar_type() == ScalarType::Float8_e5m2) {
+    TORCH_CHECK(ROCM_VERSION >= 60000, "Float8_e5m2 is only supported for ROCm 6.0 and above");
+  }
+  if (mat1.scalar_type() == ScalarType::Float8_e4m3fn || mat2.scalar_type() == ScalarType::Float8_e4m3fn) {
+    TORCH_CHECK(ROCM_VERSION >= 60000, "Float8_e4m3fn is only supported for ROCm 6.0 and above");
+  }
 #endif
   if (use_fast_accum) {
     TORCH_CHECK(mat1.scalar_type() != ScalarType::Float4_e2m1fn_x2 && mat2.scalar_type() != ScalarType::Float4_e2m1fn_x2, "`use_fast_accum` is not supported when `mat1` or `mat2` tensors have the `Float4_e2m1fn_x2` dtype.");
@@ -1327,17 +1339,33 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   }
 #else
   if (scaling_choice == ScalingType::RowWise) {
-    // For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes.
+    // For ROCm, match behavior of f8f8bf16_rowwise type checking
     Tensor b = mat2;
     if (_scaled_mm_is_fnuz()) {
       TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fnuz);
     }
     else {
       TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fn);
     }
-    // Until more than bf16 is supported.
+    // Until more than bf16 is supported
     TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16,
-        "hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type());
+        "hipblaslt rowwise _scaled_mm only supports BFloat16 output");
+  }
+  else if (scaling_choice == ScalingType::BlockWise) {
+#if ROCM_VERSION >= 70000
+    TORCH_CHECK(at::detail::getCUDAHooks().isGPUArch({"gfx950"}, 0),
+        "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");
+
+    TORCH_CHECK(mat1.size(0) % 32 == 0 && mat1.size(1) % 32 == 0 &&
+        mat2.size(0) % 32 == 0 && mat2.size(1) % 32 == 0,
+        "Matrix dimensions must be multiples of 32 for block-wise scaling");
+
+    TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16 ||
+        out.scalar_type() == ScalarType::Half,
+        "Block-wise scaling only supports BFloat16 or Half output types");
+#else
+    TORCH_CHECK(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
+#endif
   }
 #endif

@@ -1416,10 +1444,12 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   params.k = args.k;
   params.a = args.mata->data_ptr();
   params.a_scale_ptr = args.scale_mata_ptr;
+  params.a_scale_dtype = scale_a.scalar_type();
   params.lda = args.lda;
   params.a_dtype = args.mata->scalar_type();
   params.b = args.matb->data_ptr();
   params.b_scale_ptr = args.scale_matb_ptr;
+  params.b_scale_dtype = scale_b.scalar_type();
   params.ldb = args.ldb;
   params.b_dtype = args.matb->scalar_type();
   params.bias_ptr = bias ? bias->data_ptr(): nullptr;

test/test_matmul_cuda.py

Lines changed: 6 additions & 2 deletions
@@ -1453,6 +1453,10 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r

         device = "cuda"
         M, K, N = mkn
+        if torch.version.hip:
+            if not (M % 32 == 0 and K % 32 == 0 and N % 32 == 0):
+                raise unittest.SkipTest("Matrix dimensions must be multiples of 32 on ROCm, skipping")
+
         if recipe == "nvfp4" and K % 32 != 0:
             return unittest.skip("K must be divisible by 32 for nvfp4 cublas gemm, skipping")

@@ -1462,7 +1466,7 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r

         if test_case_name == "a_eye_b_eye":
             if not ((M == K) and (M == N)):
-                return unittest.skip("this test is only defined for M == K == N, skipping")
+                raise unittest.SkipTest("this test is only defined for M == K == N, skipping")
             A_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
             B_ref = torch.eye(M, device=device, dtype=torch.bfloat16)

@@ -1601,7 +1605,7 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r

         elif test_case_name == "data_random_scales_from_data":
             if not K % BLOCK_SIZE == 0:
-                return unittest.skip(f"this test is only defined for K a multiple of {BLOCK_SIZE}, skipping")
+                raise unittest.SkipTest(f"this test is only defined for K a multiple of {BLOCK_SIZE}, skipping")
             require_exact_match = False
             # random data, scales from data
             A_ref = torch.randn((M, K), device=device, dtype=torch.bfloat16) * 1000
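Editor's note: besides the ROCm gate, the hunks above also switch `return unittest.skip(...)` to `raise unittest.SkipTest(...)`. A small standalone sketch (not from the commit) of why that matters: `unittest.skip()` returns a decorator, so returning it from a test body does nothing, while raising `SkipTest` records the test as skipped.

    import unittest

    class SkipDemo(unittest.TestCase):
        def test_return_skip(self):
            # unittest.skip("...") returns a decorator; returning it from a test body
            # does not skip anything, so this test is reported as passed.
            return unittest.skip("not actually skipped")

        def test_raise_skiptest(self):
            # Raising SkipTest marks the test as skipped in the results.
            raise unittest.SkipTest("properly skipped")

    if __name__ == "__main__":
        unittest.main(verbosity=2)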

torch/testing/_internal/common_cuda.py

Lines changed: 9 additions & 1 deletion
@@ -108,7 +108,15 @@ def evaluate_platform_supports_fp8():

 PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8())

-PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: TEST_CUDA and SM100OrLater)
+def _platform_supports_mx_gemm():
+    if torch.cuda.is_available():
+        if torch.version.hip:
+            return 'gfx95' in torch.cuda.get_device_properties(0).gcnArchName
+        else:
+            return SM100OrLater
+    return False
+
+PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: _platform_supports_mx_gemm())

 if TEST_NUMBA:
     try:
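Editor's note: a short usage sketch (not from this commit) of how a test can gate on the updated flag, following the existing skipIf pattern used for PLATFORM_SUPPORTS_FP8; the test class and body here are hypothetical.

    import unittest
    import torch
    from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_MX_GEMM

    class HypotheticalMxTest(unittest.TestCase):
        @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM,
                         "MX GEMM requires SM100+ (CUDA) or gfx95* (ROCm)")
        def test_mx_path_available(self):
            # Placeholder body; a real test would exercise the MX scaled-GEMM path.
            self.assertTrue(torch.cuda.is_available())

    if __name__ == "__main__":
        unittest.main()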

torch/utils/hipify/cuda_to_hip_mappings.py

Lines changed: 5 additions & 0 deletions
@@ -3870,6 +3870,7 @@
     ("CUDA_C_64U", ("HIP_C_64U", CONV_TYPE, API_RUNTIME)),
     ("CUDA_R_8F_E4M3", ("HIP_R_8F_E4M3", CONV_TYPE, API_RUNTIME)),
     ("CUDA_R_8F_E5M2", ("HIP_R_8F_E5M2", CONV_TYPE, API_RUNTIME)),
+    ("CUDA_R_4F_E2M1", ("HIP_R_4F_E2M1", CONV_TYPE, API_RUNTIME)),
     (
         "MAJOR_VERSION",
         ("hipLibraryMajorVersion", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED),
@@ -7347,6 +7348,10 @@
     ("CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F", ("HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F", CONV_MATH_FUNC, API_BLAS)),
     ("CUBLASLT_MATMUL_DESC_AMAX_D_POINTER", ("HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER", CONV_MATH_FUNC, API_BLAS)),
     ("CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", ("HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_DESC_A_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_A_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_DESC_B_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_B_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0", ("HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3", ("HIPBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3", CONV_MATH_FUNC, API_BLAS)),
     ("cublasLtMatrixLayout_t", ("hipblasLtMatrixLayout_t", CONV_MATH_FUNC, API_BLAS)),
     ("cublasLtMatrixLayoutOpaque_t", ("hipblasLtMatrixLayoutOpaque_t", CONV_MATH_FUNC, API_BLAS)),
     ("cublasLtMatrixLayoutAttribute_t", ("hipblasLtMatrixLayoutAttribute_t", CONV_MATH_FUNC, API_BLAS)),

0 commit comments
