Skip to content

Commit fa29ef0

Browse files
author
chengduo
authored
Merge pull request #11277 from chengduoZH/check_ssa_graph
Check SSA Graph
2 parents 1cfd3cb + 0c851ca commit fa29ef0

9 files changed

+141
-5
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
8787
framework_proto glog lod_rank_table feed_fetch_method)
8888

8989

90-
cc_library(parallel_executor SRCS parallel_executor.cc DEPS graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
90+
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
9191

9292
cc_library(prune SRCS prune.cc DEPS framework_proto)
9393
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place
88
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
99
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
1010
cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
11+
cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder)
1112

1213
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
1314

@@ -31,7 +32,7 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS
3132
scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
3233

3334

34-
cc_library(graph_builder_factory SRCS graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer)
35+
cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
3536

3637
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
3738
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope

paddle/fluid/framework/details/ssa_graph_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
14-
1514
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
15+
#include <utility>
1616

1717
namespace paddle {
1818
namespace framework {

paddle/fluid/framework/details/graph_builder_factory.cc renamed to paddle/fluid/framework/details/ssa_graph_builder_factory.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include "paddle/fluid/framework/details/graph_builder_factory.h"
15+
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
1616
#include <fstream>
1717
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
18+
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
1819
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
1920

2021
namespace paddle {
@@ -40,6 +41,8 @@ std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() {
4041
res.reset(new SSAGraghBuilderWithPrinter(
4142
std::move(fout), std::move(graphviz_printer), std::move(res)));
4243
}
44+
res.reset(new SSAGraghBuilderWithChecker(std::move(res)));
45+
4346
return res;
4447
}
4548
} // namespace details
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/details/ssa_graph.h"
16+
#include <string>
17+
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace details {
22+
23+
bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const {
24+
std::unordered_map<OpHandleBase *, size_t> pending_ops;
25+
std::unordered_set<VarHandleBase *> pending_vars;
26+
std::unordered_set<VarHandleBase *> ready_vars;
27+
std::unordered_set<OpHandleBase *> ready_ops;
28+
29+
auto insert_pending_var = [&](VarHandleBase *var) {
30+
pending_vars.insert(var);
31+
if (var->generated_op_ == nullptr) {
32+
ready_vars.emplace(var);
33+
}
34+
};
35+
36+
for (auto &var_map : graph->vars_) {
37+
for (auto &name_pair : var_map) {
38+
for (auto &version_pair : name_pair.second) {
39+
insert_pending_var(version_pair.get());
40+
}
41+
}
42+
}
43+
44+
for (auto &var : graph->dep_vars_) {
45+
insert_pending_var(var.get());
46+
}
47+
48+
for (auto &op : graph->ops_) {
49+
if (op->Inputs().empty()) {
50+
ready_ops.insert(op.get());
51+
} else {
52+
pending_ops.insert({op.get(), op.get()->NoDupInputSize()});
53+
}
54+
}
55+
56+
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
57+
for (auto *op : set) {
58+
for (auto out : op->Outputs()) {
59+
ready_vars.emplace(out);
60+
}
61+
}
62+
set.clear();
63+
};
64+
65+
while (!pending_vars.empty()) {
66+
run_all_ops(ready_ops);
67+
68+
if (ready_vars.empty()) {
69+
return false;
70+
}
71+
72+
for (auto ready_var : ready_vars) {
73+
pending_vars.erase(ready_var);
74+
for (auto *op : ready_var->pending_ops_) {
75+
auto &deps = --pending_ops[op];
76+
if (deps == 0) {
77+
ready_ops.insert(op);
78+
}
79+
}
80+
}
81+
ready_vars.clear();
82+
}
83+
return true;
84+
}
85+
} // namespace details
86+
} // namespace framework
87+
} // namespace paddle
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace details {
22+
class SSAGraph;
23+
24+
class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
25+
public:
26+
explicit SSAGraghBuilderWithChecker(
27+
std::unique_ptr<SSAGraphBuilder>&& builder)
28+
: builder_(std::move(builder)) {}
29+
30+
std::unique_ptr<SSAGraph> Build(const ProgramDesc& program) const override {
31+
auto graph = builder_->Build(program);
32+
PADDLE_ENFORCE(IsValidGraph(graph.get()));
33+
return graph;
34+
}
35+
36+
bool IsValidGraph(const SSAGraph* graph) const;
37+
38+
private:
39+
std::unique_ptr<SSAGraphBuilder> builder_;
40+
};
41+
42+
} // namespace details
43+
} // namespace framework
44+
} // namespace paddle

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ void ThreadedSSAGraphExecutor::InsertPendingVar(
185185
ready_vars->Push(var);
186186
}
187187
}
188+
188189
void ThreadedSSAGraphExecutor::RunOp(
189190
BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
190191
auto op_run = [ready_var_q, op, this] {

paddle/fluid/framework/parallel_executor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ limitations under the License. */
2222
#include "paddle/fluid/platform/nccl_helper.h"
2323
#endif
2424

25-
#include "paddle/fluid/framework/details/graph_builder_factory.h"
2625
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
26+
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
2727
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
2828
#include "paddle/fluid/platform/profiler.h"
2929

0 commit comments

Comments
 (0)