
Commit 9f252e0

Combine Inference Analysis with IR (#13914)
1 parent 893c1b0 commit 9f252e0

109 files changed, +2722 −4433 lines changed


cmake/inference_lib.cmake

Lines changed: 3 additions & 3 deletions
@@ -164,7 +164,7 @@ endif()
   set(module "inference")
   copy(inference_lib DEPS ${inference_deps}
     SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.*
-         ${src_dir}/${module}/api/paddle_inference_api.h
+         ${src_dir}/${module}/api/paddle_*.h
          ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h
     DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}
   )
@@ -202,10 +202,10 @@ copy(third_party DEPS fluid_lib_dist
     DSTS ${FLUID_INFERENCE_INSTALL_DIR} ${FLUID_INFERENCE_INSTALL_DIR}
 )
 
-# only need libpaddle_fluid.so/a and paddle_inference_api.h for inference-only library
+# only need libpaddle_fluid.so/a and paddle_*.h for inference-only library
 copy(inference_api_lib DEPS fluid_lib_dist
     SRCS ${FLUID_INSTALL_DIR}/paddle/fluid/inference/libpaddle_fluid.*
-         ${FLUID_INSTALL_DIR}/paddle/fluid/inference/paddle_inference_api.h
+         ${FLUID_INSTALL_DIR}/paddle/fluid/inference/paddle_*.h
     DSTS ${FLUID_INFERENCE_INSTALL_DIR}/paddle/lib ${FLUID_INFERENCE_INSTALL_DIR}/paddle/include
 )

cmake/tensorrt.cmake

Lines changed: 1 addition & 0 deletions
@@ -34,4 +34,5 @@ if(TENSORRT_FOUND)
     "Current TensorRT version is v${TENSORRT_MAJOR_VERSION}. ")
   include_directories(${TENSORRT_INCLUDE_DIR})
   list(APPEND EXTERNAL_LIBS ${TENSORRT_LIBRARY})
+  add_definitions(-DPADDLE_WITH_TENSORRT)
 endif()
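
The new compile definition is how the build signals TensorRT availability to the C++ sources. A minimal sketch of how such a flag is typically consumed, assuming nothing beyond the macro itself; the function name below is hypothetical and not part of this commit:

    // Hypothetical illustration: guard TensorRT-only code with the macro that
    // cmake/tensorrt.cmake now defines whenever TensorRT is found.
    #include <cstdio>

    void ReportTensorRTSupport() {
    #ifdef PADDLE_WITH_TENSORRT
      // This branch is compiled only when the build defined the flag.
      std::printf("built with TensorRT support\n");
    #else
      std::printf("built without TensorRT support\n");
    #endif
    }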

paddle/fluid/framework/executor.cc

Lines changed: 1 addition & 0 deletions
@@ -359,6 +359,7 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
 void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
                                   bool create_local_scope, bool create_vars,
                                   bool keep_kids) {
+  PADDLE_ENFORCE_NOT_NULL(scope);
   Scope* local_scope = scope;
   if (create_vars) {
     if (create_local_scope) {

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
@@ -5,6 +5,7 @@ file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
 
 
 # Usage: pass_library(target inference) will append to paddle_inference_pass.h
+unset(INFER_IR_PASSES CACHE) # clear the global variable
 function(pass_library TARGET DEST)
   set(options "")
   set(oneValueArgs "")
@@ -15,10 +16,11 @@ function(pass_library TARGET DEST)
   if (${DEST} STREQUAL "base" OR ${DEST} STREQUAL "inference")
     message(STATUS "add pass ${TARGET} ${DEST}")
     file(APPEND ${pass_file} "USE_PASS(${TARGET});\n")
-    set(PASS_LIBRARY ${TARGET} ${PASS_LIBRARY} PARENT_SCOPE)
+    set(INFER_IR_PASSES ${INFER_IR_PASSES} ${TARGET} CACHE INTERNAL "")
   endif()
 endfunction()
 
+
 cc_library(node SRCS node.cc DEPS proto_desc)
 cc_library(graph SRCS graph.cc DEPS node pretty_log)
 cc_library(graph_helper SRCS graph_helper.cc DEPS graph)

paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc

Lines changed: 4 additions & 4 deletions
@@ -91,10 +91,10 @@ void FindWhileOp(Graph* graph) {
 #undef OP_SET_IN
 #undef OP_SET_OUT
 
-  auto* X = graph->RetriveNode(34);
-  auto* LSTMOUT = graph->RetriveNode(81);
-  auto* cell_init = graph->RetriveNode(6);
-  auto* hidden_init = graph->RetriveNode(8);
+  auto* X = graph->RetrieveNode(34);
+  auto* LSTMOUT = graph->RetrieveNode(81);
+  auto* cell_init = graph->RetrieveNode(6);
+  auto* hidden_init = graph->RetrieveNode(8);
 
   auto* lstm_op = graph->CreateOpNode(&op_desc);
   PrepareParameters(graph, param);

paddle/fluid/framework/ir/graph.cc

Lines changed: 0 additions & 2 deletions
@@ -84,8 +84,6 @@ void CheckProgram(const ProgramDesc &program) {
 
 Graph::Graph(const ProgramDesc &program) : program_(program) {
   CheckProgram(program_);
-  // Make the nodes id start from 0.
-  Node::ResetId();
   auto var_nodes = InitFromProgram(program_);
   ResolveHazard(var_nodes);
 }

paddle/fluid/framework/ir/graph.h

Lines changed: 15 additions & 5 deletions
@@ -116,13 +116,17 @@ class Graph {
   // Create a normal variable with non-null VarDesc.
   ir::Node *CreateVarNode(VarDesc *var_desc) {
     PADDLE_ENFORCE(var_desc);
-    return AddNode(new ir::Node(var_desc));
+    auto *x = AddNode(new ir::Node(var_desc));
+    x->SetId(num_node_created_++);
+    return x;
   }
 
   // Create a normal runnable operator with OpDesc.
   ir::Node *CreateOpNode(OpDesc *op_desc) {
     PADDLE_ENFORCE(op_desc);
-    return AddNode(new ir::Node(op_desc));
+    auto *x = AddNode(new ir::Node(op_desc));
+    x->SetId(num_node_created_++);
+    return x;
   }
 
   // Create a control dependency var that connects 2 operations. The
@@ -132,13 +136,17 @@
     // TODO(panyx0718): control var name should be really unique.
     const std::string name = string::Sprintf(
         "%s@%llu", ir::Node::kControlDepVarName, node_set_.size());
-    return AddNode(new ir::Node(name, ir::Node::Type::kVariable));
+    auto *x = AddNode(new ir::Node(name, ir::Node::Type::kVariable));
+    x->SetId(num_node_created_++);
+    return x;
   }
 
   // A more free style way of creating a graph node. Mostly use for test
   // or "copy" from another node. Avoid using it if possible.
   ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
-    return AddNode(new ir::Node(name, type));
+    auto *x = AddNode(new ir::Node(name, type));
+    x->SetId(num_node_created_++);
+    return x;
   }
 
   // Clear all node information of the graph and return the ownership of the
@@ -160,7 +168,7 @@
   }
 
   // NOTE low performance, but simple and secure.
-  Node *RetriveNode(int id) {
+  Node *RetrieveNode(int id) {
     for (auto &node : nodes_) {
       if (node.second->id() == id) {
         return node.second.get();
@@ -169,6 +177,7 @@
     return nullptr;
   }
 
+  const ProgramDesc &program() const { return program_; }
   std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
       const ProgramDesc &program);
 
@@ -190,6 +199,7 @@
   std::map<std::string, std::function<void(void)>> attr_dels_;
   std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
   std::unordered_set<ir::Node *> node_set_;
+  size_t num_node_created_{0};  // help to generate a unique node id.
 };
 
 bool IsControlDepVar(const ir::Node &var);
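
These changes replace the old global id reset (Node::ResetId(), removed in graph.cc above) with a per-graph counter, num_node_created_, so every node created through a Graph gets a unique, deterministic id without touching process-wide state. A minimal standalone sketch of the same idea, assuming invented ToyGraph/ToyNode types rather than Paddle's classes:

    // Sketch only: per-graph monotonic counter assigns each new node a unique id.
    #include <cstddef>
    #include <memory>
    #include <string>
    #include <vector>

    struct ToyNode {
      std::string name;
      int id = -1;
    };

    class ToyGraph {
     public:
      ToyNode *CreateNode(const std::string &name) {
        nodes_.push_back(std::make_unique<ToyNode>());
        ToyNode *n = nodes_.back().get();
        n->name = name;
        n->id = static_cast<int>(num_node_created_++);  // unique within this graph
        return n;
      }

     private:
      std::vector<std::unique_ptr<ToyNode>> nodes_;
      std::size_t num_node_created_{0};
    };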

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 2 additions & 2 deletions
@@ -310,8 +310,8 @@ void GraphSafeRemoveNodes(Graph* graph,
                           const std::unordered_set<const Node*>& nodes);
 
 // Some pre-defined patterns those can be reused in multiple passes.
-// The related Fluid Layer or Op should be one pattern here for better reusage
-// accross different fusion.
+// The related Fluid Layer or Op should be one pattern here for better re-usage
+// across different fusion.
 namespace patterns {
 
 struct KeyCounter {

paddle/fluid/framework/ir/graph_to_program_pass.cc

Lines changed: 2 additions & 1 deletion
@@ -35,10 +35,11 @@ std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
       new proto::ProgramDesc(*program.Proto()));
 
   auto block = program_pb->mutable_blocks(kRootBlockIndex);
+  block->set_idx(kRootBlockIndex);
   block->clear_vars();
   std::unordered_set<std::string> visited_vars;
   for (ir::Node* n : graph->Nodes()) {
-    if (n->NodeType() == ir::Node::Type::kVariable) {
+    if (n->IsVar()) {
       if (n->Var() && visited_vars.count(n->Var()->Name()) == 0) {
         visited_vars.insert(n->Var()->Name());
         block->add_vars()->MergeFrom(*n->Var()->Proto());

paddle/fluid/framework/ir/graph_traits.cc

Lines changed: 70 additions & 0 deletions
@@ -66,6 +66,76 @@ NodesDFSIterator &NodesDFSIterator::operator=(const NodesDFSIterator &other) {
 }
 Node *NodesDFSIterator::operator->() { return stack_.top(); }
 
+inline bool CheckNodeIndegreeEquals(const Node &node, size_t n) {
+  return node.inputs.size() == n;
+}
+
+NodesTSIterator::NodesTSIterator(const std::vector<Node *> &source) {
+  PADDLE_ENFORCE(!source.empty(),
+                 "Start points of topological sorting should not be empty!");
+  // CHECK all the inputs' in-degree is 0
+  for (auto *node : source) {
+    PADDLE_ENFORCE(CheckNodeIndegreeEquals(*node, 0));
+  }
+
+  std::unordered_set<Node *> visited;
+  std::unordered_set<Node *> to_visit{source.begin(), source.end()};
+
+  std::vector<Node *> inlink_visited;
+  while (!to_visit.empty()) {
+    std::vector<Node *> queue(to_visit.begin(), to_visit.end());
+    for (auto *p : queue) {
+      inlink_visited.clear();
+
+      std::copy_if(p->inputs.begin(), p->inputs.end(),
+                   std::back_inserter(inlink_visited),
+                   [&](Node *x) -> bool { return visited.count(x) != 0; });
+
+      if (inlink_visited.size() == p->inputs.size()) {
+        sorted_.push_back(p);
+        for (auto *_ : p->outputs) {
+          if (!visited.count(_)) {
+            to_visit.insert(_);
+          }
+        }
+
+        to_visit.erase(p);
+        visited.insert(p);
+      }
+    }
+  }
+}
+
+NodesTSIterator::NodesTSIterator(const NodesTSIterator &other)
+    : sorted_(other.sorted_), cursor_(other.cursor_) {}
+
+Node &NodesTSIterator::operator*() {
+  PADDLE_ENFORCE_LT(cursor_, sorted_.size());
+  return *sorted_[cursor_];
+}
+
+NodesTSIterator &NodesTSIterator::operator++() {
+  if (++cursor_ >= sorted_.size()) {
+    sorted_.clear();
+    cursor_ = 0;
+  }
+  return *this;
+}
+NodesTSIterator &NodesTSIterator::operator=(const NodesTSIterator &other) {
+  cursor_ = other.cursor_;
+  sorted_ = other.sorted_;
+  return *this;
+}
+
+bool NodesTSIterator::operator==(const NodesTSIterator &other) {
+  return sorted_ == other.sorted_ && cursor_ == other.cursor_;
+}
+
+Node *NodesTSIterator::operator->() {
+  PADDLE_ENFORCE_LT(cursor_, sorted_.size());
+  return sorted_[cursor_];
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
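
The new NodesTSIterator builds the topological order up front: a node is appended to sorted_ only once all of its inputs have been visited, and its outputs then become candidates for the next pass. A standalone sketch of that same traversal, using an invented ToyNode type instead of Paddle's ir::Node, for illustration only:

    // Kahn-style ordering over a DAG reachable from zero-in-degree sources.
    #include <cstddef>
    #include <unordered_set>
    #include <vector>

    struct ToyNode {
      std::vector<ToyNode *> inputs, outputs;
    };

    std::vector<ToyNode *> TopologicalOrder(const std::vector<ToyNode *> &sources) {
      std::unordered_set<ToyNode *> visited;
      std::unordered_set<ToyNode *> to_visit(sources.begin(), sources.end());
      std::vector<ToyNode *> sorted;
      while (!to_visit.empty()) {
        // Snapshot the current candidates so we can mutate to_visit safely.
        std::vector<ToyNode *> queue(to_visit.begin(), to_visit.end());
        for (ToyNode *p : queue) {
          std::size_t ready = 0;  // inputs already emitted
          for (ToyNode *in : p->inputs) ready += visited.count(in);
          if (ready == p->inputs.size()) {    // all predecessors visited
            sorted.push_back(p);
            for (ToyNode *out : p->outputs)   // successors become candidates
              if (!visited.count(out)) to_visit.insert(out);
            to_visit.erase(p);
            visited.insert(p);
          }
        }
      }
      return sorted;
    }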
