Skip to content

Commit a1a9632

Browse files
authored
[ROCm] fix build for newer hipblaslt BC-breaking change (#2510)
[ROCm] fix build for newer hipblaslt BC-breaking change: hipblaslt adds HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F and removes HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT.
1 parent 46ba24c commit a1a9632

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

setup.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ def get_extensions():
429429
# naive search for hipblaslt.h; check whether any header found contains HIPBLASLT_ORDER_COL16 and VEC_EXT
430430
found_col16 = False
431431
found_vec_ext = False
432+
found_outer_vec = False
432433
print("ROCM_HOME", ROCM_HOME)
433434
hipblaslt_headers = list(
434435
glob.glob(os.path.join(ROCM_HOME, "include", "hipblaslt", "hipblaslt.h"))
@@ -441,12 +442,17 @@ def get_extensions():
441442
found_col16 = True
442443
if "HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT" in text:
443444
found_vec_ext = True
445+
if "HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F" in text:
446+
found_outer_vec = True
444447
if found_col16:
445448
extra_compile_args["cxx"].append("-DHIPBLASLT_HAS_ORDER_COL16")
446449
print("hipblaslt found extended col order enums")
447450
else:
448451
print("hipblaslt does not have extended col order enums")
449-
if found_vec_ext:
452+
if found_outer_vec:
453+
extra_compile_args["cxx"].append("-DHIPBLASLT_OUTER_VEC")
454+
print("hipblaslt found outer vec")
455+
elif found_vec_ext:
450456
extra_compile_args["cxx"].append("-DHIPBLASLT_VEC_EXT")
451457
print("hipblaslt found vec ext")
452458
else:

torchao/csrc/rocm/swizzle/swizzle.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ ScalingType get_scaling_type(
362362
// Check for RowWise scaling
363363
if (scale_a.size(0) == dim_m && scale_a.size(1) == 1 &&
364364
scale_b.size(0) == 1 && scale_b.size(1) == dim_n) {
365-
#if defined(HIPBLASLT_VEC_EXT)
365+
#if defined(HIPBLASLT_VEC_EXT) || defined(HIPBLASLT_OUTER_VEC)
366366
TORCH_CHECK(
367367
scale_a.is_contiguous() && scale_b.is_contiguous(),
368368
"Both scale_a and scale_b must be contiguous for RowWise scaling.");
@@ -619,17 +619,25 @@ void _scaled_gemm(
619619
computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
620620
hipblasLtMatmulDescAttributes_t matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER;
621621
hipblasLtMatmulDescAttributes_t matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER;
622-
#if defined(HIPBLASLT_VEC_EXT)
622+
#if defined(HIPBLASLT_OUTER_VEC)
623+
// this case is handled later with HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F
624+
#elif defined(HIPBLASLT_VEC_EXT)
623625
if (use_rowwise) {
624626
matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
625627
matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
626628
}
627629
#else
628-
// rowwise isn't supported using cublaslt or older hipblaslt
630+
// rowwise isn't supported using older hipblaslt
629631
TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
630632
#endif
631633
computeDesc.setAttribute(matmulDescA, mat1_scale_ptr);
632634
computeDesc.setAttribute(matmulDescB, mat2_scale_ptr);
635+
#if defined(HIPBLASLT_OUTER_VEC)
636+
if (use_rowwise) {
637+
computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
638+
computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
639+
}
640+
#endif
633641
if (result_scale_ptr != nullptr) {
634642
computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
635643
}

0 commit comments

Comments
 (0)