Skip to content

Commit 187ba08

Browse files
committed
enable tensor core for conv cudnn
1 parent 66e0aed commit 187ba08

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

paddle/fluid/operators/conv_cudnn_op.cu.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,32 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
128128
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
129129
cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
130130
workspace_size_limit, &algo));
131+
132+
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
133+
// Tensor core is supported since the Volta GPU and
134+
// is only enabled when input and filter data are float16
135+
if (dev_ctx.GetComputeCapability() >= 70 &&
136+
std::type_index(typeid(T)) ==
137+
std::type_index(typeid(platform::float16))) {
138+
PADDLE_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
139+
cudnn_conv_desc, CUDNN_TENSOR_OP_MATH));
140+
// Currently tensor core is only enabled using this algo
141+
algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
142+
} else {
143+
PADDLE_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
144+
cudnn_conv_desc, CUDNN_DEFAULT_MATH));
145+
}
146+
#endif
147+
131148
// get workspace size able to allocate
132149
PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
133150
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
134151
cudnn_output_desc, algo, &workspace_size_in_bytes));
152+
// It is possible for float16 on Volta GPU to allocate more memory than
153+
// the limit because the algo is overridden to use tensor core.
154+
PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
155+
"workspace_size to be allocated exceeds the limit");
156+
135157
// Allocate on GPU memory
136158
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
137159
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);

paddle/fluid/platform/cudnn_helper.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,11 @@ class ScopedConvolutionDescriptor {
257257
}
258258
#endif
259259

260+
cudnnDataType_t compute_type =
261+
(type == CUDNN_DATA_DOUBLE) ? CUDNN_DATA_DOUBLE : CUDNN_DATA_FLOAT;
260262
PADDLE_ENFORCE(dynload::cudnnSetConvolutionNdDescriptor(
261263
desc_, pads.size(), pads.data(), strides.data(), dilations.data(),
262-
CUDNN_CROSS_CORRELATION, type));
264+
CUDNN_CROSS_CORRELATION, compute_type));
263265
return desc_;
264266
}
265267

paddle/fluid/platform/dynload/cudnn.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ limitations under the License. */
1616

1717
#include <cudnn.h>
1818
#include <dlfcn.h>
19-
#include <mutex>
2019
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
2120

2221
namespace paddle {
@@ -140,7 +139,8 @@ CUDNN_DNN_ROUTINE_EACH_R5(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
140139

141140
#if CUDNN_VERSION >= 7001
142141
#define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \
143-
__macro(cudnnSetConvolutionGroupCount);
142+
__macro(cudnnSetConvolutionGroupCount); \
143+
__macro(cudnnSetConvolutionMathType);
144144
CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
145145
#endif
146146

0 commit comments

Comments
 (0)