Skip to content

Commit 2337da4

Browse files
authored
[release/2.7] support hipblaslt outer vec 32f enum (#2226)
Cherry-pick of upstream pytorch#154680.
1 parent 17364f3 commit 2337da4

File tree

6 files changed, with 55 additions and 10 deletions.

6 files changed, with 55 additions and 10 deletions.

aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1524,15 +1524,19 @@ void scaled_gemm(
15241524
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
15251525
cublasLtMatmulDescAttributes_t matmulDescA = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER;
15261526
cublasLtMatmulDescAttributes_t matmulDescB = CUBLASLT_MATMUL_DESC_B_SCALE_POINTER;
1527-
#if defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT)
1527+
#if defined(USE_ROCM)
1528+
#if defined(HIPBLASLT_OUTER_VEC)
1529+
// this case is handled later as hipified CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F
1530+
#elif defined(HIPBLASLT_VEC_EXT)
15281531
if (use_rowwise) {
15291532
matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
15301533
matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
15311534
}
15321535
#else
1533-
// rowwise isn't supported using cublaslt or older hipblaslt
1534-
TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
1536+
// rowwise isn't supported using older hipblaslt
1537+
TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with older hipblaslt");
15351538
#endif
1539+
#endif // defined(USE_ROCM)
15361540
computeDesc.setAttribute(matmulDescA, mat1_scale_ptr);
15371541
computeDesc.setAttribute(matmulDescB, mat2_scale_ptr);
15381542
if (result_scale_ptr != nullptr) {
@@ -1572,7 +1576,15 @@ void scaled_gemm(
15721576
#else
15731577
TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 and above");
15741578
#endif // CUDA_VERSION >= 12080
1575-
}
1579+
} else if (mat1_scale_dtype == kFloat && mat2_scale_dtype == kFloat && use_rowwise) {
1580+
#if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC))
1581+
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
1582+
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
1583+
#elif defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT)
1584+
// no-op here for older hipblaslt ext enums, to avoid TORCH_CHECK below
1585+
#else
1586+
TORCH_CHECK(false, "scaled_gemm with `torch.float` outer vector scaling is only supported for CUDA 12.9 and above");
1587+
#endif // if CUDA_VERSION >= 12090
15761588

15771589
size_t workspaceSize = _getWorkspaceSize();
15781590
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));

aten/src/ATen/cuda/tunable/GemmHipblaslt.h

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -522,6 +522,12 @@ class HipblasltGemmOp : public Callable<ParamsT> {
522522
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
523523
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
524524
}
525+
#ifdef HIPBLASLT_OUTER_VEC
526+
if (GetUseRowwiseFromParams<CT>(params)) {
527+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
528+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
529+
}
530+
#endif
525531
}
526532
if (result_scale_ptr) {
527533
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);

aten/src/ATen/native/cuda/Blas.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1056,7 +1056,7 @@ ScalingType get_scaling_type(
10561056
if (scale_a.size(0) == dim_m && scale_a.size(1) == 1 &&
10571057
scale_b.size(0) == 1 && scale_b.size(1) == dim_n) {
10581058
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || \
1059-
(defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT))
1059+
(defined(USE_ROCM) && (defined(HIPBLASLT_VEC_EXT) || defined(HIPBLASLT_OUTER_VEC)))
10601060
TORCH_CHECK(
10611061
scale_a.is_contiguous() && scale_b.is_contiguous(),
10621062
"Both scale_a and scale_b must be contiguous for RowWise scaling.");

cmake/Dependencies.cmake

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1059,6 +1059,9 @@ if(USE_ROCM)
10591059
list(APPEND HIP_CXX_FLAGS -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP)
10601060
list(APPEND HIP_CXX_FLAGS -std=c++17)
10611061
list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
1062+
if(HIPBLASLT_OUTER_VEC)
1063+
list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_OUTER_VEC)
1064+
endif()
10621065
if(HIPBLASLT_VEC_EXT)
10631066
list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_VEC_EXT)
10641067
endif()

cmake/public/LoadHIP.cmake

Lines changed: 26 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -178,6 +178,21 @@ if(HIP_FOUND)
178178
set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}")
179179

180180
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
181+
# check whether hipblaslt provides HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F
182+
set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_outer_vec.cc")
183+
file(WRITE ${file} ""
184+
"#define LEGACY_HIPBLAS_DIRECT\n"
185+
"#include <hipblaslt/hipblaslt.h>\n"
186+
"int main() {\n"
187+
" hipblasLtMatmulMatrixScale_t attr = HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F;\n"
188+
" return 0;\n"
189+
"}\n"
190+
)
191+
try_compile(hipblaslt_compile_result_outer_vec ${PROJECT_RANDOM_BINARY_DIR} ${file}
192+
CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
193+
COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
194+
OUTPUT_VARIABLE hipblaslt_compile_output_outer_vec)
195+
181196
# check whether hipblaslt provides HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT
182197
set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_vec_ext.cc")
183198
file(WRITE ${file} ""
@@ -191,15 +206,21 @@ if(HIP_FOUND)
191206
try_compile(hipblaslt_compile_result_vec_ext ${PROJECT_RANDOM_BINARY_DIR} ${file}
192207
CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
193208
COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
194-
OUTPUT_VARIABLE hipblaslt_compile_output)
195-
if(hipblaslt_compile_result_vec_ext)
209+
OUTPUT_VARIABLE hipblaslt_compile_output_vec_ext)
210+
211+
if(hipblaslt_compile_result_outer_vec)
212+
set(HIPBLASLT_OUTER_VEC ON)
213+
set(HIPBLASLT_VEC_EXT OFF)
214+
message("hipblaslt is using scale pointer outer vec")
215+
elseif(hipblaslt_compile_result_vec_ext)
216+
set(HIPBLASLT_OUTER_VEC OFF)
196217
set(HIPBLASLT_VEC_EXT ON)
197-
#message("hipblaslt is using scale pointer vec ext: ${hipblaslt_compile_output}")
198218
message("hipblaslt is using scale pointer vec ext")
199219
else()
220+
set(HIPBLASLT_OUTER_VEC OFF)
200221
set(HIPBLASLT_VEC_EXT OFF)
201-
message("hipblaslt is NOT using scale pointer vec ext: ${hipblaslt_compile_output}")
202-
#message("hipblaslt is NOT using scale pointer vec ext")
222+
message("hipblaslt is NOT using scale pointer outer vec: ${hipblaslt_compile_output_outer_vec}")
223+
message("hipblaslt is NOT using scale pointer vec ext: ${hipblaslt_compile_output_vec_ext}")
203224
endif()
204225
endif()
205226
endif()

torch/utils/hipify/cuda_to_hip_mappings.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7320,6 +7320,9 @@
73207320
("CUBLASLT_MATMUL_DESC_A_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)),
73217321
("CUBLASLT_MATMUL_DESC_B_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)),
73227322
("CUBLASLT_MATMUL_DESC_D_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)),
7323+
("CUBLASLT_MATMUL_DESC_A_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_A_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
7324+
("CUBLASLT_MATMUL_DESC_B_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_B_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
7325+
("CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F", ("HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F", CONV_MATH_FUNC, API_BLAS)),
73237326
("CUBLASLT_MATMUL_DESC_AMAX_D_POINTER", ("HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER", CONV_MATH_FUNC, API_BLAS)),
73247327
("CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", ("HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", CONV_MATH_FUNC, API_BLAS)),
73257328
("cublasLtMatrixLayout_t", ("hipblasLtMatrixLayout_t", CONV_MATH_FUNC, API_BLAS)),

0 commit comments

Comments
 (0)