Skip to content

Commit 9605fcd

Browse files
committed
all graphs
1 parent af79b19 commit 9605fcd

File tree

5 files changed

+12
-17
lines changed

5 files changed

+12
-17
lines changed

paddle/fluid/framework/details/ssa_graph_checker.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
namespace paddle {
2222
namespace framework {
2323
namespace details {
24-
struct SSAGraph;
2524

2625
class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
2726
public:

paddle/fluid/framework/details/ssa_graph_printer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
namespace paddle {
2222
namespace framework {
2323
namespace details {
24-
struct SSAGraph;
24+
2525
class SSAGraphPrinter {
2626
public:
2727
virtual ~SSAGraphPrinter() {}

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414

1515
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
1616

17+
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
18+
1719
namespace paddle {
1820
namespace framework {
1921
namespace details {
2022
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
2123
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
22-
const std::vector<platform::Place> &places,
23-
std::unique_ptr<SSAGraph> &&graph)
24+
const std::vector<platform::Place> &places, std::unique_ptr<Graph> &&graph)
2425
: graph_(std::move(graph)),
2526
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
2627
: nullptr),
@@ -43,18 +44,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
4344
std::unordered_set<OpHandleBase *> delayed_ops;
4445

4546
// Transform SSAGraph to pending_ops & pending_vars
46-
for (auto &var_map : graph_->vars_) {
47+
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
4748
for (auto &name_pair : var_map) {
4849
for (auto &version_pair : name_pair.second) {
4950
InsertPendingVar(&pending_vars, &ready_vars, version_pair.get());
5051
}
5152
}
5253
}
53-
for (auto &var : graph_->dep_vars_) {
54+
for (auto &var : graph_->Get<details::GraphDepVars>("dep_vars")) {
5455
InsertPendingVar(&pending_vars, &ready_vars, var.get());
5556
}
5657

57-
for (auto &op : graph_->ops_) {
58+
for (auto &op : graph_->Get<details::GraphOps>("ops")) {
5859
if (op->Inputs().empty()) { // Special case, Op has no input.
5960
ready_ops.insert(op.get());
6061
} else {
@@ -158,7 +159,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
158159
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
159160

160161
for (auto &fetch_var_name : fetch_tensors) {
161-
for (auto &var_map : graph_->vars_) {
162+
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
162163
auto it = var_map.find(fetch_var_name);
163164
if (it != var_map.end()) {
164165
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "paddle/fluid/framework/details/execution_strategy.h"
2828
#include "paddle/fluid/framework/details/fetch_op_handle.h"
2929
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
30+
#include "paddle/fluid/framework/ir/graph.h"
3031

3132
namespace paddle {
3233
namespace framework {
@@ -39,7 +40,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
3940
ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
4041
const std::vector<Scope *> &local_scopes,
4142
const std::vector<platform::Place> &places,
42-
std::unique_ptr<SSAGraph> &&graph);
43+
std::unique_ptr<Graph> &&graph);
4344

4445
// Run a SSAGraph by a thread pool
4546
// Use topological sort algorithm
@@ -52,7 +53,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
5253
details::OpHandleBase *op);
5354

5455
private:
55-
std::unique_ptr<SSAGraph> graph_;
56+
std::unique_ptr<Graph> graph_;
5657
std::unique_ptr<::ThreadPool> pool_;
5758
std::vector<Scope *> local_scopes_;
5859
std::vector<platform::Place> places_;

paddle/fluid/framework/parallel_executor.cc

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -135,14 +135,8 @@ ParallelExecutor::ParallelExecutor(
135135
builder_ = builder_factory.Create();
136136
std::unique_ptr<Graph> graph = builder_->Build(ProgramToGraph(main_program));
137137

138-
std::unique_ptr<details::SSAGraph> ssa_graph(new details::SSAGraph);
139-
ssa_graph->vars_ = std::move(graph->Get<details::GraphVars>("vars"));
140-
ssa_graph->ops_ = std::move(graph->Get<details::GraphOps>("ops"));
141-
ssa_graph->dep_vars_ =
142-
std::move(graph->Get<details::GraphDepVars>("dep_vars"));
143-
144138
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
145-
exec_strategy, member_->local_scopes_, places, std::move(ssa_graph)));
139+
exec_strategy, member_->local_scopes_, places, std::move(graph)));
146140

147141
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
148142
exec_strategy, member_->local_scopes_, std::move(var_infos),

0 commit comments

Comments
 (0)