Skip to content

Commit cac315f

Browse files
author
chengduo
authored
update alloc_continuous_space_for_grad_pass (#18288)
test=release/1.5
1 parent 618c2c7 commit cac315f

File tree

7 files changed

+72
-16
lines changed

7 files changed

+72
-16
lines changed

paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.cc

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <algorithm>
1717
#include <string>
1818
#include <unordered_map>
19+
#include <unordered_set>
1920
#include <utility>
2021
#include <vector>
2122
#include "paddle/fluid/framework/details/build_strategy.h"
@@ -84,24 +85,27 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
8485
}
8586

8687
if (params_grads.size() == 0) {
87-
LOG(WARNING) << "Doesn't find gradients";
88+
LOG(INFO) << "Doesn't find gradients";
8889
return;
8990
}
9091

91-
std::unordered_map<std::string, ir::Node *> vars;
92+
std::unordered_map<std::string, ir::Node *> var_name2node;
93+
std::unordered_map<std::string, std::unordered_set<ir::Node *>>
94+
var_name2node_set;
9295
for (ir::Node *node : result.Nodes()) {
9396
if (node->IsVar() && node->Var()) {
9497
// Note: The graph may have the same name node. For example, parameter
9598
// is the input of operator and it also is the output of optimizer;
96-
vars.emplace(node->Var()->Name(), node);
99+
var_name2node.emplace(node->Var()->Name(), node);
100+
var_name2node_set[node->Var()->Name()].emplace(node);
97101
}
98102
}
99103

100104
auto &group_grads_params =
101105
result.Get<details::GroupGradsAndParams>(details::kGroupGradsAndParams);
102106

103107
// Note: the order of params_grads may be changed by SetGroupGradsAndParams.
104-
SetGroupGradsAndParams(vars, params_grads, &group_grads_params);
108+
SetGroupGradsAndParams(var_name2node, params_grads, &group_grads_params);
105109

106110
params_grads.clear();
107111
for (auto &group_p_g : group_grads_params) {
@@ -116,9 +120,16 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
116120
auto dtype = kDefaultDtype;
117121
for (auto &p_g : params_grads) {
118122
// Get gradient var
119-
auto iter = vars.find(p_g.second);
120-
PADDLE_ENFORCE(iter != vars.end(), "%s is not found.", p_g.second);
121-
iter->second->Var()->SetPersistable(true);
123+
auto iter = var_name2node.find(p_g.second);
124+
PADDLE_ENFORCE(iter != var_name2node.end(), "%s is not found.",
125+
p_g.second);
126+
// Set persistable
127+
auto same_nodes = var_name2node_set.find(p_g.second);
128+
PADDLE_ENFORCE(same_nodes != var_name2node_set.end(), "%s is not found.",
129+
p_g.second);
130+
for (auto it : same_nodes->second) {
131+
it->Var()->SetPersistable(true);
132+
}
122133

123134
PADDLE_ENFORCE(IsSupportedVarType(iter->second->Var()->GetType()));
124135

@@ -151,7 +162,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
151162
"%s is duplicate in FusedVars.", fused_var_name);
152163
fused_var_set.insert(fused_var_name);
153164

154-
InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars,
165+
InitFusedVarsAndAllocSpaceForVars(places, local_scopes, var_name2node,
155166
fused_var_name, params_grads);
156167
}
157168

paddle/fluid/framework/ir/graph_helper.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,33 @@ bool HasCircle(const Graph &graph) {
103103
return HasCircleInternal(BuildOperationAdjList(graph), nullptr);
104104
}
105105

106+
bool VarDescIsConsistency(const Graph &graph) {
107+
std::unordered_map<std::string, std::unordered_set<ir::Node *>>
108+
var_name2node_set;
109+
for (ir::Node *node : graph.Nodes()) {
110+
if (node->IsVar() && node->Var()) {
111+
// Note: The graph may have the same name node. For example, parameter
112+
// is the input of operator and it also is the output of optimizer;
113+
var_name2node_set[node->Var()->Name()].emplace(node);
114+
}
115+
}
116+
for (auto &iter : var_name2node_set) {
117+
auto &first_node = *iter.second.begin();
118+
bool is_persistable = std::any_of(iter.second.begin(), iter.second.end(),
119+
[&first_node](const ir::Node *node) {
120+
return node->Var()->Persistable();
121+
});
122+
if (is_persistable) {
123+
bool is_consistency =
124+
std::all_of(iter.second.begin(), iter.second.end(),
125+
[&first_node](const ir::Node *node) {
126+
return *node->Var() == *first_node->Var();
127+
});
128+
if (!is_consistency) return false;
129+
}
130+
}
131+
return true;
132+
}
106133
bool FindCircleSubGraph(const Graph &graph,
107134
std::vector<std::vector<ir::Node *>> *circles) {
108135
return HasCircleInternal(BuildOperationAdjList(graph), circles);

paddle/fluid/framework/ir/graph_helper.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
#include <map>
1818
#include <memory>
1919
#include <set>
20+
#include <string>
2021
#include <vector>
2122

2223
#include "paddle/fluid/framework/ir/graph.h"
@@ -36,6 +37,9 @@ struct NodeComp {
3637
// Test if the graph contains circle.
3738
bool HasCircle(const Graph &graph);
3839

40+
// Check if the var desc of node is consistency.
41+
bool VarDescIsConsistency(const Graph &graph);
42+
3943
// Find All Circles for debugging,
4044
// store all subgraph in circles.
4145
bool FindCircleSubGraph(const Graph &graph,

paddle/fluid/framework/ir/pass.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ Graph* Pass::Apply(Graph* graph) const {
3838
// TODO(panyx0718): Add more verifications.
3939
PADDLE_ENFORCE(!HasCircle(*graph),
4040
"Illegal Pass. Generated graph shouldn't has cycle.");
41+
PADDLE_ENFORCE(VarDescIsConsistency(*graph),
42+
"The VarDescs of persistable variable are not consistency.");
4143
PADDLE_ENFORCE(graph == native_graph,
4244
"Pass::Apply() cannot delete the passed graph and shouldn't "
4345
"return a new graph.(For the need of pybind11)");

paddle/fluid/framework/parallel_executor.cc

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -320,12 +320,14 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
320320
}
321321
#endif
322322
if (!member_->use_all_reduce_) {
323-
PADDLE_ENFORCE(places.size() > 1,
324-
"If you set build_strategy.reduce with 'Reduce',"
325-
"the number of places must be greater than 1.");
323+
if (places.size() == 1) {
324+
LOG(INFO) << "If you set build_strategy.reduce with 'Reduce',"
325+
"the number of places should be greater than 1.";
326+
member_->use_all_reduce_ = true;
327+
}
326328
}
327329

328-
LOG(WARNING) << string::Sprintf(
330+
LOG(INFO) << string::Sprintf(
329331
"The number of %s, which is used in ParallelExecutor, is %lu. And "
330332
"the Program will be copied %lu copies",
331333
(member_->use_cuda_ ? "CUDAPlace" : "CPUPlace"), places.size(),
@@ -364,10 +366,11 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
364366
// choice the execution strategy.
365367
build_strategy.enable_parallel_graph_ =
366368
EnableParallelGraphExecution(*graph, exec_strategy, build_strategy);
367-
if (build_strategy.enable_parallel_graph_)
368-
VLOG(0) << "The Executor would execute the graph by ParallelGraph "
369-
"Execution which can get better performance,"
370-
<< "you can force it off by env FLAGS_enable_parallel_graph=0";
369+
if (build_strategy.enable_parallel_graph_) {
370+
LOG(INFO) << "The Executor would execute the graph by ParallelGraph "
371+
"Execution which can get better performance,"
372+
<< "you can force it off by env FLAGS_enable_parallel_graph=0";
373+
}
371374

372375
if (member_->use_cuda_ && member_->nranks_ > 1) {
373376
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)

paddle/fluid/framework/var_desc.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,5 +264,10 @@ std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
264264
}
265265
}
266266

267+
bool operator==(const VarDesc &left, const VarDesc &right) {
268+
return left.Proto()->SerializeAsString() ==
269+
right.Proto()->SerializeAsString();
270+
}
271+
267272
} // namespace framework
268273
} // namespace paddle

paddle/fluid/framework/var_desc.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ class VarDesc {
6767

6868
proto::VarDesc *Proto() { return &desc_; }
6969

70+
const proto::VarDesc *Proto() const { return &desc_; }
71+
7072
std::string Name() const { return desc_.name(); }
7173

7274
void SetName(std::string name) { desc_.set_name(name); }
@@ -116,5 +118,7 @@ class VarDesc {
116118

117119
proto::VarDesc desc_;
118120
};
121+
122+
bool operator==(const VarDesc &left, const VarDesc &right);
119123
} // namespace framework
120124
} // namespace paddle

0 commit comments

Comments
 (0)