
Commit 55c289d

eqy authored and jeffdaily committed
[cuBLASLt][FP8] cuBLASLt appears to support float8 rowwise-scaling on H100 (pytorch#161305)
Following pytorch#157905, I think the macro around `TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");` was never updated, and this would cause `float8` tests to fail. Also, it appears the `Lt` interface accepts two inputs with `e4m3` and `e5m2` dtypes simultaneously, so removing that check here as well...

CC @lw

Pull Request resolved: pytorch#161305
Approved by: https://github.com/Skylion007, https://github.com/drisspg, https://github.com/jeffdaily
Co-authored-by: Jeff Daily <[email protected]>
1 parent 2042d21 commit 55c289d
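
To illustrate the second point, here is a minimal sketch (not taken from the commit) of the mixed-dtype call the message describes: one `e4m3` operand and one `e5m2` operand passed to `torch._scaled_mm` with plain tensorwise scales. The shapes and the H100 (sm90) assumption are hypothetical.

```python
import torch

# Hypothetical shapes; assumes an H100 (sm90) GPU and a cuBLASLt build that
# accepts mixed e4m3/e5m2 operands, as the commit message reports.
M, K, N = 32, 64, 16
x = torch.ones(M, K, device="cuda").to(torch.float8_e4m3fn)    # row-major mat1
y = torch.ones(N, K, device="cuda").to(torch.float8_e5m2).t()  # column-major mat2, as _scaled_mm expects
out = torch._scaled_mm(
    x,
    y,
    scale_a=torch.tensor(1.0, device="cuda"),  # tensorwise scale for mat1
    scale_b=torch.tensor(1.0, device="cuda"),  # tensorwise scale for mat2
    out_dtype=torch.bfloat16,
)
```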

File tree

2 files changed: +21 -9 lines changed


aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 7 additions & 3 deletions

```diff
@@ -1947,11 +1947,11 @@ void scaled_gemm(
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
   cublasLtMatmulDescAttributes_t matmulDescA = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER;
   cublasLtMatmulDescAttributes_t matmulDescB = CUBLASLT_MATMUL_DESC_B_SCALE_POINTER;
+#if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
   // hipblaslt supported row-wise before cublas, and did so their own way (via
   // the SCALE_POINTERSs), but then migrated to match how cublas does it (via
   // the SCALE_MODEs). Here we check for this early custom mode.
   bool use_rowwise = (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise);
-#if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
   if (use_rowwise) {
     matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
     matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
@@ -1966,8 +1966,12 @@ void scaled_gemm(
     }
 #endif
   }
-#else
-  // rowwise isn't supported using cublaslt or older hipblaslt
+#elif (CUDA_VERSION < 12080) && !defined(USE_ROCM)
+  // hipblaslt supported row-wise before cublas, and did so their own way (via
+  // the SCALE_POINTERSs), but then migrated to match how cublas does it (via
+  // the SCALE_MODEs). Here we check for this early custom mode.
+  bool use_rowwise = (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise);
+  // rowwise isn't supported using older cublaslt or older hipblaslt
   TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
 #endif // if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
   computeDesc.setAttribute(matmulDescA, mat1_scale_ptr);
```
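
The `TORCH_INTERNAL_ASSERT` that this hunk moves behind a `CUDA_VERSION < 12080` guard is what used to reject rowwise-scaled calls under cuBLASLt. A minimal sketch of such a call, assuming an H100 (sm90) GPU with CUDA 12.8+ (shapes hypothetical, not from the diff):

```python
import torch

# Rowwise scaling: one scale per row of mat1, one per column of mat2.
# Assumes an H100 (sm90) GPU and CUDA >= 12.8, per the guard above.
M, K, N = 32, 64, 16  # hypothetical shapes
x = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)      # row-major mat1
y = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()  # column-major mat2
out = torch._scaled_mm(
    x,
    y,
    scale_a=torch.ones(M, 1, device="cuda"),  # (M, 1) rowwise scales
    scale_b=torch.ones(1, N, device="cuda"),  # (1, N) columnwise scales
    out_dtype=torch.bfloat16,
)
```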

test/test_matmul_cuda.py

Lines changed: 14 additions & 6 deletions

```diff
@@ -1315,18 +1315,26 @@ def test_float8_error_messages(self, device) -> None:
                 out_dtype=torch.bfloat16,
             )
 
-        # Note re.compile is used, not re.escape. This is to accommodate fn vs fnuz type message.
-        with self.assertRaisesRegex(
-            RuntimeError,
-            r"Expected b\.dtype\(\) == at::kFloat8_e4m3fnu?z? to be true, but got false\.",
-        ):
-            torch._scaled_mm(
+        def e5m2():
+            out = torch._scaled_mm(
                 x_fp8,
                 y_fp8.to(e5m2_type),
                 scale_a=torch.ones((M, 1), device="cuda"),
                 scale_b=torch.ones((1, N), device="cuda"),
                 out_dtype=torch.bfloat16,
             )
+            return out
+
+        if torch.cuda.get_device_capability() == (9, 0):
+            out = e5m2()
+            self.assertEqual(out, torch.ones_like(out) * 128.)
+        else:
+            # Note re.compile is used, not re.escape. This is to accommodate fn vs fnuz type message.
+            with self.assertRaisesRegex(
+                RuntimeError,
+                r"Expected b\.dtype\(\) == at::kFloat8_e4m3fnu?z? to be true, but got false\.",
+            ):
+                e5m2()
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
     @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific")
```
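
A note on the new expected value: with all-ones inputs and unit scales, each output element is a dot product of K ones, so the `torch.ones_like(out) * 128.` assertion implies the test's reduction dimension K is 128 (an inference from the assertion; the setup is outside this hunk). A GPU-free check of the arithmetic:

```python
import torch

# ones(M, K) @ ones(K, N) == K * ones(M, N); with K == 128 every element is
# exactly 128, matching the test's expectation. M and N are hypothetical here.
M, K, N = 16, 128, 32
ref = torch.ones(M, K, dtype=torch.bfloat16) @ torch.ones(K, N, dtype=torch.bfloat16)
assert torch.equal(ref, torch.full((M, N), 128.0, dtype=torch.bfloat16))
```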
