Commit aaf818f

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/add_fwd_block_id

2 parents: 65058cf + bf9ed4a

20 files changed, +759 -308 lines

doc/design/parallel_do.md

Lines changed: 2 additions & 2 deletions
````diff
@@ -24,7 +24,7 @@ A vanilla implementation of parallel_do can be shown as the following (`|` means
 ```
 In the forward pass
 | Split input onto different devices
-| Copy parameter to onto different devices
+| Copy parameter onto different devices
 |||| Compute forward pass in parallel
 | Merge output from different devices
@@ -87,7 +87,7 @@ block2 {
 }
 ```

-## Proformance Imporvement
+## Performance Imporvement

 There are serial places we can make this parallel_do faster.
````

paddle/fluid/framework/block_desc.cc

Lines changed: 19 additions & 0 deletions
```diff
@@ -44,6 +44,25 @@ bool BlockDesc::HasVar(const std::string &name) const {
   return vars_.find(name) != vars_.end();
 }

+VarDesc *BlockDesc::RenameVar(const std::string &old_name,
+                              const std::string &new_name) {
+  if (!this->HasVar(old_name)) {
+    return nullptr;
+  }
+  need_update_ = true;
+  auto *var = this->Var(old_name);
+  VarDesc *new_var = new VarDesc(*(var->Proto()));
+  new_var->SetName(new_name);
+  vars_[new_name].reset(new_var);
+  // rename inputs and outputs
+  for (const auto &op : ops_) {
+    auto *it = op.get();
+    it->Rename(old_name, new_name);
+  }
+  vars_.erase(old_name);
+  return new_var;
+}
+
 VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
   if (name == kEmptyVarName) return nullptr;
```
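A minimal usage sketch for the new `BlockDesc::RenameVar` (illustrative only; the variable names and the calling context are hypothetical, not part of this commit):

```cpp
#include <cassert>
#include "paddle/fluid/framework/block_desc.h"

// Hypothetical helper: rename a variable inside a block and check the effects.
void RenameVarSketch(paddle::framework::BlockDesc *block) {
  paddle::framework::VarDesc *renamed =
      block->RenameVar("fc_0.w", "fc_0.w_new");  // hypothetical names
  if (renamed == nullptr) {
    return;  // RenameVar returns nullptr when old_name is absent.
  }
  // The old entry was erased, and every op in the block had its inputs
  // and outputs rewritten through OpDesc::Rename.
  assert(renamed->Name() == "fc_0.w_new");
  assert(!block->HasVar("fc_0.w"));
}
```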

paddle/fluid/framework/block_desc.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -57,6 +57,8 @@ class BlockDesc {

   bool HasVar(const std::string &var_name) const;

+  VarDesc *RenameVar(const std::string &old_name, const std::string &new_name);
+
   VarDesc *FindVarRecursive(const std::string &name_bytes) const;

   VarDesc &FindRecursiveOrCreateVar(const std::string &name_bytes);
```

paddle/fluid/framework/channel.h

Lines changed: 73 additions & 0 deletions
```diff
@@ -15,6 +15,8 @@ limitations under the License. */
 #pragma once

 #include <stddef.h>  // for size_t
+#include <typeindex>
+#include "paddle/fluid/platform/enforce.h"

 namespace paddle {
 namespace framework {
@@ -51,6 +53,77 @@ void CloseChannel(Channel<T>* ch) {
   ch->Close();
 }

+/*
+ * The ChannelHolder class serves two main purposes:
+ * 1. It acts as a unified wrapper for the different kinds of
+ *    channels, i.e. Buffered and Unbuffered channels. This is
+ *    similar to the ReaderHolder class.
+ * 2. It also helps us in TypeHiding. This is similar to the
+ *    PlaceHolder implementations in variable.h and tensor.h.
+ */
+class ChannelHolder {
+ public:
+  template <typename T>
+  void Reset(size_t buffer_size) {
+    holder_.reset(new PlaceholderImpl<T>(buffer_size));
+  }
+
+  template <typename T>
+  bool Send(T* data) {
+    if (!IsInitialized()) return false;
+    PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
+    // Static cast should be safe because we have ensured that types are same
+    Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
+    return channel != nullptr ? channel->Send(data) : false;
+  }
+
+  template <typename T>
+  bool Receive(T* data) {
+    if (!IsInitialized()) return false;
+    PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
+    Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
+    return channel != nullptr ? channel->Receive(data) : false;
+  }
+
+  void close() {
+    if (IsInitialized()) holder_->Close();
+  }
+
+  inline bool IsInitialized() const { return holder_ != nullptr; }
+
+ private:
+  /**
+   * @note Placeholder hides type T, so it doesn't appear as a template
+   *       parameter of ChannelHolder.
+   */
+  struct Placeholder {
+    virtual ~Placeholder() {}
+    virtual const std::type_index Type() const = 0;
+    virtual void* Ptr() const = 0;
+    virtual void Close() const = 0;
+    std::type_info type_;
+  };
+
+  template <typename T>
+  struct PlaceholderImpl : public Placeholder {
+    PlaceholderImpl(size_t buffer_size) : type_(std::type_index(typeid(T))) {
+      channel_.reset(MakeChannel<T>(buffer_size));
+    }
+
+    virtual const std::type_index Type() const { return type_; }
+    virtual void* Ptr() const { return static_cast<void*>(channel_.get()); }
+    virtual void Close() {
+      if (channel_) channel_->Close();
+    }
+
+    std::unique_ptr<Channel<T>*> channel_;
+    const std::type_index type_;
+  };
+
+  // Pointer to a PlaceholderImpl object
+  std::unique_ptr<Placeholder> holder_;
+};
+
 }  // namespace framework
 }  // namespace paddle
```
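A short sketch of how the new `ChannelHolder` is meant to be used, relying only on the API added above; the payload type and buffer size are illustrative:

```cpp
#include "paddle/fluid/framework/channel.h"

// Illustrative: type-erased send/receive through a ChannelHolder.
void ChannelHolderSketch() {
  paddle::framework::ChannelHolder holder;
  holder.Reset<int>(4);  // wraps a buffered Channel<int> with capacity 4

  int value = 42;
  holder.Send(&value);  // PADDLE_ENFORCE_EQ fires if T mismatches the holder

  int received = 0;
  if (holder.Receive(&received)) {
    // received now holds 42.
  }
  holder.close();  // forwards to the wrapped channel's Close()
}
```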

paddle/fluid/framework/executor.cc

Lines changed: 4 additions & 1 deletion
```diff
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <set>

 #include "gflags/gflags.h"
+#include "paddle/fluid/framework/channel.h"
 #include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/framework/feed_fetch_type.h"
 #include "paddle/fluid/framework/lod_rank_table.h"
@@ -55,13 +56,15 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
     var->GetMutable<platform::PlaceList>();
   } else if (var_type == proto::VarType::READER) {
     var->GetMutable<ReaderHolder>();
+  } else if (var_type == proto::VarType::CHANNEL) {
+    var->GetMutable<ChannelHolder>();
   } else if (var_type == proto::VarType::NCCL_COM) {
     // GetMutable will be called in ncclInit
   } else {
     PADDLE_THROW(
         "Variable type %d is not in "
         "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
-        "LOD_RANK_TABLE, PLACE_LIST, READER, NCCL_COM]",
+        "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, NCCL_COM]",
         var_type);
   }
 }
```
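With this branch, a variable of type `CHANNEL` gets a `ChannelHolder` payload when the executor materializes it; a hedged sketch of what that amounts to (the scope setup and names are illustrative, not from this commit):

```cpp
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/scope.h"

// Illustrative: what CreateTensor's new CHANNEL branch establishes.
void ChannelVariableSketch(paddle::framework::Scope *scope) {
  paddle::framework::Variable *var = scope->Var("ch");  // hypothetical name
  // The new branch calls GetMutable<ChannelHolder>() for CHANNEL vars:
  auto *holder = var->GetMutable<paddle::framework::ChannelHolder>();
  holder->Reset<float>(8);  // an op would choose the payload type and capacity
}
```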

paddle/fluid/framework/var_desc.cc

Lines changed: 52 additions & 2 deletions
```diff
@@ -88,7 +88,13 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
 }

 void VarDesc::SetDataType(proto::VarType::Type data_type) {
-  mutable_tensor_desc()->set_data_type(data_type);
+  switch (desc_.type().type()) {
+    case proto::VarType::CHANNEL:
+      mutable_channel_desc()->set_data_type(data_type);
+      break;
+    default:
+      mutable_tensor_desc()->set_data_type(data_type);
+  }
 }

 void VarDesc::SetDataTypes(
@@ -109,7 +115,13 @@ void VarDesc::SetDataTypes(
 }

 proto::VarType::Type VarDesc::GetDataType() const {
-  return tensor_desc().data_type();
+  switch (desc_.type().type()) {
+    case proto::VarType::CHANNEL:
+      return channel_desc().data_type();
+      break;
+    default:
+      return tensor_desc().data_type();
+  }
 }

 std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
@@ -122,6 +134,17 @@ std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
   return res;
 }

+void VarDesc::SetCapacity(int64_t capacity) {
+  switch (desc_.type().type()) {
+    case proto::VarType::CHANNEL:
+      desc_.mutable_type()->mutable_channel()->set_capacity(capacity);
+      break;
+    default:
+      PADDLE_THROW("Setting 'capacity' is not supported by the type of var %s.",
+                   this->Name());
+  }
+}
+
 void VarDesc::SetLoDLevel(int32_t lod_level) {
   switch (desc_.type().type()) {
     case proto::VarType::LOD_TENSOR:
@@ -191,6 +214,19 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
   }
 }

+const proto::VarType::ChannelDesc &VarDesc::channel_desc() const {
+  PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
+  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
+  switch (desc_.type().type()) {
+    case proto::VarType::CHANNEL:
+      return desc_.type().channel();
+    default:
+      PADDLE_THROW(
+          "Getting 'channel_desc' is not supported by the type of var %s.",
+          this->Name());
+  }
+}
+
 const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
   PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
   PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
@@ -226,6 +262,20 @@ std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
   }
 }

+proto::VarType::ChannelDesc *VarDesc::mutable_channel_desc() {
+  PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
+  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
+  switch (desc_.type().type()) {
+    case proto::VarType::CHANNEL:
+      return desc_.mutable_type()->mutable_channel();
+    default:
+      PADDLE_THROW(
+          "Getting 'mutable_channel_desc' is not supported by the type of var "
+          "%s.",
+          this->Name());
+  }
+}
+
 proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
   PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
   PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
```
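Together with `SetCapacity`, these switches let a channel variable be described like any other; a minimal sketch, assuming the existing `VarDesc(name)` constructor and `SetType` (the name and parameters are illustrative):

```cpp
#include "paddle/fluid/framework/var_desc.h"

// Illustrative: describing a channel variable with the new setters.
void ChannelVarDescSketch() {
  paddle::framework::VarDesc var("pipe");  // hypothetical variable name
  var.SetType(paddle::framework::proto::VarType::CHANNEL);   // set type first
  var.SetDataType(paddle::framework::proto::VarType::FP32);  // channel branch
  var.SetCapacity(64);  // CHANNEL only; other types hit the PADDLE_THROW
  // GetDataType() now reads channel_desc().data_type(), i.e. FP32.
}
```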

paddle/fluid/framework/var_desc.h

Lines changed: 4 additions & 0 deletions
```diff
@@ -85,6 +85,8 @@ class VarDesc {
   void SetDataTypes(
       const std::vector<proto::VarType::Type> &multiple_data_type);

+  void SetCapacity(int64_t capacity);
+
   proto::VarType::Type GetDataType() const;

   std::vector<proto::VarType::Type> GetDataTypes() const;
@@ -106,8 +108,10 @@ class VarDesc {
   void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }

  private:
+  const proto::VarType::ChannelDesc &channel_desc() const;
   const proto::VarType::TensorDesc &tensor_desc() const;
   std::vector<proto::VarType::TensorDesc> tensor_descs() const;
+  proto::VarType::ChannelDesc *mutable_channel_desc();
   proto::VarType::TensorDesc *mutable_tensor_desc();
   std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();
```

paddle/fluid/framework/var_type.h

Lines changed: 6 additions & 0 deletions
```diff
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once
+#include "paddle/fluid/framework/channel.h"
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/lod_rank_table.h"
 #include "paddle/fluid/framework/lod_tensor.h"
@@ -34,6 +35,8 @@ inline proto::VarType::Type ToVarType(std::type_index type) {
     return proto::VarType_Type_SELECTED_ROWS;
   } else if (type.hash_code() == typeid(ReaderHolder).hash_code()) {
     return proto::VarType_Type_READER;
+  } else if (type.hash_code() == typeid(ChannelHolder).hash_code()) {
+    return proto::VarType_Type_CHANNEL;
   } else {
     PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
   }
@@ -57,6 +60,9 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
     case proto::VarType_Type_READER:
       visitor(var.Get<ReaderHolder>());
       return;
+    case proto::VarType_Type_CHANNEL:
+      visitor(var.Get<ChannelHolder>());
+      return;
     default:
       PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type()));
   }
```

paddle/fluid/operators/elementwise_add_op.h

Lines changed: 5 additions & 57 deletions
```diff
@@ -41,59 +41,8 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
 };

 template <typename T>
-struct ElementwiseAddGradFunctor {
-  template <typename Device, typename X, typename Y, typename Z, typename dX,
-            typename dY, typename dZ>
-  void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
-    auto dz_e = framework::EigenVector<T>::Flatten(*dz);
-    if (dx) {
-      auto dx_e = framework::EigenVector<T>::Flatten(*dx);
-      dx_e.device(d) = dz_e;
-    }
-    if (dy) {
-      auto dy_e = framework::EigenVector<T>::Flatten(*dy);
-      dy_e.device(d) = dz_e;
-    }
-  }
-};
-
-template <typename T>
-struct ElementwiseAddBroadCastGradFunctor {
-  template <typename Device, typename X, typename Y, typename Z, typename dX,
-            typename dY, typename dZ, typename Pre, typename N>
-  void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) {
-    auto dz_e = framework::EigenVector<T>::Flatten(*dz);
-    if (dx) {
-      auto dx_e = framework::EigenVector<T>::Flatten(*dx);
-      dx_e.device(d) = dz_e;
-    }
-
-    if (dy) {
-      auto dy_e = framework::EigenVector<T>::Flatten(*dy);
-      dy_e.device(d) = dz_e.reshape(Eigen::DSizes<int, 2>(pre, n))
-                           .sum(Eigen::array<int, 1>{{0}});
-    }
-  }
-};
-
-template <typename T>
-struct ElementwiseAddBroadCast2GradFunctor {
-  template <typename Device, typename X, typename Y, typename Z, typename dX,
-            typename dY, typename dZ, typename Pre, typename N, typename Post>
-  void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n,
-                  Post post) {
-    auto dz_e = framework::EigenVector<T>::Flatten(*dz);
-    if (dx) {
-      auto dx_e = framework::EigenVector<T>::Flatten(*dx);
-      dx_e.device(d) = dz_e;
-    }
-
-    if (dy) {
-      auto dy_e = framework::EigenVector<T>::Flatten(*dy);
-      dy_e.device(d) = dz_e.reshape(Eigen::DSizes<int, 3>(pre, n, post))
-                           .sum(Eigen::array<int, 2>{{0, 2}});
-    }
-  }
+struct IdentityGrad {
+  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
 };

 template <typename DeviceContext, typename T>
@@ -109,10 +58,9 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
    int axis = ctx.Attr<int>("axis");
-    ElementwiseGradCompute<DeviceContext, T, ElementwiseAddGradFunctor<T>,
-                           ElementwiseAddBroadCastGradFunctor<T>,
-                           ElementwiseAddBroadCast2GradFunctor<T>>(
-        ctx, x, y, out, dout, axis, dx, dy);
+    ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
+        ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
+        IdentityGrad<T>());
   }
 };
```
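The collapse of three hand-written functors into one `IdentityGrad` follows from the calculus of addition: for `out = x + y` both local derivatives are 1, so each input gradient simply forwards `dout`, and the broadcast-axis reduction that the two deleted functors performed by hand is now the job of the shared `ElemwiseGradCompute` path. In symbols:

```latex
out = x + y \;\Rightarrow\;
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial out}, \qquad
\frac{\partial L}{\partial y} = \sum_{\text{broadcast axes}} \frac{\partial L}{\partial out}
```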
