Skip to content

Commit 5175b3c

Browse files
author
chengduo
authored
Add GraphChecker (#13580)
* add GraphNum test=develop * add graph number check in parallelExecutor test=develop * fix transformer_model bug test=develop * fix graph num
1 parent 7cd2761 commit 5175b3c

File tree

5 files changed

+171
-3
lines changed

5 files changed

+171
-3
lines changed

paddle/fluid/framework/ir/graph_helper.cc

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#include "paddle/fluid/framework/ir/graph_helper.h"
1516
#include <algorithm>
17+
#include <deque>
1618
#include <unordered_set>
1719

18-
#include "paddle/fluid/framework/ir/graph_helper.h"
19-
2020
namespace paddle {
2121
namespace framework {
2222
namespace ir {
@@ -113,6 +113,74 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
113113
return adj_list;
114114
}
115115

116+
size_t GraphNum(const Graph &graph) {
117+
std::unordered_set<ir::Node *> nodes = graph.Nodes();
118+
std::unordered_set<ir::Node *> visited_nodes;
119+
visited_nodes.reserve(nodes.size());
120+
std::deque<ir::Node *> q_nodes;
121+
std::vector<std::unordered_set<ir::Node *>> graph_nodes;
122+
std::unordered_set<ir::Node *> g_nodes;
123+
size_t graph_count = 0;
124+
125+
auto traverse_nodes = [&visited_nodes,
126+
&q_nodes](const std::vector<ir::Node *> &nodes) {
127+
std::copy_if(
128+
nodes.begin(), nodes.end(), std::back_inserter(q_nodes),
129+
[&visited_nodes](Node *node) { return !visited_nodes.count(node); });
130+
};
131+
132+
while (visited_nodes.size() != nodes.size()) {
133+
if (!q_nodes.empty()) {
134+
auto cur_node = q_nodes.front();
135+
q_nodes.pop_front();
136+
visited_nodes.insert(cur_node);
137+
g_nodes.insert(cur_node);
138+
traverse_nodes(cur_node->inputs);
139+
traverse_nodes(cur_node->outputs);
140+
} else {
141+
++graph_count;
142+
if (g_nodes.size()) {
143+
graph_nodes.emplace_back(g_nodes);
144+
}
145+
g_nodes.clear();
146+
for (auto &n : nodes) {
147+
if (visited_nodes.count(n) == 0) {
148+
q_nodes.push_back(n);
149+
break;
150+
}
151+
}
152+
}
153+
}
154+
155+
if (g_nodes.size()) {
156+
graph_nodes.emplace_back(g_nodes);
157+
}
158+
159+
if (VLOG_IS_ON(10)) {
160+
VLOG(10) << "graph_num: " << graph_nodes.size();
161+
for (auto &g_n : graph_nodes) {
162+
VLOG(10) << "graph_nodes: " << g_n.size();
163+
if (g_n.size() < 10) {
164+
std::stringstream out;
165+
for (auto &node : g_n) {
166+
out << "\nNode: " << node->Name() << " in [";
167+
for (auto &n : node->inputs) {
168+
out << n->Name() << ", ";
169+
}
170+
out << "], out[";
171+
for (auto &n : node->outputs) {
172+
out << n->Name() << ", ";
173+
}
174+
out << "]";
175+
}
176+
VLOG(10) << out.str();
177+
}
178+
}
179+
}
180+
181+
return graph_count;
182+
}
183+
116184
} // namespace ir
117185
} // namespace framework
118186
} // namespace paddle

paddle/fluid/framework/ir/graph_helper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ namespace ir {
2727
// Test if the graph contains circle.
2828
bool HasCircle(const Graph &graph);
2929

30+
// Count the connected sub-graphs in `graph`; nodes are treated as connected
// through their inputs and outputs, regardless of edge direction.
size_t GraphNum(const Graph &graph);
31+
3032
// Topology Sort the operations in the graph from inputs to outputs.
3133
// `graph` cannot contain circle.
3234
std::vector<ir::Node *> TopologySortOperations(const Graph &graph);

paddle/fluid/framework/ir/graph_helper_test.cc

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,97 @@ TEST(GraphHelperTest, Basic) {
120120
ASSERT_EQ(node_map.at("op2"), 1UL);
121121
ASSERT_TRUE(node_map.at("op3") < node_map.at("op5"));
122122
}
123+
124+
// Adds nothing to `g`: the fixture for the "no nodes at all" case.
void BuildZeroGraph(Graph* g) {
  (void)g;  // Intentionally left untouched.
}
125+
126+
// Populates `g` with five operations and four variables that are all linked
// together, i.e. a single connected graph:
//
//   op1 -> var1 -> op2 -> var2 -> op3 -> var4 -> op5
//                     |       `-> op4             ^
//                     `-> var3 -------------------'
void BuildOneGraph(Graph* g) {
  auto new_op = [g](const char* name) {
    return g->CreateEmptyNode(name, Node::Type::kOperation);
  };
  auto new_var = [g](const char* name) {
    return g->CreateEmptyNode(name, Node::Type::kVariable);
  };
  // `op` produces `var`.
  auto writes = [](ir::Node* op, ir::Node* var) {
    op->outputs.push_back(var);
    var->inputs.push_back(op);
  };
  // `var` is consumed by `op`.
  auto feeds = [](ir::Node* var, ir::Node* op) {
    var->outputs.push_back(op);
    op->inputs.push_back(var);
  };

  ir::Node* op1 = new_op("op1");
  ir::Node* op2 = new_op("op2");
  ir::Node* op3 = new_op("op3");
  ir::Node* op4 = new_op("op4");
  ir::Node* op5 = new_op("op5");
  ir::Node* var1 = new_var("var1");
  ir::Node* var2 = new_var("var2");
  ir::Node* var3 = new_var("var3");
  ir::Node* var4 = new_var("var4");

  // op1 -> var1 -> op2
  writes(op1, var1);
  feeds(var1, op2);

  // op2 -> var2 -> {op3, op4}
  writes(op2, var2);
  feeds(var2, op3);
  feeds(var2, op4);

  // op2 -> var3 -> op5
  writes(op2, var3);
  feeds(var3, op5);

  // op3 -> var4 -> op5
  writes(op3, var4);
  feeds(var4, op5);
}
161+
162+
// Populates `g` with the same nodes as BuildOneGraph but leaves out the links
// op2 -> var3 and var4 -> op5, which splits the nodes into two disconnected
// groups:
//
//   op1 -> var1 -> op2 -> var2 -> op3 -> var4
//                             `-> op4
//
//   var3 -> op5
void BuildTwoGraphs(Graph* g) {
  auto new_op = [g](const char* name) {
    return g->CreateEmptyNode(name, Node::Type::kOperation);
  };
  auto new_var = [g](const char* name) {
    return g->CreateEmptyNode(name, Node::Type::kVariable);
  };
  // `op` produces `var`.
  auto writes = [](ir::Node* op, ir::Node* var) {
    op->outputs.push_back(var);
    var->inputs.push_back(op);
  };
  // `var` is consumed by `op`.
  auto feeds = [](ir::Node* var, ir::Node* op) {
    var->outputs.push_back(op);
    op->inputs.push_back(var);
  };

  ir::Node* op1 = new_op("op1");
  ir::Node* op2 = new_op("op2");
  ir::Node* op3 = new_op("op3");
  ir::Node* op4 = new_op("op4");
  ir::Node* op5 = new_op("op5");
  ir::Node* var1 = new_var("var1");
  ir::Node* var2 = new_var("var2");
  ir::Node* var3 = new_var("var3");
  ir::Node* var4 = new_var("var4");

  // op1 -> var1 -> op2
  writes(op1, var1);
  feeds(var1, op2);

  // op2 -> var2 -> {op3, op4}
  writes(op2, var2);
  feeds(var2, op3);
  feeds(var2, op4);

  // var3 -> op5 only; op2 deliberately does NOT produce var3.
  feeds(var3, op5);

  // op3 -> var4 only; var4 deliberately does NOT feed op5.
  writes(op3, var4);
}
197+
198+
// Verifies GraphNum on graphs made of zero, one and two connected components.
// Expected values are unsigned literals so that they match the size_t result
// of GraphNum without a signed/unsigned comparison.
TEST(GraphHelperTest, GraphNum) {
  ProgramDesc prog;

  Graph g(prog);
  BuildZeroGraph(&g);
  ASSERT_EQ(GraphNum(g), 0UL);

  Graph g2(prog);
  BuildOneGraph(&g2);
  ASSERT_EQ(GraphNum(g2), 1UL);

  Graph g3(prog);
  BuildTwoGraphs(&g3);
  ASSERT_EQ(GraphNum(g3), 2UL);
}
213+
123214
} // namespace ir
124215
} // namespace framework
125216
} // namespace paddle

paddle/fluid/framework/parallel_executor.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/framework/parallel_executor.h"
16-
1716
#include <string>
1817
#include <tuple>
1918
#include <vector>
19+
#include "paddle/fluid/framework/ir/graph_helper.h"
2020

2121
#include "paddle/fluid/framework/ir/graph.h"
2222

@@ -156,6 +156,12 @@ ParallelExecutor::ParallelExecutor(
156156
params, member_->local_scopes_, member_->use_cuda_);
157157
#endif
158158

159+
// If loss_var_name is given, the graph must consist of exactly one connected sub-graph.
160+
if (loss_var_name.size()) {
161+
PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1,
162+
"The number of graph should be only one");
163+
}
164+
159165
if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
160166
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
161167
exec_strategy, member_->local_scopes_, places, std::move(graph)));

python/paddle/fluid/tests/unittests/transformer_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ def prepare_encoder(src_word,
246246
padding_idx=pos_pad_idx,
247247
param_attr=fluid.ParamAttr(
248248
name=pos_enc_param_name, trainable=False))
249+
src_pos_enc.stop_gradient = True
249250
enc_input = src_word_emb + src_pos_enc
250251

251252
# FIXME(guosheng): Decouple the program desc with batch_size.

0 commit comments

Comments
 (0)