@@ -20,6 +20,11 @@ limitations under the License. */
 #include "paddle/fluid/platform/cudnn_helper.h"
 #include "paddle/fluid/platform/float16.h"
 
+DEFINE_bool(cudnn_algo_use_autotune, true,
+            "Whether to allow using an autotuning algorithm for the "
+            "convolution operator. The autotuning algorithm may be "
+            "non-deterministic. If false, the algorithm is deterministic.");
+
 namespace paddle {
 namespace operators {
 
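A quick illustration of how this flag could be flipped from code that links against the operator library; the DECLARE_bool/FLAGS_ usage below is standard gflags, but the surrounding function and the way a particular binary parses its flags are assumptions of the sketch, not part of this patch.

#include <gflags/gflags.h>

// The flag itself is defined by the patch above; other translation units only
// declare it and can then read or override it at runtime.
DECLARE_bool(cudnn_algo_use_autotune);

void ForceDeterministicConvGrad() {
  // Fall back to the fixed ALGO_1 choices instead of the cuDNN heuristic.
  FLAGS_cudnn_algo_use_autotune = false;
}

With ordinary gflags command-line parsing, the same effect would come from passing --cudnn_algo_use_autotune=false.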
@@ -267,17 +272,23 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
     auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     auto handle = dev_ctx.cudnn_handle();
     if (input_grad) {
-      PADDLE_ENFORCE(
-          platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
-              handle, cudnn_filter_desc,
-              // dyDesc: Handle to the previously initialized input differential
-              // tensor descriptor.
-              cudnn_output_grad_desc, cudnn_conv_desc,
-              // dxDesc: Handle to the previously initialized output tensor
-              // descriptor.
-              cudnn_input_desc,
-              CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
-              workspace_size_limit, &data_algo));
+      if (FLAGS_cudnn_algo_use_autotune) {
+        PADDLE_ENFORCE(
+            platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
+                handle, cudnn_filter_desc,
+                // dyDesc: Handle to the previously initialized input
+                // differential
+                // tensor descriptor.
+                cudnn_output_grad_desc, cudnn_conv_desc,
+                // dxDesc: Handle to the previously initialized output tensor
+                // descriptor.
+                cudnn_input_desc,
+                CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
+                workspace_size_limit, &data_algo));
+      } else {
+        data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
+      }
+
       PADDLE_ENFORCE(
           platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
               handle, cudnn_filter_desc, cudnn_output_grad_desc,
@@ -286,12 +297,16 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
     }
 
     if (filter_grad) {
-      PADDLE_ENFORCE(
-          platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
-              handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc,
-              cudnn_filter_desc,
-              CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
-              workspace_size_limit, &filter_algo));
+      if (FLAGS_cudnn_algo_use_autotune) {
+        PADDLE_ENFORCE(
+            platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
+                handle, cudnn_input_desc, cudnn_output_grad_desc,
+                cudnn_conv_desc, cudnn_filter_desc,
+                CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
+                workspace_size_limit, &filter_algo));
+      } else {
+        filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
+      }
 
       PADDLE_ENFORCE(
           platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
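For reference, a minimal standalone sketch of the selection pattern this change implements, written against the pre-cuDNN-8 API and outside Paddle's dynload/PADDLE_ENFORCE wrappers; the helper name, parameter names, and the omission of error handling are assumptions of the sketch.

#include <cudnn.h>

// Hypothetical helper mirroring the patch: either let cuDNN's heuristic pick a
// backward-data algorithm under a workspace limit, or force the deterministic
// CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 (status checks elided for brevity).
cudnnConvolutionBwdDataAlgo_t ChooseBwdDataAlgo(
    cudnnHandle_t handle, cudnnFilterDescriptor_t w_desc,
    cudnnTensorDescriptor_t dy_desc, cudnnConvolutionDescriptor_t conv_desc,
    cudnnTensorDescriptor_t dx_desc, size_t workspace_limit,
    bool use_autotune) {
  cudnnConvolutionBwdDataAlgo_t algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
  if (use_autotune) {
    // Heuristic choice; may return a non-deterministic algorithm.
    cudnnGetConvolutionBackwardDataAlgorithm(
        handle, w_desc, dy_desc, conv_desc, dx_desc,
        CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, workspace_limit,
        &algo);
  }
  return algo;
}

Either way, the workspace-size query that follows in the kernel (cudnnGetConvolutionBackwardDataWorkspaceSize) still applies, since it takes the chosen algorithm as an explicit argument.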