
Commit 4ff1bde

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/hide_api_cont
2 parents: 4dccb58 + ebe3b5e

38 files changed: +2773 additions, −825 deletions

paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc

Lines changed: 13 additions & 2 deletions
@@ -13,6 +13,7 @@
 // limitations under the License.

 #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
+#include <stdexcept>
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/executor.h"
@@ -53,8 +54,14 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
       }
     }
   }
+  std::vector<framework::LoDTensor> fetch_data;
+  std::exception_ptr eptr;
+  try {
+    fetch_data = underlying_executor_->Run(fetch_tensors);
+  } catch (...) {
+    eptr = std::current_exception();
+  }

-  auto fetch_data = underlying_executor_->Run(fetch_tensors);
   drop_scope_counter_ += 1;
   if (!fetch_tensors.empty() ||
       drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
@@ -69,7 +76,11 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
       scope->DeleteScope(local_scope);
     }
   }
-  return fetch_data;
+  if (eptr) {
+    std::rethrow_exception(eptr);
+  } else {
+    return fetch_data;
+  }
 }
 }  // namespace details
 }  // namespace framework
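
Note: the change above defers any exception thrown by the underlying executor so that the scope bookkeeping (the drop_scope_counter_ update and scope deletion) still runs before the error is surfaced. Below is a minimal standalone sketch of that capture/cleanup/rethrow pattern; RunWork and Cleanup are hypothetical stand-ins for underlying_executor_->Run() and the scope maintenance, not Paddle APIs:

#include <exception>
#include <iostream>
#include <stdexcept>

// Hypothetical work that may throw (stands in for underlying_executor_->Run()).
int RunWork(bool fail) {
  if (fail) throw std::runtime_error("worker failed");
  return 42;
}

// Hypothetical cleanup that must always run (stands in for scope deletion).
void Cleanup() { std::cout << "cleanup ran\n"; }

int Run(bool fail) {
  int result = 0;
  std::exception_ptr eptr;
  try {
    result = RunWork(fail);
  } catch (...) {
    eptr = std::current_exception();  // remember the error, keep going
  }
  Cleanup();  // always executed, whether or not RunWork threw
  if (eptr) std::rethrow_exception(eptr);  // surface the original exception
  return result;
}

int main() {
  std::cout << Run(false) << "\n";  // prints "cleanup ran", then 42
  try {
    Run(true);
  } catch (const std::exception &e) {
    std::cout << "caught: " << e.what() << "\n";  // original type and message
  }
  return 0;
}

Because std::exception_ptr preserves the original exception object, the caller still sees exactly the type and message the worker threw.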

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 11 additions & 4 deletions
@@ -78,6 +78,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     set.clear();
   };

+  // Clean run context
+  run_op_futures_.clear();
+  exception_.reset();
+
   // Step 3. Execution
   while (!pending_vars.empty()) {
     // 1. Run All Ready ops
@@ -96,16 +100,19 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     auto cur_ready_vars = ready_vars.PopAll(1, &timeout);

     if (timeout) {
-      std::lock_guard<std::mutex> l(exception_mu_);
+      std::unique_lock<std::mutex> l(exception_mu_);
       if (exception_) {
+        l.unlock();
+        for (auto &run_op_future : run_op_futures_) {
+          run_op_future.wait();
+        }
+        l.lock();
         std::exception *exp = exception_.get();
         if (dynamic_cast<platform::EOFException *>(exp)) {
           auto e = *static_cast<platform::EOFException *>(exp);
-          exception_.reset();
           throw e;
         } else if (dynamic_cast<platform::EnforceNotMet *>(exp)) {
           auto e = *static_cast<platform::EnforceNotMet *>(exp);
-          exception_.reset();
           throw e;
         } else {
           LOG(FATAL) << "Unknown exception.";
@@ -222,7 +229,7 @@ void ThreadedSSAGraphExecutor::RunOp(
     }
   };
   if (pool_) {
-    pool_->enqueue(op_run);
+    run_op_futures_.emplace_back(pool_->enqueue(op_run));
   } else {
     op_run();
   }
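
Note: RunOp now records a future for every operator handed to the thread pool, and Run waits for all in-flight operators before rethrowing a captured exception, so no worker is still running when the error propagates (exception_ is also no longer reset in the catch branches, since the run context is cleared at the start of the next Run). A hedged sketch of the record-futures / wait-before-rethrow idea, using std::async in place of the executor's thread pool:

#include <exception>
#include <future>
#include <iostream>
#include <list>
#include <mutex>
#include <stdexcept>

int main() {
  std::list<std::future<void>> run_op_futures;  // one future per submitted op
  std::mutex exception_mu;
  std::exception_ptr captured;

  // Submit a few "ops"; one of them fails and records its exception.
  for (int i = 0; i < 4; ++i) {
    run_op_futures.emplace_back(std::async(std::launch::async, [&, i] {
      if (i == 2) {
        std::lock_guard<std::mutex> guard(exception_mu);
        captured = std::make_exception_ptr(std::runtime_error("op 2 failed"));
      }
    }));
  }

  // Before surfacing the error, wait until every in-flight op has finished.
  for (auto &f : run_op_futures) f.wait();

  std::lock_guard<std::mutex> guard(exception_mu);
  if (captured) {
    try {
      std::rethrow_exception(captured);
    } catch (const std::exception &e) {
      std::cout << "caught after all ops finished: " << e.what() << "\n";
    }
  }
  return 0;
}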

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,7 @@
 #pragma once

 #include <deque>
+#include <list>
 #include <string>
 #include <unordered_set>
 #include <utility>
@@ -77,6 +78,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {

  private:
   ExecutionStrategy strategy_;
+  // use std::list because clear(), push_back, and for_each are O(1)
+  std::list<std::future<void>> run_op_futures_;
 };

 }  // namespace details

paddle/fluid/framework/parallel_executor.cc

Lines changed: 7 additions & 3 deletions
@@ -95,7 +95,7 @@ ParallelExecutor::ParallelExecutor(
   }

   if (member_->local_scopes_.size() != 1 && local_scopes.empty()) {
-    BCastParamsToGPUs(bcast_vars);
+    BCastParamsToDevs(bcast_vars);
   }
   // Startup Program has been run. All local scopes has correct parameters.

@@ -131,7 +131,7 @@ ParallelExecutor::ParallelExecutor(
       member_->places_, std::move(member_->executor_)));
 }

-void ParallelExecutor::BCastParamsToGPUs(
+void ParallelExecutor::BCastParamsToDevs(
     const std::unordered_set<std::string> &vars) const {
   // the the initializing bcast, all vars would be bcast from device(0),
   // otherwise
@@ -202,7 +202,11 @@ void ParallelExecutor::BCastParamsToGPUs(
 #endif
     } else {
       platform::CPUPlace cpu;
-      for (size_t i = 1; i < member_->places_.size(); ++i) {
+      for (size_t i = 0; i < member_->places_.size(); ++i) {
+        if ((initializing && i == 0) ||
+            (!initializing && static_cast<int>(i) == var_dev_id))
+          continue;
+
         auto local_scope = member_->local_scopes_[i];
         auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
         t->Resize(dims);
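
Note: the loop fix broadcasts a parameter to every device except the one that already holds the source copy: device 0 while initializing, otherwise the device recorded in var_dev_id. A small sketch of just that skip logic; Broadcast and CopyTo are illustrative placeholders, not ParallelExecutor's API:

#include <iostream>

// Hypothetical stand-in for copying the tensor to device i.
void CopyTo(size_t i) { std::cout << "  broadcast to device " << i << "\n"; }

void Broadcast(size_t num_devices, bool initializing, int var_dev_id) {
  for (size_t i = 0; i < num_devices; ++i) {
    // Skip the device that already owns the up-to-date parameter.
    if ((initializing && i == 0) ||
        (!initializing && static_cast<int>(i) == var_dev_id)) {
      continue;
    }
    CopyTo(i);
  }
}

int main() {
  std::cout << "initializing:\n";
  Broadcast(4, /*initializing=*/true, /*var_dev_id=*/-1);   // skips device 0
  std::cout << "var already on device 2:\n";
  Broadcast(4, /*initializing=*/false, /*var_dev_id=*/2);   // skips device 2
  return 0;
}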

paddle/fluid/framework/parallel_executor.h

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ class ParallelExecutor {
   void Run(const std::vector<std::string> &fetch_tensors,
            const std::string &fetched_var_name);

-  void BCastParamsToGPUs(const std::unordered_set<std::string> &vars) const;
+  void BCastParamsToDevs(const std::unordered_set<std::string> &vars) const;

  private:
   ParallelExecutorPrivate *member_;

paddle/fluid/framework/reader.h

Lines changed: 4 additions & 4 deletions
@@ -29,11 +29,11 @@ enum ReaderStatus { kRunning, kStopped };

 class ReaderBase {
  public:
-  void ReadNext(std::vector<LoDTensor>* out);
+  virtual void ReadNext(std::vector<LoDTensor>* out);

-  void Shutdown();
+  virtual void Shutdown();

-  void Start();
+  virtual void Start();

   // Return the readers which are the end of decorating chain. Basically
   // they are readers just before read op.
@@ -42,7 +42,7 @@ class ReaderBase {
   virtual ~ReaderBase();

  protected:
-  virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
+  virtual void ReadNextImpl(std::vector<LoDTensor>* out) {}

   virtual void ShutdownImpl() {}
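
Note: making ReadNext, Shutdown, and Start virtual lets decorating readers override the public entry points directly, and the empty ReadNextImpl default means such a decorator no longer has to implement the protected hook. A minimal sketch of that shape; Batch, RangeReader, and CountingReader are illustrative simplifications, not Paddle types:

#include <iostream>
#include <vector>

using Batch = std::vector<int>;  // simplified stand-in for std::vector<LoDTensor>

class ReaderBase {
 public:
  virtual void ReadNext(Batch* out) { ReadNextImpl(out); }
  virtual void Shutdown() {}
  virtual void Start() {}
  virtual ~ReaderBase() = default;

 protected:
  // Non-pure default: a reader that overrides ReadNext() need not implement it.
  virtual void ReadNextImpl(Batch* out) {}
};

// A plain reader implements only the protected hook.
class RangeReader : public ReaderBase {
 protected:
  void ReadNextImpl(Batch* out) override { *out = {next_++, next_++}; }

 private:
  int next_ = 0;
};

// A decorating reader overrides the public entry point and forwards.
class CountingReader : public ReaderBase {
 public:
  explicit CountingReader(ReaderBase* underlying) : underlying_(underlying) {}
  void ReadNext(Batch* out) override {
    underlying_->ReadNext(out);
    ++batches_read_;
  }
  int batches_read() const { return batches_read_; }

 private:
  ReaderBase* underlying_;
  int batches_read_ = 0;
};

int main() {
  RangeReader base;
  CountingReader reader(&base);
  Batch batch;
  reader.ReadNext(&batch);
  std::cout << "items: " << batch.size()
            << ", batches read: " << reader.batches_read() << "\n";
  return 0;
}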

paddle/fluid/operators/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -259,12 +259,15 @@ op_library(max_sequence_len_op DEPS lod_rank_table)
 op_library(sequence_conv_op DEPS context_project)
 op_library(sequence_pool_op DEPS sequence_pooling)
 op_library(lstm_op DEPS sequence2batch lstm_compute)
+op_library(hierarchical_sigmoid_op DEPS matrix_bit_code)
 op_library(lstmp_op DEPS sequence2batch lstm_compute)
 op_library(gru_op DEPS sequence2batch gru_compute)
 op_library(recurrent_op DEPS executor)
 op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
 op_library(cos_sim_op DEPS cos_sim_functor)
 op_library(parallel_do_op DEPS executor)
+op_library(unsqueeze_op DEPS reshape_op)
+op_library(squeeze_op DEPS reshape_op)

 if (WITH_GPU)
     op_library(conv_op DEPS vol2col depthwise_conv im2col)
paddle/fluid/operators/hierarchical_sigmoid_op.cc (new file)

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/hierarchical_sigmoid_op.h"
#include <vector>

namespace paddle {
namespace operators {

/**
 * Organize the classes into a binary tree. At each node, a sigmoid function
 * is used to calculate the probability of belonging to the right branch.
 * This idea is from "F. Morin, Y. Bengio (AISTATS 05):
 * Hierarchical Probabilistic Neural Network Language Model."
 *
 * Here we uses a simple way of making the binary tree.
 * Assuming the number of classes C = 6,
 * The classes are organized as a binary tree in the following way:
 *
 * @code{.py}
 * *-*-*- 2
 * | | |- 3
 * | |
 * | |-*- 4
 * |   |- 5
 * |
 * |-*- 0
 *   |- 1
 * @endcode
 *
 * where * indicates an internal node, and each leaf node represents a class.
 * - Node 0 ... C-2 are internal nodes.
 * - Node C-1 ... 2C-2 are leaf nodes.
 * - Class c is represented by leaf node \f$c+C-1\f$.
 *
 * We assign an id for each node:
 * - the id of root be 0.
 * - the left child of a node i is 2*i+1.
 * - the right child of a node i is 2*i+2.
 *
 * It's easy to see that:
 * - the parent of node i is \f$\left\lfloor(i-1)/2\right\rfloor\f$.
 * - the j-th level ancestor of node i is
 *   \f$\left\lfloor(i+1)/2^{j+1}\right\rfloor - 1\f$.
 * - A node i is a left child of its parent if \f$(i-1)\%2==0\f$.
 *
 */

class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("PreOut"),
                   "Output(PreOut) should not be null.");
    const int64_t batch_size = ctx->GetInputDim("X")[0];
    std::vector<int64_t> output_shape({batch_size, 1});
    ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
        ctx.GetPlace());
  }
};

template <typename AttrType>
class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor, required) The input tensor with shape [N, D], "
             "where N is the size of mini-batch, and D is the feature size.");
    AddInput("W",
             "(Tensor, required), The parameters of hierarchical "
             "sigmoid operator, each of them is a 2-D tensor, the shape is"
             "[num_classes - 1, D].");
    AddInput("Label",
             "(Tensor, required), The labels of training data. It's a"
             "tensor with shape [N, 1].");
    AddInput("Bias",
             "(Tensor, optional), The bias is a tensor with shape"
             "[1, num_classes - 1].");
    AddOutput("Out",
              "(Tensor, required) The output of hierarchical sigmoid operator."
              "The shape is [N, 1].");
    AddOutput("PreOut",
              "(Tensor, required) A intermedia 2-D tensor with shape "
              "[batch_size, code_length], where code_length represents the "
              "maximum path length from root to leaf nodes.")
        .AsIntermediate();
    AddAttr<AttrType>("num_classes", "(int, required), The number of classes")
        .SetDefault(2);
    AddComment(R"DOC(
The hierarchical sigmoid operator organize the classes into a binary tree.
At each node, a sigmoid function is used to calculate the probability of
belonging to the right branch. This idea is from
"F. Morin, Y. Bengio (AISTATS 05):
Hierarchical Probabilistic Neural Network Language Model."
)DOC");
  }
};

class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("PreOut"),
                   "Input(Preout) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
                   "Output(W@Grad should not be null.)");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")));
    if (ctx->HasOutput(framework::GradVarName("Bias"))) {
      ctx->SetOutputDim(framework::GradVarName("Bias"),
                        ctx->GetInputDim("Bias"));
    }
    ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
        ctx.GetPlace());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
                  ops::HierarchicalSigmoidOpMaker<int>,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp);
REGISTER_OP_CPU_KERNEL(
    hierarchical_sigmoid,
    ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext,
                                     double>);
REGISTER_OP_CPU_KERNEL(
    hierarchical_sigmoid_grad,
    ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
                                         float>,
    ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
                                         double>);
