Skip to content

Commit 7ed457e

Browse files
kexinzhao
authored and qingqing01 committed
Fix cuda 7.5 error with cublas GEMM (#9811)
* fix gemm error for cuda 7.5
* fix version number
1 parent 20f202a commit 7ed457e

File tree

3 files changed

+37
-12
lines changed

3 files changed

+37
-12
lines changed

paddle/fluid/operators/math/math_function.cu

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,14 @@ void gemm<platform::CUDADeviceContext, float16>(
3939
cublasOperation_t cuTransB =
4040
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
4141

42-
float h_alpha = static_cast<float>(alpha);
43-
float h_beta = static_cast<float>(beta);
44-
4542
// TODO(kexinzhao): add processing code for compute capability < 53 case
4643
PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
4744
"cublas fp16 gemm requires GPU compute capability >= 53");
4845

46+
#if CUDA_VERSION >= 8000
47+
float h_alpha = static_cast<float>(alpha);
48+
float h_beta = static_cast<float>(beta);
49+
4950
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
5051
#if CUDA_VERSION >= 9000
5152
if (context.GetComputeCapability() >= 70) {
@@ -56,7 +57,7 @@ void gemm<platform::CUDADeviceContext, float16>(
5657
PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(context.cublas_handle(),
5758
CUBLAS_DEFAULT_MATH));
5859
}
59-
#endif
60+
#endif // CUDA_VERSION >= 9000
6061

6162
// cublasHgemm does true FP16 computation which is slow for non-Volta
6263
// GPUs. So use cublasGemmEx instead which does pesudo FP16 computation:
@@ -66,6 +67,18 @@ void gemm<platform::CUDADeviceContext, float16>(
6667
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, B,
6768
CUDA_R_16F, ldb, A, CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N,
6869
CUDA_R_32F, algo));
70+
#else
71+
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
72+
const half h_alpha = static_cast<const half>(alpha);
73+
const half h_beta = static_cast<const half>(beta);
74+
const half* h_A = reinterpret_cast<const half*>(A);
75+
const half* h_B = reinterpret_cast<const half*>(B);
76+
half* h_C = reinterpret_cast<half*>(C);
77+
78+
PADDLE_ENFORCE(platform::dynload::cublasHgemm(
79+
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
80+
h_A, lda, &h_beta, h_C, N));
81+
#endif // CUDA_VERSION >= 8000
6982
}
7083

7184
template <>

paddle/fluid/platform/dynload/cublas.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ CUBLAS_BLAS_ROUTINE_EACH(DEFINE_WRAP);
2828
CUBLAS_BLAS_ROUTINE_EACH_R2(DEFINE_WRAP);
2929
#endif
3030

31+
#ifdef CUBLAS_BLAS_ROUTINE_EACH_R3
32+
CUBLAS_BLAS_ROUTINE_EACH_R3(DEFINE_WRAP);
33+
#endif
34+
3135
} // namespace dynload
3236
} // namespace platform
3337
} // namespace paddle

paddle/fluid/platform/dynload/cublas.h

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ extern void *cublas_dso_handle;
7171
__macro(cublasDgemm_v2); \
7272
__macro(cublasHgemm); \
7373
__macro(cublasSgemmEx); \
74-
__macro(cublasGemmEx); \
7574
__macro(cublasSgeam_v2); \
7675
__macro(cublasDgeam_v2); \
7776
__macro(cublasCreate_v2); \
@@ -83,22 +82,31 @@ extern void *cublas_dso_handle;
8382
__macro(cublasDgemmBatched); \
8483
__macro(cublasCgemmBatched); \
8584
__macro(cublasZgemmBatched); \
86-
__macro(cublasSgemmStridedBatched); \
87-
__macro(cublasDgemmStridedBatched); \
88-
__macro(cublasCgemmStridedBatched); \
89-
__macro(cublasZgemmStridedBatched); \
90-
__macro(cublasHgemmStridedBatched); \
9185
__macro(cublasSgetrfBatched); \
9286
__macro(cublasSgetriBatched); \
9387
__macro(cublasDgetrfBatched); \
9488
__macro(cublasDgetriBatched);
9589

9690
CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
9791

92+
// APIs available after CUDA 8.0
93+
#if CUDA_VERSION >= 8000
94+
#define CUBLAS_BLAS_ROUTINE_EACH_R2(__macro) \
95+
__macro(cublasGemmEx); \
96+
__macro(cublasSgemmStridedBatched); \
97+
__macro(cublasDgemmStridedBatched); \
98+
__macro(cublasCgemmStridedBatched); \
99+
__macro(cublasZgemmStridedBatched); \
100+
__macro(cublasHgemmStridedBatched);
101+
102+
CUBLAS_BLAS_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
103+
#endif
104+
98105
// APIs available after CUDA 9.0
99106
#if CUDA_VERSION >= 9000
100-
#define CUBLAS_BLAS_ROUTINE_EACH_R2(__macro) __macro(cublasSetMathMode);
101-
CUBLAS_BLAS_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
107+
#define CUBLAS_BLAS_ROUTINE_EACH_R3(__macro) __macro(cublasSetMathMode);
108+
109+
CUBLAS_BLAS_ROUTINE_EACH_R3(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
102110
#endif
103111

104112
#undef DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP

0 commit comments

Comments
 (0)