Skip to content

Commit af79b19

Browse files
committed
add a simple program to graph
1 parent 7231ef6 commit af79b19

File tree

10 files changed

+30
-21
lines changed

10 files changed

+30
-21
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ else()
9494
endif()
9595

9696

97-
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
97+
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph)
9898

9999
cc_library(prune SRCS prune.cc DEPS framework_proto)
100100
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,8 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
168168
}
169169

170170
std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Build(
171-
const ProgramDesc &program) const {
172-
std::unique_ptr<Graph> graph(new Graph);
171+
std::unique_ptr<Graph> graph) const {
172+
const ProgramDesc &program = graph->Program();
173173
for (auto *var : program.Block(0).AllVars()) {
174174
all_vars_.emplace(var->Name(), var);
175175
}

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
4747
const BuildStrategy &strategy);
4848
#endif
4949

50-
std::unique_ptr<Graph> Build(const ProgramDesc &program) const override;
50+
std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const override;
5151
int GetVarDeviceID(const std::string &varname) const override;
5252

5353
private:

paddle/fluid/framework/details/ssa_graph_builder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class SSAGraphBuilder {
3838
public:
3939
SSAGraphBuilder() {}
4040
virtual ~SSAGraphBuilder() {}
41-
virtual std::unique_ptr<Graph> Build(const ProgramDesc &program) const = 0;
41+
virtual std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const = 0;
4242
virtual int GetVarDeviceID(const std::string &var_name) const = 0;
4343

4444
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);

paddle/fluid/framework/details/ssa_graph_checker.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
2929
std::unique_ptr<SSAGraphBuilder>&& builder)
3030
: builder_(std::move(builder)) {}
3131

32-
std::unique_ptr<Graph> Build(const ProgramDesc& program) const override {
33-
auto graph = builder_->Build(program);
34-
PADDLE_ENFORCE(IsValidGraph(graph.get()));
35-
return graph;
32+
std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const override {
33+
auto new_graph = builder_->Build(std::move(graph));
34+
PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
35+
return new_graph;
3636
}
3737

3838
int GetVarDeviceID(const std::string& var_name) const override {

paddle/fluid/framework/details/ssa_graph_printer.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
5050
stream_ptr_(std::move(sout)),
5151
stream_ref_(*stream_ptr_) {}
5252

53-
std::unique_ptr<Graph> Build(const ProgramDesc& program) const override {
54-
auto graph = builder_->Build(program);
55-
printer_->Print(*graph, stream_ref_);
56-
return graph;
53+
std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const override {
54+
auto new_graph = builder_->Build(std::move(graph));
55+
printer_->Print(*new_graph, stream_ref_);
56+
return new_graph;
5757
}
5858

5959
int GetVarDeviceID(const std::string& var_name) const override {

paddle/fluid/framework/ir/graph.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,12 @@ limitations under the License. */
1515
#include "paddle/fluid/framework/ir/graph.h"
1616

1717
namespace paddle {
18-
namespace framework {} // namespace framework
18+
namespace framework {
19+
20+
std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc &program) {
21+
std::unique_ptr<Graph> graph(new Graph(program));
22+
return std::move(graph);
23+
}
24+
25+
} // namespace framework
1926
} // namespace paddle

paddle/fluid/framework/ir/graph.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License. */
2020
#include <vector>
2121

2222
#include "paddle/fluid/framework/ir/node.h"
23+
#include "paddle/fluid/framework/program_desc.h"
2324
#include "paddle/fluid/platform/enforce.h"
2425
#include "paddle/fluid/platform/variant.h"
2526

@@ -28,6 +29,8 @@ namespace framework {
2829

2930
class Graph {
3031
public:
32+
explicit Graph(const ProgramDesc& program) : program_(program) {}
33+
3134
virtual ~Graph() {
3235
for (auto& attr : attrs_) {
3336
attr_dels_[attr.first]();
@@ -36,6 +39,8 @@ class Graph {
3639
attr_dels_.clear();
3740
}
3841

42+
const ProgramDesc& Program() const { return program_; }
43+
3944
template <typename AttrType>
4045
AttrType& Get(const std::string& attr_name) const {
4146
return *boost::any_cast<AttrType*>(attrs_.at(attr_name));
@@ -63,9 +68,12 @@ class Graph {
6368
std::vector<std::unique_ptr<ir::Node>> nodes;
6469

6570
private:
71+
const ProgramDesc& program_;
6672
std::map<std::string, boost::any> attrs_;
6773
std::map<std::string, std::function<void(void)>> attr_dels_;
6874
};
6975

76+
std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc& program);
77+
7078
} // namespace framework
7179
} // namespace paddle

paddle/fluid/framework/ir/pass.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,5 @@ class Pass {
3030
}
3131
};
3232

33-
std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc& program) {
34-
std::unique_ptr<Graph> g(new Graph);
35-
36-
return std::move(g);
37-
}
38-
3933
} // namespace framework
4034
} // namespace paddle

paddle/fluid/framework/parallel_executor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ ParallelExecutor::ParallelExecutor(
133133
}
134134

135135
builder_ = builder_factory.Create();
136-
std::unique_ptr<Graph> graph = builder_->Build(main_program);
136+
std::unique_ptr<Graph> graph = builder_->Build(ProgramToGraph(main_program));
137137

138138
std::unique_ptr<details::SSAGraph> ssa_graph(new details::SSAGraph);
139139
ssa_graph->vars_ = std::move(graph->Get<details::GraphVars>("vars"));

0 commit comments

Comments
 (0)