Skip to content

Commit 11b5c44

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix-pserver-crash-when-no-parameter
2 parents a11d4f3 + 90d9e5a commit 11b5c44

File tree

10 files changed

+460
-28
lines changed

10 files changed

+460
-28
lines changed

paddle/fluid/API.spec

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,8 @@ paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, k
177177
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
178178
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
179179
paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
180+
paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
181+
paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
180182
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
181183
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
182184
paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)

paddle/fluid/framework/parallel_executor.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -303,10 +303,8 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
303303
}
304304

305305
ParallelExecutor::~ParallelExecutor() {
306-
const auto dev_ctxs =
307-
platform::DeviceContextPool::Instance().GetAllDeviceContexts();
308-
for (auto &dev_ctx : dev_ctxs) {
309-
dev_ctx->Wait();
306+
for (auto &p : member_->places_) {
307+
platform::DeviceContextPool::Instance().Get(p)->Wait();
310308
}
311309

312310
if (member_->own_local_scope_) {
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/add_position_encoding_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
class AddPositionEncodingOp : public framework::OperatorWithKernel {
21+
public:
22+
using framework::OperatorWithKernel::OperatorWithKernel;
23+
24+
void InferShape(framework::InferShapeContext* ctx) const override {
25+
PADDLE_ENFORCE(ctx->HasInput("X"),
26+
"X(Input) of add_position_encoding_op should not be null.");
27+
PADDLE_ENFORCE(
28+
ctx->HasOutput("Out"),
29+
"Out(Output) of add_position_encoding_op should not be null.");
30+
31+
auto x_dims = ctx->GetInputDim("X");
32+
ctx->SetOutputDim("Out", x_dims);
33+
ctx->ShareLoD("X", /*->*/ "Out");
34+
}
35+
};
36+
37+
class AddPositionEncodingOpGrad : public framework::OperatorWithKernel {
38+
public:
39+
using framework::OperatorWithKernel::OperatorWithKernel;
40+
41+
void InferShape(framework::InferShapeContext* ctx) const override {
42+
PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) must not be null.");
43+
PADDLE_ENFORCE(ctx->HasInput("Out"), "Out must not be null.");
44+
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
45+
"Out@GRAD must not be null.");
46+
47+
auto out_dims = ctx->GetInputDim("Out");
48+
if (ctx->HasOutput(framework::GradVarName("X"))) {
49+
ctx->SetOutputDim(framework::GradVarName("X"), out_dims);
50+
}
51+
}
52+
};
53+
54+
class AddPositionEncodingOpMaker : public framework::OpProtoAndCheckerMaker {
55+
public:
56+
void Make() override {
57+
AddInput("X", "Input of AddPositionEncoding operator");
58+
AddOutput("Out", "Output of AddPositionEncoding operator");
59+
AddAttr<float>("alpha", "The scale of Original Embedding.")
60+
.SetDefault(1.0f)
61+
.AddCustomChecker([](const float& alpha) {
62+
PADDLE_ENFORCE(alpha >= 0.0f, "'alpha' must be above 0.0.");
63+
});
64+
AddAttr<float>("beta", "The scale of Position Embedding.")
65+
.SetDefault(1.0f)
66+
.AddCustomChecker([](const float& beta) {
67+
PADDLE_ENFORCE(beta >= 0.0f, "'beta' must be between 0.0.");
68+
});
69+
AddComment(R"DOC(
70+
Add Position Encoding Operator.
71+
72+
The add position encoding calculates the output based on the input, alpha, beta.
73+
The size of each dimension of the parameters checked in the infer-shape.
74+
)DOC");
75+
}
76+
};
77+
78+
} // namespace operators
79+
} // namespace paddle
80+
81+
namespace ops = paddle::operators;
82+
namespace plt = paddle::platform;
83+
84+
REGISTER_OPERATOR(add_position_encoding, ops::AddPositionEncodingOp,
85+
ops::AddPositionEncodingOpMaker,
86+
paddle::framework::DefaultGradOpDescMaker<true>);
87+
REGISTER_OPERATOR(add_position_encoding_grad, ops::AddPositionEncodingOpGrad);
88+
89+
REGISTER_OP_CPU_KERNEL(
90+
add_position_encoding,
91+
ops::AddPositionEncodingKernel<plt::CPUDeviceContext, float>,
92+
ops::AddPositionEncodingKernel<plt::CPUDeviceContext, double>);
93+
94+
REGISTER_OP_CPU_KERNEL(
95+
add_position_encoding_grad,
96+
ops::AddPositionEncodingGradKernel<plt::CPUDeviceContext, float>,
97+
ops::AddPositionEncodingGradKernel<plt::CPUDeviceContext, double>);
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include "paddle/fluid/framework/eigen.h"
17+
#include "paddle/fluid/framework/op_registry.h"
18+
#include "paddle/fluid/operators/detail/safe_ref.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
23+
template <typename DeviceContext, typename T>
24+
class AddPositionEncodingKernel : public framework::OpKernel<T> {
25+
public:
26+
void Compute(const framework::ExecutionContext& context) const override {
27+
auto* X = context.Input<framework::LoDTensor>("X");
28+
auto& x_lod = X->lod();
29+
auto* src_ptr = X->data<T>();
30+
31+
auto* Out = context.Output<framework::LoDTensor>("Out");
32+
auto* dst_ptr = Out->mutable_data<T>(context.GetPlace());
33+
34+
float alpha = context.Attr<float>("alpha");
35+
float beta = context.Attr<float>("beta");
36+
37+
auto x_dim = X->dims();
38+
int batch_size = 0;
39+
int max_seq_len = 0;
40+
int enc_size = 0;
41+
42+
if (x_lod.empty()) {
43+
PADDLE_ENFORCE(
44+
x_dim.size() == 3UL,
45+
"The input X of Add Position Encoding should be 3-D Tensor!");
46+
batch_size = x_dim[0];
47+
max_seq_len = x_dim[1];
48+
enc_size = x_dim[2];
49+
} else {
50+
PADDLE_ENFORCE(
51+
x_dim.size() == 2UL,
52+
"The input X of Add Position Encoding should be 2-D LoDTensor!");
53+
PADDLE_ENFORCE(
54+
x_lod.size() == 1UL,
55+
"The Add Position Encoding Op only supports lod_level == 1!");
56+
batch_size = x_lod[0].size() - 1;
57+
max_seq_len = -1;
58+
enc_size = x_dim[1];
59+
}
60+
61+
PADDLE_ENFORCE(enc_size % 2 == 0, "Only support even encode size!");
62+
63+
const int half_size = enc_size / 2;
64+
for (int i = 0; i < batch_size; ++i) {
65+
const int max_length =
66+
x_lod.empty() ? max_seq_len : x_lod[0][i + 1] - x_lod[0][i];
67+
for (int j = 0; j < max_length; ++j) {
68+
for (int k = 0; k < half_size; ++k) {
69+
const double val = (half_size > 1)
70+
? j / pow(10000.0, double(k) / (half_size - 1))
71+
: j / 10000.0;
72+
dst_ptr[k] = src_ptr[k] * alpha + sin(val) * beta;
73+
dst_ptr[half_size + k] =
74+
src_ptr[half_size + k] * alpha + cos(val) * beta;
75+
}
76+
src_ptr += enc_size;
77+
dst_ptr += enc_size;
78+
}
79+
}
80+
}
81+
};
82+
83+
template <typename DeviceContext, typename T>
84+
class AddPositionEncodingGradKernel : public framework::OpKernel<T> {
85+
public:
86+
void Compute(const framework::ExecutionContext& context) const override {
87+
auto* dOut =
88+
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
89+
auto dout = framework::EigenVector<T>::Flatten(*dOut);
90+
91+
auto* dX =
92+
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
93+
dX->mutable_data<T>(context.GetPlace());
94+
auto dx = framework::EigenVector<T>::Flatten(*dX);
95+
96+
float alpha = context.Attr<float>("alpha");
97+
98+
auto* place =
99+
context.template device_context<DeviceContext>().eigen_device();
100+
dx.device(*place) = dout * static_cast<T>(alpha);
101+
}
102+
};
103+
104+
} // namespace operators
105+
} // namespace paddle

paddle/fluid/platform/device_context.cc

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,25 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
3232
"'Place' is not supported, Please re-compile with WITH_GPU "
3333
"option");
3434
}
35-
return it->second.get();
35+
return it->second.get().get();
3636
}
3737

38-
const std::vector<const DeviceContext*>
39-
DeviceContextPool::GetAllDeviceContexts() const {
40-
std::vector<const DeviceContext*> all_device_ctx;
41-
all_device_ctx.reserve(device_contexts_.size());
42-
for (auto& dev_ctx : device_contexts_) {
43-
all_device_ctx.emplace_back(dev_ctx.second.get());
44-
}
45-
return all_device_ctx;
38+
template <typename DevCtx, typename PlaceType>
39+
inline void EmplaceDeviceContext(
40+
std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
41+
map_ptr,
42+
platform::Place p) {
43+
using PtrType = std::unique_ptr<DeviceContext>;
44+
map_ptr->emplace(p, std::async(std::launch::deferred, [=] {
45+
// lazy evaluation. i.e., only create device context at
46+
// first `Get`
47+
return PtrType(new DevCtx(boost::get<PlaceType>(p)));
48+
}));
4649
}
4750

4851
DeviceContextPool::DeviceContextPool(
4952
const std::vector<platform::Place>& places) {
5053
PADDLE_ENFORCE_GT(places.size(), 0);
51-
using PtrType = std::unique_ptr<DeviceContext>;
5254
std::set<Place> set;
5355
for (auto& p : places) {
5456
set.insert(p);
@@ -57,26 +59,22 @@ DeviceContextPool::DeviceContextPool(
5759
for (auto& p : set) {
5860
if (platform::is_cpu_place(p)) {
5961
#ifdef PADDLE_WITH_MKLDNN
60-
device_contexts_.emplace(
61-
p, PtrType(new MKLDNNDeviceContext(boost::get<CPUPlace>(p))));
62+
EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, p);
6263
#else
63-
device_contexts_.emplace(
64-
p, PtrType(new CPUDeviceContext(boost::get<CPUPlace>(p))));
64+
EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, p);
6565
#endif
6666
} else if (platform::is_gpu_place(p)) {
6767
#ifdef PADDLE_WITH_CUDA
68-
device_contexts_.emplace(
69-
p, PtrType(new CUDADeviceContext(boost::get<CUDAPlace>(p))));
68+
EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
7069
#else
7170
PADDLE_THROW(
7271
"'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
7372
"option");
7473
#endif
7574
} else if (platform::is_cuda_pinned_place(p)) {
7675
#ifdef PADDLE_WITH_CUDA
77-
device_contexts_.emplace(
78-
p,
79-
PtrType(new CUDAPinnedDeviceContext(boost::get<CUDAPinnedPlace>(p))));
76+
EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
77+
&device_contexts_, p);
8078
#else
8179
PADDLE_THROW(
8280
"'CUDAPlace' is not supported, Please re-compile with WITH_GPU "

paddle/fluid/platform/device_context.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
1010
limitations under the License. */
1111
#pragma once
1212

13+
#include <future> // NOLINT
1314
#include <memory>
1415
#include <mutex> // NOLINT
1516
#include <string>
@@ -223,9 +224,6 @@ class DeviceContextPool {
223224
/*! \brief Return handle of single device context. */
224225
platform::DeviceContext* Get(const platform::Place& place);
225226

226-
/*! \brief Return all the device contexts. */
227-
const std::vector<const DeviceContext*> GetAllDeviceContexts() const;
228-
229227
template <typename Place>
230228
const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
231229
const Place& place) {
@@ -237,7 +235,8 @@ class DeviceContextPool {
237235

238236
private:
239237
static DeviceContextPool* pool;
240-
std::map<Place, std::unique_ptr<DeviceContext>> device_contexts_;
238+
std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>
239+
device_contexts_;
241240
DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
242241
};
243242

0 commit comments

Comments
 (0)