#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/dynload/cublas.h"
+ #include "paddle/fluid/platform/gpu_info.h"
+
+ DECLARE_bool(enable_cublas_tensor_op_math);

namespace paddle {
namespace operators {
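The DECLARE_bool above only announces the gflags flag to this translation unit; the matching DEFINE_bool (with the default value and help text) lives in another Paddle source file and is not part of this hunk. A minimal sketch of how such a gflags pair fits together, with an illustrative default and description rather than Paddle's actual ones:

// flags.cc -- exactly one translation unit defines the flag.
#include <gflags/gflags.h>
DEFINE_bool(enable_cublas_tensor_op_math, false,
            "Let cuBLAS pick Tensor Core kernels for float32 GEMM.");

// consumer.cc -- any other file declares the flag and reads FLAGS_<name>.
#include <gflags/gflags.h>
DECLARE_bool(enable_cublas_tensor_op_math);

bool UseTensorOpMath() { return FLAGS_enable_cublas_tensor_op_math; }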
@@ -42,11 +45,44 @@ struct CUBlas<float> {
  }

  template <typename... ARGS>
-   static void GEMM_BATCH(ARGS... args) {
+   static void GEMM_STRIDED_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000
    PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(args...));
#else
    PADDLE_THROW("SgemmStridedBatched is not supported on cuda <= 7.5");
+ #endif
+   }
+
+   // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiplication.
+   // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
+   template <typename... ARGS>
+   static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
+                       cublasOperation_t transa, cublasOperation_t transb, int m,
+                       int n, int k, const float *alpha, const void *A,
+                       cudaDataType_t Atype, int lda, const void *B,
+                       cudaDataType_t Btype, int ldb, const float *beta, void *C,
+                       cudaDataType_t Ctype, int ldc) {
+     // Because gcc 4.8 doesn't expand a template parameter pack that
+     // appears in a lambda-expression, we cannot use a template parameter
+     // pack here.
+     auto cublas_call = [&]() {
+ #if CUDA_VERSION >= 8000
+       VLOG(5) << "use_tensor_op_math: "
+               << (platform::TensorCoreAvailable() ? "True" : "False");
+       PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
+           dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype,
+           lda, B, Btype, ldb, beta, C, Ctype, ldc));
+ #else
+       PADDLE_THROW("cublasSgemmEx is only supported on cuda >= 8.0");
+ #endif
+     };
+
+ #if CUDA_VERSION >= 9000
+     // NOTES: To use Tensor Core, we need to change the cublas config,
+     // but the cublas handle may be held by multiple threads.
+     dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
+ #else
+     cublas_call();
#endif
  }
};
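CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH) is provided by CUDADeviceContext and is not shown in this diff; conceptually it has to switch the shared handle into tensor-op math mode, run the call, and switch back, all under a lock so concurrent users of the handle are unaffected. A rough standalone sketch of that pattern against the raw cuBLAS API (the guard function, mutex, and handle here are illustrative, not Paddle's implementation):

#include <mutex>

#include <cublas_v2.h>

// Run `callback` (some cublas*Gemm* call) with the handle temporarily switched
// to `mode`, e.g. CUBLAS_TENSOR_OP_MATH, then restore the previous mode.
// The lock serializes the mode change because the handle is shared.
template <typename Callback>
void GuardedCublasCall(cublasHandle_t handle, std::mutex *mtx,
                       cublasMath_t mode, Callback &&callback) {
  std::lock_guard<std::mutex> guard(*mtx);
  cublasMath_t old_mode;
  cublasGetMathMode(handle, &old_mode);  // requires CUDA >= 9.0
  cublasSetMathMode(handle, mode);
  callback();
  cublasSetMathMode(handle, old_mode);
}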
@@ -69,13 +105,18 @@ struct CUBlas<double> {
  }

  template <typename... ARGS>
-   static void GEMM_BATCH(ARGS... args) {
+   static void GEMM_STRIDED_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000
    PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(args...));
#else
    PADDLE_THROW("DgemmStridedBatched is not supported on cuda <= 7.5");
#endif
  }
+
+   template <typename... ARGS>
+   static void GEMM_EX(ARGS... args) {
+     PADDLE_THROW("Currently there is no cublasDgemmEx.");
+   }
};

template <>
@@ -96,14 +137,16 @@ struct CUBlas<platform::float16> {
        reinterpret_cast<__half *>(C), ldc));
  }

-   static void GEMM_BATCH(cublasHandle_t handle, cublasOperation_t transa,
-                          cublasOperation_t transb, int m, int n, int k,
-                          const float16 *alpha, const float16 *A, int lda,
-                          long long int strideA, const float16 *B,  // NOLINT
-                          int ldb, long long int strideB,           // NOLINT
-                          const float16 *beta, float16 *C, int ldc,
-                          long long int strideC,                    // NOLINT
-                          int batchCount) {
+   static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
+                                  cublasOperation_t transa,
+                                  cublasOperation_t transb, int m, int n, int k,
+                                  const float16 *alpha, const float16 *A,
+                                  int lda, long long int strideA,  // NOLINT
+                                  const float16 *B,                // NOLINT
+                                  int ldb, long long int strideB,  // NOLINT
+                                  const float16 *beta, float16 *C, int ldc,
+                                  long long int strideC,           // NOLINT
+                                  int batchCount) {
#if CUDA_VERSION >= 8000
    PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
        handle, transa, transb, m, n, k,
@@ -114,6 +157,45 @@ struct CUBlas<platform::float16> {
        ldc, strideC, batchCount));
#else
    PADDLE_THROW("HgemmStridedBatched is not supported on cuda <= 7.5");
+ #endif
+   }
+
+   // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiplication.
+   // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
+   template <typename... ARGS>
+   static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
+                       cublasOperation_t transa, cublasOperation_t transb, int m,
+                       int n, int k, const void *alpha, const void *A,
+                       cudaDataType_t Atype, int lda, const void *B,
+                       cudaDataType_t Btype, int ldb, const void *beta, void *C,
+                       cudaDataType_t Ctype, int ldc,
+                       cudaDataType_t computeType) {
+     auto cublas_call = [&]() {
+ #if CUDA_VERSION >= 8000
+       cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+ #if CUDA_VERSION >= 9000
+       bool use_tensor_op_math = platform::TensorCoreAvailable();
+       if (use_tensor_op_math) {
+         algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+       }
+       VLOG(5) << "use_tensor_op_math: "
+               << (use_tensor_op_math ? "True" : "False");
+ #endif  // CUDA_VERSION >= 9000
+
+       PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
+           dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype,
+           lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo));
+ #else
+       PADDLE_THROW("cublasGemmEx is only supported on cuda >= 8.0");
+ #endif
+     };
+
+ #if CUDA_VERSION >= 9000
+     // NOTES: To use Tensor Core, we need to change the cublas config,
+     // but the cublas handle may be held by multiple threads.
+     dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
+ #else
+     cublas_call();
#endif
  }
};
@@ -133,8 +215,21 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

-   CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha,
-                   B, ldb, A, lda, &beta, C, N);
+ #if CUDA_VERSION >= 8000
+   if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
+     auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
+     CUBlas<T>::GEMM_EX(&cuda_ctx, cuTransB, cuTransA, N, M, K, &alpha, B,
+                        CUDA_R_32F, ldb, A, CUDA_R_32F, lda, &beta, C,
+                        CUDA_R_32F, N);
+   } else {
+ #endif  // CUDA_VERSION >= 8000
+
+     CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
+                     &alpha, B, ldb, A, lda, &beta, C, N);
+
+ #if CUDA_VERSION >= 8000
+   }
+ #endif  // CUDA_VERSION >= 8000
}

template <>
@@ -157,30 +252,18 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
  PADDLE_ENFORCE_GE(context_.GetComputeCapability(), 53,
                    "cublas fp16 gemm requires GPU compute capability >= 53");

- #if CUDA_VERSION >= 8000
  float h_alpha = static_cast<float>(alpha);
  float h_beta = static_cast<float>(beta);

-   cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
- #if CUDA_VERSION >= 9000
-   if (context_.GetComputeCapability() >= 70) {
-     PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(
-         context_.cublas_handle(), CUBLAS_TENSOR_OP_MATH));
-     algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
-   } else {
-     PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(
-         context_.cublas_handle(), CUBLAS_DEFAULT_MATH));
-   }
- #endif  // CUDA_VERSION >= 9000
-
+ #if CUDA_VERSION >= 8000
  // cublasHgemm does true FP16 computation which is slow for non-Volta
  // GPUs. So use cublasGemmEx instead which does pseudo FP16 computation:
  // input/output in fp16, computation in fp32, which can also be accelerated
  // using tensor cores in volta GPUs.
-   PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
-       context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, B,
-       CUDA_R_16F, ldb, A, CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N,
-       CUDA_R_32F, algo));
+   auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
+   CUBlas<platform::float16>::GEMM_EX(
+       &cuda_ctx, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16F, ldb, A,
+       CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F);
#else
  // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
  CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
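Outside of the Paddle wrappers, the "pseudo FP16" configuration described in the comment above is just a particular choice of cublasGemmEx arguments: half-precision storage types with a float compute type. A self-contained sketch (sizes, handle, and buffers are illustrative; CUBLAS_GEMM_DFALT_TENSOR_OP needs CUDA >= 9, on CUDA 8 only CUBLAS_GEMM_DFALT is available):

#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

int main() {
  // Illustrative sizes; a real caller passes its own m, n, k and leading dims.
  const int m = 64, n = 64, k = 64;
  __half *d_A, *d_B, *d_C;
  cudaMalloc(&d_A, sizeof(__half) * m * k);
  cudaMalloc(&d_B, sizeof(__half) * k * n);
  cudaMalloc(&d_C, sizeof(__half) * m * n);

  cublasHandle_t handle;
  cublasCreate(&handle);

  // alpha/beta are float because computeType below is CUDA_R_32F.
  float alpha = 1.0f, beta = 0.0f;

  // A, B, C are stored in fp16 (CUDA_R_16F) but accumulation runs in fp32
  // (CUDA_R_32F); the tensor-op algo lets cuBLAS pick Tensor Core kernels
  // where the hardware supports them.
  cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
               d_A, CUDA_R_16F, m,  // lda
               d_B, CUDA_R_16F, k,  // ldb
               &beta,
               d_C, CUDA_R_16F, m,  // ldc
               CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP);

  cublasDestroy(handle);
  cudaFree(d_A);
  cudaFree(d_B);
  cudaFree(d_C);
  return 0;
}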
@@ -199,8 +282,38 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
  // the cblas convention.
  cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
  cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
-   CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha,
-                   B, ldb, A, lda, &beta, C, ldc);
+
+ #if CUDA_VERSION >= 8000
+   if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
+     auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
+     CUBlas<T>::GEMM_EX(&cuda_ctx, cuTransB, cuTransA, N, M, K, &alpha, B,
+                        CUDA_R_32F, ldb, A, CUDA_R_32F, lda, &beta, C,
+                        CUDA_R_32F, ldc);
+   } else {
+ #endif  // CUDA_VERSION >= 8000
+
+     CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
+                     &alpha, B, ldb, A, lda, &beta, C, ldc);
+
+ #if CUDA_VERSION >= 8000
+   }
+ #endif  // CUDA_VERSION >= 8000
+ }
+
+ template <>
+ template <>
+ inline void Blas<platform::CUDADeviceContext>::GEMM(
+     bool transA, bool transB, int M, int N, int K, platform::float16 alpha,
+     const platform::float16 *A, int lda, const platform::float16 *B, int ldb,
+     platform::float16 beta, platform::float16 *C, int ldc) const {
+   // Note that cublas follows fortran order, so the order is different from
+   // the cblas convention.
+   cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
+   cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+   CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
+                                   N, M, K, &alpha, B, ldb, A, lda, &beta, C,
+                                   ldc);
}

template <>
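A note on the "fortran order" comments in the GEMM overloads above: the Blas interface takes row-major A (M x K), B (K x N), and C (M x N), while cuBLAS assumes column-major storage. A row-major matrix reinterpreted as column-major is its transpose, so

    C = A * B  (row-major)   is the same memory as   C^T = B^T * A^T  (column-major),

since (A * B)^T = B^T * A^T. That is why every call passes B before A, swaps M and N, and uses N (or ldc) as the leading dimension of C; no data is transposed or copied.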
@@ -238,9 +351,34 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int64_t strideC = M * N;

-   CUBlas<T>::GEMM_BATCH(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
-                         &alpha, B, ldb, strideB, A, lda, strideA, &beta, C, ldc,
-                         strideC, batchCount);
+ #if CUDA_VERSION >= 9010
+   if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
+     auto cublas_call = [&]() {
+       cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+       bool use_tensor_op_math = platform::TensorCoreAvailable();
+       if (use_tensor_op_math) {
+         algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+       }
+       VLOG(5) << "use_tensor_op_math: "
+               << (use_tensor_op_math ? "True" : "False");
+
+       PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
+           context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B,
+           CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, &beta, C,
+           CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo));
+     };
+     auto &dev_ctx = const_cast<platform::CUDADeviceContext &>(context_);
+     dev_ctx.CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
+   } else {
+ #endif  // CUDA_VERSION >= 9010
+
+     CUBlas<T>::GEMM_STRIDED_BATCH(context_.cublas_handle(), cuTransB, cuTransA,
+                                   N, M, K, &alpha, B, ldb, strideB, A, lda,
+                                   strideA, &beta, C, ldc, strideC, batchCount);
+
+ #if CUDA_VERSION >= 9010
+   }
+ #endif  // CUDA_VERSION >= 9010
}

}  // namespace math