Commit 409b43a

Enabled fp8 gemm gelu_aux_bias (#315)
* Enabled fp8 gemm gelu_aux_bias
* Added new test
* Updated GemmAlgoCache
* Addressed reviews
1 parent a9c4026 commit 409b43a

2 files changed: 59 additions and 12 deletions
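What the commit enables, in hipBLASLt terms: the FP8 GEMM path can now request the GELU_AUX_BIAS epilogue, i.e. bias add and GELU fused into the matmul, with the pre-GELU ("aux") tensor written out for the backward pass. The sketch below shows roughly what that descriptor setup looks like with the hipBLASLt C API; the helper name, buffers, leading dimension, and the omission of status checks are illustrative assumptions, not code from this commit (the actual wiring lives in rocm_gemm.cu below).

```cpp
#include <hip/hip_runtime.h>
#include <hipblaslt/hipblaslt.h>

// Hypothetical helper: attach the fused bias + GELU epilogue to an existing
// matmul descriptor and point it at an FP16 pre-GELU ("aux") buffer.
// Status checks are omitted for brevity.
void set_gelu_aux_bias_epilogue(hipblasLtMatmulDesc_t desc,
                                void *pre_gelu_out,   // FP16 aux output buffer
                                int64_t aux_ld,       // leading dimension of the aux buffer
                                const void *bias) {   // FP16 bias vector
  // D = GELU(A * B + bias); the pre-GELU result is stored separately so the
  // backward pass can reuse it (the GELU gradient needs the pre-activation).
  hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_GELU_AUX_BIAS;
  hipblasLtMatmulDescSetAttribute(desc, HIPBLASLT_MATMUL_DESC_EPILOGUE,
                                  &epilogue, sizeof(epilogue));

  // Auxiliary (pre-GELU) output: pointer, leading dimension, data type.
  hipblasLtMatmulDescSetAttribute(desc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                  &pre_gelu_out, sizeof(pre_gelu_out));
  hipblasLtMatmulDescSetAttribute(desc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
                                  &aux_ld, sizeof(aux_ld));
  hipDataType aux_type = HIP_R_16F;   // FP16 aux, per the supported FP8 config
  hipblasLtMatmulDescSetAttribute(desc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE,
                                  &aux_type, sizeof(aux_type));

  // Bias pointer and data type (FP16 in the supported FP8 config).
  hipblasLtMatmulDescSetAttribute(desc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER,
                                  &bias, sizeof(bias));
  hipDataType bias_type = HIP_R_16F;
  hipblasLtMatmulDescSetAttribute(desc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE,
                                  &bias_type, sizeof(bias_type));
}
```

The combination enabled here is E4M3 A/B/D with an FP16 bias and an FP16 (or output-typed) aux tensor, which is exactly what the gating code and the new test below check for.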

tests/cpp/operator/test_cublaslt_gemm.cu (19 additions, 3 deletions)

@@ -209,6 +209,21 @@ void performTest(const TestParams& params) {
   (void)cudaGetDeviceProperties(&prop, 0);
 
 #ifdef __HIP_PLATFORM_AMD__
+
+  // Enable FP8 GEMM + GELU fusion tests only on MI300 (gfx942) with ROCm >= 7.0.
+  // hipBLASLt currently supports only this config.
+  bool fp8_gelu_fusion_config = false;
+#if HIP_VERSION >= 70000000
+  if (prop.major == 9 && prop.minor == 4)
+  {
+    fp8_gelu_fusion_config = atype == DType::kFloat8E4M3 &&
+                             btype == DType::kFloat8E4M3 &&
+                             dtype == DType::kFloat8E4M3 &&
+                             (params.use_gelu && gelu_type == DType::kFloat16) &&
+                             (!params.use_bias || bias_type == DType::kFloat16);
+  }
+#endif
+
   if (has_fp8)
   {
     bool fp8_supported = (prop.major == 9 && prop.minor >= 4);
@@ -227,8 +242,8 @@ void performTest(const TestParams& params) {
      }
    }
 
-    if (params.use_gelu) {
-      GTEST_SKIP() << "FP8 GEMM with GELU is not supported";
+    if (params.use_gelu && !fp8_gelu_fusion_config) {
+      GTEST_SKIP() << "FP8 GEMM with GELU is not supported in current config";
     }
     if (params.use_bias && dtype == DType::kFloat16) {
       GTEST_SKIP() << "FP8 GEMM with bias and FP16 output is not supported";
@@ -252,7 +267,7 @@ void performTest(const TestParams& params) {
     if (params.use_gelu && dtype == DType::kBFloat16 && !params.transa) {
       GTEST_SKIP() << "BF16 GEMM with GELU is not supported in current config";
     }
-    if (has_fp8 && params.use_bias && dtype == DType::kFloat8E4M3) {
+    if (has_fp8 && params.use_bias && dtype == DType::kFloat8E4M3 && !fp8_gelu_fusion_config) {
      GTEST_SKIP() << "FP8 GEMM with bias and FP8 output is not supported in current config";
    }
  }
@@ -506,6 +521,7 @@ MAKE_GEMM_TEST(Testbf8xfp8xbf16xbf16xfp8, bf8, fp8, bf16, bf16, fp8);
 
 MAKE_GEMM_TEST(Testbf8xfp8xbf16xbf16xbf8, bf8, fp8, bf16, bf16, bf8);
 
+MAKE_GEMM_TEST(Testfp8xfp8xfp16xfp16xfp8, fp8, fp8, fp16, fp16, fp8);
 
 INSTANTIATE_TEST_SUITE_P(
     OperatorTest,
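The gate added to performTest above only opts into the fused path on MI300 (gfx942, reported as prop.major == 9, prop.minor == 4) when building against ROCm 7.0 or newer, and only for E4M3 inputs and output with an FP16 GELU aux tensor and, if present, an FP16 bias. Written out as a standalone predicate it looks roughly like the sketch below; the function name and includes are assumptions for illustration, not part of the test.

```cpp
#include <hip/hip_runtime.h>
#include <hip/hip_version.h>
#include <transformer_engine/transformer_engine.h>

using transformer_engine::DType;

// Hypothetical helper: the condition the test's fp8_gelu_fusion_config flag encodes.
bool fp8_gelu_fusion_supported(const hipDeviceProp_t &prop,
                               DType atype, DType btype, DType dtype,
                               bool use_gelu, DType gelu_type,
                               bool use_bias, DType bias_type) {
#if HIP_VERSION >= 70000000  // ROCm 7.0 and newer
  const bool is_mi300 = (prop.major == 9 && prop.minor == 4);  // gfx942
  return is_mi300 &&
         atype == DType::kFloat8E4M3 &&
         btype == DType::kFloat8E4M3 &&
         dtype == DType::kFloat8E4M3 &&
         (use_gelu && gelu_type == DType::kFloat16) &&
         (!use_bias || bias_type == DType::kFloat16);
#else
  // Older ROCm: hipBLASLt does not expose the fused FP8 GEMM + GELU aux path.
  (void)prop; (void)atype; (void)btype; (void)dtype;
  (void)use_gelu; (void)gelu_type; (void)use_bias; (void)bias_type;
  return false;
#endif
}
```

The new MAKE_GEMM_TEST(Testfp8xfp8xfp16xfp16xfp8, ...) instantiation presumably covers exactly this combination (FP8 inputs, FP16 bias/aux, FP8 output); every other FP8 + GELU case still hits the GTEST_SKIP paths.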

transformer_engine/common/gemm/rocm_gemm.cu (40 additions, 9 deletions)

@@ -522,7 +522,7 @@ static class GemmAlgoCache {
 public:
   struct Key {
     int deviceCap;
-    hipDataType a_type, b_type, d_type, bias_type;
+    hipDataType a_type, b_type, d_type, bias_type, aux_type;
     int m, n, k;
     int lda, ldb, ldd;
     hipblasOperation_t transa, transb;
@@ -532,13 +532,13 @@ public:
 
     Key(int deviceCap_,
         hipDataType a_type_, hipDataType b_type_,
-        hipDataType d_type_, hipDataType bias_type_,
+        hipDataType d_type_, hipDataType bias_type_, hipDataType aux_type_,
         int m_, int n_, int k_, int lda_, int ldb_, int ldd_,
         hipblasOperation_t transa_, hipblasOperation_t transb_,
         int scaling_mode_, hipblasLtEpilogue_t epilogue_):
       deviceCap(deviceCap_),
       a_type(a_type_), b_type(b_type_),
-      d_type(d_type_), bias_type(bias_type_),
+      d_type(d_type_), bias_type(bias_type_), aux_type(aux_type_),
       m(m_), n(n_), k(k_), lda(lda_), ldb(ldb_), ldd(ldd_),
       transa(transa_), transb(transb_),
       scaling_mode(scaling_mode_), epilogue(epilogue_) {}
@@ -550,6 +550,7 @@ public:
       return ((deviceCap == val.deviceCap)
               && (a_type == val.a_type) && (b_type == val.b_type)
               && (d_type == val.d_type) && (bias_type == val.bias_type)
+              && (aux_type == val.aux_type)
               && (m == val.m) && (n == val.n) && (k == val.k)
               && (lda == val.lda) && (ldb == val.ldb) && (ldd == val.ldd)
               && (transa == val.transa) && (transb == val.transb)
@@ -681,7 +682,7 @@ protected:
   {
     csv_helper fs(ofs, csv_sep);
     fs << "dev_cap" << "m" << "n" << "k" << "trans_a" << "trans_b"
-       << "type_a" << "type_b" << "type_d" << "bias_type"
+       << "type_a" << "type_b" << "type_d" << "bias_type" << "aux_type"
        << "lda" << "ldb" << "ldd" << "scale_mode" << "epi" << "comp" << "scale_type"
        << "ws_min" << "ws_max" << "algo_id" << "aidx";
   }
@@ -723,7 +724,7 @@ protected:
     if (line.empty() || line[0] == '#') continue;
     std::istringstream is(line);
     char c;
-    std::string type_a, type_b, type_d, bias_type, trans_a, trans_b, epi, comp, scale;
+    std::string type_a, type_b, type_d, bias_type, aux_type, trans_a, trans_b, epi, comp, scale;
     int64_t algo_id;
     int algo_idx;
     size_t ws_min, ws_max;
@@ -750,6 +751,7 @@ protected:
     std::getline(is, type_d, csv_sep);
     std::getline(is, bias_type, csv_sep);
     is >> cfg.lda >> c >> cfg.ldb >> c >> cfg.ldd >> c >> cfg.scaling_mode >> c;
+    std::getline(is, aux_type, csv_sep);
     std::getline(is, epi, csv_sep);
     std::getline(is, comp, csv_sep);
     std::getline(is, scale, csv_sep);
@@ -801,6 +803,9 @@ protected:
     cfg.bias_type = (bias_type == "-")
                     ? (hipDataType)-1
                     : typeNameMapper.getValue(bias_type, "bias_type", fp8_filter);
+    cfg.aux_type = (aux_type == "-")
+                   ? (hipDataType)-1
+                   : typeNameMapper.getValue(aux_type, "aux_type", fp8_filter);
 
     cfg.transa = transposeNameMapper.getValue(trans_a, "trans_a");
     cfg.transb = transposeNameMapper.getValue(trans_b, "trans_b");
@@ -886,6 +891,7 @@ protected:
        << transposeNameMapper.getName(cfg.transa) << transposeNameMapper.getName(cfg.transb)
        << typeNameMapper.getName(cfg.a_type) << typeNameMapper.getName(cfg.b_type) << typeNameMapper.getName(cfg.d_type)
        << ((cfg.bias_type == (hipDataType)-1) ? "-" : typeNameMapper.getName(cfg.bias_type))
+       << ((cfg.aux_type == (hipDataType)-1) ? "-" : typeNameMapper.getName(cfg.aux_type))
        << cfg.lda << cfg.ldb << cfg.ldd << cfg.scaling_mode << epilogueNameMapper.getName(cfg.epilogue)
        << computeNameMapper.getName(HIPBLAS_COMPUTE_32F) << typeNameMapper.getName(HIP_R_32F)
        << algo.ws_size_min << algo.ws_size_max << algo.algoId << algo.index << csv_helper::end() << "\n";
@@ -1003,19 +1009,35 @@ void hipblaslt_gemm(const Tensor *inputA,
   const hipDataType B_type = get_hipblaslt_dtype(param.Btype);
   const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype);
   const hipDataType bias_type = get_hipblaslt_dtype(inputBias->data.dtype);
-  // const hipblasltDatatype_t aux_type = get_hipblaslt_dtype(outputPreGelu->data.dtype);
+  const hipDataType aux_type = get_hipblaslt_dtype(outputPreGelu->data.dtype);
 
   NVTE_CHECK(!is_fp8_dtype(param.Atype) || param.A_scale_inv != nullptr,
              "FP8 input to GEMM requires inverse of scale!");
   NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr,
              "FP8 input to GEMM requires inverse of scale!");
 
-  // check consistency of arguments:
-  // if fp8 is desired, context cannot be null
+#if HIPBLASLT_VERSION_MAJOR > 0 || HIPBLASLT_VERSION_MINOR >= 15
+  if (use_fp8 && gelu) {
+    hipDeviceProp_t prop;
+    NVTE_CHECK_CUDA(hipGetDeviceProperties(&prop, 0));
+    // Currently hipblasLt supports fp8 gemm + gelu fusion only on MI300
+    if (prop.major == 9 && prop.minor == 4) {
+      bool allow_fp8_gemm = (param.Atype == DType::kFloat8E4M3) &&
+                            (param.Btype == DType::kFloat8E4M3) &&
+                            (outputD->data.dtype == DType::kFloat8E4M3) &&
+                            (!bias || inputBias->data.dtype == DType::kFloat16) &&
+                            (outputPreGelu->data.dtype == DType::kFloat16 || outputPreGelu->data.dtype == outputD->data.dtype);
+      NVTE_CHECK(allow_fp8_gemm, "fp8 gemm + gelu fusion is unavailable with current config!");
+    } else {
+      NVTE_CHECK(false, "fp8 gemm + gelu fusion is unavailable right now!");
+    }
+  }
+#else
   // fp8 + gelu fusion + fp8 aux is unavailable right now.
   if (use_fp8) {
     NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!");
   }
+#endif
   if (is_fp8_dtype(outputD->data.dtype)) {
     NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!");
   }
@@ -1064,7 +1086,7 @@ void hipblaslt_gemm(const Tensor *inputA,
                                                      &param.transB, sizeof(param.transB)));
 
   // set fp8 attributes -- input and output types should already be set to fp8 as appropriate
-  // Note: gelu fusion isn't available right now, and we don't need
+  // Note: gelu fusion is available for certain configs from ROCm 7.0, and we don't need
   // amax(D) either (next op is high precision).
 #if HIPBLASLT_VERSION_MAJOR > 0 || HIPBLASLT_VERSION_MINOR >= 15
   hipblasLtMatmulMatrixScale_t scaling_mode;
@@ -1116,6 +1138,14 @@ void hipblaslt_gemm(const Tensor *inputA,
                                                            HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE,
                                                            &bias_type, sizeof(bias_type)));
     }
+#if HIPBLASLT_VERSION_MAJOR > 0 || HIPBLASLT_VERSION_MINOR >= 15
+    if (gelu) {
+      NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
+                                                           HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE,
+                                                           &aux_type,
+                                                           sizeof(aux_type)));
+    }
+#endif
   }
 
   if (bias && gelu) {
@@ -1167,6 +1197,7 @@ void hipblaslt_gemm(const Tensor *inputA,
 
   GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type,
                               use_fp8 ? bias_type : (hipDataType)-1,
+                              (use_fp8 && gelu) ? aux_type : (hipDataType)-1,
                               m, n, k, param.lda, param.ldb, ldd, param.transA, param.transB, scaling_mode, epilogue );
   GemmAlgoCache::Algo cached_algo;
   if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 || !cached_algo.algo.has_value())
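Why GemmAlgoCache changes at all: cached hipBLASLt algorithm selections are looked up by a key built from the full GEMM configuration and persisted as CSV rows, so once the epilogue aux data type can vary, it has to take part in key equality and in the serialized schema; otherwise an algorithm tuned for an FP16 aux output could be replayed for a different aux type. A deliberately simplified sketch of the idea follows; the struct and hash names are hypothetical, and the real Key carries many more fields (leading dimensions, transposes, scaling mode, epilogue, and so on).

```cpp
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <hip/library_types.h>  // hipDataType

struct AlgoKey {
  hipDataType d_type;     // output type
  hipDataType bias_type;  // (hipDataType)-1 when no bias epilogue is used
  hipDataType aux_type;   // (hipDataType)-1 when no GELU aux output is requested
  int m, n, k;

  // aux_type participates in equality, so FP16-aux and FP8-aux configs never alias.
  bool operator==(const AlgoKey &o) const {
    return d_type == o.d_type && bias_type == o.bias_type &&
           aux_type == o.aux_type && m == o.m && n == o.n && k == o.k;
  }
};

// Hash combines every field that participates in operator==, aux_type included.
// Usable as std::unordered_map<AlgoKey, Algo, AlgoKeyHash>.
struct AlgoKeyHash {
  std::size_t operator()(const AlgoKey &k) const {
    std::size_t h = 0;
    for (std::size_t v : {static_cast<std::size_t>(k.d_type),
                          static_cast<std::size_t>(k.bias_type),
                          static_cast<std::size_t>(k.aux_type),
                          static_cast<std::size_t>(k.m),
                          static_cast<std::size_t>(k.n),
                          static_cast<std::size_t>(k.k)}) {
      h = h * 1000003u + std::hash<std::size_t>{}(v);  // simple polynomial combine
    }
    return h;
  }
};
```

In the commit itself this shows up as the extra aux_type member and constructor argument, the extra clause in operator==, and the new "aux_type" CSV column; entries without a GELU aux output store "-" in that column, just like bias-less entries do for bias_type.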
