
Commit 6c28b4b

jagadish-amd authored and jeffdaily committed
[ROCm] Fix mx fp8 and fp4 code after scaling refactor changes. (pytorch#163127)
PR pytorch#151360 added mx fp8 and fp4 support on ROCm. 1. However, on recent upstream, scaling function in Blas.cpp along with test_matmul_cuda changes triggered failures. This patch corrects is_blockwise_1x32_scaling function code. 2. Fixes the m, n, k dimensions for ROCm mx case. 3. Modify FP4E2M1FN_LARGEST_POW2 (largest power of 2 representable in `torch.float4_e2m1fn_x2`) to 2. This resulted in higher SQNR value for mx fp4 test. Testing result on gfx950 w/ ROCm7.0 PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v Ran 452 tests in 22.698s OK passed 111 This is same as before. (when PR 151360 was merged) Pull Request resolved: pytorch#163127 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily <[email protected]>
1 parent b3b62fa commit 6c28b4b
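
The heart of fix 1 is the element count that is_blockwise_1x32_scaling expects from the e8m0 scale tensor. A minimal Python sketch of that arithmetic, with illustrative helper names (ceil_div, round_up, and expected_1x32_scale_numel below are stand-ins for the C++ helpers shown in the Blas.cpp diff, not a PyTorch API):

def ceil_div(a: int, b: int) -> int:
    # Integer ceiling division.
    return -(-a // b)

def round_up(a: int, b: int) -> int:
    # Round a up to the nearest multiple of b.
    return ceil_div(a, b) * b

def expected_1x32_scale_numel(rows: int, stored_cols: int, packed_fp4: bool = False) -> int:
    # One e8m0 scale per 1x32 block of logical elements. For the packed
    # torch.float4_e2m1fn_x2 dtype, each stored byte holds two 4-bit values,
    # so the logical column count is twice the stored size(1); that factor of 2
    # is what the fix adds for the ROCm fp4 path.
    logical_cols = stored_cols * 2 if packed_fp4 else stored_cols
    return round_up(rows, 128) * round_up(ceil_div(logical_cols, 32), 4)

# A 256x512 fp8 operand and a 256x256 packed-fp4 operand (512 logical columns)
# both need 256 * 16 = 4096 scale elements.
assert expected_1x32_scale_numel(256, 512) == 4096
assert expected_1x32_scale_numel(256, 256, packed_fp4=True) == 4096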

File tree

3 files changed, +33 -15 lines changed


aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 2 additions & 2 deletions

@@ -1954,8 +1954,8 @@ void scaled_gemm(
 #if ROCM_VERSION >= 70000
   if (at::detail::getCUDAHooks().isGPUArch({"gfx950"})) {
     // TODO: add constraints based on hipblaslt internals
-    TORCH_CHECK((m % 32 == 0) && (n % 32 == 0) && (k % 32 == 0),
-      "Matrix dimensions must be multiples of 32 for MX format. "
+    TORCH_CHECK((m % 16 == 0) && (n % 16 == 0) && (k % 128 == 0),
+      "M, N must be multiples of 16 and K should be multiple of 128 for MX format. "
      "Got m=", m, ", n=", n, ", k=", k);
   }
 #endif
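
In short, on gfx950 with ROCm >= 7.0 the MX path now requires M and N to be multiples of 16 and K to be a multiple of 128. A minimal sketch of that rule (rocm_mx_dims_ok is an illustrative name, not part of PyTorch):

def rocm_mx_dims_ok(m: int, n: int, k: int) -> bool:
    # Mirrors the TORCH_CHECK above for the gfx950 MX format path.
    return m % 16 == 0 and n % 16 == 0 and k % 128 == 0

assert rocm_mx_dims_ok(32, 64, 256)       # accepted
assert not rocm_mx_dims_ok(32, 64, 96)    # rejected: K is not a multiple of 128
assert not rocm_mx_dims_ok(24, 64, 128)   # rejected: M is not a multiple of 16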

aten/src/ATen/native/cuda/Blas.cpp

Lines changed: 17 additions & 6 deletions

@@ -1138,9 +1138,14 @@ bool is_blockwise_1x16_scaling(const at::Tensor& t, const at::Tensor& scale) {
 bool is_blockwise_1x32_scaling(const at::Tensor& t, const at::Tensor& scale) {
   // TODO: We might want to enforce some structure on the shapes of the scale
   // tensors
-  return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat8_e8m0fnu
-    && scale.numel() == round_up<int64_t>(t.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(t.size(1), 32), 4)
-    && scale.is_contiguous());
+  bool is_fp8_path = (isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat8_e8m0fnu
+    && scale.numel() == round_up<int64_t>(t.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(t.size(1), 32), 4));
+  bool is_packed_fp4_path = false;
+#ifdef USE_ROCM
+  is_packed_fp4_path = (t.scalar_type() == ScalarType::Float4_e2m1fn_x2 && scale.scalar_type() == at::kFloat8_e8m0fnu
+    && scale.numel() == round_up<int64_t>(t.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(t.size(1) * 2, 32), 4));
+#endif
+  return (is_fp8_path || is_packed_fp4_path) && scale.is_contiguous();
 }

 bool is_blockwise_1x128_scaling(const at::Tensor& t, const at::Tensor& scale) {

@@ -1381,9 +1386,15 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
     TORCH_CHECK(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
       "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");

-    TORCH_CHECK(mat1.size(0) % 32 == 0 && mat1.size(1) % 32 == 0 &&
-      mat2.size(0) % 32 == 0 && mat2.size(1) % 32 == 0,
-      "Matrix dimensions must be multiples of 32 for block-wise scaling");
+    int packed_factor = 1;
+    if (mat1.scalar_type() == ScalarType::Float4_e2m1fn_x2) {
+      // For float4 data type, each byte stores two 4-bit floating-point values,
+      // effectively packing two elements into one byte.
+      packed_factor = 2;
+    }
+    TORCH_CHECK(mat1.size(0) % 16 == 0 && (mat1.size(1) * packed_factor) % 128 == 0 &&
+      mat2.size(1) % 16 == 0,
+      "M, N must be multiples of 16 and K must be multiple of 128 for block-wise scaling");

     TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16 ||
       out.scalar_type() == ScalarType::Half,
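
The packed_factor above accounts for torch.float4_e2m1fn_x2 storing two 4-bit values per byte: mat1.size(1) is the stored byte count, half the logical K. A small sketch of the check under that assumption (mat1_dims_ok is an illustrative name, not part of PyTorch):

def mat1_dims_ok(m: int, stored_k: int, is_packed_fp4: bool) -> bool:
    # Float4_e2m1fn_x2 packs two logical elements per stored byte, so the
    # logical K is stored_k * packed_factor; that logical K must be a multiple
    # of 128, and M a multiple of 16, for the ROCm block-wise scaling path.
    packed_factor = 2 if is_packed_fp4 else 1
    return m % 16 == 0 and (stored_k * packed_factor) % 128 == 0

# A logical K of 128 stored as 64 packed bytes passes; the same 64 columns
# of a plain fp8 operand would not.
assert mat1_dims_ok(32, 64, is_packed_fp4=True)
assert not mat1_dims_ok(32, 64, is_packed_fp4=False)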

test/test_matmul_cuda.py

Lines changed: 14 additions & 7 deletions
@@ -926,7 +926,7 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 # largest power of 2 representable in `torch.float8_e4m3fn`
 F8E4M3_LARGEST_POW2 = 8
 # largest power of 2 representable in `torch.float4_e2m1fn_x2`
-FP4E2M1FN_LARGEST_POW2 = 1.0
+FP4E2M1FN_LARGEST_POW2 = 2.0
 # max value of `torch.float8_e4m3fn` (448)
 F8E4M3_MAX_VAL = torch.finfo(torch.float8_e4m3fn).max
 # exponent bias of `torch.float8_e8m0fnu`
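
These *_LARGEST_POW2 constants read as exponents: float8_e4m3fn tops out at 448, whose largest representable power of two is 2**8 = 256, hence F8E4M3_LARGEST_POW2 = 8; the e2m1 magnitudes are {0, 0.5, 1, 1.5, 2, 3, 4, 6}, whose largest power of two is 2**2 = 4, hence 2.0 rather than 1.0. A quick check of that arithmetic (illustrative, not part of the PR):

import math

E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
powers_of_two = [v for v in E2M1_MAGNITUDES if v > 0 and math.log2(v).is_integer()]
assert max(powers_of_two) == 2 ** 2.0   # matches FP4E2M1FN_LARGEST_POW2 = 2.0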
@@ -1746,8 +1746,12 @@ def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum,

         device = "cuda"
         M, K, N = mkn
-        if (recipe == "nvfp4" or recipe == "mxfp4") and K % 32 != 0:
-            raise unittest.SkipTest("K must be divisible by 32 for nvfp4/mxfp4 cublas gemm, skipping")
+        if recipe == "nvfp4" and K % 32 != 0:
+            raise unittest.SkipTest("K must be divisible by 32 for nvfp4 cublas gemm, skipping")
+
+        if torch.version.hip:
+            if not (M % 16 == 0 and K % 128 == 0 and N % 16 == 0):
+                raise unittest.SkipTest("M and N must be multiples of 16 and K must be multiple of 128 on ROCm, skipping")

         fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
         BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32)
@@ -1912,9 +1916,12 @@ def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum,
             B = (B_ref.reshape(-1, BLOCK_SIZE) / B_scale.reshape(N * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(N, K)
             B = B.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
         else: # nvfp4 # mxfp4
-            scale_func = data_to_mx_scale if recipe == "mxfp4" else data_to_nvfp4_scale
-            A_scale = scale_func(*([A_ref, BLOCK_SIZE] + recipe if recipe == "mxfp4" else [A_ref, BLOCK_SIZE]))
-            B_scale = scale_func(*([B_ref, BLOCK_SIZE] + recipe if recipe == "mxfp4" else [B_ref, BLOCK_SIZE]))
+            if recipe == "mxfp4":
+                A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE, recipe)
+                B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE, recipe)
+            else:
+                A_scale = data_to_nvfp4_scale(A_ref, BLOCK_SIZE)
+                B_scale = data_to_nvfp4_scale(B_ref, BLOCK_SIZE)
             max_val = FP4_MAX_VAL
             min_val = -1 * max_val

@@ -1925,7 +1932,7 @@ def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum,
             B = B.clamp(min=min_val, max=max_val)
             B = _bfloat16_to_float4_e2m1fn_x2(B)

-        approx_match_sqnr_target = 12.0 if torch.version.hip else 15.8
+        approx_match_sqnr_target = 15 if torch.version.hip else 15.8

         C_ref = A_ref @ B_ref.t()
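
The SQNR target for the ROCm fp4 path tightens from 12.0 to 15. As a reference for what that number measures, here is a sketch assuming the conventional SQNR definition in decibels (the test's compute_error helper may differ in detail; sqnr_db is an illustrative name):

import torch

def sqnr_db(reference: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in decibels:
    # 20 * log10(||x|| / ||x - y||); higher means the low-precision result
    # tracks the reference more closely, so raising the target is a stricter bar.
    signal = torch.linalg.norm(reference.float())
    noise = torch.linalg.norm((reference - approx).float())
    return 20 * torch.log10(signal / noise)

x = torch.randn(64, 64)
assert sqnr_db(x, x + 1e-3 * torch.randn_like(x)) > 15   # tiny perturbation -> high SQNR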
