@@ -23,6 +23,16 @@ limitations under the License. */
23
23
#include " paddle/fluid/platform/cudnn_helper.h"
24
24
#include " paddle/fluid/platform/float16.h"
25
25
26
+ // CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster in
27
+ // some tasks because an optimized path may be selected for CUDNN_DATA_FLOAT
28
+ // and CUDNN_DATA_HALF data types, compute capability 6.0 or higher. The
29
+ // reason we set it to false by default is that this mode may use scaled
30
+ // atomic integer reduction that may cause a numerical overflow for certain
31
+ // input data range.
32
+ DEFINE_bool (cudnn_batchnorm_spatial_persistent, false ,
33
+ " Whether enable CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode for cudnn "
34
+ " batch_norm, defalut is False." );
35
+
26
36
namespace paddle {
27
37
namespace operators {
28
38
@@ -76,7 +86,11 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
76
86
}
77
87
epsilon = std::max (epsilon, CUDNN_BN_MIN_EPSILON);
78
88
#if CUDNN_VERSION_MIN(7, 0, 0)
79
- mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
89
+ if (FLAGS_cudnn_batchnorm_spatial_persistent) {
90
+ mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
91
+ } else {
92
+ mode_ = CUDNN_BATCHNORM_SPATIAL;
93
+ }
80
94
#else
81
95
mode_ = CUDNN_BATCHNORM_SPATIAL;
82
96
#endif
@@ -302,7 +316,11 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
302
316
}
303
317
epsilon = std::max (epsilon, CUDNN_BN_MIN_EPSILON);
304
318
#if CUDNN_VERSION_MIN(7, 0, 0)
305
- mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
319
+ if (FLAGS_cudnn_batchnorm_spatial_persistent) {
320
+ mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
321
+ } else {
322
+ mode_ = CUDNN_BATCHNORM_SPATIAL;
323
+ }
306
324
#else
307
325
mode_ = CUDNN_BATCHNORM_SPATIAL;
308
326
#endif
0 commit comments