File tree Expand file tree Collapse file tree 1 file changed +9
-0
lines changed Expand file tree Collapse file tree 1 file changed +9
-0
lines changed Original file line number Diff line number Diff line change @@ -1572,6 +1572,15 @@ void scaled_gemm(
15721572#else
15731573 TORCH_CHECK (false , " scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 and above" );
15741574#endif // CUDA_VERSION >= 12080
1575+ } else if (mat1_scale_dtype == kFloat && mat2_scale_dtype == kFloat && use_rowwise) {
1576+ #if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC))
1577+ computeDesc.setAttribute (CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
1578+ computeDesc.setAttribute (CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
1579+ #elif defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT)
1580+ // no-op here for older hipblaslt ext enums, to avoid TORCH_CHECK below
1581+ #else
1582+ TORCH_CHECK (false , " scaled_gemm with `torch.float` outer vector scaling is only supported for CUDA 12.9 and above" );
1583+ #endif // if CUDA_VERSION >= 12090
15751584 }
15761585
15771586 size_t workspaceSize = _getWorkspaceSize ();
You can’t perform that action at this time.
0 commit comments