
Commit 379d933

Merge pull request #14036 from phlrain/add_dropout_att_new
Add dropout att new 1.1 merge
2 parents: 18be725 + a4ad286

File tree

11 files changed: +174, -28 lines

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -28,3 +28,4 @@ third_party/
 build_*
 # clion workspace.
 cmake-build-*
+model_test

paddle/fluid/API.spec

Lines changed: 1 addition & 1 deletion

@@ -86,7 +86,7 @@ paddle.fluid.layers.reduce_prod ArgSpec(args=['input', 'dim', 'keep_dim', 'name'
 paddle.fluid.layers.sequence_first_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.sequence_last_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.sequence_slice ArgSpec(args=['input', 'offset', 'length', 'name'], varargs=None, keywords=None, defaults=(None,))
-paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name'], varargs=None, keywords=None, defaults=(False, None, None))
+paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name', 'dropout_implementation'], varargs=None, keywords=None, defaults=(False, None, None, 'downgrade_in_infer'))
 paddle.fluid.layers.split ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None))
 paddle.fluid.layers.ctc_greedy_decoder ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens'], varargs=None, keywords=None, defaults=(True, None))
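For readers who do not work with API.spec often: the defaults tuple in an ArgSpec is right-aligned against args, so x and dropout_prob stay required while the new trailing dropout_implementation argument defaults to 'downgrade_in_infer'. A quick, hypothetical way to check this locally (assumes a Fluid install; not part of this repo):

```python
import inspect
import paddle.fluid as fluid

# getargspec mirrors the Python-2-style ArgSpec format used by API.spec.
spec = inspect.getargspec(fluid.layers.dropout)
print(spec.args)      # ['x', 'dropout_prob', 'is_test', 'seed', 'name', 'dropout_implementation']
print(spec.defaults)  # (False, None, None, 'downgrade_in_infer') -> covers the last four args
```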

paddle/fluid/operators/dropout_op.cc

Lines changed: 28 additions & 2 deletions

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/dropout_op.h"
+#include <string>
 
 namespace paddle {
 namespace operators {
@@ -57,6 +58,29 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
                   "will be dropped.")
         .SetDefault(false);
     AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
+    AddAttr<std::string>(
+        "dropout_implementation",
+        "[\"downgrade_in_infer\"|\"upscale_in_train\"]"
+        "There are two kinds of ways to implement dropout"
+        "(the mask below is a tensor have the same shape with input"
+        "the value of mask is 0 or 1, the ratio of 0 is dropout_prob)"
+        "1. downgrade_in_infer(default), downgrade the outcome at inference "
+        "time"
+        "   train: out = input * mask"
+        "   inference: out = input * dropout_prob"
+        "2. upscale_in_train, upscale the outcome at training time, do nothing "
+        "in inference"
+        "   train: out = input * mask / ( 1.0 - dropout_prob )"
+        "   inference: out = input"
+        "   dropout op can be removed from the program. the program will be "
+        "efficient")
+        .SetDefault("downgrade_in_infer")
+        .AddCustomChecker([](const std::string& type) {
+          PADDLE_ENFORCE(
+              type == "downgrade_in_infer" || type == "upscale_in_train",
+              "dropout_implementation can only be downgrade_in_infer or "
+              "upscale_in_train");
+        });
 
     AddComment(R"DOC(
 Dropout Operator.
@@ -104,7 +128,9 @@ REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad);
 REGISTER_OP_CPU_KERNEL(
-    dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>);
+    dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     dropout_grad,
-    ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, double>);
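The attribute text above is the whole behavioural contract of the new option: both modes zero out the same fraction of elements and differ only in where the 1/(1 - dropout_prob) rescaling happens. A minimal NumPy sketch of that contract (illustrative only, not the operator code; the helper name dropout_forward is made up here):

```python
import numpy as np

def dropout_forward(x, dropout_prob, training, mode="downgrade_in_infer", seed=0):
    """Illustrative forward pass for the two dropout_implementation modes."""
    if training:
        rng = np.random.RandomState(seed)
        mask = (rng.uniform(size=x.shape) >= dropout_prob).astype(x.dtype)
        if mode == "upscale_in_train":
            mask = mask / (1.0 - dropout_prob)   # fold the rescale into the mask
        return x * mask, mask
    # inference
    if mode == "upscale_in_train":
        return x, None                           # identity: the op can be elided
    return x * (1.0 - dropout_prob), None        # downgrade the activations instead

x = np.ones((2, 3), dtype=np.float32)
y_train, _ = dropout_forward(x, 0.5, training=True, mode="upscale_in_train")
y_infer, _ = dropout_forward(x, 0.5, training=False, mode="upscale_in_train")
# In either mode, the expected training output equals the inference output.
```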

paddle/fluid/operators/dropout_op.cu

Lines changed: 22 additions & 7 deletions

@@ -17,6 +17,7 @@ limitations under the License. */
 #include <thrust/iterator/counting_iterator.h>
 #include <thrust/random.h>
 #include <thrust/transform.h>
+#include <string>
 #include "paddle/fluid/operators/dropout_op.h"
 #include "paddle/fluid/platform/float16.h"
 
@@ -26,7 +27,8 @@ namespace operators {
 template <typename T>
 __global__ void RandomGenerator(const size_t n, const int seed,
                                 const float dropout_prob, const T* src,
-                                T* mask_data, T* dst) {
+                                T* mask_data, T* dst,
+                                bool is_upscale_in_train) {
   thrust::minstd_rand rng;
   rng.seed(seed);
   thrust::uniform_real_distribution<float> dist(0, 1);
@@ -47,7 +49,11 @@ __global__ void RandomGenerator(const size_t n, const int seed,
     if (dist(rng) < dropout_prob) {
       mask = static_cast<T>(0);
     } else {
-      mask = static_cast<T>(1);
+      if (is_upscale_in_train) {
+        mask = static_cast<T>(1.0f / (1.0f - dropout_prob));
+      } else {
+        mask = static_cast<T>(1);
+      }
     }
     dest = s * mask;
     mask_data[idx] = mask;
@@ -67,6 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
     y->mutable_data<T>(context.GetPlace());
     float dropout_prob = context.Attr<float>("dropout_prob");
 
+    auto dropout_implementation =
+        context.Attr<std::string>("dropout_implementation");
     auto& place = *context.template device_context<Place>().eigen_device();
     if (!context.Attr<bool>("is_test")) {
       auto* mask = context.Output<Tensor>("Mask");
@@ -83,11 +91,16 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
       int grid = (x->numel() + threads - 1) / threads;
       RandomGenerator<
          T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
-          size, seed, dropout_prob, x_data, mask_data, y_data);
+          size, seed, dropout_prob, x_data, mask_data, y_data,
+          (dropout_implementation == "upscale_in_train"));
     } else {
       auto X = EigenMatrix<T>::Reshape(*x, 1);
       auto Y = EigenMatrix<T>::Reshape(*y, 1);
-      Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
+      if (dropout_implementation == "upscale_in_train") {
+        Y.device(place) = X;
+      } else {
+        Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
+      }
     }
   }
 };
@@ -99,6 +112,8 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
     dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
-    ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>);
-REGISTER_OP_CUDA_KERNEL(dropout_grad,
-                        ops::DropoutGradKernel<plat::CUDADeviceContext, float>);
+    ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>,
+    ops::GPUDropoutKernel<plat::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(
+    dropout_grad, ops::DropoutGradKernel<plat::CUDADeviceContext, float>,
+    ops::DropoutGradKernel<plat::CUDADeviceContext, double>);
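The CUDA kernel change is deliberately small: rather than branching later, the 1/(1 - dropout_prob) factor is baked into the mask value itself, so dest = s * mask stays untouched and the saved mask already carries the scale. A quick Monte-Carlo check of the equivalence the attribute promises (NumPy, illustrative only):

```python
import numpy as np

rng = np.random.RandomState(42)
x, p, n = 3.0, 0.3, 200_000
keep = rng.uniform(size=n) >= p            # ratio of zeros ~= dropout_prob

downgrade_train = x * keep                 # mask values in {0, 1}
upscale_train = x * keep / (1.0 - p)       # mask values in {0, 1/(1-p)}

print(downgrade_train.mean())              # ~2.1, matches inference output x * (1 - p)
print(upscale_train.mean())                # ~3.0, matches inference output x (identity)
```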

paddle/fluid/operators/dropout_op.h

Lines changed: 16 additions & 3 deletions

@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 
 #include <random>
+#include <string>
 
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
@@ -36,6 +37,8 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
     auto* y_data = y->mutable_data<T>(context.GetPlace());
     float dropout_prob = context.Attr<float>("dropout_prob");
 
+    auto dropout_implementation =
+        context.Attr<std::string>("dropout_implementation");
     if (!context.Attr<bool>("is_test")) {
       auto* mask = context.Output<Tensor>("Mask");
       auto* mask_data = mask->mutable_data<T>(context.GetPlace());
@@ -49,22 +52,32 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
       engine.seed(seed);
 
       std::uniform_real_distribution<float> dist(0, 1);
+
       size_t size = framework::product(mask->dims());
       for (size_t i = 0; i < size; ++i) {
         if (dist(engine) < dropout_prob) {
           mask_data[i] = 0;
           y_data[i] = 0;
         } else {
-          mask_data[i] = 1;
-          y_data[i] = x_data[i];
+          if (dropout_implementation == "upscale_in_train") {
+            mask_data[i] = 1.0f / static_cast<T>(1.0f - dropout_prob);
+            y_data[i] = x_data[i] / static_cast<T>(1.0f - dropout_prob);
+          } else {
+            mask_data[i] = 1;
+            y_data[i] = x_data[i];
+          }
         }
       }
     } else {
       auto X = EigenMatrix<T>::Reshape(*x, 1);
       auto Y = EigenMatrix<T>::Reshape(*y, 1);
       auto& place =
           *context.template device_context<DeviceContext>().eigen_device();
-      Y.device(place) = X * (1.0f - dropout_prob);
+      if (dropout_implementation == "upscale_in_train") {
+        Y.device(place) = X;
+      } else {
+        Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
+      }
     }
   }
 };
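Note that this diff only re-registers the grad kernels for double; their logic is untouched. Assuming the backward pass multiplies the upstream gradient by the saved mask (which is how DropoutGradKernel is usually described, but is not shown in this diff), storing the scaled mask in upscale_in_train mode gives the backward pass the 1/(1 - dropout_prob) factor for free. A hedged NumPy sketch of that assumption:

```python
import numpy as np

def dropout_backward(dout, mask):
    # Assumption: dX = dOut * mask, elementwise; verify against DropoutGradKernel.
    return dout * mask

p = 0.25
mask_downgrade = np.array([0.0, 1.0, 1.0, 0.0])   # downgrade_in_infer: values in {0, 1}
mask_upscale = mask_downgrade / (1.0 - p)         # upscale_in_train: values in {0, 1/(1-p)}

dout = np.ones(4)
print(dropout_backward(dout, mask_downgrade))     # [0. 1. 1. 0.]
print(dropout_backward(dout, mask_upscale))       # [0. 1.333... 1.333... 0.]
```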

paddle/fluid/operators/softmax_cudnn_op.cu.cc

Lines changed: 3 additions & 1 deletion

@@ -76,6 +76,8 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
                    ops::SoftmaxCUDNNKernel<float>,
+                   ops::SoftmaxCUDNNKernel<double>,
                    ops::SoftmaxCUDNNKernel<plat::float16>);
 REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
-                   ops::SoftmaxGradCUDNNKernel<float>);
+                   ops::SoftmaxGradCUDNNKernel<float>,
+                   ops::SoftmaxGradCUDNNKernel<double>);

paddle/fluid/operators/transpose_op.cc

Lines changed: 8 additions & 5 deletions

@@ -210,18 +210,21 @@ REGISTER_OPERATOR(transpose, ops::TransposeOp, ops::TransposeOpMaker,
 REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad);
 
 REGISTER_OP_CPU_KERNEL(
-    transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>);
+    transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     transpose_grad,
-    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>);
 
 REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker,
                   ops::Transpose2GradMaker);
 REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad);
 
 REGISTER_OP_CPU_KERNEL(
-    transpose2,
-    ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>);
+    transpose2, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     transpose2_grad,
-    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>);

paddle/fluid/operators/transpose_op.cu.cc

Lines changed: 8 additions & 5 deletions

@@ -16,15 +16,18 @@ limitations under the License. */
 
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
-    transpose,
-    ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>);
+    transpose, ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::TransposeKernel<paddle::platform::CUDADeviceContext, double>);
 REGISTER_OP_CUDA_KERNEL(
     transpose_grad,
-    ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, double>);
 
 REGISTER_OP_CUDA_KERNEL(
     transpose2,
-    ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::TransposeKernel<paddle::platform::CUDADeviceContext, double>);
 REGISTER_OP_CUDA_KERNEL(
     transpose2_grad,
-    ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, double>);

python/paddle/fluid/clip.py

Lines changed: 1 addition & 2 deletions

@@ -272,7 +272,7 @@ def _process_context(self, context, param, grad):
                 )
 
         square = grad * grad
-        local_norm_var = layers.cast(layers.reduce_sum(input=square), 'float64')
+        local_norm_var = layers.reduce_sum(input=square)
         context[self.group_name].append(local_norm_var)
 
         self.context = context
@@ -282,7 +282,6 @@ def _create_operators(self, param, grad):
         if group_scale_name not in self.context:
             group_norm_var = layers.sums(input=self.context[self.group_name])
             group_norm_var = layers.sqrt(x=group_norm_var)
-            group_norm_var = layers.cast(group_norm_var, 'float32')
             clip_var = self.context[self.group_name + "_clip"]
             group_scale_var = layers.elementwise_div(
                 x=clip_var,
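For context, this class assembles clip-by-global-norm out of the ops shown above (square, reduce_sum, sums, sqrt, elementwise_div); the removed lines simply drop the float64 round-trip around the squared-norm accumulation. A NumPy sketch of the underlying formula (illustrative, assuming the standard scale = clip_norm / max(global_norm, clip_norm) rule):

```python
import numpy as np

def clip_by_global_norm(grads, clip_norm):
    # global_norm = sqrt of the sum over all gradients of sum(g^2)
    global_norm = np.sqrt(sum(np.sum(g * g) for g in grads))
    scale = clip_norm / max(global_norm, clip_norm)
    return [g * scale for g in grads], global_norm

grads = [np.array([3.0, 4.0]), np.array([0.0, 12.0])]   # global norm = 13
clipped, norm = clip_by_global_norm(grads, clip_norm=6.5)
print(norm)                              # 13.0
print([g.tolist() for g in clipped])     # both gradients halved; new global norm = 6.5
```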

python/paddle/fluid/layers/nn.py

Lines changed: 23 additions & 2 deletions

@@ -980,7 +980,12 @@ def cos_sim(X, Y):
     return out
 
 
-def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
+def dropout(x,
+            dropout_prob,
+            is_test=False,
+            seed=None,
+            name=None,
+            dropout_implementation="downgrade_in_infer"):
     """
     Computes dropout.
 
@@ -1000,6 +1005,21 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
                     units will be dropped. DO NOT use a fixed seed in training.
        name (str|None): A name for this layer(optional). If set None, the layer
                         will be named automatically.
+       dropout_implementation(string): ['downgrade_in_infer'(defauld)|'upscale_in_train']
+                                        1. downgrade_in_infer(default), downgrade the outcome at inference
+                                           train: out = input * mask
+                                           inference: out = input * dropout_prob
+                                           (make is a tensor same shape with input, value is 0 or 1
+                                           ratio of 0 is dropout_prob)
+                                        2. upscale_in_train, upscale the outcome at training time
+                                           train: out = input * mask / ( 1.0 - dropout_prob )
+                                           inference: out = input
+                                           (make is a tensor same shape with input, value is 0 or 1
+                                           ratio of 0 is dropout_prob)
+                                        dropout op can be removed from the program.
+                                        the program will be efficient
+
+
 
     Returns:
         Variable: A tensor variable is the shape with `x`.
@@ -1029,7 +1049,8 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
             'dropout_prob': dropout_prob,
             'is_test': is_test,
             'fix_seed': seed is not None,
-            'seed': seed if seed is not None else 0
+            'seed': seed if seed is not None else 0,
+            'dropout_implementation': dropout_implementation,
         })
     return out
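At the Python level the new attribute is just one more keyword argument on fluid.layers.dropout. A minimal usage sketch (assumes a typical Fluid 1.x program; the data layer name and shape are made up for illustration):

```python
import paddle.fluid as fluid

x = fluid.layers.data(name="x", shape=[32], dtype="float32")

# Default behaviour: activations are scaled by (1 - dropout_prob) at inference.
drop_a = fluid.layers.dropout(x, dropout_prob=0.5)

# New mode added by this PR: rescale by 1/(1 - dropout_prob) during training,
# so at inference the op is an identity and can be removed from the program.
drop_b = fluid.layers.dropout(
    x, dropout_prob=0.5, dropout_implementation="upscale_in_train")
```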