Skip to content

Commit 26af9cf

Browse files
authored
Merge pull request #14565 from chengduoZH/fix_cublas_warp_error
Fix cublas wrap error
2 parents 923c8e3 + f7847ca commit 26af9cf

File tree

2 files changed

+23
-10
lines changed

2 files changed

+23
-10
lines changed

paddle/fluid/platform/dynload/cublas.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ CUBLAS_BLAS_ROUTINE_EACH_R2(DEFINE_WRAP);
3232
CUBLAS_BLAS_ROUTINE_EACH_R3(DEFINE_WRAP);
3333
#endif
3434

35+
#ifdef CUBLAS_BLAS_ROUTINE_EACH_R4
36+
CUBLAS_BLAS_ROUTINE_EACH_R4(DEFINE_WRAP);
37+
#endif
3538
} // namespace dynload
3639
} // namespace platform
3740
} // namespace paddle

paddle/fluid/platform/dynload/cublas.h

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,23 +90,33 @@ CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
9090

9191
// APIs available in CUDA 8.0 and later
9292
#if CUDA_VERSION >= 8000
93-
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmEx);
94-
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmStridedBatched);
95-
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmStridedBatched);
96-
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmStridedBatched);
97-
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmStridedBatched);
98-
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasHgemmStridedBatched);
93+
#define CUBLAS_BLAS_ROUTINE_EACH_R2(__macro) \
94+
__macro(cublasGemmEx); \
95+
__macro(cublasSgemmStridedBatched); \
96+
__macro(cublasDgemmStridedBatched); \
97+
__macro(cublasCgemmStridedBatched); \
98+
__macro(cublasZgemmStridedBatched); \
99+
__macro(cublasHgemmStridedBatched);
100+
101+
CUBLAS_BLAS_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
99102
#endif
100103

101104
// APIs available in CUDA 9.0 and later
102105
#if CUDA_VERSION >= 9000
103-
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSetMathMode);
104-
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGetMathMode);
106+
#define CUBLAS_BLAS_ROUTINE_EACH_R3(__macro) \
107+
__macro(cublasSetMathMode); \
108+
__macro(cublasGetMathMode);
109+
110+
CUBLAS_BLAS_ROUTINE_EACH_R3(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
105111
#endif
106112

113+
// APIs available in CUDA 9.1 and later
107114
#if CUDA_VERSION >= 9010
108-
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmBatchedEx);
109-
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmStridedBatchedEx);
115+
#define CUBLAS_BLAS_ROUTINE_EACH_R4(__macro) \
116+
__macro(cublasGemmBatchedEx); \
117+
__macro(cublasGemmStridedBatchedEx);
118+
119+
CUBLAS_BLAS_ROUTINE_EACH_R4(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
110120
#endif
111121

112122
#undef DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP

0 commit comments

Comments
 (0)