@@ -522,7 +522,7 @@ static class GemmAlgoCache {
 public:
   struct Key {
     int deviceCap;
-    hipDataType a_type, b_type, d_type, bias_type;
+    hipDataType a_type, b_type, d_type, bias_type, aux_type;
     int m, n, k;
     int lda, ldb, ldd;
     hipblasOperation_t transa, transb;
@@ -532,13 +532,13 @@ public:

     Key(int deviceCap_,
         hipDataType a_type_, hipDataType b_type_,
-        hipDataType d_type_, hipDataType bias_type_,
+        hipDataType d_type_, hipDataType bias_type_, hipDataType aux_type_,
         int m_, int n_, int k_, int lda_, int ldb_, int ldd_,
         hipblasOperation_t transa_, hipblasOperation_t transb_,
         int scaling_mode_, hipblasLtEpilogue_t epilogue_):
       deviceCap(deviceCap_),
       a_type(a_type_), b_type(b_type_),
-      d_type(d_type_), bias_type(bias_type_),
+      d_type(d_type_), bias_type(bias_type_), aux_type(aux_type_),
       m(m_), n(n_), k(k_), lda(lda_), ldb(ldb_), ldd(ldd_),
       transa(transa_), transb(transb_),
       scaling_mode(scaling_mode_), epilogue(epilogue_) {}
@@ -550,6 +550,7 @@ public:
       return ((deviceCap == val.deviceCap)
           && (a_type == val.a_type) && (b_type == val.b_type)
           && (d_type == val.d_type) && (bias_type == val.bias_type)
+          && (aux_type == val.aux_type)
           && (m == val.m) && (n == val.n) && (k == val.k)
           && (lda == val.lda) && (ldb == val.ldb) && (ldd == val.ldd)
           && (transa == val.transa) && (transb == val.transb)
@@ -681,7 +682,7 @@ protected:
   {
     csv_helper fs(ofs, csv_sep);
     fs << "dev_cap" << "m" << "n" << "k" << "trans_a" << "trans_b"
-       << "type_a" << "type_b" << "type_d" << "bias_type"
+       << "type_a" << "type_b" << "type_d" << "bias_type" << "aux_type"
        << "lda" << "ldb" << "ldd" << "scale_mode" << "epi" << "comp" << "scale_type"
        << "ws_min" << "ws_max" << "algo_id" << "aidx";
   }
@@ -723,7 +724,7 @@ protected:
       if (line.empty() || line[0] == '#') continue;
       std::istringstream is(line);
       char c;
-      std::string type_a, type_b, type_d, bias_type, trans_a, trans_b, epi, comp, scale;
+      std::string type_a, type_b, type_d, bias_type, aux_type, trans_a, trans_b, epi, comp, scale;
       int64_t algo_id;
       int algo_idx;
       size_t ws_min, ws_max;
@@ -750,6 +751,7 @@ protected:
       std::getline(is, type_d, csv_sep);
       std::getline(is, bias_type, csv_sep);
+      std::getline(is, aux_type, csv_sep);  // aux_type column is written right after bias_type
       is >> cfg.lda >> c >> cfg.ldb >> c >> cfg.ldd >> c >> cfg.scaling_mode >> c;
       std::getline(is, epi, csv_sep);
       std::getline(is, comp, csv_sep);
       std::getline(is, scale, csv_sep);
@@ -801,6 +803,9 @@ protected:
       cfg.bias_type = (bias_type == "-")
           ? (hipDataType)-1
           : typeNameMapper.getValue(bias_type, "bias_type", fp8_filter);
+      cfg.aux_type = (aux_type == "-")
+          ? (hipDataType)-1
+          : typeNameMapper.getValue(aux_type, "aux_type", fp8_filter);

       cfg.transa = transposeNameMapper.getValue(trans_a, "trans_a");
       cfg.transb = transposeNameMapper.getValue(trans_b, "trans_b");
@@ -886,6 +891,7 @@ protected:
          << transposeNameMapper.getName(cfg.transa) << transposeNameMapper.getName(cfg.transb)
          << typeNameMapper.getName(cfg.a_type) << typeNameMapper.getName(cfg.b_type) << typeNameMapper.getName(cfg.d_type)
          << ((cfg.bias_type == (hipDataType)-1) ? "-" : typeNameMapper.getName(cfg.bias_type))
+         << ((cfg.aux_type == (hipDataType)-1) ? "-" : typeNameMapper.getName(cfg.aux_type))
          << cfg.lda << cfg.ldb << cfg.ldd << cfg.scaling_mode << epilogueNameMapper.getName(cfg.epilogue)
          << computeNameMapper.getName(HIPBLAS_COMPUTE_32F) << typeNameMapper.getName(HIP_R_32F)
          << algo.ws_size_min << algo.ws_size_max << algo.algoId << algo.index << csv_helper::end() << "\n";
@@ -1003,19 +1009,35 @@ void hipblaslt_gemm(const Tensor *inputA,
   const hipDataType B_type = get_hipblaslt_dtype(param.Btype);
   const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype);
   const hipDataType bias_type = get_hipblaslt_dtype(inputBias->data.dtype);
-  // const hipblasltDatatype_t aux_type = get_hipblaslt_dtype(outputPreGelu->data.dtype);
+  const hipDataType aux_type = get_hipblaslt_dtype(outputPreGelu->data.dtype);

   NVTE_CHECK(!is_fp8_dtype(param.Atype) || param.A_scale_inv != nullptr,
              "FP8 input to GEMM requires inverse of scale!");
   NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr,
              "FP8 input to GEMM requires inverse of scale!");

-  // check consistency of arguments:
-  // if fp8 is desired, context cannot be null
+#if HIPBLASLT_VERSION_MAJOR > 0 || HIPBLASLT_VERSION_MINOR >= 15
+  if (use_fp8 && gelu) {
+    hipDeviceProp_t prop;
+    NVTE_CHECK_CUDA(hipGetDeviceProperties(&prop, 0));
+    // Currently hipblasLt supports fp8 gemm + gelu fusion only on MI300 (gfx94x)
+    if (prop.major == 9 && prop.minor == 4) {
+      bool allow_fp8_gemm = (param.Atype == DType::kFloat8E4M3) &&
+                            (param.Btype == DType::kFloat8E4M3) &&
+                            (outputD->data.dtype == DType::kFloat8E4M3) &&
+                            (!bias || inputBias->data.dtype == DType::kFloat16) &&
+                            (outputPreGelu->data.dtype == DType::kFloat16 || outputPreGelu->data.dtype == outputD->data.dtype);
+      NVTE_CHECK(allow_fp8_gemm, "fp8 gemm + gelu fusion is unavailable with the current config!");
+    } else {
+      NVTE_CHECK(false, "fp8 gemm + gelu fusion is unavailable right now!");
+    }
+  }
+#else
   // fp8 + gelu fusion + fp8 aux is unavailable right now.
   if (use_fp8) {
     NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!");
   }
+#endif
   if (is_fp8_dtype(outputD->data.dtype)) {
     NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!");
   }
@@ -1064,7 +1086,7 @@ void hipblaslt_gemm(const Tensor *inputA,
                                                     &param.transB, sizeof(param.transB)));

   // set fp8 attributes -- input and output types should already be set to fp8 as appropriate
-  // Note: gelu fusion isn't available right now, and we don't need
+  // Note: gelu fusion is available for certain configs from ROCm 7.0, and we don't need
   // amax(D) either (next op is high precision).
 #if HIPBLASLT_VERSION_MAJOR > 0 || HIPBLASLT_VERSION_MINOR >= 15
   hipblasLtMatmulMatrixScale_t scaling_mode;
@@ -1116,6 +1138,14 @@ void hipblaslt_gemm(const Tensor *inputA,
                                                          HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE,
                                                          &bias_type, sizeof(bias_type)));
     }
+#if HIPBLASLT_VERSION_MAJOR > 0 || HIPBLASLT_VERSION_MINOR >= 15
+    if (gelu) {
+      NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
+                                                           HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE,
+                                                           &aux_type,
+                                                           sizeof(aux_type)));
+    }
+#endif
   }

   if (bias && gelu) {
@@ -1167,6 +1197,7 @@ void hipblaslt_gemm(const Tensor *inputA,

   GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type,
                               use_fp8 ? bias_type : (hipDataType)-1,
+                              (use_fp8 && gelu) ? aux_type : (hipDataType)-1,
                               m, n, k, param.lda, param.ldb, ldd, param.transA, param.transB, scaling_mode, epilogue);
   GemmAlgoCache::Algo cached_algo;
   if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 || !cached_algo.algo.has_value())