
Commit 1b22dd2

Cudnn convolution reconstruction (#18284) (#18776)
* rewrite the conv_op using cudnn_conv_helper
* add workspace limit for v7 test=develop
* fix test=develop
* add half float test=develop
* fix test=develop
* fix test=develop
* revise code style test=develop
* fix test=develop
1 parent 7af67f9 commit 1b22dd2
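The heart of the change is the same in all three search paths: ask the cuDNN 7 heuristic (`cudnnGet*Algorithm_v7`) for its ranked algorithm list, take the top result, and fall back to the legacy workspace-capped query if that algorithm would exceed `FLAGS_conv_workspace_size_limit`. Below is a minimal standalone sketch of that pattern for the forward pass, written against the raw cuDNN 7 API rather than Paddle's `platform::dynload` wrappers; the `CHECK_CUDNN` macro and `ChooseFwdAlgo` helper are illustrative, not code from this PR.

```cpp
#include <cstdio>
#include <cstdlib>
#include <cudnn.h>

// Illustrative error-check macro (not from this PR).
#define CHECK_CUDNN(call)                                                   \
  do {                                                                      \
    cudnnStatus_t s_ = (call);                                              \
    if (s_ != CUDNN_STATUS_SUCCESS) {                                       \
      std::fprintf(stderr, "cuDNN error: %s\n", cudnnGetErrorString(s_));   \
      std::exit(1);                                                         \
    }                                                                       \
  } while (0)

// Pick a forward algorithm: try the v7 heuristic first; if its top choice
// needs more workspace than the limit allows, fall back to the legacy
// query that honors an explicit workspace cap.
cudnnConvolutionFwdAlgo_t ChooseFwdAlgo(
    cudnnHandle_t handle, cudnnTensorDescriptor_t x,
    cudnnFilterDescriptor_t w, cudnnConvolutionDescriptor_t conv,
    cudnnTensorDescriptor_t y, size_t workspace_limit) {
  cudnnConvolutionFwdAlgo_t algo;
#if CUDNN_VERSION >= 7001
  int returned = 0;
  cudnnConvolutionFwdAlgoPerf_t perf[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
  CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm_v7(
      handle, x, w, conv, y, CUDNN_CONVOLUTION_FWD_ALGO_COUNT, &returned,
      perf));
  algo = perf[0].algo;  // results arrive sorted by expected performance

  size_t needed = 0;
  CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(
      handle, x, w, conv, y, algo, &needed));
  if (needed > workspace_limit) {
    // The heuristic's favorite is too memory-hungry; ask the legacy API
    // for the fastest algorithm that fits under the cap.
    CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm(
        handle, x, w, conv, y,
        CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, workspace_limit,
        &algo));
  }
#else
  CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm(
      handle, x, w, conv, y, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
      workspace_limit, &algo));
#endif
  return algo;
}
```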

2 files changed: +253 -447 lines changed

paddle/fluid/operators/conv_cudnn_helper.h

Lines changed: 152 additions & 10 deletions
@@ -14,11 +14,11 @@ limitations under the License. */
 
 #pragma once
 
+#include <memory>
 #include <vector>
 #include "paddle/fluid/framework/operator_kernel_configs.h"
 #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
 #include "paddle/fluid/platform/cudnn_desc.h"
-
 namespace paddle {
 namespace operators {
 
@@ -57,16 +57,57 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
                      bool deterministic, int algo_cache_id,
                      const framework::ExecutionContext& ctx) {
     auto dtype = platform::CudnnDataType<T>::type;
+    bool has_got_workspace_size = true;
     bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
-
     size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
-
+    size_t workspace_size = 0;
     algo_t algo;
+
+#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
+      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
+          args.cdesc.desc(), CUDNN_TENSOR_OP_MATH));
+      VLOG(5) << "use cudnn_tensor_op_math";
+    } else {
+      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
+          args.cdesc.desc(), CUDNN_DEFAULT_MATH));
+      VLOG(5) << "NOT use cudnn_tensor_op_math";
+    }
+#endif
+
     if (!exhaustive) {
+#if CUDNN_VERSION >= 7001
+      int perf_count;
+      int best_algo_idx = 0;
+      std::unique_ptr<perf_t[]> perf_results(new perf_t[kNUM_CUDNN_FWD_ALGS]);
+      CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7(
+          args.handle, args.idesc.desc(), args.wdesc.desc(), args.cdesc.desc(),
+          args.odesc.desc(), kNUM_CUDNN_FWD_ALGS, &perf_count,
+          perf_results.get()));
+      algo = (perf_results.get())[best_algo_idx].algo;
+      workspace_size = GetWorkspaceSize(args, algo);
+
+      if (workspace_size > workspace_size_limit) {
+        has_got_workspace_size = false;
+        VLOG(1) << "Fallback to non-v7 method to find conv algorithm because "
+                   "the workspace size request("
+                << workspace_size << ") exceeds the limit("
+                << workspace_size_limit << ")";
+      }
+      if (!has_got_workspace_size) {
+        CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
+            args.handle, args.idesc.desc(), args.wdesc.desc(),
+            args.cdesc.desc(), args.odesc.desc(),
+            CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, workspace_size_limit,
+            &algo));
+      }
+#else
       CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
           args.handle, args.idesc.desc(), args.wdesc.desc(), args.cdesc.desc(),
           args.odesc.desc(), CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
           workspace_size_limit, &algo));
+#endif
       VLOG(3) << "choose algo " << algo;
     } else {
       AlgorithmsCache<algo_t>& algo_cache =
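The new `#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)` block opts FP16 convolutions into Tensor Core math on compute capability 7.0+ before any algorithm is queried, since the math type affects which algorithms the heuristic recommends. Condensed to a standalone equivalent outside Paddle; the `SetConvMathType` helper and the runtime capability query are illustrative assumptions, not this PR's code:

```cpp
#include <cuda_runtime.h>
#include <cudnn.h>

// Enable Tensor Core math for an FP16 convolution on SM 7.0+ (Volta and
// newer); otherwise stay on the default FMA path. Mirrors the #if block
// in the hunk above, with the capability read via the CUDA runtime.
void SetConvMathType(cudnnConvolutionDescriptor_t conv,
                     cudnnDataType_t dtype, int device) {
#if CUDA_VERSION >= 9000 && CUDNN_VERSION >= 7001
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device);
  int capability = prop.major * 10 + prop.minor;  // e.g. 70 on V100
  if (capability >= 70 && dtype == CUDNN_DATA_HALF) {
    cudnnSetConvolutionMathType(conv, CUDNN_TENSOR_OP_MATH);
  } else {
    cudnnSetConvolutionMathType(conv, CUDNN_DEFAULT_MATH);
  }
#endif
}
```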
@@ -128,15 +169,72 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
                      const framework::ExecutionContext& ctx) {
     auto dtype = platform::CudnnDataType<T>::type;
     bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
-
     size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
-
+    size_t workspace_size = 0;
+    bool has_got_workspace_size = true;
     algo_t algo;
+
+#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
+      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
+          args.cdesc.desc(), CUDNN_TENSOR_OP_MATH));
+      VLOG(5) << "use cudnn_tensor_op_math";
+    } else {
+      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
+          args.cdesc.desc(), CUDNN_DEFAULT_MATH));
+      VLOG(5) << "NOT use cudnn_tensor_op_math";
+    }
+#endif
+
     if (!exhaustive && !deterministic) {
+#if CUDNN_VERSION >= 7001
+      int perf_count;
+      int best_algo_idx = 0;
+      std::unique_ptr<perf_t[]> perf_results(
+          new perf_t[kNUM_CUDNN_BWD_DATA_ALGS]);
+      CUDNN_ENFORCE(
+          platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm_v7(
+              args.handle, args.wdesc.desc(), args.odesc.desc(),
+              args.cdesc.desc(), args.idesc.desc(), kNUM_CUDNN_BWD_DATA_ALGS,
+              &perf_count, perf_results.get()));
+      algo = (perf_results.get())[best_algo_idx].algo;
+
+#if CUDNN_VERSION < 7500
+      int stride_dim = args.x->dims().size() - 2;
+      bool blacklist = std::any_of(args.s.begin(), args.s.begin() + stride_dim,
+                                   [=](int n) { return n != 1; });
+      if (blacklist && (static_cast<cudnnConvolutionBwdDataAlgo_t>(
+                            perf_results[best_algo_idx].algo) ==
+                            CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING ||
+                        static_cast<cudnnConvolutionBwdDataAlgo_t>(
+                            perf_results[best_algo_idx].algo) ==
+                            CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {
+        algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
+      }
+#endif
+      workspace_size = GetWorkspaceSize(args, algo);
+      if (workspace_size > workspace_size_limit) {
+        has_got_workspace_size = false;
+        VLOG(1) << "Fallback to non-v7 method to find conv algorithm because "
+                   "the workspace size request("
+                << workspace_size << ") exceeds the limit("
+                << workspace_size_limit << ")";
+      }
+      if (!has_got_workspace_size) {
+        CUDNN_ENFORCE(
+            platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
+                args.handle, args.wdesc.desc(), args.odesc.desc(),
+                args.cdesc.desc(), args.idesc.desc(),
+                CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
+                workspace_size_limit, &algo));
+      }
+#else
       CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
-          args.handle, args.wdesc.desc(), args.idesc.desc(), args.cdesc.desc(),
-          args.odesc.desc(), CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
+          args.handle, args.wdesc.desc(), args.odesc.desc(), args.cdesc.desc(),
+          args.idesc.desc(), CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
           workspace_size_limit, &algo));
+#endif
     } else if (deterministic) {
       return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
     } else {
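The `#if CUDNN_VERSION < 7500` block above appears to be a workaround: for strided convolutions, the v7 heuristic can recommend FFT-based backward-data algorithms that the commit does not trust on older cuDNN releases, so it overrides them with `ALGO_1`. The guard reduces to a predicate like the following sketch; the `MustAvoidFftBwdData` name and plain `strides` vector are illustrative stand-ins for `args.s`:

```cpp
#include <algorithm>
#include <vector>
#include <cudnn.h>

// True when any spatial stride != 1 and the heuristic picked an FFT-based
// backward-data algorithm: the combination the commit blacklists before
// cuDNN 7.5, substituting CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 instead.
bool MustAvoidFftBwdData(const std::vector<int>& strides,
                         cudnnConvolutionBwdDataAlgo_t algo) {
  bool strided = std::any_of(strides.begin(), strides.end(),
                             [](int s) { return s != 1; });
  return strided && (algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT ||
                     algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING);
}
```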
@@ -186,8 +284,8 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
     size_t workspace_size = 0;
     CUDNN_ENFORCE(
         platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
-            args.handle, args.wdesc.desc(), args.idesc.desc(),
-            args.cdesc.desc(), args.odesc.desc(), algo, &workspace_size));
+            args.handle, args.wdesc.desc(), args.odesc.desc(),
+            args.cdesc.desc(), args.idesc.desc(), algo, &workspace_size));
     return workspace_size;
   }
 };
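Note that the backward-data hunks above also fix an argument-order bug: the old calls passed the input descriptor where cuDNN expects the output-gradient descriptor, and vice versa. For reference, the cuDNN 7-era declaration the corrected call now matches (the parameter comments are ours, not cuDNN's):

```cpp
// Workspace query for the backward-data (dgrad) pass, as declared in cudnn.h.
cudnnStatus_t cudnnGetConvolutionBackwardDataWorkspaceSize(
    cudnnHandle_t handle,
    const cudnnFilterDescriptor_t wDesc,         // filter
    const cudnnTensorDescriptor_t dyDesc,        // gradient w.r.t. output
    const cudnnConvolutionDescriptor_t convDesc,
    const cudnnTensorDescriptor_t dxDesc,        // gradient w.r.t. input
    cudnnConvolutionBwdDataAlgo_t algo,
    size_t *sizeInBytes);
```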
@@ -203,17 +301,61 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
                      const framework::ExecutionContext& ctx) {
     auto dtype = platform::CudnnDataType<T>::type;
     bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
-
     size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
+    size_t workspace_size = 0;
+    bool has_got_workspace_size = true;
+
+#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
+      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
+          args.cdesc.desc(), CUDNN_TENSOR_OP_MATH));
+      VLOG(5) << "use cudnn_tensor_op_math";
+    } else {
+      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
+          args.cdesc.desc(), CUDNN_DEFAULT_MATH));
+      VLOG(5) << "NOT use cudnn_tensor_op_math";
+    }
+#endif
 
     algo_t algo;
     if (!exhaustive && !deterministic) {
+#if CUDNN_VERSION >= 7001
+      using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
+      int perf_count;
+      int best_algo_idx = 0;
+      std::unique_ptr<perf_t[]> perf_results(
+          new perf_t[kNUM_CUDNN_BWD_FILTER_ALGS]);
+      CUDNN_ENFORCE(
+          platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm_v7(
+              args.handle, args.idesc.desc(), args.odesc.desc(),
+              args.cdesc.desc(), args.wdesc.desc(), kNUM_CUDNN_BWD_FILTER_ALGS,
+              &perf_count, perf_results.get()));
+      algo = (perf_results.get())[best_algo_idx].algo;
+      workspace_size = GetWorkspaceSize(args, algo);
+      if (workspace_size > workspace_size_limit) {
+        has_got_workspace_size = false;
+        VLOG(1) << "Fallback to non-v7 method to find conv algorithm because "
+                   "the workspace size request("
+                << workspace_size << ") exceeds the limit("
+                << workspace_size_limit << ")";
+      }
+      if (!has_got_workspace_size) {
+        CUDNN_ENFORCE(
+            platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
+                args.handle, args.idesc.desc(), args.odesc.desc(),
+                args.cdesc.desc(), args.wdesc.desc(),
+                CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
+                workspace_size_limit, &algo));
+      }
+#else
       CUDNN_ENFORCE(
           platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
               args.handle, args.idesc.desc(), args.odesc.desc(),
               args.cdesc.desc(), args.wdesc.desc(),
               CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
               workspace_size_limit, &algo));
+#endif
     } else if (deterministic) {
       return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
     } else {
