
Commit 885c4e5

fea/infer memory optim2 (#14953)
1 parent 6597ccb commit 885c4e5


47 files changed: +1450 −154 lines

paddle/fluid/framework/ir/fc_fuse_pass.cc

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/ir/fc_fuse_pass.h"
 #include <string>
 #include <vector>
+#include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/platform/enforce.h"

 namespace paddle {

paddle/fluid/framework/ir/graph_helper.cc

Lines changed: 141 additions & 2 deletions
@@ -18,8 +18,10 @@ limitations under the License. */
 #include <fstream>
 #include <iosfwd>
 #include <ostream>
+#include <stack>
 #include <unordered_map>
 #include <unordered_set>
+#include "paddle/fluid/framework/ir/graph_traits.h"

 DEFINE_string(print_sub_graph_dir, "",
               "FLAGS_print_sub_graph_dir is used "
@@ -41,7 +43,7 @@ void SortHelper(
     }
   }

-  VLOG(3) << "topology sort insert: " << node->Name()
+  VLOG(5) << "topology sort insert: " << node->Name() << " "
          << reinterpret_cast<void *>(node) << " input " << node->inputs.size();
   ret->push_back(node);
 }
@@ -99,12 +101,13 @@ std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
   return ret;
 }

+// Build operator inlink edge table.
 std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
     const Graph &graph) {
   std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;

   for (auto &n : graph.Nodes()) {
-    if (n->NodeType() != ir::Node::Type::kOperation) continue;
+    if (!n->IsOp()) continue;
     if (adj_list.find(n) == adj_list.end()) {
       adj_list[n] = std::unordered_set<ir::Node *>();
     }
@@ -121,6 +124,119 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
   return adj_list;
 }

+// Build operator outlink edge table.
+std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationOutAdjList(
+    const Graph &graph) {
+  std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
+
+  for (auto &n : graph.Nodes()) {
+    if (!n->IsOp()) continue;
+    if (adj_list.find(n) == adj_list.end()) {
+      adj_list[n] = std::unordered_set<ir::Node *>();
+    }
+    for (auto &var : n->outputs) {
+      for (auto &adj_n : var->outputs) {
+        PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
+        VLOG(40) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
+                 << " -> " << n->Name() << reinterpret_cast<void *>(n)
+                 << " via " << var->Name() << reinterpret_cast<void *>(var);
+        adj_list[n].insert(adj_n);
+      }
+    }
+  }
+  return adj_list;
+}
+
+std::vector<ir::Node *> OpDFSSort(const Graph &graph) {
+  auto edge_table = BuildOperationOutAdjList(graph);
+  std::stack<Node *> stack;
+  for (auto &ele : edge_table) {
+    if (ele.first->inputs.empty()) {
+      // find the input ops (those without input vars)
+      stack.push(ele.first);
+    } else {
+      // find the ops with only persistable vars as inputs.
+      bool all_persistable = true;
+      for (auto *input : ele.first->inputs) {
+        if (!(input->IsVar() && input->Var() && input->Var()->Persistable())) {
+          all_persistable = false;
+        }
+      }
+      if (all_persistable) {
+        stack.push(ele.first);
+      }
+    }
+  }
+
+  std::vector<Node *> res;
+  // start from the feed op and DFS
+  std::unordered_set<Node *> unique_set;
+  while (!stack.empty()) {
+    // will start from the last feed by default.
+    auto cur = stack.top();
+    stack.pop();
+    unique_set.insert(cur);
+    res.push_back(cur);
+
+    for (auto *op : edge_table[cur]) {
+      if (!unique_set.count(op)) {
+        stack.push(op);
+      }
+    }
+  }
+  return res;
+}
+
+std::vector<ir::Node *> TopologyDfsSortOperations(const Graph &graph) {
+  std::vector<ir::Node *> nodes;
+  std::unordered_map<Node *, int> in_degree;
+
+  auto set_out_ops_ready = [&](Node *var) {
+    for (auto *op : var->outputs) {
+      --in_degree[op];
+    }
+  };
+  // build in_degree
+  for (auto *node : graph.Nodes()) {
+    if (node->IsOp()) {
+      in_degree[node] += node->inputs.size();
+    } else if (node->IsVar() && node->inputs.empty()) {
+      // put all the inputs of the whole graph ready.
+      set_out_ops_ready(node);
+    }
+  }
+
+  std::deque<Node *> op_queue;
+  // first visit
+  for (auto &node : OpDFSSort(graph)) {
+    if (node->IsOp()) {
+      op_queue.push_back(node);
+    }
+  }
+
+  // traverse the graph
+  int num_ops = op_queue.size();
+  while (num_ops) {
+    for (auto it = op_queue.begin(); it != op_queue.end(); it++) {
+      auto *&cur_op = *it;
+      if (!cur_op || in_degree[cur_op] > 0) continue;
+      // visit this node
+      // put all the output var of this op valid.
+      for (auto *out_var : cur_op->outputs) {
+        if (!out_var) continue;
+        set_out_ops_ready(out_var);
+      }
+      VLOG(8) << "visit " << cur_op->Name();
+      nodes.push_back(cur_op);
+
+      cur_op = nullptr;
+      num_ops--;
+    }
+  }
+
+  return nodes;
+}
+
 size_t GraphNum(const Graph &graph) {
   std::unordered_set<ir::Node *> nodes(graph.Nodes());
   std::unordered_set<ir::Node *> visited_nodes;
@@ -203,6 +319,29 @@ size_t GraphNum(const Graph &graph) {
   return graph_count;
 }

+void CleanIndividualNodes(Graph *graph) {
+  std::unordered_set<Node *> nodes2rm;
+  for (auto *node : graph->Nodes()) {
+    if (node->inputs.empty() && node->outputs.empty()) {
+      nodes2rm.insert(node);
+    }
+  }
+
+  for (auto *node : nodes2rm) {
+    graph->RemoveNode(node);
+  }
+}
+
+std::vector<Node *> TopologyVarientSort(const Graph &graph,
+                                        SortKind sort_kind) {
+  switch (sort_kind) {
+    case SortKind::TS:
+      return framework::ir::TopologySortOperations(graph);
+    default:
+      return framework::ir::TopologyDfsSortOperations(graph);
+  }
+}
+
 } // namespace ir
 } // namespace framework
 } // namespace paddle
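The new TopologyDfsSortOperations combines two ideas: OpDFSSort produces a DFS preorder seeded from ops whose inputs are empty or all persistable, and an in-degree table then gates when each op in that preorder may actually be emitted. Below is a minimal standalone sketch of that scheduling idea on a toy string-keyed DAG; it is illustration only, does not use the Paddle Node/Graph types, and the node names are invented.

// Minimal sketch of "DFS preorder + in-degree gating"; illustration only,
// the graph below and its node names are invented for this example.
#include <iostream>
#include <stack>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Toy DAG: feed -> conv -> relu -> fetch, plus a skip edge feed -> relu.
  std::unordered_map<std::string, std::vector<std::string>> out_edges = {
      {"feed", {"conv", "relu"}},
      {"conv", {"relu"}},
      {"relu", {"fetch"}},
      {"fetch", {}}};

  // Count how many producers each node still waits for.
  std::unordered_map<std::string, int> in_degree;
  for (auto &kv : out_edges) in_degree.emplace(kv.first, 0);
  for (auto &kv : out_edges)
    for (auto &dst : kv.second) ++in_degree[dst];

  // DFS preorder starting from the zero in-degree nodes (the "feeds").
  std::vector<std::string> preorder;
  std::unordered_map<std::string, bool> seen;
  std::stack<std::string> stk;
  for (auto &kv : in_degree)
    if (kv.second == 0) stk.push(kv.first);
  while (!stk.empty()) {
    std::string cur = stk.top();
    stk.pop();
    if (seen[cur]) continue;
    seen[cur] = true;
    preorder.push_back(cur);
    for (auto &nxt : out_edges[cur]) stk.push(nxt);
  }

  // Emit a node only once its in-degree drops to zero, sweeping the preorder
  // until everything is scheduled (mirrors the while (num_ops) loop above).
  std::vector<std::string> order;
  std::vector<bool> emitted(preorder.size(), false);
  while (order.size() < preorder.size()) {
    for (size_t i = 0; i < preorder.size(); ++i) {
      if (emitted[i] || in_degree[preorder[i]] > 0) continue;
      emitted[i] = true;
      order.push_back(preorder[i]);
      for (auto &dst : out_edges[preorder[i]]) --in_degree[dst];
    }
  }

  for (auto &name : order) std::cout << name << "\n";  // feed conv relu fetch
  return 0;
}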

paddle/fluid/framework/ir/graph_helper.h

Lines changed: 17 additions & 0 deletions
@@ -34,6 +34,23 @@ size_t GraphNum(const Graph &graph);
 // `graph` cannot contain circle.
 std::vector<ir::Node *> TopologySortOperations(const Graph &graph);

+// Topological sort, but try to DFS.
+std::vector<ir::Node *> TopologyDfsSortOperations(const Graph &graph);
+
+// Different kinds to sort the operators in a graph to a sequence.
+enum class SortKind {
+  // Topological Search
+  TS = 0,
+  // Topological and Depth First Search
+  TDFS
+};
+
+// Several kinds of topological sort.
+std::vector<Node *> TopologyVarientSort(const Graph &graph, SortKind sort_kind);
+
+// Clean the nodes that doesn't connect to others.
+void CleanIndividualNodes(Graph *graph);
+
 // Build an adjacency list of operations for the `graph`.
 std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
     const Graph &graph);
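A rough usage sketch of the declarations added above; the ExampleSort helper is hypothetical (not part of this commit) and only shows how a caller might pick between the two orderings.

#include <cstddef>
#include <vector>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

namespace paddle {
namespace framework {
namespace ir {

// Hypothetical helper, not part of this commit.
size_t ExampleSort(const Graph &graph) {
  // Plain topological order (the pre-existing behavior).
  std::vector<Node *> ts_order = TopologyVarientSort(graph, SortKind::TS);
  // DFS-flavored topological order (used by the inference memory optimizer).
  std::vector<Node *> tdfs_order = TopologyVarientSort(graph, SortKind::TDFS);
  // Both orderings visit the same operator set, just in a different sequence.
  return ts_order.size() + tdfs_order.size();
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle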

paddle/fluid/framework/ir/graph_to_program_pass.cc

Lines changed: 24 additions & 7 deletions
@@ -20,7 +20,6 @@ limitations under the License. */

 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
-
 #include "paddle/fluid/framework/program_desc.h"

 namespace paddle {
@@ -29,6 +28,14 @@ namespace ir {

 std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
     std::unique_ptr<Graph> graph) const {
+  // Remove the unneeded variables after memory optimization.
+  std::unordered_set<std::string> vars2remove;
+  if (graph->Has(kGraphToProgramVarsToRemove)) {
+    vars2remove = graph->Get<std::unordered_set<std::string>>(
+        kGraphToProgramVarsToRemove);
+    VLOG(2) << "graph to program remove " << vars2remove.size() << " nodes";
+  }
+
   ProgramDesc& program = Get<ProgramDesc>("program");

   std::unique_ptr<proto::ProgramDesc> program_pb(
@@ -40,25 +47,35 @@ std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
   std::unordered_set<std::string> visited_vars;
   for (ir::Node* n : graph->Nodes()) {
     if (n->IsVar()) {
-      if (n->Var() && visited_vars.count(n->Var()->Name()) == 0) {
+      if (n->Var() && visited_vars.count(n->Var()->Name()) == 0 &&
+          !vars2remove.count(n->Var()->Name())) {
         visited_vars.insert(n->Var()->Name());
         block->add_vars()->MergeFrom(*n->Var()->Proto());
       }
     }
   }
-
   block->clear_ops();
-  std::vector<ir::Node*> nodes = TopologySortOperations(*graph);
+
+  std::vector<ir::Node*> nodes;
+  if (Has(kGraphToProgramSortKind)) {
+    // Inference Memory Optimize relays on this branch.
+    int sort_kind = Get<int>(kGraphToProgramSortKind);
+    nodes = TopologyVarientSort(
+        *graph, static_cast<framework::ir::SortKind>(sort_kind));
+  } else {
+    nodes = TopologySortOperations(*graph);
+  }
+
   for (ir::Node* n : nodes) {
-    if (!n->Op()) {
-      continue;
-    }
+    if (!n->Op()) continue;
+
     block->add_ops()->MergeFrom(*n->Op()->Proto());
   }

   program.CopyFrom(*program_pb);
   return graph;
 }
+
 } // namespace ir
 } // namespace framework
 } // namespace paddle

paddle/fluid/framework/ir/graph_to_program_pass.h

Lines changed: 4 additions & 0 deletions
@@ -20,6 +20,10 @@ namespace paddle {
 namespace framework {
 namespace ir {

+const char kGraphToProgramVarsToRemove[] =
+    "__graph_to_program_vars_to_remove__";
+const char kGraphToProgramSortKind[] = "__graph_to_program_sort_kind__";
+
 class GraphToProgramPass : public Pass {
  protected:
   std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override;
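These two keys are how the inference memory-optimization path hands its decisions to GraphToProgramPass: the sort kind is read as a pass attribute (Has/Get<int>), while the removable-variable set travels on the graph. A rough wiring sketch follows, assuming the usual Pass/Graph attribute API; the ConvertGraphToProgram helper and its arguments are hypothetical, not code from this commit.

#include <memory>
#include <string>
#include <unordered_set>
#include <utility>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"

namespace fw = paddle::framework;

// Hypothetical helper, not part of this commit.
std::unique_ptr<fw::ir::Graph> ConvertGraphToProgram(
    std::unique_ptr<fw::ir::Graph> graph, fw::ProgramDesc *program,
    const std::unordered_set<std::string> &vars_to_remove) {
  auto pass = fw::ir::PassRegistry::Instance().Get("graph_to_program_pass");

  // Ask for the DFS-flavored operator order; ApplyImpl reads this with
  // Has()/Get<int>() on the pass.
  pass->Set(fw::ir::kGraphToProgramSortKind,
            new int(static_cast<int>(fw::ir::SortKind::TDFS)));

  // Variables the memory optimizer marked as removable; ApplyImpl skips them
  // when copying VarDescs into the target block.
  graph->Set(fw::ir::kGraphToProgramVarsToRemove,
             new std::unordered_set<std::string>(vars_to_remove));

  // The pass writes its result into the ProgramDesc bound to "program".
  pass->SetNotOwned("program", program);
  return pass->Apply(std::move(graph));
}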

paddle/fluid/framework/ir/graph_viz_pass.cc

Lines changed: 1 addition & 1 deletion
@@ -135,4 +135,4 @@ GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
 } // namespace paddle

 REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass)
-    .RequirePassAttr(paddle::framework::ir::kGraphVizPath);
+    .RequirePassAttr(paddle::framework::ir::kGraphVizPath);

paddle/fluid/framework/ir/node.h

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ class Node {

   std::string Name() const { return name_; }

-  VarDesc* Var() {
+  VarDesc* Var() const {
     PADDLE_ENFORCE(IsVar());
     return var_desc_.get();
   }

paddle/fluid/framework/naive_executor.cc

Lines changed: 5 additions & 2 deletions
@@ -50,8 +50,8 @@ void NaiveExecutor::Run() {
               "running Paddle Inference";
 #endif // PADDLE_ON_INFERENCE
   for (auto &op : ops_) {
-    VLOG(3) << std::this_thread::get_id() << " run " << op->Type()
-            << " on scope " << scope_;
+    VLOG(4) << std::this_thread::get_id() << " run "
+            << op->DebugStringEx(scope_) << " on scope " << scope_;
     op->SetIsCalledByExecutor(false);
     op->Run(*scope_, place_);
   }
@@ -69,10 +69,12 @@ void NaiveExecutor::CreateVariables(const ProgramDesc &desc, int block_id,
     anc = anc->parent();
   }

+  int num_vars = 0;
   for (auto &var : global_block.AllVars()) {
     if (var->Name() == framework::kEmptyVarName) {
       continue;
     }
+    num_vars++;

     if (persistable == var->Persistable()) {
       if (persistable) {
@@ -90,6 +92,7 @@ void NaiveExecutor::CreateVariables(const ProgramDesc &desc, int block_id,
       }
     }
   }
+  VLOG(4) << "naive executor create " << num_vars << " vars";
 }

 void NaiveExecutor::CreateOps(const ProgramDesc &desc, int block_id,

paddle/fluid/inference/analysis/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ cc_library(analysis SRCS
   analyzer.cc
   analysis_pass
   DEPS ${analysis_deps} analysis_helper
+  ${INFER_IR_PASSES}
   )

 cc_test(test_dot SRCS dot_tester.cc DEPS analysis)

paddle/fluid/inference/analysis/analyzer.cc

Lines changed: 10 additions & 7 deletions
@@ -15,22 +15,25 @@
 #include "paddle/fluid/inference/analysis/analyzer.h"
 #include <string>
 #include <vector>
-#include "paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.h"
 #include "paddle/fluid/inference/analysis/passes/passes.h"
+#include "paddle/fluid/string/pretty_log.h"

 namespace paddle {
 namespace inference {
 namespace analysis {

 Analyzer::Analyzer() {}

-void Analyzer::Run(Argument *argument) { RunIrAnalysis(argument); }
+void Analyzer::Run(Argument *argument) { RunAnalysis(argument); }

-void Analyzer::RunIrAnalysis(Argument *argument) {
-  std::vector<std::string> passes({"ir_analysis_compose_pass"});
-
-  for (auto &pass : passes) {
-    PassRegistry::Global().Retreive(pass)->Run(argument);
+void Analyzer::RunAnalysis(Argument *argument) {
+  PADDLE_ENFORCE(argument->analysis_passes_valid(),
+                 "analsis_passes is not valid in the argument.");
+  for (auto &pass : argument->analysis_passes()) {
+    string::PrettyLogH1("--- Running analysis [%s]", pass);
+    auto *ptr = PassRegistry::Global().Retreive(pass);
+    PADDLE_ENFORCE_NOT_NULL(ptr, "no analysis pass called %s", pass);
+    ptr->Run(argument);
   }
 }
3639
