
Commit 00b9e9a

Author: chengduo
Refine cublas to support CUBLAS_TENSOR_OP_MATH (#13929)
* refine cublas test=develop
* code refine
* refine cublas
* add GEMM_EX
* add enable_cublas_tensor_op_math doc and add cublasCall test=develop
* fix CublasCall for cuda version test=develop
* fix error test=develop
* fix GEMM_EX to be compatible with gcc 4.8 test=develop
* add GEMM_EX test=develop
* to be compatible with gcc 4.8 test=develop
1 parent dd6fd4c commit 00b9e9a
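
The switch introduced here, enable_cublas_tensor_op_math, is a gflags boolean that the GEMM dispatch below consults. A minimal sketch of flipping it from C++ (the helper function is hypothetical; the flag's DEFINE_bool is not part of the files shown in this commit):

// Hypothetical illustration only: flip the switch added by this commit before
// running float32 GEMMs. DECLARE_bool makes the gflags symbol visible here;
// the matching DEFINE_bool lives elsewhere in the Paddle tree.
#include <gflags/gflags.h>

DECLARE_bool(enable_cublas_tensor_op_math);

void EnableTensorCoreGemm() {
  // After this, Blas<CUDADeviceContext>::GEMM for float dispatches to the new
  // GEMM_EX path (cublasSgemmEx / cublasGemmStridedBatchedEx) under
  // CUBLAS_TENSOR_OP_MATH when the GPU supports Tensor Cores.
  FLAGS_enable_cublas_tensor_op_math = true;
}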

File tree

6 files changed: 256 additions and 49 deletions


paddle/fluid/operators/math/blas_impl.cu.h

Lines changed: 172 additions & 34 deletions

@@ -16,6 +16,9 @@
 
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/dynload/cublas.h"
+#include "paddle/fluid/platform/gpu_info.h"
+
+DECLARE_bool(enable_cublas_tensor_op_math);
 
 namespace paddle {
 namespace operators {
@@ -42,11 +45,44 @@ struct CUBlas<float> {
   }
 
   template <typename... ARGS>
-  static void GEMM_BATCH(ARGS... args) {
+  static void GEMM_STRIDED_BATCH(ARGS... args) {
 #if CUDA_VERSION >= 8000
     PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(args...));
 #else
     PADDLE_THROW("SgemmStridedBatched is not supported on cuda <= 7.5");
+#endif
+  }
+
+  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
+  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
+  template <typename... ARGS>
+  static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
+                      cublasOperation_t transa, cublasOperation_t transb, int m,
+                      int n, int k, const float *alpha, const void *A,
+                      cudaDataType_t Atype, int lda, const void *B,
+                      cudaDataType_t Btype, int ldb, const float *beta, void *C,
+                      cudaDataType_t Ctype, int ldc) {
+    // Because the gcc 4.8 doesn't expand template parameter pack that
+    // appears in a lambda-expression, I can not use template parameter pack
+    // here.
+    auto cublas_call = [&]() {
+#if CUDA_VERSION >= 8000
+      VLOG(5) << "use_tensor_op_math: "
+              << (platform::TensorCoreAvailable() ? "True" : "False");
+      PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
+          dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype,
+          lda, B, Btype, ldb, beta, C, Ctype, ldc));
+#else
+      PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
+#endif
+    };
+
+#if CUDA_VERSION >= 9000
+    // NOTES: To use Tensor Core, we should change the cublas config,
+    // but the cublas may be hold by multi-thread.
+    dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
+#else
+    cublas_call();
 #endif
   }
 };
@@ -69,13 +105,18 @@ struct CUBlas<double> {
   }
 
   template <typename... ARGS>
-  static void GEMM_BATCH(ARGS... args) {
+  static void GEMM_STRIDED_BATCH(ARGS... args) {
 #if CUDA_VERSION >= 8000
     PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(args...));
 #else
     PADDLE_THROW("DgemmStridedBatched is not supported on cuda <= 7.5");
 #endif
   }
+
+  template <typename... ARGS>
+  static void GEMM_EX(ARGS... args) {
+    PADDLE_THROW("Currently there are not cublasDgemmEx.");
+  }
 };
 
 template <>
@@ -96,14 +137,16 @@ struct CUBlas<platform::float16> {
                             reinterpret_cast<__half *>(C), ldc));
   }
 
-  static void GEMM_BATCH(cublasHandle_t handle, cublasOperation_t transa,
-                         cublasOperation_t transb, int m, int n, int k,
-                         const float16 *alpha, const float16 *A, int lda,
-                         long long int strideA, const float16 *B,  // NOLINT
-                         int ldb, long long int strideB,           // NOLINT
-                         const float16 *beta, float16 *C, int ldc,
-                         long long int strideC,  // NOLINT
-                         int batchCount) {
+  static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
+                                 cublasOperation_t transa,
+                                 cublasOperation_t transb, int m, int n, int k,
+                                 const float16 *alpha, const float16 *A,
+                                 int lda, long long int strideA,  // NOLINT
+                                 const float16 *B,                // NOLINT
+                                 int ldb, long long int strideB,  // NOLINT
+                                 const float16 *beta, float16 *C, int ldc,
+                                 long long int strideC,  // NOLINT
+                                 int batchCount) {
 #if CUDA_VERSION >= 8000
     PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
         handle, transa, transb, m, n, k,
@@ -114,6 +157,45 @@ struct CUBlas<platform::float16> {
         ldc, strideC, batchCount));
 #else
     PADDLE_THROW("HgemmStridedBatched is not supported on cuda <= 7.5");
+#endif
+  }
+
+  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
+  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
+  template <typename... ARGS>
+  static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
+                      cublasOperation_t transa, cublasOperation_t transb, int m,
+                      int n, int k, const void *alpha, const void *A,
+                      cudaDataType_t Atype, int lda, const void *B,
+                      cudaDataType_t Btype, int ldb, const void *beta, void *C,
+                      cudaDataType_t Ctype, int ldc,
+                      cudaDataType_t computeType) {
+    auto cublas_call = [&]() {
+#if CUDA_VERSION >= 8000
+      cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+#if CUDA_VERSION >= 9000
+      bool use_tensor_op_math = platform::TensorCoreAvailable();
+      if (use_tensor_op_math) {
+        algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+      }
+      VLOG(5) << "use_tensor_op_math: "
+              << (use_tensor_op_math ? "True" : "False");
+#endif  // CUDA_VERSION >= 9000
+
+      PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
+          dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype,
+          lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo));
+#else
+      PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
+#endif
+    };
+
+#if CUDA_VERSION >= 9000
+    // NOTES: To use Tensor Core, we should change the cublas config,
+    // but the cublas may be hold by multi-thread.
+    dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
+#else
+    cublas_call();
 #endif
   }
 };
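
The float16 GEMM_EX above forwards to cublasGemmEx with half-precision inputs and outputs but a float computeType, so accumulation happens in FP32. A self-contained sketch of the same data-type combination, using the pre-CUDA-11 cublasGemmEx signature that this diff targets (the helper and its column-major operand layout are hypothetical, not part of this commit):

// Hypothetical standalone sketch (not Paddle code): C = A * B with half
// storage and float accumulation, mirroring the Atype/Btype/Ctype/computeType
// combination used by CUBlas<platform::float16>::GEMM_EX above.
#include <cublas_v2.h>
#include <cuda_fp16.h>

cublasStatus_t HalfInFloatAccumGemm(cublasHandle_t handle, int m, int n, int k,
                                    const __half *A, const __half *B,
                                    __half *C) {
  const float alpha = 1.0f, beta = 0.0f;  // alpha/beta match the computeType
  // Column-major m x k (A), k x n (B), m x n (C); leading dims are the rows.
  return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, A,
                      CUDA_R_16F, m, B, CUDA_R_16F, k, &beta, C, CUDA_R_16F, m,
                      CUDA_R_32F, CUBLAS_GEMM_DFALT);
}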
@@ -133,8 +215,21 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
 
-  CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha,
-                  B, ldb, A, lda, &beta, C, N);
+#if CUDA_VERSION >= 8000
+  if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
+    auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
+    CUBlas<T>::GEMM_EX(&cuda_ctx, cuTransB, cuTransA, N, M, K, &alpha, B,
+                       CUDA_R_32F, ldb, A, CUDA_R_32F, lda, &beta, C,
+                       CUDA_R_32F, N);
+  } else {
+#endif  // CUDA_VERSION >= 8000
+
+    CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
+                    &alpha, B, ldb, A, lda, &beta, C, N);
+
+#if CUDA_VERSION >= 8000
+  }
+#endif  // CUDA_VERSION >= 8000
 }
 
 template <>
@@ -157,30 +252,18 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
   PADDLE_ENFORCE_GE(context_.GetComputeCapability(), 53,
                     "cublas fp16 gemm requires GPU compute capability >= 53");
 
-#if CUDA_VERSION >= 8000
   float h_alpha = static_cast<float>(alpha);
   float h_beta = static_cast<float>(beta);
 
-  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
-#if CUDA_VERSION >= 9000
-  if (context_.GetComputeCapability() >= 70) {
-    PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(
-        context_.cublas_handle(), CUBLAS_TENSOR_OP_MATH));
-    algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
-  } else {
-    PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(
-        context_.cublas_handle(), CUBLAS_DEFAULT_MATH));
-  }
-#endif  // CUDA_VERSION >= 9000
-
+#if CUDA_VERSION >= 8000
   // cublasHgemm does true FP16 computation which is slow for non-Volta
   // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation:
   // input/output in fp16, computation in fp32, which can also be accelerated
   // using tensor cores in volta GPUs.
-  PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
-      context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, B,
-      CUDA_R_16F, ldb, A, CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N,
-      CUDA_R_32F, algo));
+  auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
+  CUBlas<platform::float16>::GEMM_EX(
+      &cuda_ctx, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16F, ldb, A,
+      CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F);
 #else
   // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
   CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
@@ -199,8 +282,38 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
   // the cblas convention.
   cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
   cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
-  CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha,
-                  B, ldb, A, lda, &beta, C, ldc);
+
+#if CUDA_VERSION >= 8000
+  if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
+    auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
+    CUBlas<T>::GEMM_EX(&cuda_ctx, cuTransB, cuTransA, N, M, K, &alpha, B,
+                       CUDA_R_32F, ldb, A, CUDA_R_32F, lda, &beta, C,
+                       CUDA_R_32F, ldc);
+  } else {
+#endif  // CUDA_VERSION >= 8000
+
+    CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
+                    &alpha, B, ldb, A, lda, &beta, C, ldc);
+
+#if CUDA_VERSION >= 8000
+  }
+#endif  // CUDA_VERSION >= 8000
+}
+
+template <>
+template <>
+inline void Blas<platform::CUDADeviceContext>::GEMM(
+    bool transA, bool transB, int M, int N, int K, platform::float16 alpha,
+    const platform::float16 *A, int lda, const platform::float16 *B, int ldb,
+    platform::float16 beta, platform::float16 *C, int ldc) const {
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+  CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
+                                  N, M, K, &alpha, B, ldb, A, lda, &beta, C,
+                                  ldc);
 }
 
 template <>
@@ -238,9 +351,34 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   const int64_t strideC = M * N;
 
-  CUBlas<T>::GEMM_BATCH(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
-                        &alpha, B, ldb, strideB, A, lda, strideA, &beta, C, ldc,
-                        strideC, batchCount);
+#if CUDA_VERSION >= 9010
+  if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
+    auto cublas_call = [&]() {
+      cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+      bool use_tensor_op_math = platform::TensorCoreAvailable();
+      if (use_tensor_op_math) {
+        algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+      }
+      VLOG(5) << "use_tensor_op_math: "
+              << (use_tensor_op_math ? "True" : "False");
+
+      PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
+          context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B,
+          CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, &beta, C,
+          CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo));
+    };
+    auto &dev_ctx = const_cast<platform::CUDADeviceContext &>(context_);
+    dev_ctx.CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
+  } else {
+#endif  // CUDA_VERSION >= 9010
+
+    CUBlas<T>::GEMM_STRIDED_BATCH(context_.cublas_handle(), cuTransB, cuTransA,
+                                  N, M, K, &alpha, B, ldb, strideB, A, lda,
+                                  strideA, &beta, C, ldc, strideC, batchCount);
+
+#if CUDA_VERSION >= 9010
+  }
+#endif  // CUDA_VERSION >= 9010
 }
 
 }  // namespace math
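
All of the GEMM wrappers above pass B before A and swap M with N. That is the standard trick for driving column-major cuBLAS with row-major buffers: C = A * B is obtained as C^T = B^T * A^T. A self-contained sketch of the same idea with plain cublasSgemm (hypothetical helper, not part of this commit):

// Hypothetical sketch: row-major C(MxN) = A(MxK) * B(KxN) computed by asking
// column-major cuBLAS for C^T = B^T * A^T, i.e. pass B first and swap M and N.
#include <cublas_v2.h>

void RowMajorSgemm(cublasHandle_t handle, int M, int N, int K, const float *A,
                   const float *B, float *C) {
  const float alpha = 1.0f, beta = 0.0f;
  cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha, B, /*ldb=*/N,
              A, /*lda=*/K, &beta, C, /*ldc=*/N);
}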

paddle/fluid/platform/device_context.h

Lines changed: 47 additions & 0 deletions

@@ -143,6 +143,39 @@ class CudnnWorkspaceHandle {
   std::unique_ptr<std::lock_guard<std::mutex>> guard_;
 };
 
+#if CUDA_VERSION >= 9000
+class ScopedCublasMathMode {
+ public:
+  ScopedCublasMathMode(cublasHandle_t handle, cublasMath_t new_math_mode)
+      : handle_(handle) {
+    need_reset = false;
+    PADDLE_ENFORCE(
+        platform::dynload::cublasGetMathMode(handle_, &old_math_mode_),
+        "Failed to get old cublas math mode");
+    if (old_math_mode_ != new_math_mode) {
+      PADDLE_ENFORCE(
+          platform::dynload::cublasSetMathMode(handle_, new_math_mode),
+          "Failed to set old cublas math mode");
+      need_reset = true;
+    }
+  }
+
+  ~ScopedCublasMathMode() {
+    if (need_reset) {
+      PADDLE_ENFORCE(
+          platform::dynload::cublasSetMathMode(handle_, old_math_mode_),
+          "Failed to set old cublas math mode");
+    }
+  }
+
+ private:
+  cublasHandle_t handle_;
+  cublasMath_t old_math_mode_;
+  bool need_reset;
+};
+#endif
+
 class CUDADeviceContext : public DeviceContext {
  public:
   explicit CUDADeviceContext(CUDAPlace place);
@@ -199,6 +232,18 @@ class CUDADeviceContext : public DeviceContext {
     callback_manager_->Wait();
   }
 
+#if CUDA_VERSION >= 9000
+  /*! \brief CublasCall may need to change cublas's config,
+   *  but the cublas may be hold by multi-thread, so we should
+   *  add lock here. */
+  template <typename Callback>
+  void CublasCall(Callback callback, cublasMath_t new_math) {
+    std::lock_guard<std::mutex> guard(cublas_mtx_);
+    ScopedCublasMathMode scoped_cublas_math(cublas_handle_, new_math);
+    callback();
+  }
+#endif
+
  private:
   CUDAPlace place_;
 
@@ -220,6 +265,8 @@ class CUDADeviceContext : public DeviceContext {
   // If we use mtx_ for StreamCallbackManager, deadlock may occur sometimes
   mutable std::mutex callback_mtx_;
   std::unique_ptr<StreamCallbackManager> callback_manager_;
+
+  mutable std::mutex cublas_mtx_;
 };
 
 template <>
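
CublasCall combines the new cublas_mtx_ lock with the ScopedCublasMathMode guard, so the handle's math mode is changed, used, and restored without another thread observing the temporary setting. A hedged caller-side sketch (the surrounding function is hypothetical; CublasCall and the guard are as introduced in this file):

// Hypothetical caller sketch: run cuBLAS work under CUBLAS_TENSOR_OP_MATH.
// The lock inside CublasCall keeps other threads from seeing the temporary
// mode, and ScopedCublasMathMode restores the old mode when the lambda ends.
#include "paddle/fluid/platform/device_context.h"

void RunUnderTensorOpMath(paddle::platform::CUDADeviceContext *dev_ctx) {
#if CUDA_VERSION >= 9000
  dev_ctx->CublasCall(
      [&]() {
        // ... issue platform::dynload::cublas* calls on dev_ctx->cublas_handle()
      },
      CUBLAS_TENSOR_OP_MATH);
#endif
}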

paddle/fluid/platform/dynload/cublas.h

Lines changed: 12 additions & 14 deletions

@@ -61,9 +61,6 @@ extern void *cublas_dso_handle;
   extern DynLoad__##__name __name
 #endif
 
-#define DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(__name) \
-  DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name)
-
 #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
   __macro(cublasSaxpy_v2);                \
   __macro(cublasDaxpy_v2);                \
@@ -93,22 +90,23 @@ CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
 
 // APIs available after CUDA 8.0
 #if CUDA_VERSION >= 8000
-#define CUBLAS_BLAS_ROUTINE_EACH_R2(__macro) \
-  __macro(cublasGemmEx);                     \
-  __macro(cublasSgemmStridedBatched);        \
-  __macro(cublasDgemmStridedBatched);        \
-  __macro(cublasCgemmStridedBatched);        \
-  __macro(cublasZgemmStridedBatched);        \
-  __macro(cublasHgemmStridedBatched);
-
-CUBLAS_BLAS_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
+DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmEx);
+DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmStridedBatched);
+DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmStridedBatched);
+DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmStridedBatched);
+DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmStridedBatched);
+DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasHgemmStridedBatched);
 #endif
 
 // APIs available after CUDA 9.0
 #if CUDA_VERSION >= 9000
-#define CUBLAS_BLAS_ROUTINE_EACH_R3(__macro) __macro(cublasSetMathMode);
+DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSetMathMode);
+DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGetMathMode);
+#endif
 
-CUBLAS_BLAS_ROUTINE_EACH_R3(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
+#if CUDA_VERSION >= 9010
+DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmBatchedEx);
+DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmStridedBatchedEx);
 #endif
 
 #undef DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP
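
For context, DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP expands to a small functor that resolves the routine from the cuBLAS shared library on first use and then forwards its arguments. A simplified generic sketch of that pattern (not the exact Paddle macro, which is generated per routine and guarded by the #if blocks above):

// Generic, simplified sketch of the lazy-dlsym wrapper pattern.
#include <dlfcn.h>

#include <cublas_v2.h>

extern void *cublas_dso_handle;  // opened elsewhere with dlopen()

struct DynLoad_cublasSgemm_v2 {
  template <typename... Args>
  cublasStatus_t operator()(Args... args) {
    using FuncT = decltype(&cublasSgemm_v2);
    // Resolve the symbol once from the already-opened cuBLAS library,
    // then forward the arguments unchanged.
    static void *sym = dlsym(cublas_dso_handle, "cublasSgemm_v2");
    return reinterpret_cast<FuncT>(sym)(args...);
  }
};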
