Skip to content

Commit 7971d4a

Browse files
authored
Feature/deterministic (#11205)
* "fix deterministic" * "fix ci" * "fix init"
1 parent 9a8b3bc commit 7971d4a

File tree

4 files changed

+26
-8
lines changed

4 files changed

+26
-8
lines changed

paddle/fluid/operators/conv_cudnn_op.cu.cc

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ limitations under the License. */
2020
#include "paddle/fluid/platform/cudnn_helper.h"
2121
#include "paddle/fluid/platform/float16.h"
2222

23-
DEFINE_bool(cudnn_algo_use_autotune, true,
23+
DEFINE_bool(cudnn_deterministic, true,
2424
"Whether allow using an autotuning algorithm for convolution "
2525
"operator. The autotuning algorithm may be non-deterministic. If "
2626
"false, the algorithm is deterministic.");
@@ -272,7 +272,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
272272
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
273273
auto handle = dev_ctx.cudnn_handle();
274274
if (input_grad) {
275-
if (FLAGS_cudnn_algo_use_autotune) {
275+
if (FLAGS_cudnn_deterministic) {
276276
PADDLE_ENFORCE(
277277
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
278278
handle, cudnn_filter_desc,
@@ -297,7 +297,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
297297
}
298298

299299
if (filter_grad) {
300-
if (FLAGS_cudnn_algo_use_autotune) {
300+
if (FLAGS_cudnn_deterministic) {
301301
PADDLE_ENFORCE(
302302
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
303303
handle, cudnn_input_desc, cudnn_output_grad_desc,

paddle/fluid/operators/pool_cudnn_op.cu.cc

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -135,7 +135,11 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
135135

136136
PoolingMode pooling_mode;
137137
if (pooling_type == "max") {
138-
pooling_mode = PoolingMode::kMaximum;
138+
if (FLAGS_cudnn_deterministic) {
139+
pooling_mode = PoolingMode::kMaximumDeterministic;
140+
} else {
141+
pooling_mode = PoolingMode::kMaximum;
142+
}
139143
} else {
140144
pooling_mode = PoolingMode::kAverage;
141145
}

paddle/fluid/platform/cudnn_helper.h

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,8 @@ limitations under the License. */
2222
#include "paddle/fluid/platform/float16.h"
2323
#include "paddle/fluid/platform/macros.h"
2424

25+
DECLARE_bool(cudnn_deterministic);
26+
2527
namespace paddle {
2628
namespace platform {
2729

@@ -76,8 +78,22 @@ enum class DataLayout { // Not use
7678
enum class PoolingMode {
7779
kMaximum,
7880
kAverage,
81+
kMaximumDeterministic,
7982
};
8083

84+
inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
85+
switch (mode) {
86+
case PoolingMode::kMaximumDeterministic:
87+
return CUDNN_POOLING_MAX_DETERMINISTIC;
88+
case PoolingMode::kAverage:
89+
return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
90+
case PoolingMode::kMaximum:
91+
return CUDNN_POOLING_MAX;
92+
default:
93+
PADDLE_THROW("Unexpected pooling mode.");
94+
}
95+
}
96+
8197
template <typename T>
8298
class CudnnDataType;
8399

@@ -293,9 +309,7 @@ class ScopedPoolingDescriptor {
293309
PADDLE_ENFORCE_EQ(kernel.size(), pads.size());
294310
PADDLE_ENFORCE_EQ(kernel.size(), strides.size());
295311
PADDLE_ENFORCE(dynload::cudnnSetPoolingNdDescriptor(
296-
desc_, (mode == PoolingMode::kMaximum
297-
? CUDNN_POOLING_MAX
298-
: CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING),
312+
desc_, (GetPoolingMode(mode)),
299313
CUDNN_PROPAGATE_NAN, // Always propagate nans.
300314
kernel.size(), kernel.data(), pads.data(), strides.data()));
301315
return desc_;

python/paddle/fluid/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -120,7 +120,7 @@ def __bootstrap__():
120120
]
121121
if core.is_compiled_with_cuda():
122122
read_env_flags += [
123-
'fraction_of_gpu_memory_to_use', 'cudnn_algo_use_autotune'
123+
'fraction_of_gpu_memory_to_use', 'cudnn_deterministic'
124124
]
125125
core.init_gflags([sys.argv[0]] +
126126
["--tryfromenv=" + ",".join(read_env_flags)])

0 commit comments

Comments (0)