Skip to content

Commit d10296c

Browse files
[release/2.6] fix scaled matmul and test_float8_basics_cuda (#2739)
This PR fixes: - test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_float8_basics_cuda - AssertionError: RuntimeError not raised - test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_mm_vs_emulated_row_wise_bfloat16_cuda - AssertionError: Tensor-likes are not close! We need to swap the A_SCALE and B_SCALE descriptor data when `use_rowwise` is set, in the same way as the [HIPBLASLT_VEC_EXT](https://github.com/ROCm/pytorch/blob/78f6ff789a11bcdca072f019305485d1cf06c7eb/aten/src/ATen/cuda/CUDABlas.cpp#L1450-L1454) path already does. Fixes SWDEV-544098
1 parent 78f6ff7 commit d10296c

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,6 +1447,12 @@ void scaled_gemm(
14471447
#if defined(USE_ROCM)
14481448
#if defined(HIPBLASLT_OUTER_VEC)
14491449
// this case is handled later as hipified CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F
1450+
if (use_rowwise) {
1451+
// swapped
1452+
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat2_scale_ptr);
1453+
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat1_scale_ptr);
1454+
}
1455+
else
14501456
#elif defined(HIPBLASLT_VEC_EXT)
14511457
if (use_rowwise) {
14521458
// swapped

test/test_matmul_cuda.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Owner(s): ["module: linear algebra"]
22

3+
from contextlib import nullcontext
34
import unittest
45
from itertools import product
56
from functools import partial
@@ -356,7 +357,8 @@ def test_float8_basics(self, device) -> None:
356357
self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32)
357358
self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48)
358359
# According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
359-
with self.assertRaises(RuntimeError):
360+
# supported on ROCm but fails on CUDA
361+
with self.assertRaises(RuntimeError) if torch.version.hip is None else nullcontext():
360362
self._test_tautological_mm(device, e5m2_type, e5m2_type)
361363

362364
self._test_tautological_mm(device, size=64, out_dtype=torch.float16)

0 commit comments

Comments
 (0)