Skip to content

Commit ffb24a7

Browse files
committed
add dropout attr; test=develop
1 parent 909e134 commit ffb24a7

File tree

11 files changed

+148
-28
lines changed

11 files changed

+148
-28
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,4 @@ third_party/
2828
build_*
2929
# clion workspace.
3030
cmake-build-*
31+
model_test

paddle/fluid/API.spec

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ paddle.fluid.layers.reduce_prod ArgSpec(args=['input', 'dim', 'keep_dim', 'name'
8686
paddle.fluid.layers.sequence_first_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
8787
paddle.fluid.layers.sequence_last_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
8888
paddle.fluid.layers.sequence_slice ArgSpec(args=['input', 'offset', 'length', 'name'], varargs=None, keywords=None, defaults=(None,))
89-
paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name'], varargs=None, keywords=None, defaults=(False, None, None))
89+
paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name', 'dropout_implementation'], varargs=None, keywords=None, defaults=(False, None, None, False))
9090
paddle.fluid.layers.split ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None))
9191
paddle.fluid.layers.ctc_greedy_decoder ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,))
9292
paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens'], varargs=None, keywords=None, defaults=(True, None))

paddle/fluid/operators/dropout_op.cc

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,15 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
5757
"will be dropped.")
5858
.SetDefault(false);
5959
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
60+
AddAttr<bool>("dropout_implementation",
61+
"When it's True, In the training, after set some value"
62+
"to 0 (probability is dropout_prob),"
63+
"all the value will divide (1-dropout_prob)"
64+
"By using this way, will do nothing in the inference program"
65+
"The dropout op can be removed in the inference program."
66+
"The inference program will be more efficient"
67+
"When it's False, same as original")
68+
.SetDefault(false);
6069

6170
AddComment(R"DOC(
6271
Dropout Operator.
@@ -104,7 +113,9 @@ REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker,
104113
paddle::framework::DefaultGradOpDescMaker<true>);
105114
REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad);
106115
REGISTER_OP_CPU_KERNEL(
107-
dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>);
116+
dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>,
117+
ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, double>);
108118
REGISTER_OP_CPU_KERNEL(
109119
dropout_grad,
110-
ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>);
120+
ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>,
121+
ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, double>);

paddle/fluid/operators/dropout_op.cu

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ namespace operators {
2626
template <typename T>
2727
__global__ void RandomGenerator(const size_t n, const int seed,
2828
const float dropout_prob, const T* src,
29-
T* mask_data, T* dst) {
29+
T* mask_data, T* dst,
30+
bool dropout_implementation) {
3031
thrust::minstd_rand rng;
3132
rng.seed(seed);
3233
thrust::uniform_real_distribution<float> dist(0, 1);
@@ -47,7 +48,11 @@ __global__ void RandomGenerator(const size_t n, const int seed,
4748
if (dist(rng) < dropout_prob) {
4849
mask = static_cast<T>(0);
4950
} else {
50-
mask = static_cast<T>(1);
51+
if (dropout_implementation) {
52+
mask = static_cast<T>(1.0f / (1.0f - dropout_prob));
53+
} else {
54+
mask = static_cast<T>(1);
55+
}
5156
}
5257
dest = s * mask;
5358
mask_data[idx] = mask;
@@ -67,6 +72,7 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
6772
y->mutable_data<T>(context.GetPlace());
6873
float dropout_prob = context.Attr<float>("dropout_prob");
6974

75+
auto dropout_implementation = context.Attr<bool>("dropout_implementation");
7076
auto& place = *context.template device_context<Place>().eigen_device();
7177
if (!context.Attr<bool>("is_test")) {
7278
auto* mask = context.Output<Tensor>("Mask");
@@ -83,11 +89,16 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
8389
int grid = (x->numel() + threads - 1) / threads;
8490
RandomGenerator<
8591
T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
86-
size, seed, dropout_prob, x_data, mask_data, y_data);
92+
size, seed, dropout_prob, x_data, mask_data, y_data,
93+
dropout_implementation);
8794
} else {
8895
auto X = EigenMatrix<T>::Reshape(*x, 1);
8996
auto Y = EigenMatrix<T>::Reshape(*y, 1);
90-
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
97+
if (dropout_implementation) {
98+
Y.device(place) = X;
99+
} else {
100+
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
101+
}
91102
}
92103
}
93104
};
@@ -99,6 +110,8 @@ namespace ops = paddle::operators;
99110
namespace plat = paddle::platform;
100111
REGISTER_OP_CUDA_KERNEL(
101112
dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
102-
ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>);
103-
REGISTER_OP_CUDA_KERNEL(dropout_grad,
104-
ops::DropoutGradKernel<plat::CUDADeviceContext, float>);
113+
ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>,
114+
ops::GPUDropoutKernel<plat::CUDADeviceContext, double>);
115+
REGISTER_OP_CUDA_KERNEL(
116+
dropout_grad, ops::DropoutGradKernel<plat::CUDADeviceContext, float>,
117+
ops::DropoutGradKernel<plat::CUDADeviceContext, double>);

paddle/fluid/operators/dropout_op.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
3636
auto* y_data = y->mutable_data<T>(context.GetPlace());
3737
float dropout_prob = context.Attr<float>("dropout_prob");
3838

39+
auto dropout_implementation = context.Attr<bool>("dropout_implementation");
3940
if (!context.Attr<bool>("is_test")) {
4041
auto* mask = context.Output<Tensor>("Mask");
4142
auto* mask_data = mask->mutable_data<T>(context.GetPlace());
@@ -49,22 +50,32 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
4950
engine.seed(seed);
5051

5152
std::uniform_real_distribution<float> dist(0, 1);
53+
5254
size_t size = framework::product(mask->dims());
5355
for (size_t i = 0; i < size; ++i) {
5456
if (dist(engine) < dropout_prob) {
5557
mask_data[i] = 0;
5658
y_data[i] = 0;
5759
} else {
58-
mask_data[i] = 1;
59-
y_data[i] = x_data[i];
60+
if (dropout_implementation) {
61+
mask_data[i] = 1.0f / static_cast<T>(1.0f - dropout_prob);
62+
y_data[i] = x_data[i] / static_cast<T>(1.0f - dropout_prob);
63+
} else {
64+
mask_data[i] = 1;
65+
y_data[i] = x_data[i];
66+
}
6067
}
6168
}
6269
} else {
6370
auto X = EigenMatrix<T>::Reshape(*x, 1);
6471
auto Y = EigenMatrix<T>::Reshape(*y, 1);
6572
auto& place =
6673
*context.template device_context<DeviceContext>().eigen_device();
67-
Y.device(place) = X * (1.0f - dropout_prob);
74+
if (dropout_implementation) {
75+
Y.device(place) = X;
76+
} else {
77+
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
78+
}
6879
}
6980
}
7081
};

paddle/fluid/operators/softmax_cudnn_op.cu.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ namespace ops = paddle::operators;
7676
namespace plat = paddle::platform;
7777
REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
7878
ops::SoftmaxCUDNNKernel<float>,
79+
ops::SoftmaxCUDNNKernel<double>,
7980
ops::SoftmaxCUDNNKernel<plat::float16>);
8081
REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
81-
ops::SoftmaxGradCUDNNKernel<float>);
82+
ops::SoftmaxGradCUDNNKernel<float>,
83+
ops::SoftmaxGradCUDNNKernel<double>);

paddle/fluid/operators/transpose_op.cc

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,18 +210,21 @@ REGISTER_OPERATOR(transpose, ops::TransposeOp, ops::TransposeOpMaker,
210210
REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad);
211211

212212
REGISTER_OP_CPU_KERNEL(
213-
transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>);
213+
transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
214+
ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>);
214215
REGISTER_OP_CPU_KERNEL(
215216
transpose_grad,
216-
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>);
217+
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
218+
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>);
217219

218220
REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker,
219221
ops::Transpose2GradMaker);
220222
REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad);
221223

222224
REGISTER_OP_CPU_KERNEL(
223-
transpose2,
224-
ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>);
225+
transpose2, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
226+
ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>);
225227
REGISTER_OP_CPU_KERNEL(
226228
transpose2_grad,
227-
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>);
229+
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
230+
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>);

paddle/fluid/operators/transpose_op.cu.cc

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,18 @@ limitations under the License. */
1616

1717
namespace ops = paddle::operators;
1818
REGISTER_OP_CUDA_KERNEL(
19-
transpose,
20-
ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>);
19+
transpose, ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>,
20+
ops::TransposeKernel<paddle::platform::CUDADeviceContext, double>);
2121
REGISTER_OP_CUDA_KERNEL(
2222
transpose_grad,
23-
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>);
23+
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>,
24+
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, double>);
2425

2526
REGISTER_OP_CUDA_KERNEL(
2627
transpose2,
27-
ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>);
28+
ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>,
29+
ops::TransposeKernel<paddle::platform::CUDADeviceContext, double>);
2830
REGISTER_OP_CUDA_KERNEL(
2931
transpose2_grad,
30-
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>);
32+
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>,
33+
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, double>);

python/paddle/fluid/clip.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def _process_context(self, context, param, grad):
272272
)
273273

274274
square = grad * grad
275-
local_norm_var = layers.cast(layers.reduce_sum(input=square), 'float64')
275+
local_norm_var = layers.reduce_sum(input=square)
276276
context[self.group_name].append(local_norm_var)
277277

278278
self.context = context
@@ -282,7 +282,6 @@ def _create_operators(self, param, grad):
282282
if group_scale_name not in self.context:
283283
group_norm_var = layers.sums(input=self.context[self.group_name])
284284
group_norm_var = layers.sqrt(x=group_norm_var)
285-
group_norm_var = layers.cast(group_norm_var, 'float32')
286285
clip_var = self.context[self.group_name + "_clip"]
287286
group_scale_var = layers.elementwise_div(
288287
x=clip_var,

python/paddle/fluid/layers/nn.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,12 @@ def cos_sim(X, Y):
974974
return out
975975

976976

977-
def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
977+
def dropout(x,
978+
dropout_prob,
979+
is_test=False,
980+
seed=None,
981+
name=None,
982+
dropout_implementation=False):
978983
"""
979984
Computes dropout.
980985
@@ -994,6 +999,14 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
994999
units will be dropped. DO NOT use a fixed seed in training.
9951000
name (str|None): A name for this layer(optional). If set None, the layer
9961001
will be named automatically.
1002+
dropout_implementation(bool): A flag indicating whether to divide by (1-dropout_prob).
1003+
When it's True, all the units will be divided by (1-dropout_prob)
1004+
after setting some units to zero in the training program.
1005+
And do nothing in the inference program.
1006+
The dropout op can be removed in the inference program.
1007+
The inference program will be more efficient.
1008+
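When it's False, the original behavior is kept: kept units are not rescaled during training, and the output is multiplied by (1-dropout_prob) at inference.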
1009+
9971010
9981011
Returns:
9991012
Variable: A tensor variable is the shape with `x`.
@@ -1022,7 +1035,8 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
10221035
'dropout_prob': dropout_prob,
10231036
'is_test': is_test,
10241037
'fix_seed': seed is not None,
1025-
'seed': seed if seed is not None else 0
1038+
'seed': seed if seed is not None else 0,
1039+
'dropout_implementation': dropout_implementation,
10261040
})
10271041
return out
10281042

0 commit comments

Comments
 (0)