Skip to content

Commit 99c0c20

Browse files
committed
add pass test
1 parent 12e9bf6 commit 99c0c20

File tree

7 files changed

+195
-26
lines changed

7 files changed

+195
-26
lines changed

doc/fluid/design/ir/draft.md

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,41 @@ can also contain other things that describe some properties of
6464
the `Graph` or `Graph` nodes. `Attribute` can be passed
6565
across `Pass`. However, it should be used with care.
6666

67+
```cpp
68+
class Graph {
69+
public:
70+
explicit Graph(const ProgramDesc &program);
71+
72+
bool Has(const std::string &attr_name) const;
73+
74+
template <typename AttrType>
75+
AttrType &Get(const std::string &attr_name) const;
76+
77+
template <typename AttrType>
78+
void Set(const std::string &attr_name, AttrType *attr);
79+
const std::unordered_set<ir::Node *> &Nodes() const;
80+
81+
// Create a normal variable with non-null VarDesc.
82+
ir::Node *CreateVarNode(VarDesc *var_desc);
83+
84+
// Create a normal runnable operator with OpDesc.
85+
ir::Node *CreateOpNode(OpDesc *op_desc);
86+
87+
// Create a control dependency var that connects 2 operations. The
88+
// var doesn't hold any data. Other than that, it's no different from
89+
// other var, considering dependency analysis.
90+
ir::Node *CreateControlDepVar();
91+
92+
// A more free style way of creating a graph node. Mostly use for test
93+
// or "copy" from another node. Avoid using it if possible.
94+
ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type);
95+
96+
// Clear all node information of the graph and return the ownership of the
97+
// nodes.
98+
std::vector<std::unique_ptr<ir::Node>> ReleaseNodes();
99+
};
100+
```
101+
67102
#### Pass
68103
69104
`Pass` represents a transformation of `Graph`. Its input
@@ -101,13 +136,15 @@ class Pass {
101136
102137
// In my_pass.cc
103138
class MyPass : public Pass {
104-
public:
105-
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
139+
protected:
140+
std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override {
106141
// do something.
107142
return graph;
108143
}
109144
}
110-
REGISTER_PASS(my_pass, MyPass);
145+
REGISTER_PASS(my_pass, MyPass)
146+
.RequirePassAttr("places")
147+
.RequireGraphAttr("dep_vars");
111148
112149
113150
// To use the pass.
@@ -132,4 +169,17 @@ maintaining the original modeling logic.
132169
* Graph is transformed from raw model logic to a
133170
form that is efficient to execute.
134171

135-
Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor
172+
```
173+
// Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor
174+
auto graph = Graph(program);
175+
graph = PassRegistry::Instance().Get("op_fuse_pass").Apply(std::move(grah));
176+
// For more complex Pass, Optimize Process can provide Pass attributes.
177+
auto mem_opt_pass = PassRegistry::Instance().Get("memory_optimization_pass");
178+
mem_opt_pass.SetNotOwned<int>("optimize_level", 1);
179+
mem_opt_pass->Apply(std::move(graph));
180+
graph = PassRegistry::Instance().Get("multi_device_pass").Apply(std::move(grah));
181+
graph = PassRegistry::Instance().Get("multi_device_check_pass").Apply(std::move(grah));
182+
Executor exe;
183+
exe.Run(graph);
184+
185+
```
Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
cc_library(node SRCS node.cc DEPS proto_desc)
22
cc_library(graph SRCS graph.cc DEPS node)
33
cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
4-
cc_library(pass SRCS pass.cc DEPS graph node)
4+
cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
55
cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper)
6-
cc_test(graph_test SRCS graph_test.cc DEPS graph op_registry)
7-
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph_helper op_registry)
6+
7+
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
8+
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
9+
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)

paddle/fluid/framework/ir/graph.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ class Graph {
5353

5454
template <typename AttrType>
5555
void Set(const std::string &attr_name, AttrType *attr) {
56-
PADDLE_ENFORCE(attrs_.count(attr_name) == 0);
56+
PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the graph",
57+
attr_name);
5758
attrs_[attr_name] = attr;
5859
attr_dels_[attr_name] = [attr, attr_name]() {
5960
VLOG(3) << "deleting " << attr_name;

paddle/fluid/framework/ir/pass.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,27 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/framework/ir/pass.h"
16+
#include "paddle/fluid/framework/ir/graph_helper.h"
1617

1718
namespace paddle {
1819
namespace framework {
1920
namespace ir {
2021
std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
22+
PADDLE_ENFORCE(!applied_, "Pass can only Apply() once.");
23+
PADDLE_ENFORCE(graph.get(), "graph passed to Pass::Apply() cannot be empty.");
2124
for (const std::string& attr : required_pass_attrs_) {
2225
PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(),
23-
"Required pass atrribute %s not registered.", attr);
26+
"Required pass atrribute %s not set.", attr);
2427
}
2528
for (const std::string& attr : required_graph_attrs_) {
26-
PADDLE_ENFORCE(graph->Has(attr), "Required graph atrribute %s not exist.",
29+
PADDLE_ENFORCE(graph->Has(attr), "Required graph atrribute %s not set.",
2730
attr);
2831
}
2932
auto applied_graph = ApplyImpl(std::move(graph));
3033
// TODO(panyx0718): Add more verifications.
3134
PADDLE_ENFORCE(!HasCircle(*applied_graph),
3235
"Illegal Pass. Generated graph shouldn't has cycle.");
36+
applied_ = true;
3337
return applied_graph;
3438
}
3539

paddle/fluid/framework/ir/pass.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ limitations under the License. */
1919
#include <string>
2020

2121
#include "paddle/fluid/framework/ir/graph.h"
22-
#include "paddle/fluid/framework/ir/graph_helper.h"
2322
#include "paddle/fluid/framework/ir/node.h"
2423
#include "paddle/fluid/framework/program_desc.h"
2524
#include "paddle/fluid/platform/variant.h"
@@ -56,7 +55,8 @@ class Pass {
5655
// Set a pointer to the attribute. Pass takes ownership of the attribute.
5756
template <typename AttrType>
5857
void Set(const std::string &attr_name, AttrType *attr) {
59-
PADDLE_ENFORCE(attrs_.count(attr_name) == 0);
58+
PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the pass",
59+
attr_name);
6060
attrs_[attr_name] = attr;
6161
attr_dels_[attr_name] = [attr, attr_name]() {
6262
VLOG(3) << "deleting " << attr_name;
@@ -89,6 +89,7 @@ class Pass {
8989
required_graph_attrs_.insert(attrs.begin(), attrs.end());
9090
}
9191

92+
mutable bool applied_{false};
9293
std::unordered_set<std::string> required_pass_attrs_;
9394
std::unordered_set<std::string> required_graph_attrs_;
9495
std::map<std::string, boost::any> attrs_;
@@ -118,14 +119,15 @@ class PassRegistry {
118119
return map_.find(pass_type) != map_.end();
119120
}
120121

121-
void Insert(const std::string &type, const PassCreator &pass_creator) {
122-
PADDLE_ENFORCE(!Has(type), "Pass %s has been registered", type);
123-
map_.insert({type, pass_creator});
122+
void Insert(const std::string &pass_type, const PassCreator &pass_creator) {
123+
PADDLE_ENFORCE(!Has(pass_type), "Pass %s has been registered", pass_type);
124+
map_.insert({pass_type, pass_creator});
124125
}
125126

126-
std::unique_ptr<Pass> Get(const std::string &type) const {
127-
PADDLE_ENFORCE(Has(type), "Pass %s has not been registered", type);
128-
return map_.at(type)();
127+
std::unique_ptr<Pass> Get(const std::string &pass_type) const {
128+
PADDLE_ENFORCE(Has(pass_type), "Pass %s has not been registered",
129+
pass_type);
130+
return map_.at(pass_type)();
129131
}
130132

131133
private:
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/ir/pass.h"
16+
#include <string>
17+
#include "gtest/gtest.h"
18+
#include "paddle/fluid/framework/ir/graph.h"
19+
20+
namespace paddle {
21+
namespace framework {
22+
namespace ir {
23+
void BuildCircleGraph(Graph* g) {
24+
ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
25+
ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
26+
ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
27+
ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
28+
29+
o1->outputs.push_back(v1);
30+
o2->inputs.push_back(v1);
31+
v1->inputs.push_back(o1);
32+
v1->outputs.push_back(o2);
33+
34+
o2->outputs.push_back(v2);
35+
o1->inputs.push_back(v2);
36+
v2->inputs.push_back(o2);
37+
v2->outputs.push_back(o1);
38+
}
39+
40+
class TestPass : public Pass {
41+
protected:
42+
std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const {
43+
graph->Set<int>("copy_test_pass_attr", new int);
44+
graph->Set<int>("copy_test_graph_attr", new int);
45+
46+
int test_pass_attr = this->Get<int>("test_pass_attr");
47+
graph->Get<int>("copy_test_pass_attr") = test_pass_attr + 1;
48+
49+
int test_graph_attr = graph->Get<int>("test_graph_attr");
50+
graph->Get<int>("copy_test_graph_attr") = test_graph_attr + 1;
51+
return graph;
52+
}
53+
};
54+
55+
TEST(PassTest, TestPassAttrCheck) {
56+
ProgramDesc prog;
57+
auto pass = PassRegistry::Instance().Get("test_pass");
58+
std::unique_ptr<Graph> graph(new Graph(prog));
59+
std::string exception;
60+
try {
61+
graph = pass->Apply(std::move(graph));
62+
} catch (paddle::platform::EnforceNotMet e) {
63+
exception = std::string(e.what());
64+
}
65+
ASSERT_TRUE(exception.find("test_pass_attr not set") != exception.npos);
66+
67+
int val = 1;
68+
graph.reset(new Graph(prog));
69+
pass->SetNotOwned<int>("test_pass_attr", &val);
70+
71+
try {
72+
graph = pass->Apply(std::move(graph));
73+
} catch (paddle::platform::EnforceNotMet e) {
74+
exception = std::string(e.what());
75+
}
76+
ASSERT_TRUE(exception.find("test_graph_attr not set") != exception.npos);
77+
78+
graph.reset(new Graph(prog));
79+
graph->Set<int>("test_graph_attr", new int);
80+
graph->Get<int>("test_graph_attr") = 1;
81+
graph = pass->Apply(std::move(graph));
82+
ASSERT_EQ(graph->Get<int>("copy_test_pass_attr"), 2);
83+
ASSERT_EQ(graph->Get<int>("copy_test_graph_attr"), 2);
84+
85+
try {
86+
graph = pass->Apply(std::move(graph));
87+
} catch (paddle::platform::EnforceNotMet e) {
88+
exception = std::string(e.what());
89+
}
90+
ASSERT_TRUE(exception.find("Pass can only Apply() once") != exception.npos);
91+
92+
pass = PassRegistry::Instance().Get("test_pass");
93+
pass->SetNotOwned<int>("test_pass_attr", &val);
94+
graph.reset(new Graph(prog));
95+
BuildCircleGraph(graph.get());
96+
graph->Set<int>("test_graph_attr", new int);
97+
graph->Get<int>("test_graph_attr") = 2;
98+
try {
99+
auto tmp = pass->Apply(std::move(graph));
100+
} catch (paddle::platform::EnforceNotMet e) {
101+
exception = std::string(e.what());
102+
}
103+
ASSERT_TRUE(exception.find("shouldn't has cycle") != exception.npos);
104+
}
105+
106+
} // namespace ir
107+
} // namespace framework
108+
} // namespace paddle
109+
110+
REGISTER_PASS(test_pass, paddle::framework::ir::TestPass)
111+
.RequirePassAttr("test_pass_attr")
112+
.RequireGraphAttr("test_graph_attr");

paddle/fluid/framework/parallel_executor.cc

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,10 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
4444
#else
4545
const BuildStrategy &strategy) {
4646
#endif
47+
// Convert the program to graph.
4748
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
49+
50+
// Apply a graph viz pass to record a graph.
4851
if (!strategy.debug_graphviz_path_.empty()) {
4952
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
5053
const std::string graph_path = string::Sprintf(
@@ -53,6 +56,7 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
5356
graph = viz_pass->Apply(std::move(graph));
5457
}
5558

59+
// Convert graph to run on multi-devices.
5660
auto multi_device_pass =
5761
ir::PassRegistry::Instance().Get("multi_device_pass");
5862
multi_device_pass->SetNotOwned<const std::vector<platform::Place>>("places",
@@ -71,6 +75,7 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
7175
#endif
7276
graph = multi_device_pass->Apply(std::move(graph));
7377

78+
// Apply a graph print pass to record a graph with device info.
7479
if (!strategy.debug_graphviz_path_.empty()) {
7580
auto multi_device_print_pass =
7681
ir::PassRegistry::Instance().Get("multi_device_print_pass");
@@ -81,17 +86,10 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
8186
graph = multi_device_print_pass->Apply(std::move(graph));
8287
}
8388

89+
// Verify that the graph is correct for multi-device executor.
8490
auto multi_device_check_pass =
8591
ir::PassRegistry::Instance().Get("multi_device_check_pass");
8692
graph = multi_device_check_pass->Apply(std::move(graph));
87-
88-
if (!strategy.debug_graphviz_path_.empty()) {
89-
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
90-
const std::string graph_path = string::Sprintf(
91-
"%s%s", strategy.debug_graphviz_path_.c_str(), "_before_exec");
92-
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
93-
graph = viz_pass->Apply(std::move(graph));
94-
}
9593
return graph;
9694
}
9795

0 commit comments

Comments
 (0)