
Commit f2e400d

reyoung authored and dzhwinter committed

Revert "accelerate dropout (#9902)" (#10082)

* Revert "accelerate dropout (#9902)"
  This reverts commit 2e331c6.
* Correct discard

1 parent ad91bfe · commit f2e400d

File tree: 3 files changed, +47 -44 lines


paddle/fluid/operators/dropout_op.cu

Lines changed: 29 additions & 21 deletions
@@ -24,12 +24,34 @@ namespace paddle {
 namespace operators {
 
 template <typename T>
-__global__ void RandomGenerator(const size_t n, const T* src,
-                                const T* cpu_mask_data, T* mask_data, T* dst) {
+__global__ void RandomGenerator(const size_t n, const int seed,
+                                const float dropout_prob, const T* src,
+                                T* mask_data, T* dst) {
+  thrust::minstd_rand rng;
+  rng.seed(seed);
+  thrust::uniform_real_distribution<float> dist(0, 1);
+
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  int step_size = 0;
+
+  T mask;
+  T dest;
   for (; idx < n; idx += blockDim.x * gridDim.x) {
-    mask_data[idx] = cpu_mask_data[idx];
-    dst[idx] = mask_data[idx] * src[idx];
+    T s = src[idx];
+    if (step_size == 0) {
+      rng.discard(idx);
+      step_size = blockDim.x * gridDim.x;
+    } else {
+      rng.discard(step_size);
+    }
+    if (dist(rng) < dropout_prob) {
+      mask = static_cast<T>(0);
+    } else {
+      mask = static_cast<T>(1);
+    }
+    dest = s * mask;
+    mask_data[idx] = mask;
+    dst[idx] = dest;
   }
 }
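To make the restored pattern easier to try in isolation, here is a minimal standalone CUDA sketch (not part of this commit; the kernel name DropoutMaskSketch, the float-only types, and the sizes are illustrative) of the same per-thread thrust RNG scheme: each thread seeds an identical thrust::minstd_rand, advances it by its global index before its first draw, and by the total grid stride before each later draw.

// Standalone sketch of the per-thread RNG pattern used by the restored kernel.
#include <cstdio>
#include <thrust/random.h>

__global__ void DropoutMaskSketch(const size_t n, const int seed,
                                  const float dropout_prob, const float* src,
                                  float* mask, float* dst) {
  // Every thread builds the same engine and seeds it identically.
  thrust::minstd_rand rng;
  rng.seed(seed);
  thrust::uniform_real_distribution<float> dist(0, 1);

  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  int step_size = 0;
  for (; idx < n; idx += blockDim.x * gridDim.x) {
    if (step_size == 0) {
      rng.discard(idx);            // before the first draw: advance by the thread's global index
      step_size = blockDim.x * gridDim.x;
    } else {
      rng.discard(step_size);      // before later draws: advance by the grid stride
    }
    float keep = dist(rng) < dropout_prob ? 0.0f : 1.0f;
    mask[idx] = keep;
    dst[idx] = keep * src[idx];
  }
}

int main() {
  const size_t n = 100;
  float *src, *mask, *dst;
  cudaMallocManaged(&src, n * sizeof(float));
  cudaMallocManaged(&mask, n * sizeof(float));
  cudaMallocManaged(&dst, n * sizeof(float));
  for (size_t i = 0; i < n; ++i) src[i] = 1.0f;

  const int threads = 512;
  const int grid = (n + threads - 1) / threads;
  DropoutMaskSketch<<<grid, threads>>>(n, /*seed=*/3, /*dropout_prob=*/0.5f,
                                       src, mask, dst);
  cudaDeviceSynchronize();

  for (size_t i = 0; i < 10; ++i) printf("%.0f ", dst[i]);
  printf("\n");

  cudaFree(src);
  cudaFree(mask);
  cudaFree(dst);
  return 0;
}

Compiled with nvcc, an all-ones input and dropout_prob = 0.5 should leave roughly half of the printed values at 0 and the rest at 1.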

@@ -56,27 +78,15 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
       std::random_device rnd;
       int seed =
           context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
-      std::minstd_rand engine;
-      engine.seed(seed);
-      std::uniform_real_distribution<float> dist(0, 1);
-      framework::Vector<T> cpu_mask(size);
-      for (size_t i = 0; i < size; ++i) {
-        if (dist(engine) < dropout_prob) {
-          cpu_mask[i] = static_cast<T>(0);
-        } else {
-          cpu_mask[i] = static_cast<T>(1);
-        }
-      }
 
       int threads = 512;
       int grid = (x->numel() + threads - 1) / threads;
       RandomGenerator<
           T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
-          size, x_data, cpu_mask.CUDAData(context.GetPlace()), mask_data,
-          y_data);
+          size, seed, dropout_prob, x_data, mask_data, y_data);
     } else {
-      auto X = EigenVector<T>::Flatten(*x);
-      auto Y = EigenVector<T>::Flatten(*y);
+      auto X = EigenMatrix<T>::Reshape(*x, 1);
+      auto Y = EigenMatrix<T>::Reshape(*y, 1);
       Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
     }
   }
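As a worked example of the launch configuration above: the 10 x 10 tensor used in dropout_op_test.cc has x->numel() = 100, so grid = (100 + 512 - 1) / 512 = 1, and a single block of 512 threads covers all 100 elements in one pass of the kernel's grid-stride loop.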
@@ -89,8 +99,6 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
     dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
-    ops::GPUDropoutKernel<plat::CUDADeviceContext, double>,
     ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(dropout_grad,
-                        ops::DropoutGradKernel<plat::CUDADeviceContext, double>,
                         ops::DropoutGradKernel<plat::CUDADeviceContext, float>);

paddle/fluid/operators/dropout_op.h

Lines changed: 6 additions & 6 deletions
@@ -24,7 +24,7 @@ namespace operators {
 using Tensor = framework::Tensor;
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
 template <typename DeviceContext, typename T>
 class CPUDropoutKernel : public framework::OpKernel<T> {
@@ -60,8 +60,8 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
         }
       }
     } else {
-      auto X = EigenVector<T>::Flatten(*x);
-      auto Y = EigenVector<T>::Flatten(*y);
+      auto X = EigenMatrix<T>::Reshape(*x, 1);
+      auto Y = EigenMatrix<T>::Reshape(*y, 1);
       auto& place =
           *context.template device_context<DeviceContext>().eigen_device();
       Y.device(place) = X * (1.0f - dropout_prob);
@@ -81,9 +81,9 @@ class DropoutGradKernel : public framework::OpKernel<T> {
     auto* mask = context.Input<Tensor>("Mask");
     grad_x->mutable_data<T>(context.GetPlace());
 
-    auto M = EigenVector<T>::Flatten(*mask);
-    auto dX = EigenVector<T>::Flatten(*grad_x);
-    auto dY = EigenVector<T>::Flatten(*grad_y);
+    auto M = EigenMatrix<T>::Reshape(*mask, 1);
+    auto dX = EigenMatrix<T>::Reshape(*grad_x, 1);
+    auto dY = EigenMatrix<T>::Reshape(*grad_y, 1);
 
     auto& place =
         *context.template device_context<DeviceContext>().eigen_device();
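For illustration, a standalone Eigen sketch (not Paddle code; the buffer names, the 1 x 6 shape, and Eigen::DefaultDevice are assumptions made only for this example) of the element-wise scaling that the inference branch Y.device(place) = X * (1.0f - dropout_prob) performs:

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  // Raw buffers standing in for the input and output tensors.
  float in[6] = {1, 1, 1, 1, 1, 1};
  float out[6] = {0};

  // Map the buffers as 2-D Eigen tensors; the shape here is only for the sketch.
  Eigen::TensorMap<Eigen::Tensor<float, 2>> X(in, 1, 6);
  Eigen::TensorMap<Eigen::Tensor<float, 2>> Y(out, 1, 6);

  const float dropout_prob = 0.5f;
  Eigen::DefaultDevice place;
  // Inference-time dropout: scale every element by the keep probability.
  Y.device(place) = X * (1.0f - dropout_prob);

  for (float v : out) std::cout << v << ' ';  // prints: 0.5 0.5 0.5 0.5 0.5 0.5
  std::cout << '\n';
}

With an all-ones input and dropout_prob = 0.5, every output element is 0.5, matching the scaled inference path rather than the 0/1 mask produced in training.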

paddle/fluid/operators/dropout_op_test.cc

Lines changed: 12 additions & 17 deletions
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include <unistd.h>
-#include <iostream>
 
 #include <string>
 #include <thread>  // NOLINT
@@ -33,16 +32,14 @@ namespace m = paddle::operators::math;
 
 USE_OP(dropout);
 
-static paddle::framework::DDim dims = {10, 10};
-
 void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
   // init
   auto var = scope->Var("X");
   auto tensor = var->GetMutable<f::LoDTensor>();
-  tensor->Resize(dims);
+  tensor->Resize({10, 10});
 
   std::vector<float> init;
-  for (int64_t i = 0; i < f::product(dims); ++i) {
+  for (int64_t i = 0; i < 10 * 10; ++i) {
     init.push_back(1.0);
   }
 
@@ -51,19 +48,18 @@ void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
   auto place = ctx.GetPlace();
   auto out_var = scope->Var("Out");
   auto out_tensor = out_var->GetMutable<f::LoDTensor>();
-  out_tensor->Resize(dims);
+  out_tensor->Resize({10, 10});
   out_tensor->mutable_data<float>(place);  // allocate
 
   auto mask_var = scope->Var("Mask");
   auto mask_tensor = mask_var->GetMutable<f::LoDTensor>();
-  mask_tensor->Resize(dims);
+  mask_tensor->Resize({10, 10});
   mask_tensor->mutable_data<float>(place);  // allocate
 
   // run
   f::AttributeMap attrs;
   float dropout_prob = 0.5;
-  attrs.insert({"is_test", false});
-  attrs.insert({"fix_seed", true});
+  attrs.insert({"fix_seed", 1});
   attrs.insert({"seed", 3});
   attrs.insert({"dropout_prob", dropout_prob});
   auto dropout_op = f::OpRegistry::CreateOp(
@@ -73,7 +69,6 @@ void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
 
   std::vector<float> out_vec;
   TensorToVector(*out_tensor, ctx, &out_vec);
-  ctx.Wait();
 
   std::vector<float> std_out = {
       0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1,
@@ -88,22 +83,22 @@ void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
   }
 }
 
+// TODO(wyi): Due to
+// https://github.com/PaddlePaddle/Paddle/issues/9507, I temporarily
+// disable this test to remove the prevention of the merge of
+// unrelated PRs.
+/*
 TEST(Dropout, CPUDense) {
   f::Scope scope;
   p::CPUPlace place;
   p::CPUDeviceContext ctx(place);
-  Compare(&scope, ctx);
+  Compare(scope, ctx);
 }
 
-// TODO(wyi, dzhwinter): Due to
-// https://github.com/PaddlePaddle/Paddle/issues/9507, I temporarily
-// disable this test to remove the prevention of the merge of
-// unrelated PRs.
-/*
 TEST(Dropout, GPUDense) {
   f::Scope scope;
   p::CUDAPlace place;
   p::CUDADeviceContext ctx(place);
-  Compare(&scope, ctx);
+  Compare(scope, ctx);
 }
 */
