Skip to content

Commit 78a3b33

Browse files
authored
Merge pull request #14125 from jacquesqiao/cherry-pick-cpu-dist
Merge pull request #14103 from jacquesqiao/cpu-for-1.1-merge-with-shape
2 parents a035c40 + 415f500 commit 78a3b33

39 files changed

+1012
-545
lines changed

paddle/fluid/framework/attribute.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
6464
case proto::AttrType::LONG: {
6565
return attr_desc.l();
6666
}
67+
case proto::AttrType::LONGS: {
68+
std::vector<int64_t> val(attr_desc.longs_size());
69+
for (int i = 0; i < attr_desc.longs_size(); ++i) {
70+
val[i] = attr_desc.longs(i);
71+
}
72+
return val;
73+
}
6774
default:
6875
PADDLE_THROW("Unsupport attr type %d", attr_desc.type());
6976
}

paddle/fluid/framework/attribute.h

Lines changed: 115 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,113 @@ limitations under the License. */
2626

2727
namespace paddle {
2828
namespace framework {
29+
30+
template <typename T>
31+
struct ExtractAttribute {
32+
explicit ExtractAttribute(const std::string& attr_name)
33+
: attr_name_(attr_name) {}
34+
35+
T* operator()(Attribute& attr) const {
36+
T* attr_value = nullptr;
37+
try {
38+
attr_value = &boost::get<T>(attr);
39+
} catch (boost::bad_get& bad_get) {
40+
PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s",
41+
attr_name_, paddle::platform::demangle(typeid(T).name()),
42+
paddle::platform::demangle(attr.type().name()));
43+
}
44+
return attr_value;
45+
}
46+
47+
const std::string& attr_name_;
48+
};
49+
50+
// special handle bool
51+
// FIXME(yuyang18): Currently we cast bool into int in python binding. It is
52+
// hard to change the logic there. In another way, we should correct handle
53+
// if the user set `some_flag=1`.
54+
//
55+
// FIX ME anytime if there is a better solution.
56+
template <>
57+
struct ExtractAttribute<bool> {
58+
explicit ExtractAttribute(const std::string& attr_name)
59+
: attr_name_(attr_name) {}
60+
61+
bool* operator()(Attribute& attr) const {
62+
if (attr.type() == typeid(int)) { // NOLINT
63+
int val = boost::get<int>(attr);
64+
attr = static_cast<bool>(val);
65+
} else if (attr.type() == typeid(float)) { // NOLINT
66+
float val = boost::get<float>(attr);
67+
attr = static_cast<bool>(val);
68+
}
69+
bool* attr_value = nullptr;
70+
try {
71+
attr_value = &boost::get<bool>(attr);
72+
} catch (boost::bad_get& bad_get) {
73+
PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s",
74+
attr_name_, paddle::platform::demangle(attr.type().name()));
75+
}
76+
return attr_value;
77+
}
78+
79+
const std::string& attr_name_;
80+
};
81+
82+
template <>
83+
struct ExtractAttribute<int64_t> {
84+
explicit ExtractAttribute(const std::string& attr_name)
85+
: attr_name_(attr_name) {}
86+
87+
int64_t* operator()(Attribute& attr) const {
88+
if (attr.type() == typeid(int)) { // NOLINT
89+
int val = boost::get<int>(attr);
90+
attr = static_cast<int64_t>(val);
91+
} else if (attr.type() == typeid(float)) { // NOLINT
92+
int val = boost::get<float>(attr);
93+
attr = static_cast<int64_t>(val);
94+
}
95+
int64_t* attr_value = nullptr;
96+
try {
97+
attr_value = &boost::get<int64_t>(attr);
98+
} catch (boost::bad_get& bad_get) {
99+
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
100+
attr_name_, paddle::platform::demangle(attr.type().name()));
101+
}
102+
return attr_value;
103+
}
104+
105+
const std::string& attr_name_;
106+
};
107+
108+
template <>
109+
struct ExtractAttribute<std::vector<int64_t>> {
110+
explicit ExtractAttribute(const std::string& attr_name)
111+
: attr_name_(attr_name) {}
112+
113+
std::vector<int64_t>* operator()(Attribute& attr) const {
114+
if (attr.type() == typeid(std::vector<int>)) { // NOLINT
115+
std::vector<int> val = boost::get<std::vector<int>>(attr);
116+
std::vector<int64_t> vec(val.begin(), val.end());
117+
attr = vec;
118+
} else if (attr.type() == typeid(std::vector<float>)) { // NOLINT
119+
std::vector<float> val = boost::get<std::vector<float>>(attr);
120+
std::vector<int64_t> vec(val.begin(), val.end());
121+
attr = vec;
122+
}
123+
std::vector<int64_t>* attr_value = nullptr;
124+
try {
125+
attr_value = &boost::get<std::vector<int64_t>>(attr);
126+
} catch (boost::bad_get& bad_get) {
127+
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
128+
attr_name_, paddle::platform::demangle(attr.type().name()));
129+
}
130+
return attr_value;
131+
}
132+
133+
const std::string& attr_name_;
134+
};
135+
29136
template <typename T>
30137
inline proto::AttrType AttrTypeID() {
31138
Attribute tmp = T();
@@ -42,7 +149,11 @@ class AttrReader {
42149
inline const T& Get(const std::string& name) const {
43150
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
44151
name);
45-
return boost::get<T>(attrs_.at(name));
152+
153+
Attribute& attr = const_cast<Attribute&>(attrs_.at(name));
154+
ExtractAttribute<T> extract_attr(name);
155+
T* attr_value = extract_attr(attr);
156+
return *attr_value;
46157
}
47158

48159
private:
@@ -82,7 +193,7 @@ class DefaultValueSetter {
82193
public:
83194
explicit DefaultValueSetter(T default_value)
84195
: default_value_(default_value) {}
85-
void operator()(T& value) const { value = default_value_; }
196+
void operator()(T& value) const { value = default_value_; } // NOLINT
86197

87198
private:
88199
T default_value_;
@@ -117,84 +228,6 @@ class EnumInContainer {
117228
std::unordered_set<T> container_;
118229
};
119230

120-
template <typename T>
121-
struct ExtractAttribute {
122-
explicit ExtractAttribute(const std::string& attr_name)
123-
: attr_name_(attr_name) {}
124-
125-
T* operator()(Attribute& attr) const {
126-
T* attr_value = nullptr;
127-
try {
128-
attr_value = &boost::get<T>(attr);
129-
} catch (boost::bad_get& bad_get) {
130-
PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s",
131-
attr_name_, paddle::platform::demangle(typeid(T).name()),
132-
paddle::platform::demangle(attr.type().name()));
133-
}
134-
return attr_value;
135-
}
136-
137-
const std::string& attr_name_;
138-
};
139-
140-
// special handle bool
141-
// FIXME(yuyang18): Currently we cast bool into int in python binding. It is
142-
// hard to change the logic there. In another way, we should correct handle
143-
// if the user set `some_flag=1`.
144-
//
145-
// FIX ME anytime if there is a better solution.
146-
template <>
147-
struct ExtractAttribute<bool> {
148-
explicit ExtractAttribute(const std::string& attr_name)
149-
: attr_name_(attr_name) {}
150-
151-
bool* operator()(Attribute& attr) const {
152-
if (attr.type() == typeid(int)) { // NOLINT
153-
int val = boost::get<int>(attr);
154-
attr = static_cast<bool>(val);
155-
} else if (attr.type() == typeid(float)) { // NOLINT
156-
float val = boost::get<float>(attr);
157-
attr = static_cast<bool>(val);
158-
}
159-
bool* attr_value = nullptr;
160-
try {
161-
attr_value = &boost::get<bool>(attr);
162-
} catch (boost::bad_get& bad_get) {
163-
PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s",
164-
attr_name_, paddle::platform::demangle(attr.type().name()));
165-
}
166-
return attr_value;
167-
}
168-
169-
const std::string& attr_name_;
170-
};
171-
172-
template <>
173-
struct ExtractAttribute<int64_t> {
174-
explicit ExtractAttribute(const std::string& attr_name)
175-
: attr_name_(attr_name) {}
176-
177-
int64_t* operator()(Attribute& attr) const {
178-
if (attr.type() == typeid(int)) { // NOLINT
179-
int val = boost::get<int>(attr);
180-
attr = static_cast<int64_t>(val);
181-
} else if (attr.type() == typeid(float)) { // NOLINT
182-
int val = boost::get<float>(attr);
183-
attr = static_cast<int64_t>(val);
184-
}
185-
int64_t* attr_value = nullptr;
186-
try {
187-
attr_value = &boost::get<int64_t>(attr);
188-
} catch (boost::bad_get& bad_get) {
189-
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
190-
attr_name_, paddle::platform::demangle(attr.type().name()));
191-
}
192-
return attr_value;
193-
}
194-
195-
const std::string& attr_name_;
196-
};
197-
198231
// check whether a certain attribute fit its limits
199232
// an attribute can have more than one limits
200233
template <typename T>
@@ -235,7 +268,7 @@ class TypedAttrChecker {
235268
return *this;
236269
}
237270

238-
void operator()(AttributeMap& attr_map) const {
271+
void operator()(AttributeMap& attr_map) const { // NOLINT
239272
if (!attr_map.count(attr_name_)) {
240273
// user do not set this attr
241274
PADDLE_ENFORCE(!default_value_setter_.empty(),
@@ -271,7 +304,7 @@ class OpAttrChecker {
271304
return *(checker.target<TypedAttrChecker<T>>());
272305
}
273306

274-
void Check(AttributeMap& attr_map) const {
307+
void Check(AttributeMap& attr_map) const { // NOLINT
275308
for (const auto& checker : attr_checkers_) {
276309
checker(attr_map);
277310
}

paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ void BroadcastOpHandle::RunImpl() {
5252
var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_);
5353
PADDLE_ENFORCE_NOT_NULL(in_var);
5454
Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
55+
if (UNLIKELY(!in_tensor.IsInitialized())) {
56+
VLOG(3) << "in var " << in_var_handle.name_ << "not inited, return!";
57+
return;
58+
}
5559

5660
InitOutputValue(*in_var_handle, out_var_handles);
5761

paddle/fluid/framework/details/multi_devices_graph_pass.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,8 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
680680
}
681681

682682
if (node->Op()->Type() == "split_byref" ||
683-
node->Op()->Type() == "split_selected_rows") {
683+
node->Op()->Type() == "split_selected_rows" ||
684+
node->Op()->Type() == "split_ids") {
684685
// TODO(paddle-dev): getting the first var is not safe.
685686
op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
686687
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {

paddle/fluid/framework/framework.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ enum AttrType {
3535
BLOCK = 8;
3636
LONG = 9;
3737
BLOCKS = 10;
38+
LONGS = 11;
3839
}
3940

4041
// OpDesc describes an instance of a C++ framework::OperatorBase
@@ -55,6 +56,7 @@ message OpDesc {
5556
optional int32 block_idx = 12;
5657
optional int64 l = 13;
5758
repeated int32 blocks_idx = 14;
59+
repeated int64 longs = 15;
5860
};
5961

6062
message Var {

paddle/fluid/framework/op_desc.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,8 +419,15 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
419419
}
420420
VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx());
421421
}
422+
422423
void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }
424+
423425
void operator()(int64_t v) const { attr_->set_l(v); }
426+
427+
void operator()(const std::vector<int64_t> &v) const {
428+
VectorToRepeated(v, attr_->mutable_longs());
429+
}
430+
424431
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
425432
};
426433

paddle/fluid/framework/parallel_executor.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,10 @@ void ParallelExecutor::BCastParamsToDevices(
185185
}
186186

187187
auto &main_tensor = main_var->Get<LoDTensor>();
188+
if (!main_tensor.IsInitialized()) {
189+
VLOG(3) << "one in var not inited, return!";
190+
continue;
191+
}
188192
auto &dims = main_tensor.dims();
189193
if (paddle::platform::is_gpu_place(main_tensor.place())) {
190194
#ifdef PADDLE_WITH_CUDA

paddle/fluid/framework/type_defs.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ using Attribute =
3636
boost::variant<boost::blank, int, float, std::string, std::vector<int>,
3737
std::vector<float>, std::vector<std::string>, bool,
3838
std::vector<bool>, BlockDesc*, int64_t,
39-
std::vector<BlockDesc*>>;
39+
std::vector<BlockDesc*>, std::vector<int64_t>>;
4040

4141
using AttributeMap = std::unordered_map<std::string, Attribute>;
4242

0 commit comments

Comments
 (0)