
Commit 91b6d60

Merge branch 'fix_bug_in_recordio' into dev_MultiEpochReader
2 parents: bce08d1 + c346a34

31 files changed: +929 −263 lines

doc/design/distributed_lookup_table_design.md renamed to doc/fluid/design/dist_train/distributed_lookup_table_design.md

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@ lookup of rows.
 The following figure illustrates the multiplication of x with two
 non-zero elements, or say, two symbols, and a lookup table W:
 
-![lookup table](./lookup_table.png)
+![lookup table](./src/lookup_table.png)
 
 ### The Backward Algorithm
 
@@ -42,7 +42,7 @@ or some more sophisticated algorithms that rely on both W' and W:
 $$W = f(W, W')$$
 
 The following figure illustrates the backward pass of the lookup
-operator: ![lookup table training](./lookup_table_training.png)
+operator: ![lookup table training](./src/lookup_table_training.png)
 
 ## Distributed Storage Service
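For context, the update rule $$W = f(W, W')$$ quoted in the second hunk admits plain SGD as its simplest instance (a standard fact, not something this commit adds):

$$f(W, W') = W - \eta \, W', \qquad \eta \text{ the learning rate,}$$

with Adagrad, Adam, and similar optimizers being the "more sophisticated algorithms" the surrounding text refers to.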

doc/fluid/design/motivation/fluid.md

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ In computability theory, a system of data-manipulation rules, such as a programm
 
 There are two ways to execute a Fluid program. When a program is executed, it creates a protobuf message [`ProgramDesc`](https://github.com/PaddlePaddle/Paddle/blob/a91efdde6910ce92a78e3aa7157412c4c88d9ee8/paddle/framework/framework.proto#L145) that describes the process and is conceptually like an [abstract syntax tree](https://en.wikipedia.org/wiki/Abstract_syntax_tree).
 
-There is a C++ class [`Executor`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/executor.h), which runs a `ProgramDesc`, similar to how an interpreter runs a Python program.
+There is a C++ class [`Executor`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/framework/executor.h), which runs a `ProgramDesc`, similar to how an interpreter runs a Python program.
 
 Fluid is moving towards the direction of a compiler, which is explain in [fluid_compiler.md](fluid_compiler.md).
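The sentence changed above compares `Executor` to an interpreter. The following standalone toy, which uses none of Paddle's real classes and is purely illustrative, shows the kind of loop that comparison implies: walk a list of op descriptions and dispatch each one against a scope of named variables.

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins for ProgramDesc/OpDesc/Scope; not Paddle's real API.
struct ToyOpDesc {
  std::string type;                 // e.g. "mul" or "add"
  std::vector<std::string> inputs;  // names of input variables
  std::string output;               // name of the output variable
};

using ToyScope = std::map<std::string, double>;

// Interpreter-style execution: visit each op description in order.
void RunProgram(const std::vector<ToyOpDesc>& program, ToyScope* scope) {
  for (const auto& op : program) {
    if (op.type == "mul") {
      (*scope)[op.output] = (*scope)[op.inputs[0]] * (*scope)[op.inputs[1]];
    } else if (op.type == "add") {
      (*scope)[op.output] = (*scope)[op.inputs[0]] + (*scope)[op.inputs[1]];
    }
  }
}

int main() {
  ToyScope scope{{"a", 2.0}, {"b", 3.0}};
  RunProgram({{"mul", {"a", "b"}, "c"}, {"add", {"c", "a"}, "d"}}, &scope);
  std::cout << scope["d"] << std::endl;  // prints 8
  return 0;
}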

doc/v2/api/fluid/optimizer.rst

Lines changed: 7 additions & 0 deletions
@@ -47,3 +47,10 @@ DecayedAdagrad
     :members:
     :noindex:
 
+Adadelta
+--------------
+
+.. autoclass:: paddle.fluid.optimizer.AdadeltaOptimizer
+    :members:
+    :noindex:
+
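For reference, the standard Adadelta rule from Zeiler (2012), which an optimizer named `AdadeltaOptimizer` is normally expected to implement (whether Paddle's class follows it exactly, and with which defaults, is not visible in this diff):

$$E[g^2]_t = \rho\, E[g^2]_{t-1} + (1-\rho)\, g_t^2$$

$$\Delta\theta_t = -\frac{\sqrt{E[\Delta\theta^2]_{t-1} + \epsilon}}{\sqrt{E[g^2]_t + \epsilon}}\; g_t$$

$$E[\Delta\theta^2]_t = \rho\, E[\Delta\theta^2]_{t-1} + (1-\rho)\, \Delta\theta_t^2, \qquad \theta_{t+1} = \theta_t + \Delta\theta_t$$

Here $\rho$ is the decay rate and $\epsilon$ a small constant for numerical stability.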

paddle/fluid/framework/channel_test.cc

Lines changed: 64 additions & 0 deletions
@@ -871,3 +871,67 @@ TEST(ChannelHolder, ChannelHolderDestroyUnblocksSendersTest) {
   ch->Reset<int>(0);
   ChannelHolderDestroyUnblockSenders(ch, false);
 }
+
+// Tests closing a ChannelHolder many times from several racing threads.
+void ChannelHolderManyTimesClose(ChannelHolder *ch) {
+  const int num_threads = 15;
+  std::thread t[num_threads];
+  bool thread_ended[num_threads];
+
+  // Launches threads that try to send data to the channel.
+  for (size_t i = 0; i < num_threads / 3; i++) {
+    thread_ended[i] = false;
+    t[i] = std::thread(
+        [&](bool *ended) {
+          int data = 10;
+          ch->Send(&data);
+          *ended = true;
+        },
+        &thread_ended[i]);
+  }
+
+  // Launches threads that try to receive data from the channel.
+  for (size_t i = num_threads / 3; i < 2 * num_threads / 3; i++) {
+    thread_ended[i] = false;
+    t[i] = std::thread(
+        [&](bool *p) {
+          int data;
+          if (ch->Receive(&data)) {
+            EXPECT_EQ(data, 10);
+          }
+          *p = true;
+        },
+        &thread_ended[i]);
+  }
+
+  // Launches threads that try to close the channel.
+  for (size_t i = 2 * num_threads / 3; i < num_threads; i++) {
+    thread_ended[i] = false;
+    t[i] = std::thread(
+        [&](bool *p) {
+          if (!ch->IsClosed()) {
+            ch->close();
+          }
+          *p = true;
+        },
+        &thread_ended[i]);
+  }
+
+  std::this_thread::sleep_for(std::chrono::milliseconds(100));  // wait
+
+  // Verify that all threads are unblocked.
+  for (size_t i = 0; i < num_threads; i++) {
+    EXPECT_EQ(thread_ended[i], true);
+  }
+  EXPECT_TRUE(ch->IsClosed());
+  // Delete the channel.
+  delete ch;
+  for (size_t i = 0; i < num_threads; i++) t[i].join();
+}
+
+TEST(ChannelHolder, ChannelHolderManyTimesCloseTest) {
+  // Check for a buffered channel.
+  ChannelHolder *ch = new ChannelHolder();
+  ch->Reset<int>(10);
+  ChannelHolderManyTimesClose(ch);
+}
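The test above hinges on close() being safe to call from several racing threads: the IsClosed()/close() pair is not checked-and-called atomically, so a second or third close must be harmless. A minimal, self-contained sketch of that idempotent-close pattern (a generic illustration with a toy class, not Paddle's ChannelHolder):

#include <mutex>
#include <thread>
#include <vector>

// Toy channel whose Close() may be invoked by many threads at once; the
// mutex-guarded flag makes repeated calls harmless.
class ToyChannel {
 public:
  void Close() {
    std::lock_guard<std::mutex> guard(mu_);
    closed_ = true;  // idempotent: later calls simply re-set the flag
  }
  bool IsClosed() {
    std::lock_guard<std::mutex> guard(mu_);
    return closed_;
  }

 private:
  std::mutex mu_;
  bool closed_ = false;
};

int main() {
  ToyChannel ch;
  std::vector<std::thread> closers;
  for (int i = 0; i < 5; ++i) {
    // Mirrors the closer threads in the test: the IsClosed()/Close() pair can
    // race, so Close() itself must tolerate being called more than once.
    closers.emplace_back([&ch] {
      if (!ch.IsClosed()) ch.Close();
    });
  }
  for (auto& t : closers) t.join();
  return ch.IsClosed() ? 0 : 1;
}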

paddle/fluid/operators/dropout_op.cc

Lines changed: 3 additions & 6 deletions
@@ -35,7 +35,6 @@ class DropoutOp : public framework::OperatorWithKernel {
   }
 };
 
-template <typename AttrType>
 class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   DropoutOpMaker(OpProto* proto, OpAttrChecker* op_checker)
@@ -73,7 +72,6 @@ are set equal to their corresponding inputs.
   }
 };
 
-template <typename AttrType>
 class DropoutOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -103,11 +101,10 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker<float>, dropout_grad,
-            ops::DropoutOpGrad<float>);
+REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad,
+            ops::DropoutOpGrad);
 REGISTER_OP_CPU_KERNEL(
-    dropout,
-    ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float, float>);
+    dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>);
 REGISTER_OP_CPU_KERNEL(
     dropout_grad,
     ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>);
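After this change the `dropout_prob` attribute is always a `float`, while the element type `T` remains a template parameter. For reference, the forward computation these kernels implement follows the non-inverted dropout convention (the inference-time branch is visible in the CUDA file below):

$$y_i = m_i\, x_i, \qquad m_i \sim \mathrm{Bernoulli}(1 - p) \quad \text{(training)}$$

$$y_i = (1 - p)\, x_i \quad \text{(inference)}$$

where $p$ is `dropout_prob`; scaling by $(1 - p)$ at inference keeps $\mathbb{E}[y_i]$ the same in both modes.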

paddle/fluid/operators/dropout_op.cu

Lines changed: 14 additions & 13 deletions
@@ -18,17 +18,18 @@ limitations under the License. */
 #include <thrust/random.h>
 #include <thrust/transform.h>
 #include "paddle/fluid/operators/dropout_op.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
 
-template <typename T, typename AttrType>
+template <typename T>
 __global__ void RandomGenerator(const size_t n, const int seed,
-                                const AttrType dropout_prob, const T* src,
+                                const float dropout_prob, const T* src,
                                 T* mask_data, T* dst) {
   thrust::minstd_rand rng;
   rng.seed(seed);
-  thrust::uniform_real_distribution<AttrType> dist(0, 1);
+  thrust::uniform_real_distribution<float> dist(0, 1);
 
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
   for (; idx < n; idx += blockDim.x * gridDim.x) {
@@ -44,14 +45,14 @@ __global__ void RandomGenerator(const size_t n, const int seed,
 // It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
 // Use std::random and thrust::random(thrust is a std library in CUDA) to
 // implement uniform random.
-template <typename Place, typename T, typename AttrType>
+template <typename Place, typename T>
 class GPUDropoutKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* x = context.Input<Tensor>("X");
     auto* y = context.Output<Tensor>("Out");
     y->mutable_data<T>(context.GetPlace());
-    AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");
+    float dropout_prob = context.Attr<float>("dropout_prob");
 
     auto X = EigenMatrix<T>::Reshape(*x, 1);
     auto Y = EigenMatrix<T>::Reshape(*y, 1);
@@ -70,11 +71,11 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
 
       int threads = 512;
       int grid = (x->numel() + threads - 1) / threads;
-      RandomGenerator<T, AttrType><<<grid, threads, 0,
-                                     context.cuda_device_context().stream()>>>(
+      RandomGenerator<
+          T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
           size, seed, dropout_prob, x_data, mask_data, y_data);
     } else {
-      Y.device(place) = X * (1.0f - dropout_prob);
+      Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
     }
   }
 };
@@ -83,9 +84,9 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
-    dropout,
-    ops::GPUDropoutKernel<paddle::platform::CUDADeviceContext, float, float>);
-REGISTER_OP_CUDA_KERNEL(
-    dropout_grad,
-    ops::DropoutGradKernel<paddle::platform::CUDADeviceContext, float>);
+    dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
+    ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>);
+REGISTER_OP_CUDA_KERNEL(dropout_grad,
+                        ops::DropoutGradKernel<plat::CUDADeviceContext, float>);
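A CPU-side sketch of the per-element work `RandomGenerator` does, written against the standard library rather than Thrust. It assumes the usual convention that an element is dropped (mask 0) when the uniform draw falls below `dropout_prob` and kept with mask 1 otherwise; the probability stays a `float` no matter what `T` is.

#include <cstddef>
#include <iostream>
#include <random>
#include <vector>

// Reference loop mirroring the kernel's per-element logic on the host.
template <typename T>
void DropoutForwardReference(const std::vector<T>& src, float dropout_prob,
                             int seed, std::vector<T>* mask,
                             std::vector<T>* dst) {
  std::minstd_rand rng(seed);
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  mask->resize(src.size());
  dst->resize(src.size());
  for (std::size_t i = 0; i < src.size(); ++i) {
    const bool drop = dist(rng) < dropout_prob;
    (*mask)[i] = static_cast<T>(drop ? 0 : 1);
    (*dst)[i] = drop ? static_cast<T>(0) : src[i];
  }
}

int main() {
  std::vector<float> x{1.f, 2.f, 3.f, 4.f}, mask, out;
  DropoutForwardReference(x, /*dropout_prob=*/0.5f, /*seed=*/42, &mask, &out);
  for (float v : out) std::cout << v << ' ';
  std::cout << std::endl;
  return 0;
}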

paddle/fluid/operators/dropout_op.h

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
-template <typename DeviceContext, typename T, typename AttrType>
+template <typename DeviceContext, typename T>
 class CPUDropoutKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {

paddle/fluid/operators/lod_reset_op.cc

Lines changed: 81 additions & 31 deletions
@@ -22,17 +22,16 @@ class LoDResetOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    // input check
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of LoDResetOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of LoDResetOp should not be null.");
-    // If target LoD is not set form Input(), then it must be set from Attr().
-    if (!ctx->HasInput("TargetLoD")) {
+
+    if (!ctx->HasInput("Y")) {
      auto level0 = ctx->Attrs().Get<std::vector<int>>("target_lod");
-      PADDLE_ENFORCE(level0.size() > 1,
-                     "Target LoD is not found, should be set to be a valid one "
-                     "through Input() or Attr().");
+      PADDLE_ENFORCE_GT(level0.size(), 1,
+                        "If Input(Y) is not provided, the target lod should be "
+                        "specified by attribute `target_lod`.");
     }
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
   }
@@ -50,36 +49,77 @@ class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   LoDResetOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "(LoDTensor) The input tensor of lod_reset operator.");
-    AddInput("TargetLoD",
-             "(Tensor, optional) The target level 0 LoD from Input().")
+    AddInput("X",
+             "(Tensor, LoDTensor) Input variable of LoDResetOp, which "
+             "could be a Tensor or a LoDTensor; the data of the output "
+             "variable is inherited from it.");
+    AddInput("Y",
+             "(Tensor, LoDTensor, optional) If provided and Y is a LoDTensor, "
+             "the lod of Input(Y) is taken as the target lod first; "
+             "otherwise the data of Input(Y) is taken as the "
+             "target lod.")
         .AsDispensable();
-    AddOutput("Out", "(LoDTensor) The output tensor of lod_reset operator.");
+    AddOutput("Out",
+              "(LoDTensor) Output variable of LoDResetOp, which should be a "
+              "LoDTensor.");
     AddAttr<std::vector<int>>("target_lod",
                               "The target level 0 LoD from Attr().")
         .SetDefault(std::vector<int>{});
     AddComment(R"DOC(LoDReset operator
 
-Reset LoD of Input(X) into a new one specified by Input(TargetLoD) or
-Attr(target_lod), or set LoD for Input(X) if it doesn't have one.
-Currently the lod_reset operator only supports the reset of level 0 LoD.
-At least one of Input(TargetLoD) and Attr(target_lod) must be set,
-and if both of them are set, Input(TargetLoD) will be chosen as the
-target LoD.
+Set the LoD of `X` to a new one specified by `Y` or by attribute `target_lod`.
+When `Y` is provided and is a LoDTensor, `Y.lod` is taken as the target LoD
+first; otherwise `Y.data` is taken as the target LoD. If `Y` is not
+provided, the target LoD must be specified by attribute `target_lod`.
+If the target LoD is specified by `Y.data` or `target_lod`, only one-level
+LoD is supported.
+
+Example 1:
+
+Given a 1-level LoDTensor input(X):
+    X.lod  = [[ 0, 2, 5, 6 ]]
+    X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
+    X.dims = [6, 1]
+
+attr(target_lod): [0, 4, 6]
+
+then we get a 1-level LoDTensor:
+    Out.lod  = [[ 0, 4, 6 ]]
+    Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
+    Out.dims = [6, 1]
+
+Example 2:
 
-An example:
-Given a float LoDTensor X with shape (6, 1), its transpose form represents
+Given a 1-level LoDTensor input(X):
+    X.lod  = [[ 0, 2, 5, 6 ]]
+    X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
+    X.dims = [6, 1]
 
-[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
+input(Y) is a Tensor:
+    Y.data = [[0, 2, 6]]
+    Y.dims = [1, 3]
 
-with LoD = [[0, 2, 5, 6]] and the three (transposed) sequences look like
+then we get a 1-level LoDTensor:
+    Out.lod  = [[ 0, 2, 6 ]]
+    Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
+    Out.dims = [6, 1]
 
-[1.0, 2.0], [3.0, 4.0, 5.0], [6.0].
+Example 3:
 
-If target LoD = [0, 4, 6], the lod_reset operator will reset the LoD and
-the sequences that the LoDTensor Output(Out) contains becomes:
+Given a 1-level LoDTensor input(X):
+    X.lod  = [[ 0, 2, 5, 6 ]]
+    X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
+    X.dims = [6, 1]
 
-[1.0, 2.0, 3.0, 4.0], [5.0, 6.0].
+input(Y) is a 2-level LoDTensor:
+    Y.lod  = [[0, 2, 4], [0, 2, 5, 6]]
+    Y.data = [[1.1], [2.1], [3.1], [4.1], [5.1], [6.1]]
+    Y.dims = [6, 1]
+
+then we get a 2-level LoDTensor:
+    Out.lod  = [[0, 2, 4], [0, 2, 5, 6]]
+    Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
+    Out.dims = [6, 1]
 
 )DOC");
   }
@@ -90,10 +130,16 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null.");
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of LoDResetGradOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@GRAD) shouldn't be null.");
-    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+                   "Input(Out@Grad) of LoDResetGradOp should not be null.");
+
+    auto x_grad_name = framework::GradVarName("X");
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
+      ctx->ShareLoD("X", /*->*/ x_grad_name);
+    }
   }
 
  protected:
@@ -111,9 +157,13 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 REGISTER_OP(lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker, lod_reset_grad,
             ops::LoDResetGradOp);
-REGISTER_OP_CPU_KERNEL(lod_reset,
-                       ops::LoDResetKernel<paddle::platform::CPUPlace, float>,
-                       ops::LoDResetKernel<paddle::platform::CPUPlace, double>);
+REGISTER_OP_CPU_KERNEL(
+    lod_reset, ops::LoDResetKernel<paddle::platform::CPUPlace, float>,
+    ops::LoDResetKernel<paddle::platform::CPUPlace, double>,
+    ops::LoDResetKernel<paddle::platform::CPUPlace, int>,
+    ops::LoDResetKernel<paddle::platform::CPUPlace, int64_t>);
 REGISTER_OP_CPU_KERNEL(
     lod_reset_grad, ops::LoDResetGradKernel<paddle::platform::CPUPlace, float>,
-    ops::LoDResetGradKernel<paddle::platform::CPUPlace, double>);
+    ops::LoDResetGradKernel<paddle::platform::CPUPlace, double>,
+    ops::LoDResetGradKernel<paddle::platform::CPUPlace, int>,
+    ops::LoDResetGradKernel<paddle::platform::CPUPlace, int64_t>);
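The precedence described in the new doc comment (Y.lod if Y is a LoDTensor with LoD, else Y.data, else the `target_lod` attribute) can be summarized with a small toy, using hypothetical stand-in types rather than Paddle's real Tensor/LoDTensor classes:

#include <vector>

// Hypothetical stand-in: lod is empty when the variable is a plain Tensor.
struct ToyVar {
  std::vector<std::vector<int>> lod;
  std::vector<int> data;
};

// Resolve the target LoD with the precedence documented above.
std::vector<std::vector<int>> ResolveTargetLoD(
    const ToyVar* y, const std::vector<int>& target_lod_attr) {
  if (y != nullptr && !y->lod.empty()) {
    return y->lod;             // Example 3: copy Y's LoD wholesale
  }
  if (y != nullptr) {
    return {y->data};          // Example 2: Y.data holds a one-level LoD
  }
  return {target_lod_attr};    // Example 1: fall back to attr(target_lod)
}

int main() {
  ToyVar y;
  y.data = {0, 2, 6};
  // With a plain-Tensor Y, the result is the one-level LoD [[0, 2, 6]].
  auto lod = ResolveTargetLoD(&y, /*target_lod_attr=*/{0, 4, 6});
  return lod.size() == 1 && lod[0].size() == 3 ? 0 : 1;
}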
