Commit 138a71b
Add tf32 switch for cuDNN (#29192) (#30574)
This PR is cherry-picked from #29192. It adds a TF32 switch for cuDNN: the switch is on by default, and TF32 is turned off when the user sets the switch to False.
1 parent 12c51f5 commit 138a71b
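
A minimal usage sketch (not part of the diff below) of the switch as exposed to Python; the set_cudnn_switch/get_cudnn_switch bindings are added in paddle/fluid/pybind/pybind.cc and exercised by the new unit test:

    import paddle.fluid.core as core

    if core.is_compiled_with_cuda():
        assert core.get_cudnn_switch()  # TF32 for cuDNN is allowed by default
        core.set_cudnn_switch(0)        # disallow TF32: FP32 convs use CUDNN_FMA_MATH on CUDA 11+
        core.set_cudnn_switch(1)        # allow TF32 again (CUDNN_DEFAULT_MATH)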

File tree

10 files changed (+124, -20 lines)

paddle/fluid/operators/conv_cudnn_helper.h

Lines changed: 21 additions & 9 deletions
@@ -210,16 +210,20 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
 
 #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
     auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
+        args.cdesc.desc(), CUDNN_DEFAULT_MATH));
+    VLOG(5) << "NOT use cudnn_tensor_op_math";
     if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
       PADDLE_ENFORCE_CUDA_SUCCESS(
           platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
                                                          CUDNN_TENSOR_OP_MATH));
       VLOG(5) << "use cudnn_tensor_op_math";
-    } else {
+    } else if (dtype == CUDNN_DATA_FLOAT && !args.cdesc.allow_tf32_) {
+#if CUDA_VERSION >= 11000
       PADDLE_ENFORCE_CUDA_SUCCESS(
           platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
-                                                         CUDNN_DEFAULT_MATH));
-      VLOG(5) << "NOT use cudnn_tensor_op_math";
+                                                         CUDNN_FMA_MATH));
+#endif  // CUDA_VERSION >= 11000
     }
 #endif
 
@@ -340,16 +344,20 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
     algo_t algo;
 #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
     auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
+        args.cdesc.desc(), CUDNN_DEFAULT_MATH));
+    VLOG(5) << "NOT use cudnn_tensor_op_math";
     if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
       PADDLE_ENFORCE_CUDA_SUCCESS(
           platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
                                                          CUDNN_TENSOR_OP_MATH));
       VLOG(5) << "use cudnn_tensor_op_math";
-    } else {
+    } else if (dtype == CUDNN_DATA_FLOAT && !args.cdesc.allow_tf32_) {
+#if CUDA_VERSION >= 11000
       PADDLE_ENFORCE_CUDA_SUCCESS(
           platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
-                                                         CUDNN_DEFAULT_MATH));
-      VLOG(5) << "NOT use cudnn_tensor_op_math";
+                                                         CUDNN_FMA_MATH));
+#endif  // CUDA_VERSION >= 11000
     }
 #endif
 
@@ -485,16 +493,20 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
 
 #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
     auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
+        args.cdesc.desc(), CUDNN_DEFAULT_MATH));
+    VLOG(5) << "NOT use cudnn_tensor_op_math";
     if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
       PADDLE_ENFORCE_CUDA_SUCCESS(
           platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
                                                          CUDNN_TENSOR_OP_MATH));
       VLOG(5) << "use cudnn_tensor_op_math";
-    } else {
+    } else if (dtype == CUDNN_DATA_FLOAT && !args.cdesc.allow_tf32_) {
+#if CUDA_VERSION >= 11000
       PADDLE_ENFORCE_CUDA_SUCCESS(
           platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
-                                                         CUDNN_DEFAULT_MATH));
-      VLOG(5) << "NOT use cudnn_tensor_op_math";
+                                                         CUDNN_FMA_MATH));
+#endif  // CUDA_VERSION >= 11000
     }
 #endif

paddle/fluid/operators/conv_cudnn_op.cu

Lines changed: 14 additions & 7 deletions
@@ -240,7 +240,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
     auto layout_format = GetCudnnTensorFormat(layout);
 
     args.handle = handle;
-    args.cdesc.set(dtype, padding_common, strides, dilations);
+    args.cdesc.set(dtype, padding_common, strides, dilations,
+                   platform::AllowTF32Cudnn());
 
 #if CUDNN_VERSION_MIN(7, 0, 1)
     // cudnn 7 can support groups, no need to do it manually
@@ -603,7 +604,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
       args1.idesc.set(transformed_input_grad, layout_tensor);
       args1.wdesc.set(transformed_filter_channel, layout_tensor, iwo_groups);
       args1.odesc.set(transformed_output_grad_channel, layout_tensor);
-      args1.cdesc.set(dtype, padding_common, strides, dilations, c_groups);
+      args1.cdesc.set(dtype, padding_common, strides, dilations,
+                      platform::AllowTF32Cudnn(), c_groups);
 
       using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
       data_algo =
@@ -620,7 +622,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
       args2.wdesc.set(transformed_filter_grad_channel, layout_tensor,
                       iwo_groups);
       args2.odesc.set(transformed_output_grad_channel, layout_tensor);
-      args2.cdesc.set(dtype, padding_common, strides, dilations, c_groups);
+      args2.cdesc.set(dtype, padding_common, strides, dilations,
+                      platform::AllowTF32Cudnn(), c_groups);
 
       using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
       filter_algo =
@@ -980,7 +983,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
      args1.idesc.set(transformed_ddX, iwo_group);
      args1.wdesc.set(*W, layout, iwo_group);
      args1.odesc.set(transformed_ddO_channel, iwo_group);
-      args1.cdesc.set(dtype, padding_common, strides, dilations, c_group);
+      args1.cdesc.set(dtype, padding_common, strides, dilations,
+                      platform::AllowTF32Cudnn(), c_group);
 
      using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
      fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, ctx);
@@ -995,7 +999,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
      args2.wdesc.set(*ddW, layout, iwo_group);
 
      args2.odesc.set(transformed_ddO_channel, iwo_group);
-      args2.cdesc.set(dtype, padding_common, strides, dilations, c_group);
+      args2.cdesc.set(dtype, padding_common, strides, dilations,
+                      platform::AllowTF32Cudnn(), c_group);
 
      using search2 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
      fwd_algo2 = search2::Find<T>(args2, exhaustive_search, false, ctx);
@@ -1012,7 +1017,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
 
      args3.odesc.set(transformed_dO_channel, iwo_group);
 
-      args3.cdesc.set(dtype, padding_common, strides, dilations, c_group);
+      args3.cdesc.set(dtype, padding_common, strides, dilations,
+                      platform::AllowTF32Cudnn(), c_group);
 
      using search3 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
      filter_algo =
@@ -1028,7 +1034,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
      args4.idesc.set(transformed_dX, iwo_group);
      args4.wdesc.set(*ddW, layout, iwo_group);
      args4.odesc.set(transformed_dO_channel, iwo_group);
-      args4.cdesc.set(dtype, padding_common, strides, dilations, c_group);
+      args4.cdesc.set(dtype, padding_common, strides, dilations,
+                      platform::AllowTF32Cudnn(), c_group);
 
      using search4 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
      data_algo =

paddle/fluid/operators/conv_transpose_cudnn_op.cu

Lines changed: 6 additions & 3 deletions
@@ -232,7 +232,8 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
     args.idesc.set(transformed_output, iwo_groups);
     args.wdesc.set(*filter, layout_tensor, iwo_groups);
     args.odesc.set(transformed_input, iwo_groups);
-    args.cdesc.set(dtype, padding_common, strides, dilations, c_groups);
+    args.cdesc.set(dtype, padding_common, strides, dilations,
+                   platform::AllowTF32Cudnn(), c_groups);
 
     using search = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
     algo = search::Find<T>(args, false, deterministic, ctx);
@@ -468,7 +469,8 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
      args1.idesc.set(transformed_output_grad, iwo_groups);
      args1.wdesc.set(*filter, layout_tensor, iwo_groups);
      args1.odesc.set(input_transpose, iwo_groups);
-      args1.cdesc.set(dtype, padding_common, strides, dilations, c_groups);
+      args1.cdesc.set(dtype, padding_common, strides, dilations,
+                      platform::AllowTF32Cudnn(), c_groups);
      using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
      data_algo = search1::Find<T>(args1, false, deterministic, ctx);
      workspace_size =
@@ -481,7 +483,8 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
      args2.idesc.set(transformed_output_grad, iwo_groups);
      args2.wdesc.set(*filter_grad, layout_tensor, iwo_groups);
      args2.odesc.set(input_transpose, iwo_groups);
-      args2.cdesc.set(dtype, padding_common, strides, dilations, c_groups);
+      args2.cdesc.set(dtype, padding_common, strides, dilations,
+                      platform::AllowTF32Cudnn(), c_groups);
      using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
      filter_algo = search2::Find<T>(args2, false, deterministic, ctx);
      workspace_size = std::max(workspace_size,

paddle/fluid/operators/fused/conv_fusion_op.cu

Lines changed: 7 additions & 0 deletions
@@ -200,6 +200,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
 
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
         cudnn_conv_desc, CUDNN_DEFAULT_MATH));
+#if CUDNN_VERSION >= 11000
+    if (!platform::allow_tf32_cudnn) {
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::cudnnSetConvolutionMathType(cudnn_conv_desc,
+                                                         CUDNN_FMA_MATH));
+    }
+#endif  // CUDA_VERSION >= 11000
 
     auto x_dims = framework::vectorize(transformed_input.dims());
     auto f_dims = framework::vectorize(filter->dims());

paddle/fluid/operators/fused/fusion_conv_inception_op.cu

Lines changed: 7 additions & 0 deletions
@@ -153,6 +153,13 @@ class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnSetConvolutionMathType(conv_desc[i],
                                                        CUDNN_DEFAULT_MATH));
+#if CUDNN_VERSION >= 11000
+    if (!platform::allow_tf32_cudnn) {
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::cudnnSetConvolutionMathType(conv_desc[i],
+                                                         CUDNN_FMA_MATH));
+    }
+#endif  // CUDA_VERSION >= 11000
   }
   in_dims[2][1] *= 2;
   in_strides[2][0] = oc * h * w;

paddle/fluid/platform/cudnn_desc.h

Lines changed: 10 additions & 1 deletion
@@ -24,6 +24,7 @@
 #include <vector>
 
 #include "paddle/fluid/platform/cudnn_helper.h"
+#include "paddle/fluid/platform/device_context.h"
 
 namespace paddle {
 namespace framework {
@@ -229,7 +230,8 @@ class ConvolutionDescriptor {
 
   void set(cudnnDataType_t dtype, const std::vector<int>& pads,
            const std::vector<int>& strides, const std::vector<int>& dilations,
-           const int groups = 1) {
+           bool allow_tf32, const int groups = 1) {
+    allow_tf32_ = allow_tf32;
     cudnnDataType_t compute_type =
         (dtype == CUDNN_DATA_DOUBLE) ? CUDNN_DATA_DOUBLE : CUDNN_DATA_FLOAT;
     T* desc = desc_.get();
@@ -246,11 +248,18 @@ class ConvolutionDescriptor {
       PADDLE_ENFORCE_CUDA_SUCCESS(
           platform::dynload::cudnnSetConvolutionMathType(desc,
                                                          CUDNN_TENSOR_OP_MATH));
+    } else if (dtype == CUDNN_DATA_FLOAT && !allow_tf32) {
+#if CUDA_VERSION >= 11000
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::cudnnSetConvolutionMathType(desc, CUDNN_FMA_MATH));
+#endif  // CUDA_VERSION >= 11000
     }
 #endif
 #endif
   }
 
+  bool allow_tf32_;
+
  private:
   std::unique_ptr<T, Deleter> desc_;
 };

paddle/fluid/platform/device_context.cc

Lines changed: 10 additions & 0 deletions
@@ -70,6 +70,16 @@ AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
 namespace paddle {
 namespace platform {
 
+#ifdef PADDLE_WITH_CUDA
+bool allow_tf32_cublas = true;
+void SetAllowTF32Cublas(bool active) { allow_tf32_cublas = active; }
+bool AllowTF32Cublas() { return allow_tf32_cublas; }
+
+bool allow_tf32_cudnn = true;
+void SetAllowTF32Cudnn(bool active) { allow_tf32_cudnn = active; }
+bool AllowTF32Cudnn() { return allow_tf32_cudnn; }
+#endif  // PADDLE_WITH_CUDA
+
 DeviceContextPool* DeviceContextPool::pool = nullptr;
 
 platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {

paddle/fluid/platform/device_context.h

Lines changed: 4 additions & 0 deletions
@@ -67,6 +67,10 @@ namespace platform {
 void SetAllowTF32Cublas(bool active);
 /*Get the global variable allow_tf32_cublas value*/
 bool AllowTF32Cublas();
+/*Set the value of the global variable allow_tf32_cudnn*/
+void SetAllowTF32Cudnn(bool active);
+/*Get the global variable allow_tf32_cudnn value*/
+bool AllowTF32Cudnn();
 #endif  // PADDLE_WITH_CUDA
 
 enum DeviceType {

paddle/fluid/pybind/pybind.cc

Lines changed: 7 additions & 0 deletions
@@ -1987,6 +1987,13 @@ All parameter, weight, gradient are variables in Paddle.
 
   m.def("size_of_dtype", framework::SizeOfType);
 
+#ifdef PADDLE_WITH_CUDA
+  m.def("set_cublas_switch", platform::SetAllowTF32Cublas);
+  m.def("get_cublas_switch", platform::AllowTF32Cublas);
+  m.def("set_cudnn_switch", platform::SetAllowTF32Cudnn);
+  m.def("get_cudnn_switch", platform::AllowTF32Cudnn);
+#endif  // PADDLE_WITH_CUDA
+
   using VarQuantScale =
       std::unordered_map<std::string, std::pair<bool, LoDTensor>>;

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import six
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+
+
+class TestTF32Switch(unittest.TestCase):
+    def test_on_off(self):
+        if core.is_compiled_with_cuda():
+            self.assertTrue(core.get_cudnn_switch())  # default
+            core.set_cudnn_switch(0)
+            self.assertFalse(core.get_cudnn_switch())  # turn off
+            core.set_cudnn_switch(1)
+            self.assertTrue(core.get_cudnn_switch())  # turn on
+
+            core.set_cudnn_switch(1)  # restore the switch
+        else:
+            pass
+
+
+if __name__ == '__main__':
+    unittest.main()
