@@ -14,11 +14,11 @@ limitations under the License. */
14
14
15
15
#pragma once
16
16
17
+ #include < memory>
17
18
#include < vector>
18
19
#include " paddle/fluid/framework/operator_kernel_configs.h"
19
20
#include " paddle/fluid/operators/conv_cudnn_op_cache.h"
20
21
#include " paddle/fluid/platform/cudnn_desc.h"
21
-
22
22
namespace paddle {
23
23
namespace operators {
24
24
@@ -57,16 +57,57 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
57
57
bool deterministic, int algo_cache_id,
58
58
const framework::ExecutionContext& ctx) {
59
59
auto dtype = platform::CudnnDataType<T>::type;
60
+ bool has_got_workspace_size = true ;
60
61
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
61
-
62
62
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024 ;
63
-
63
+ size_t workspace_size = 0 ;
64
64
algo_t algo;
65
+
66
+ #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
67
+ auto & dev_ctx = ctx.template device_context <platform::CUDADeviceContext>();
68
+ if (dev_ctx.GetComputeCapability () >= 70 && dtype == CUDNN_DATA_HALF) {
69
+ CUDNN_ENFORCE (platform::dynload::cudnnSetConvolutionMathType (
70
+ args.cdesc .desc (), CUDNN_TENSOR_OP_MATH));
71
+ VLOG (5 ) << " use cudnn_tensor_op_math" ;
72
+ } else {
73
+ CUDNN_ENFORCE (platform::dynload::cudnnSetConvolutionMathType (
74
+ args.cdesc .desc (), CUDNN_DEFAULT_MATH));
75
+ VLOG (5 ) << " NOT use cudnn_tensor_op_math" ;
76
+ }
77
+ #endif
78
+
65
79
if (!exhaustive) {
80
+ #if CUDNN_VERSION >= 7001
81
+ int perf_count;
82
+ int best_algo_idx = 0 ;
83
+ std::unique_ptr<perf_t []> perf_results (new perf_t [kNUM_CUDNN_FWD_ALGS ]);
84
+ CUDNN_ENFORCE (platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7 (
85
+ args.handle , args.idesc .desc (), args.wdesc .desc (), args.cdesc .desc (),
86
+ args.odesc .desc (), kNUM_CUDNN_FWD_ALGS , &perf_count,
87
+ perf_results.get ()));
88
+ algo = (perf_results.get ())[best_algo_idx].algo ;
89
+ workspace_size = GetWorkspaceSize (args, algo);
90
+
91
+ if (workspace_size > workspace_size_limit) {
92
+ has_got_workspace_size = false ;
93
+ VLOG (1 ) << " Fallback to non-v7 method to find conv algorithm becasue "
94
+ " the workspace size request("
95
+ << workspace_size << " ) exceeds the limit("
96
+ << workspace_size_limit << " )" ;
97
+ }
98
+ if (!has_got_workspace_size) {
99
+ CUDNN_ENFORCE (platform::dynload::cudnnGetConvolutionForwardAlgorithm (
100
+ args.handle , args.idesc .desc (), args.wdesc .desc (),
101
+ args.cdesc .desc (), args.odesc .desc (),
102
+ CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, workspace_size_limit,
103
+ &algo));
104
+ }
105
+ #else
66
106
CUDNN_ENFORCE (platform::dynload::cudnnGetConvolutionForwardAlgorithm (
67
107
args.handle , args.idesc .desc (), args.wdesc .desc (), args.cdesc .desc (),
68
108
args.odesc .desc (), CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
69
109
workspace_size_limit, &algo));
110
+ #endif
70
111
VLOG (3 ) << " choose algo " << algo;
71
112
} else {
72
113
AlgorithmsCache<algo_t >& algo_cache =
@@ -128,15 +169,72 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
128
169
const framework::ExecutionContext& ctx) {
129
170
auto dtype = platform::CudnnDataType<T>::type;
130
171
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
131
-
132
172
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024 ;
133
-
173
+ size_t workspace_size = 0 ;
174
+ bool has_got_workspace_size = true ;
134
175
algo_t algo;
176
+
177
+ #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
178
+ auto & dev_ctx = ctx.template device_context <platform::CUDADeviceContext>();
179
+ if (dev_ctx.GetComputeCapability () >= 70 && dtype == CUDNN_DATA_HALF) {
180
+ CUDNN_ENFORCE (platform::dynload::cudnnSetConvolutionMathType (
181
+ args.cdesc .desc (), CUDNN_TENSOR_OP_MATH));
182
+ VLOG (5 ) << " use cudnn_tensor_op_math" ;
183
+ } else {
184
+ CUDNN_ENFORCE (platform::dynload::cudnnSetConvolutionMathType (
185
+ args.cdesc .desc (), CUDNN_DEFAULT_MATH));
186
+ VLOG (5 ) << " NOT use cudnn_tensor_op_math" ;
187
+ }
188
+ #endif
189
+
135
190
if (!exhaustive && !deterministic) {
191
+ #if CUDNN_VERSION >= 7001
192
+ int perf_count;
193
+ int best_algo_idx = 0 ;
194
+ std::unique_ptr<perf_t []> perf_results (
195
+ new perf_t [kNUM_CUDNN_BWD_DATA_ALGS ]);
196
+ CUDNN_ENFORCE (
197
+ platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm_v7 (
198
+ args.handle , args.wdesc .desc (), args.odesc .desc (),
199
+ args.cdesc .desc (), args.idesc .desc (), kNUM_CUDNN_BWD_DATA_ALGS ,
200
+ &perf_count, perf_results.get ()));
201
+ algo = (perf_results.get ())[best_algo_idx].algo ;
202
+
203
+ #if CUDNN_VERSION < 7500
204
+ int stride_dim = args.x ->dims ().size () - 2 ;
205
+ bool blacklist = std::any_of (args.s .begin (), args.s .begin () + stride_dim,
206
+ [=](int n) { return n != 1 ; });
207
+ if (blacklist && (static_cast <cudnnConvolutionBwdDataAlgo_t>(
208
+ perf_results[best_algo_idx].algo ) ==
209
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING ||
210
+ static_cast <cudnnConvolutionBwdDataAlgo_t>(
211
+ perf_results[best_algo_idx].algo ) ==
212
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {
213
+ algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
214
+ }
215
+ #endif
216
+ workspace_size = GetWorkspaceSize (args, algo);
217
+ if (workspace_size > workspace_size_limit) {
218
+ has_got_workspace_size = false ;
219
+ VLOG (1 ) << " Fallback to non-v7 method to find conv algorithm becasue "
220
+ " the workspace size request("
221
+ << workspace_size << " ) exceeds the limit("
222
+ << workspace_size_limit << " )" ;
223
+ }
224
+ if (!has_got_workspace_size) {
225
+ CUDNN_ENFORCE (
226
+ platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm (
227
+ args.handle , args.wdesc .desc (), args.odesc .desc (),
228
+ args.cdesc .desc (), args.idesc .desc (),
229
+ CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
230
+ workspace_size_limit, &algo));
231
+ }
232
+ #else
136
233
CUDNN_ENFORCE (platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm (
137
- args.handle , args.wdesc .desc (), args.idesc .desc (), args.cdesc .desc (),
138
- args.odesc .desc (), CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
234
+ args.handle , args.wdesc .desc (), args.odesc .desc (), args.cdesc .desc (),
235
+ args.idesc .desc (), CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
139
236
workspace_size_limit, &algo));
237
+ #endif
140
238
} else if (deterministic) {
141
239
return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
142
240
} else {
@@ -186,8 +284,8 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
186
284
size_t workspace_size = 0 ;
187
285
CUDNN_ENFORCE (
188
286
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize (
189
- args.handle , args.wdesc .desc (), args.idesc .desc (),
190
- args.cdesc .desc (), args.odesc .desc (), algo, &workspace_size));
287
+ args.handle , args.wdesc .desc (), args.odesc .desc (),
288
+ args.cdesc .desc (), args.idesc .desc (), algo, &workspace_size));
191
289
return workspace_size;
192
290
}
193
291
};
@@ -203,17 +301,61 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
203
301
const framework::ExecutionContext& ctx) {
204
302
auto dtype = platform::CudnnDataType<T>::type;
205
303
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
206
-
207
304
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024 ;
305
+ size_t workspace_size = 0 ;
306
+ bool has_got_workspace_size = true ;
307
+
308
+ #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
309
+ auto & dev_ctx = ctx.template device_context <platform::CUDADeviceContext>();
310
+ if (dev_ctx.GetComputeCapability () >= 70 && dtype == CUDNN_DATA_HALF) {
311
+ CUDNN_ENFORCE (platform::dynload::cudnnSetConvolutionMathType (
312
+ args.cdesc .desc (), CUDNN_TENSOR_OP_MATH));
313
+ VLOG (5 ) << " use cudnn_tensor_op_math" ;
314
+ } else {
315
+ CUDNN_ENFORCE (platform::dynload::cudnnSetConvolutionMathType (
316
+ args.cdesc .desc (), CUDNN_DEFAULT_MATH));
317
+ VLOG (5 ) << " NOT use cudnn_tensor_op_math" ;
318
+ }
319
+ #endif
208
320
209
321
algo_t algo;
210
322
if (!exhaustive && !deterministic) {
323
+ #if CUDNN_VERSION >= 7001
324
+ using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
325
+ int perf_count;
326
+ int best_algo_idx = 0 ;
327
+ std::unique_ptr<perf_t []> perf_results (
328
+ new perf_t [kNUM_CUDNN_BWD_FILTER_ALGS ]);
329
+ CUDNN_ENFORCE (
330
+ platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm_v7 (
331
+ args.handle , args.idesc .desc (), args.odesc .desc (),
332
+ args.cdesc .desc (), args.wdesc .desc (), kNUM_CUDNN_BWD_FILTER_ALGS ,
333
+ &perf_count, perf_results.get ()));
334
+ algo = (perf_results.get ())[best_algo_idx].algo ;
335
+ workspace_size = GetWorkspaceSize (args, algo);
336
+ if (workspace_size > workspace_size_limit) {
337
+ has_got_workspace_size = false ;
338
+ VLOG (1 ) << " Fallback to non-v7 method to find conv algorithm becasue "
339
+ " the workspace size request("
340
+ << workspace_size << " ) exceeds the limit("
341
+ << workspace_size_limit << " )" ;
342
+ }
343
+ if (!has_got_workspace_size) {
344
+ CUDNN_ENFORCE (
345
+ platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm (
346
+ args.handle , args.idesc .desc (), args.odesc .desc (),
347
+ args.cdesc .desc (), args.wdesc .desc (),
348
+ CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
349
+ workspace_size_limit, &algo));
350
+ }
351
+ #else
211
352
CUDNN_ENFORCE (
212
353
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm (
213
354
args.handle , args.idesc .desc (), args.odesc .desc (),
214
355
args.cdesc .desc (), args.wdesc .desc (),
215
356
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
216
357
workspace_size_limit, &algo));
358
+ #endif
217
359
} else if (deterministic) {
218
360
return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
219
361
} else {
0 commit comments