Skip to content

Commit 0441c2c

Browse files
committed
fix ci
1 parent f9c680c commit 0441c2c

File tree

3 files changed

+27
-22
lines changed

3 files changed

+27
-22
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
107107

108108
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
109109
const ProgramDesc &program) const {
110+
std::unordered_map<std::string, proto::VarType::Type> var_types;
111+
for (auto *var : program.Block(0).AllVars()) {
112+
var_types[var->Name()] = var->GetType();
113+
}
110114
auto graph = new SSAGraph();
111115
SSAGraph &result = *graph;
112116
std::unordered_set<std::string> og_has_been_broadcast;
@@ -116,7 +120,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
116120
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
117121
places_.size());
118122

119-
size_t cur_update_sparse_gp_dev_id = 0;
123+
size_t cur_dev_id = 0;
120124
std::vector<std::unordered_set<std::string>> sparse_var_name_on_devices;
121125
std::vector<std::unordered_set<std::string>> bcast_sparse_var_name_set;
122126

@@ -156,14 +160,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
156160
// broadcast, and each gradient is only broadcast once.
157161
for (auto &og : op->OutputArgumentNames()) {
158162
if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
159-
if (IsSparseGradient(og)) {
160-
CreateReduceOp(&result, cur_update_sparse_gp_dev_id, og);
161-
sparse_var_name_on_devices[cur_update_sparse_gp_dev_id].emplace(
162-
og);
163-
bcast_sparse_var_name_set[cur_update_sparse_gp_dev_id].emplace(
163+
if (IsSparseGradient(var_types, og)) {
164+
CreateReduceOp(&result, cur_dev_id, og);
165+
sparse_var_name_on_devices[cur_dev_id].emplace(og);
166+
bcast_sparse_var_name_set[cur_dev_id].emplace(
164167
og.substr(0, og.size() - strlen(kGradVarSuffix)));
165-
cur_update_sparse_gp_dev_id =
166-
(cur_update_sparse_gp_dev_id + 1) % places_.size();
168+
cur_dev_id = (cur_dev_id + 1) % places_.size();
167169
} else {
168170
InsertNCCLAllReduceOp(&result, og);
169171
}
@@ -201,10 +203,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
201203
return std::unique_ptr<SSAGraph>(graph);
202204
}
203205

204-
bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
205-
auto og_var = local_scopes_[0]->FindVar(og);
206-
PADDLE_ENFORCE_NOT_NULL(og_var);
207-
return og_var->IsType<SelectedRows>();
206+
bool MultiDevSSAGraphBuilder::IsSparseGradient(
207+
const std::unordered_map<std::string, proto::VarType::Type> &var_types,
208+
const std::string &og) const {
209+
PADDLE_ENFORCE(var_types.count(og) != 0);
210+
if (var_types.at(og) == proto::VarType::SELECTED_ROWS) {
211+
return true;
212+
}
213+
return false;
208214
}
209215

210216
int MultiDevSSAGraphBuilder::GetOpDeviceID(

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
9999
*/
100100
OpDesc *GetSendOpDesc(const ProgramDesc &program) const;
101101

102-
bool IsSparseGradient(const std::string &og) const;
102+
bool IsSparseGradient(
103+
const std::unordered_map<std::string, proto::VarType::Type> &var_types,
104+
const std::string &og) const;
103105
};
104106
} // namespace details
105107
} // namespace framework

paddle/fluid/framework/details/var_handle.h

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,12 @@ struct VarHandle : public VarHandleBase {
6363
platform::Place place_;
6464

6565
// NOTE(zcd): Strictly speaking, if the two var_handle is equal, the four
66-
// member
67-
// variables(version_, scope_id_, name_, place_) must be equal. But sometimes
68-
// judging whether the two var_handle is equal is actually to determine
69-
// whether
70-
// the two Variables that represented by var_handle is the same. And the same
71-
// Variable may have many different var_handles, the version_ of these
72-
// var_handles
73-
// is different. So I don't take care of version_ temporarily when overloading
74-
// equal.
66+
// member variables(version_, scope_id_, name_, place_) must be equal. But
67+
// sometimes judging whether the two var_handle is equal is actually to
68+
// determine whether the two Variables that represented by var_handle is the
69+
// same. And the same Variable may have many different var_handles, the
70+
// version_ of these var_handles is different. So I don't take care of
71+
// version_ temporarily when overloading equal.
7572
bool operator==(const VarHandle& o) const {
7673
return o.generated_op_ == generated_op_ && o.name_ == name_ &&
7774
o.scope_idx_ == scope_idx_;

0 commit comments

Comments
 (0)