Skip to content

Commit 03d5094

Browse files
jeffdaily authored and ethanwee1 committed
fix 2337da4 missing brace } after cherry-pick
1 parent d2aa57c commit 03d5094

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,6 +1572,15 @@ void scaled_gemm(
15721572
#else
15731573
TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 and above");
15741574
#endif // CUDA_VERSION >= 12080
1575+
} else if (mat1_scale_dtype == kFloat && mat2_scale_dtype == kFloat && use_rowwise) {
1576+
#if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC))
1577+
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
1578+
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
1579+
#elif defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT)
1580+
// no-op here for older hipblaslt ext enums, to avoid TORCH_CHECK below
1581+
#else
1582+
TORCH_CHECK(false, "scaled_gemm with `torch.float` outer vector scaling is only supported for CUDA 12.9 and above");
1583+
#endif // if CUDA_VERSION >= 12090
15751584
}
15761585

15771586
size_t workspaceSize = _getWorkspaceSize();

0 commit comments

Comments
 (0)