Skip to content

Commit 0b3465d

Browse files
committed
better
1 parent dcaf183 commit 0b3465d

File tree

9 files changed

+308
-49
lines changed

9 files changed

+308
-49
lines changed

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
55
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
66
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
77

8-
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph)
8+
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph graph_helper)
99
cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
1010
cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder)
1111

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "paddle/fluid/framework/details/reduce_op_handle.h"
2626
#include "paddle/fluid/framework/details/rpc_op_handle.h"
2727
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
28+
#include "paddle/fluid/framework/ir/graph_helper.h"
2829
#include "paddle/fluid/framework/ir/node.h"
2930
#include "paddle/fluid/framework/op_info.h"
3031
#include "paddle/fluid/framework/scope.h"
@@ -186,9 +187,55 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
186187
return dev_id;
187188
}
188189

190+
// Topologically sort the graph nodes from inputs to outputs.
191+
// SSAGraphBuilder depends on forward/backward nodes to assign devices to
192+
// parameters/gradients before the optimizer ops, so a topological sort alone
193+
// is insufficient (some optimizer ops might not depend on any nodes).
194+
// We therefore manually move all optimizer nodes after the last backward node.
195+
std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const Graph &graph) {
196+
std::vector<ir::Node *> ret = ir::TopologySort(graph);
// Index in sorted_ret just past the most recently appended backward op.
// NOTE(review): stays 0 when no backward op is seen, so the optimizer ops
// would then be inserted at the front of the result -- confirm this is intended.
197+
size_t last_backward = 0;
198+
std::vector<ir::Node *> optimize_ops;
199+
std::vector<ir::Node *> sorted_ret;
// Partition the sorted ops by role: backward and all other ops are appended to
// sorted_ret in order; ops whose role equals kOptimize are held back.
200+
for (size_t i = 0; i < ret.size(); ++i) {
201+
if (boost::get<int>(
202+
ret[i]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
203+
static_cast<int>(OpRole::kBackward)) {
204+
sorted_ret.push_back(ret[i]);
205+
last_backward = sorted_ret.size();
206+
} else if (boost::get<int>(ret[i]->Op()->GetAttr(
207+
OpProtoAndCheckerMaker::OpRoleAttrName())) ==
208+
static_cast<int>(OpRole::kOptimize)) {
209+
optimize_ops.push_back(ret[i]);
210+
} else {
211+
sorted_ret.push_back(ret[i]);
212+
}
213+
}
214+
// Splice the held-back optimizer ops in directly after the last backward op,
// keeping their relative order from the topological sort.
215+
sorted_ret.insert(sorted_ret.begin() + last_backward, optimize_ops.begin(),
216+
optimize_ops.end());
217+
// Remove every node named ir::Node::kControlDepVarName from each op's input
// and output lists (erase/remove_if). The captured `n` is not used by the
// predicates.
218+
for (ir::Node *n : sorted_ret) {
219+
n->inputs.erase(std::remove_if(n->inputs.begin(), n->inputs.end(),
220+
[n](ir::Node *t) {
221+
return t->Name() ==
222+
ir::Node::kControlDepVarName;
223+
}),
224+
n->inputs.end());
225+
n->outputs.erase(std::remove_if(n->outputs.begin(), n->outputs.end(),
226+
[n](ir::Node *t) {
227+
return t->Name() ==
228+
ir::Node::kControlDepVarName;
229+
}),
230+
n->outputs.end());
231+
}
232+
return sorted_ret;
233+
}
234+
189235
std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
190236
std::unique_ptr<Graph> graph) const {
191237
// Rebuild the graph structure.
238+
std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
192239
auto nodes = std::move(graph->nodes);
193240
graph->nodes.clear();
194241

@@ -217,12 +264,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
217264
size_t cur_device_id = 0;
218265
bool is_forwarding = true;
219266

220-
// NOTE: Currently, passes before SSAGraphBuilder cannot reorder
221-
// forward, backward nodes. E.g. you can't append an forward node
222-
// at the end of the node list.
223-
// TODO(panyx0718): FIXME: Needs to sort by forward->backward order.
224-
for (ir::Node *node : TopologySortOperationFromInToOut(nodes)) {
225-
VLOG(3) << "apply node: " << node->Name() << reinterpret_cast<void *>(node);
267+
for (ir::Node *node : sorted_ops) {
226268
if (boost::get<int>(
227269
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
228270
static_cast<int>(OpRole::kRPC)) {
@@ -240,7 +282,6 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
240282
// It also assumes backward op will always follow the forward op in
241283
// the block.
242284
is_forwarding = false;
243-
LOG(ERROR) << "forward flipping!!!!!!!";
244285
} else {
245286
int op_dev_id = GetOpDeviceID(node);
246287
if (op_dev_id != -1) { // This op only runs on one specific device.
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
cc_library(node SRCS node.cc DEPS proto_desc)
22
cc_library(graph SRCS graph.cc DEPS node)
3+
cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
34
cc_library(pass SRCS pass.cc DEPS graph node)
4-
55
cc_test(graph_test SRCS graph_test.cc DEPS graph proto_desc op_registry)

paddle/fluid/framework/ir/graph.cc

Lines changed: 100 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ limitations under the License. */
2222

2323
namespace paddle {
2424
namespace framework {
25+
/*
2526
namespace {
2627
void SortHelper(
2728
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
@@ -39,7 +40,21 @@ void SortHelper(
3940
<< reinterpret_cast<void *>(node) << " input " << node->inputs.size();
4041
ret->push_back(node);
4142
}
43+
44+
std::vector<ir::Node*> TopologySort(
45+
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list) {
46+
std::unordered_set<ir::Node *> visited;
47+
std::vector<ir::Node *> ret;
48+
49+
for (auto adj : adj_list) {
50+
if (visited.find(adj.first) == visited.end()) {
51+
SortHelper(adj_list, adj.first, &visited, &ret);
52+
}
53+
}
54+
return ret;
55+
}
4256
} // namespace
57+
*/
4358

4459
Graph::Graph(const ProgramDesc &program) : program_(program) {
4560
VLOG(3) << "block in program:" << program_.Size();
@@ -48,20 +63,9 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
4863
all_vars.emplace(var->Name(), var);
4964
}
5065

51-
ir::Node *last_backward = nullptr;
52-
std::vector<ir::Node *> optimize_ops;
5366
std::map<std::string, std::vector<ir::Node *>> var_nodes;
5467
for (auto *op : program.Block(0).AllOps()) {
5568
ir::Node *node = CreateOpNode(op);
56-
if (boost::get<int>(
57-
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
58-
static_cast<int>(OpRole::kBackward)) {
59-
last_backward = node;
60-
} else if (boost::get<int>(
61-
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
62-
static_cast<int>(OpRole::kOptimize)) {
63-
optimize_ops.push_back(node);
64-
}
6569

6670
for (auto &each_var_name : op->InputArgumentNames()) {
6771
ir::Node *var = nullptr;
@@ -106,70 +110,130 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
106110
// Read Write is the same op.
107111
continue;
108112
}
109-
ir::Node *dep_var = CreateEmptyNode("dummy", ir::Node::Type::kVariable);
113+
ir::Node *dep_var = CreateEmptyNode(ir::Node::kControlDepVarName,
114+
ir::Node::Type::kVariable);
110115
read_op->outputs.push_back(dep_var);
111116
dep_var->inputs.push_back(read_op);
112117
write_op->inputs.push_back(dep_var);
113118
dep_var->outputs.push_back(write_op);
114119
}
115120
}
116121
}
122+
}
117123

118-
if (last_backward) {
119-
for (ir::Node *opt_node : optimize_ops) {
120-
ir::Node *dep_var = CreateEmptyNode("dummy", ir::Node::Type::kVariable);
121-
last_backward->outputs.push_back(dep_var);
122-
dep_var->inputs.push_back(last_backward);
123-
opt_node->inputs.push_back(dep_var);
124-
dep_var->outputs.push_back(opt_node);
125-
VLOG(3) << "appending connect: " << last_backward->Name()
126-
<< reinterpret_cast<void *>(last_backward) << "->"
127-
<< opt_node->Name() << reinterpret_cast<void *>(opt_node);
124+
/*
125+
bool HasCircleHelper(ir::Node* node,
126+
const std::map<ir::Node *, std::unordered_set<ir::Node *>>
127+
&adj_list,
128+
std::unordered_set<ir::Node*>* visited,
129+
std::unordered_set<ir::Node*>* in_trace) {
130+
if (visited->find(node) == visited->end()) {
131+
visited->insert(node);
132+
in_trace->insert(node);
133+
134+
for (ir::Node *in : adj_list.at(node)) {
135+
if (visited->find(in) == visited->end() &&
136+
HasCircleHelper(in, adj_list, visited, in_trace)) {
137+
return true;
138+
} else if (in_trace->find(in) != in_trace->end()) {
139+
return true;
140+
}
128141
}
129142
}
143+
in_trace->erase(node);
144+
return false;
130145
}
131146
132-
std::vector<ir::Node *> TopologySortOperationFromInToOut(
133-
const std::vector<std::unique_ptr<ir::Node>> &nodes) {
147+
bool HasCircle(const std::map<ir::Node *, std::unordered_set<ir::Node *>>
148+
&adj_list) {
149+
std::unordered_set<ir::Node*> visited;
150+
std::unordered_set<ir::Node*> in_trace;
151+
for (auto& adj : adj_list) {
152+
if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace)) {
153+
return true;
154+
}
155+
}
156+
return false;
157+
}
158+
159+
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildAdjList(
160+
const std::vector<ir::Node*> &nodes) {
134161
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
135-
std::unordered_set<ir::Node *> visited;
136-
std::vector<ir::Node *> ret;
137162
138163
for (auto &n : nodes) {
139164
if (n->NodeType() != ir::Node::Type::kOperation) continue;
140-
if (adj_list.find(n.get()) == adj_list.end()) {
141-
adj_list[n.get()] = std::unordered_set<ir::Node *>();
165+
if (adj_list.find(n) == adj_list.end()) {
166+
adj_list[n] = std::unordered_set<ir::Node *>();
142167
}
143168
for (auto &var : n->inputs) {
144169
for (auto &adj_n : var->inputs) {
145170
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
146-
adj_list[n.get()].insert(adj_n);
171+
adj_list[n].insert(adj_n);
147172
LOG(ERROR) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
148-
<< " -> " << n->Name() << reinterpret_cast<void *>(n.get())
173+
<< " -> " << n->Name() << reinterpret_cast<void *>(n)
149174
<< " via " << var->Name() << reinterpret_cast<void *>(var);
150175
}
151176
}
152177
}
178+
return adj_list;
179+
}
153180
154-
for (auto adj : adj_list) {
155-
if (visited.find(adj.first) == visited.end()) {
156-
SortHelper(adj_list, adj.first, &visited, &ret);
181+
std::vector<ir::Node *> TopologySortOperationFromInToOut(
182+
const std::vector<std::unique_ptr<ir::Node>> &nodes) {
183+
std::vector<ir::Node*> tmp;
184+
for (auto& n : nodes) {
185+
tmp.push_back(n.get());
186+
}
187+
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
188+
BuildAdjList(tmp);
189+
190+
PADDLE_ENFORCE(!HasCircle(adj_list));
191+
std::vector<ir::Node*> ret = TopologySort(adj_list);
192+
193+
ir::Node *last_backward = nullptr;
194+
std::vector<ir::Node *> optimize_ops;
195+
for (ir::Node* n : ret) {
196+
if (boost::get<int>(
197+
n->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
198+
static_cast<int>(OpRole::kBackward)) {
199+
last_backward = n;
200+
} else if (boost::get<int>(
201+
n->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
202+
static_cast<int>(OpRole::kOptimize)) {
203+
optimize_ops.push_back(n);
157204
}
158205
}
159206
207+
if (last_backward) {
208+
for (ir::Node *opt_node : optimize_ops) {
209+
ir::Node *dep_var = CreateEmptyNode(ir::Node::kControlDepVarName,
210+
ir::Node::Type::kVariable);
211+
last_backward->outputs.push_back(dep_var);
212+
dep_var->inputs.push_back(last_backward);
213+
opt_node->inputs.push_back(dep_var);
214+
dep_var->outputs.push_back(opt_node);
215+
VLOG(3) << "appending connect: " << last_backward->Name()
216+
<< reinterpret_cast<void *>(last_backward) << "->"
217+
<< opt_node->Name() << reinterpret_cast<void *>(opt_node);
218+
}
219+
}
220+
221+
PADDLE_ENFORCE(!HasCircle(adj_list));
160222
for (ir::Node *n : ret) {
161223
std::unordered_set<ir::Node *> dummy;
162224
n->inputs.erase(
163225
std::remove_if(n->inputs.begin(), n->inputs.end(),
164-
[n](ir::Node *t) { return t->Name() == "dummy"; }),
226+
[n](ir::Node *t) {
227+
return t->Name() == ir::Node::kControlDepVarName; }),
165228
n->inputs.end());
166229
n->outputs.erase(
167230
std::remove_if(n->outputs.begin(), n->outputs.end(),
168-
[n](ir::Node *t) { return t->Name() == "dummy"; }),
231+
[n](ir::Node *t) {
232+
return t->Name() == ir::Node::kControlDepVarName; }),
169233
n->outputs.end());
170234
}
171235
return ret;
172-
}
236+
}*/
173237

174238
} // namespace framework
175239
} // namespace paddle

paddle/fluid/framework/ir/graph.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,5 @@ class Graph {
7878
std::map<std::string, std::function<void(void)>> attr_dels_;
7979
};
8080

81-
std::vector<ir::Node*> TopologySortOperationFromInToOut(
82-
const std::vector<std::unique_ptr<ir::Node>>& nodes);
83-
8481
} // namespace framework
8582
} // namespace paddle

0 commit comments

Comments
 (0)