Skip to content

Commit 9f001c6

Browse files
committed
skip dist. test=develop
1 parent 2561a6f commit 9f001c6

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

paddle/fluid/framework/details/inplace_op_pass.cc

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
301301

302302
// 3. if output has been memory optimize by python(fluid.memory_optmize()).
303303
// this candidate can not be inplaced. Will be deprecated in the future.
304-
if (view_.ReusedInPythonMemOpt(out_node->Name())) {
304+
if (view_.InSkipSet(out_node->Name())) {
305305
VLOG(4) << string::Sprintf(
306306
"Skiped %s => %s reused previous memory block in python memory "
307307
"optmize,"
@@ -385,7 +385,7 @@ void GraphView::Build(ir::Graph* g) {
385385
// resolve data harzards depends on the var nodes in right order.
386386
ops_ = SortOpLikeDescOrder(*g);
387387

388-
// track the nodes which reused previous node in Python memory optimize.
388+
// 1. track the nodes which reused previous node in Python memory optimize.
389389
// these node can not be inplaced, otherwise may generate a circle in graph.
390390
std::unordered_set<std::string> all_vars;
391391
for (auto& node : g->Nodes()) {
@@ -399,11 +399,28 @@ void GraphView::Build(ir::Graph* g) {
399399
}
400400
}
401401
}
402+
403+
// 2. track the nodes which used by parameter server.
404+
// these node can not be inplaced, otherwise trainer
405+
// pserver can not find each other name.
406+
for (auto& node : g->Nodes()) {
407+
if (!node->IsOp()) continue;
408+
if (node->Name() == "send") {
409+
for (auto& in : node->inputs) {
410+
dup_nodes_.emplace(in->Name());
411+
}
412+
}
413+
if (node->Name() == "recv") {
414+
for (auto& out : node->outputs) {
415+
dup_nodes_.emplace(out->Name());
416+
}
417+
}
418+
}
402419
}
403420

404421
const std::vector<ir::Node*>& GraphView::AllOps() { return ops_; }
405422

406-
bool GraphView::ReusedInPythonMemOpt(const std::string& var) const {
423+
bool GraphView::InSkipSet(const std::string& var) const {
407424
return dup_nodes_.count(var);
408425
}
409426

paddle/fluid/framework/details/inplace_op_pass.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,14 @@ class GraphView {
4141
std::vector<ir::Node*> PendingOpsOnVar(ir::Node* var);
4242

4343
// Will Deperated in the future.
44-
// NOTE(dzhwinter) : Python memory optimize will reuse
44+
// NOTE(dzhwinter) :
45+
// 1. Python memory optimize will reuse
4546
// memory based var name, so different op output may
4647
// have the same variable name. enable inplace on such node
4748
// will generate a circle in ssa graph.
48-
bool ReusedInPythonMemOpt(const std::string& var) const;
49+
// 2. DistributeTranspiler will use unique name to
50+
// map the parameter and gradient, must be skipped.
51+
bool InSkipSet(const std::string& var) const;
4952

5053
private:
5154
std::vector<ir::Node*> ops_;

0 commit comments

Comments
 (0)