Skip to content

Commit 0c851ca

Browse files
committed
add SSA graph checker
1 parent 1076e85 commit 0c851ca

File tree

7 files changed

+137
-75
lines changed

7 files changed

+137
-75
lines changed

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place
88
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
99
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
1010
cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
11+
cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder)
1112

1213
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
1314

@@ -30,7 +31,7 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS
3031
scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
3132

3233

33-
cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer)
34+
cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
3435

3536
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
3637
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope

paddle/fluid/framework/details/ssa_graph_builder.cc

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -83,76 +83,6 @@ void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) {
8383
op->AddOutput(dummy_leaf);
8484
}
8585
}
86-
87-
std::unique_ptr<SSAGraph> SSAGraphBuilder::BuildAndCheck(
88-
const ProgramDesc &program) {
89-
std::unique_ptr<SSAGraph> graph = Build(program);
90-
PADDLE_ENFORCE(IsValidGraph(graph.get()));
91-
return std::move(graph);
92-
}
93-
94-
bool SSAGraphBuilder::IsValidGraph(const SSAGraph *graph) const {
95-
std::unordered_map<OpHandleBase *, size_t> pending_ops;
96-
std::unordered_set<VarHandleBase *> pending_vars;
97-
std::unordered_set<VarHandleBase *> ready_vars;
98-
std::unordered_set<OpHandleBase *> ready_ops;
99-
100-
auto insert_pending_var = [&](VarHandleBase *var) {
101-
pending_vars.insert(var);
102-
if (var->generated_op_ == nullptr) {
103-
ready_vars.emplace(var);
104-
}
105-
};
106-
107-
for (auto &var_map : graph->vars_) {
108-
for (auto &name_pair : var_map) {
109-
for (auto &version_pair : name_pair.second) {
110-
insert_pending_var(version_pair.get());
111-
}
112-
}
113-
}
114-
115-
for (auto &var : graph->dep_vars_) {
116-
insert_pending_var(var.get());
117-
}
118-
119-
for (auto &op : graph->ops_) {
120-
if (op->Inputs().empty()) {
121-
ready_ops.insert(op.get());
122-
} else {
123-
pending_ops.insert({op.get(), op.get()->NoDupInputSize()});
124-
}
125-
}
126-
127-
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
128-
for (auto *op : set) {
129-
for (auto out : op->Outputs()) {
130-
ready_vars.emplace(out);
131-
}
132-
}
133-
set.clear();
134-
};
135-
136-
while (!pending_vars.empty()) {
137-
run_all_ops(ready_ops);
138-
139-
if (ready_vars.empty()) {
140-
return false;
141-
}
142-
143-
for (auto ready_var : ready_vars) {
144-
pending_vars.erase(ready_var);
145-
for (auto *op : ready_var->pending_ops_) {
146-
auto &deps = --pending_ops[op];
147-
if (deps == 0) {
148-
ready_ops.insert(op);
149-
}
150-
}
151-
}
152-
ready_vars.clear();
153-
}
154-
return true;
155-
}
15686
} // namespace details
15787
} // namespace framework
15888
} // namespace paddle

paddle/fluid/framework/details/ssa_graph_builder.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ class SSAGraphBuilder {
3131
virtual ~SSAGraphBuilder() {}
3232
virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
3333

34-
std::unique_ptr<SSAGraph> BuildAndCheck(const ProgramDesc &program);
35-
3634
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
3735

3836
protected:
@@ -50,7 +48,6 @@ class SSAGraphBuilder {
5048
const platform::Place &place,
5149
size_t place_offset);
5250

53-
bool IsValidGraph(const SSAGraph *graph) const;
5451
// Add an output variable (each_var_name, place, place_offset) to op_handle,
5552
// which belongs to graph
5653
static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,

paddle/fluid/framework/details/ssa_graph_builder_factory.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
1616
#include <fstream>
1717
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
18+
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
1819
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
1920

2021
namespace paddle {
@@ -40,6 +41,8 @@ std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() {
4041
res.reset(new SSAGraghBuilderWithPrinter(
4142
std::move(fout), std::move(graphviz_printer), std::move(res)));
4243
}
44+
res.reset(new SSAGraghBuilderWithChecker(std::move(res)));
45+
4346
return res;
4447
}
4548
} // namespace details
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/details/ssa_graph.h"
16+
#include <string>
17+
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace details {
22+
23+
bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const {
24+
std::unordered_map<OpHandleBase *, size_t> pending_ops;
25+
std::unordered_set<VarHandleBase *> pending_vars;
26+
std::unordered_set<VarHandleBase *> ready_vars;
27+
std::unordered_set<OpHandleBase *> ready_ops;
28+
29+
auto insert_pending_var = [&](VarHandleBase *var) {
30+
pending_vars.insert(var);
31+
if (var->generated_op_ == nullptr) {
32+
ready_vars.emplace(var);
33+
}
34+
};
35+
36+
for (auto &var_map : graph->vars_) {
37+
for (auto &name_pair : var_map) {
38+
for (auto &version_pair : name_pair.second) {
39+
insert_pending_var(version_pair.get());
40+
}
41+
}
42+
}
43+
44+
for (auto &var : graph->dep_vars_) {
45+
insert_pending_var(var.get());
46+
}
47+
48+
for (auto &op : graph->ops_) {
49+
if (op->Inputs().empty()) {
50+
ready_ops.insert(op.get());
51+
} else {
52+
pending_ops.insert({op.get(), op.get()->NoDupInputSize()});
53+
}
54+
}
55+
56+
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
57+
for (auto *op : set) {
58+
for (auto out : op->Outputs()) {
59+
ready_vars.emplace(out);
60+
}
61+
}
62+
set.clear();
63+
};
64+
65+
while (!pending_vars.empty()) {
66+
run_all_ops(ready_ops);
67+
68+
if (ready_vars.empty()) {
69+
return false;
70+
}
71+
72+
for (auto ready_var : ready_vars) {
73+
pending_vars.erase(ready_var);
74+
for (auto *op : ready_var->pending_ops_) {
75+
auto &deps = --pending_ops[op];
76+
if (deps == 0) {
77+
ready_ops.insert(op);
78+
}
79+
}
80+
}
81+
ready_vars.clear();
82+
}
83+
return true;
84+
}
85+
} // namespace details
86+
} // namespace framework
87+
} // namespace paddle
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace details {
22+
class SSAGraph;
23+
24+
class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
25+
public:
26+
explicit SSAGraghBuilderWithChecker(
27+
std::unique_ptr<SSAGraphBuilder>&& builder)
28+
: builder_(std::move(builder)) {}
29+
30+
std::unique_ptr<SSAGraph> Build(const ProgramDesc& program) const override {
31+
auto graph = builder_->Build(program);
32+
PADDLE_ENFORCE(IsValidGraph(graph.get()));
33+
return graph;
34+
}
35+
36+
bool IsValidGraph(const SSAGraph* graph) const;
37+
38+
private:
39+
std::unique_ptr<SSAGraphBuilder> builder_;
40+
};
41+
42+
} // namespace details
43+
} // namespace framework
44+
} // namespace paddle

paddle/fluid/framework/parallel_executor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ ParallelExecutor::ParallelExecutor(
114114

115115
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
116116
exec_strategy, member_->local_scopes_, places,
117-
builder_factory.Create()->BuildAndCheck(main_program)));
117+
builder_factory.Create()->Build(main_program)));
118118

119119
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
120120
exec_strategy, member_->local_scopes_, std::move(var_infos),

0 commit comments

Comments
 (0)