Skip to content

Commit a6e6bc4

Browse files
committed
modify dropout att; test=develop
1 parent 049c9c7 commit a6e6bc4

File tree

5 files changed

+55
-29
lines changed

5 files changed

+55
-29
lines changed

paddle/fluid/operators/dropout_op.cc

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/dropout_op.h"
16+
#include <string>
1617

1718
namespace paddle {
1819
namespace operators {
@@ -57,15 +58,29 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
5758
"will be dropped.")
5859
.SetDefault(false);
5960
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
60-
AddAttr<bool>("dropout_implementation",
61-
"When it's True, In the training, after set some value"
62-
"to 0 (probability is dropout_prob),"
63-
"all the value will divide (1-dropout_prob)"
64-
"By using this way, will do nothing in the inference program"
65-
"The dropout op can be removed in the inference program."
66-
"The inference program will be more efficient"
67-
"When it's False, same as original")
68-
.SetDefault(false);
61+
AddAttr<std::string>(
62+
"dropout_implementation",
63+
"[\"downgrade_in_infer\"|\"upscale_in_train\"]"
64+
"There are two kinds of ways to implement dropout"
65+
"(the mask below is a tensor have the same shape with input"
66+
"the value of mask is 0 or 1, the ratio of 0 is dropout_prob)"
67+
"1. downgrade_in_infer(default), downgrade the outcome at inference "
68+
"time"
69+
" train: out = input * mask"
70+
" inference: out = input * (1.0 - dropout_prob)"
71+
"2. upscale_in_train, upscale the outcome at training time, do nothing "
72+
"in inference"
73+
" train: out = input * mask / ( 1.0 - dropout_prob )"
74+
" inference: out = input"
75+
" dropout op can be removed from the program. the program will be "
76+
"efficient")
77+
.SetDefault("downgrade_in_infer")
78+
.AddCustomChecker([](const std::string& type) {
79+
PADDLE_ENFORCE(
80+
type == "downgrade_in_infer" || type == "upscale_in_train",
81+
"dropout_implementation can only be downgrade_in_infer or "
82+
"upscale_in_train");
83+
});
6984

7085
AddComment(R"DOC(
7186
Dropout Operator.

paddle/fluid/operators/dropout_op.cu

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
#include <thrust/iterator/counting_iterator.h>
1818
#include <thrust/random.h>
1919
#include <thrust/transform.h>
20+
#include <string>
2021
#include "paddle/fluid/operators/dropout_op.h"
2122
#include "paddle/fluid/platform/float16.h"
2223

@@ -27,7 +28,7 @@ template <typename T>
2728
__global__ void RandomGenerator(const size_t n, const int seed,
2829
const float dropout_prob, const T* src,
2930
T* mask_data, T* dst,
30-
bool dropout_implementation) {
31+
bool is_upscale_in_train) {
3132
thrust::minstd_rand rng;
3233
rng.seed(seed);
3334
thrust::uniform_real_distribution<float> dist(0, 1);
@@ -48,7 +49,7 @@ __global__ void RandomGenerator(const size_t n, const int seed,
4849
if (dist(rng) < dropout_prob) {
4950
mask = static_cast<T>(0);
5051
} else {
51-
if (dropout_implementation) {
52+
if (is_upscale_in_train) {
5253
mask = static_cast<T>(1.0f / (1.0f - dropout_prob));
5354
} else {
5455
mask = static_cast<T>(1);
@@ -72,7 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
7273
y->mutable_data<T>(context.GetPlace());
7374
float dropout_prob = context.Attr<float>("dropout_prob");
7475

75-
auto dropout_implementation = context.Attr<bool>("dropout_implementation");
76+
auto dropout_implementation =
77+
context.Attr<std::string>("dropout_implementation");
7678
auto& place = *context.template device_context<Place>().eigen_device();
7779
if (!context.Attr<bool>("is_test")) {
7880
auto* mask = context.Output<Tensor>("Mask");
@@ -90,11 +92,11 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
9092
RandomGenerator<
9193
T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
9294
size, seed, dropout_prob, x_data, mask_data, y_data,
93-
dropout_implementation);
95+
(dropout_implementation == "upscale_in_train"));
9496
} else {
9597
auto X = EigenMatrix<T>::Reshape(*x, 1);
9698
auto Y = EigenMatrix<T>::Reshape(*y, 1);
97-
if (dropout_implementation) {
99+
if (dropout_implementation == "upscale_in_train") {
98100
Y.device(place) = X;
99101
} else {
100102
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);

paddle/fluid/operators/dropout_op.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414
#pragma once
1515

1616
#include <random>
17+
#include <string>
1718

1819
#include "paddle/fluid/framework/eigen.h"
1920
#include "paddle/fluid/framework/op_registry.h"
@@ -36,7 +37,8 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
3637
auto* y_data = y->mutable_data<T>(context.GetPlace());
3738
float dropout_prob = context.Attr<float>("dropout_prob");
3839

39-
auto dropout_implementation = context.Attr<bool>("dropout_implementation");
40+
auto dropout_implementation =
41+
context.Attr<std::string>("dropout_implementation");
4042
if (!context.Attr<bool>("is_test")) {
4143
auto* mask = context.Output<Tensor>("Mask");
4244
auto* mask_data = mask->mutable_data<T>(context.GetPlace());
@@ -57,7 +59,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
5759
mask_data[i] = 0;
5860
y_data[i] = 0;
5961
} else {
60-
if (dropout_implementation) {
62+
if (dropout_implementation == "upscale_in_train") {
6163
mask_data[i] = 1.0f / static_cast<T>(1.0f - dropout_prob);
6264
y_data[i] = x_data[i] / static_cast<T>(1.0f - dropout_prob);
6365
} else {
@@ -71,7 +73,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
7173
auto Y = EigenMatrix<T>::Reshape(*y, 1);
7274
auto& place =
7375
*context.template device_context<DeviceContext>().eigen_device();
74-
if (dropout_implementation) {
76+
if (dropout_implementation == "upscale_in_train") {
7577
Y.device(place) = X;
7678
} else {
7779
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);

python/paddle/fluid/layers/nn.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -985,7 +985,7 @@ def dropout(x,
985985
is_test=False,
986986
seed=None,
987987
name=None,
988-
dropout_implementation=False):
988+
dropout_implementation="downgrade_in_infer"):
989989
"""
990990
Computes dropout.
991991
@@ -1005,13 +1005,20 @@ def dropout(x,
10051005
units will be dropped. DO NOT use a fixed seed in training.
10061006
name (str|None): A name for this layer(optional). If set None, the layer
10071007
will be named automatically.
1008-
dropout_implementation(bool): A Flag indicating whether divide (1-dropout_prob).
1009-
When it's True, all the units will divide (1-dropout_prob)
1010-
after set some units to zero in the train program.
1011-
And do nothing in the inference program.
1012-
The dropout op can be removed in the inference program.
1013-
The inference program will be more efficient
1014-
When it's False, same as original
1008+
dropout_implementation(string): ['downgrade_in_infer'(default)|'upscale_in_train']
1009+
1. downgrade_in_infer(default), downgrade the outcome at inference
1010+
train: out = input * mask
1011+
inference: out = input * (1.0 - dropout_prob)
1012+
(mask is a tensor with the same shape as input, value is 0 or 1
1013+
ratio of 0 is dropout_prob)
1014+
2. upscale_in_train, upscale the outcome at training time
1015+
train: out = input * mask / ( 1.0 - dropout_prob )
1016+
inference: out = input
1017+
(mask is a tensor with the same shape as input, value is 0 or 1
1018+
ratio of 0 is dropout_prob)
1019+
dropout op can be removed from the program.
1020+
the program will be efficient
1021+
10151022
10161023
10171024
Returns:

python/paddle/fluid/tests/unittests/test_dropout_op.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def setUp(self):
9393
'dropout_prob': 1.0,
9494
'fix_seed': True,
9595
'is_test': False,
96-
'div_prob_in_train': True
96+
'dropout_implementation': 'upscale_in_train'
9797
}
9898
self.outputs = {
9999
'Out': np.zeros((32, 64)).astype('float32'),
@@ -109,7 +109,7 @@ def setUp(self):
109109
'dropout_prob': 0.0,
110110
'fix_seed': True,
111111
'is_test': False,
112-
'div_prob_in_train': True
112+
'dropout_implementation': 'upscale_in_train'
113113
}
114114
self.outputs = {
115115
'Out': self.inputs['X'],
@@ -125,7 +125,7 @@ def setUp(self):
125125
'dropout_prob': 0.35,
126126
'fix_seed': True,
127127
'is_test': True,
128-
'div_prob_in_train': True
128+
'dropout_implementation': 'upscale_in_train'
129129
}
130130
self.outputs = {'Out': self.inputs['X']}
131131

@@ -140,7 +140,7 @@ def setUp(self):
140140
self.attrs = {
141141
'dropout_prob': 0.75,
142142
'is_test': True,
143-
'div_prob_in_train': True
143+
'dropout_implementation': 'upscale_in_train'
144144
}
145145
self.outputs = {'Out': self.inputs['X']}
146146

0 commit comments

Comments
 (0)