@@ -362,7 +362,7 @@ ScalingType get_scaling_type(
362
362
// Check for RowWise scaling
363
363
if (scale_a.size (0 ) == dim_m && scale_a.size (1 ) == 1 &&
364
364
scale_b.size (0 ) == 1 && scale_b.size (1 ) == dim_n) {
365
- #if defined(HIPBLASLT_VEC_EXT)
365
+ #if defined(HIPBLASLT_VEC_EXT) || defined(HIPBLASLT_OUTER_VEC)
366
366
TORCH_CHECK (
367
367
scale_a.is_contiguous () && scale_b.is_contiguous (),
368
368
" Both scale_a and scale_b must be contiguous for RowWise scaling." );
@@ -619,17 +619,25 @@ void _scaled_gemm(
619
619
computeDesc.setAttribute (HIPBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar (transb));
620
620
hipblasLtMatmulDescAttributes_t matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER;
621
621
hipblasLtMatmulDescAttributes_t matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER;
622
- #if defined(HIPBLASLT_VEC_EXT)
622
+ #if defined(HIPBLASLT_OUTER_VEC)
623
+ // rowwise scaling keeps the default A/B scale pointer attributes here; it is selected further below by setting the A/B scale mode to HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F
624
+ #elif defined(HIPBLASLT_VEC_EXT)
623
625
if (use_rowwise) {
624
626
matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
625
627
matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
626
628
}
627
629
#else
628
- // rowwise isn't supported using cublaslt or older hipblaslt
630
+ // rowwise isn't supported using older hipblaslt
629
631
TORCH_INTERNAL_ASSERT (use_rowwise == false , " rowwise scaled_gemm not supported with blaslt" );
630
632
#endif
631
633
computeDesc.setAttribute (matmulDescA, mat1_scale_ptr);
632
634
computeDesc.setAttribute (matmulDescB, mat2_scale_ptr);
635
+ #if defined(HIPBLASLT_OUTER_VEC)
636
+ if (use_rowwise) {
637
+ computeDesc.setAttribute (HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
638
+ computeDesc.setAttribute (HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
639
+ }
640
+ #endif
633
641
if (result_scale_ptr != nullptr ) {
634
642
computeDesc.setAttribute (HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
635
643
}
0 commit comments