File tree Expand file tree Collapse file tree 10 files changed +30
-21
lines changed Expand file tree Collapse file tree 10 files changed +30
-21
lines changed Original file line number Diff line number Diff line change 94
94
endif ()
95
95
96
96
97
- cc_library (parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor )
97
+ cc_library (parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph )
98
98
99
99
cc_library (prune SRCS prune.cc DEPS framework_proto )
100
100
cc_test (prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context )
Original file line number Diff line number Diff line change @@ -168,8 +168,8 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
168
168
}
169
169
170
170
std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Build (
171
- const ProgramDesc &program ) const {
172
- std::unique_ptr<Graph> graph ( new Graph );
171
+ std::unique_ptr<Graph> graph ) const {
172
+ const ProgramDesc &program = graph-> Program ( );
173
173
for (auto *var : program.Block (0 ).AllVars ()) {
174
174
all_vars_.emplace (var->Name (), var);
175
175
}
Original file line number Diff line number Diff line change @@ -47,7 +47,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
47
47
const BuildStrategy &strategy);
48
48
#endif
49
49
50
- std::unique_ptr<Graph> Build (const ProgramDesc &program ) const override ;
50
+ std::unique_ptr<Graph> Build (std::unique_ptr<Graph> graph ) const override ;
51
51
int GetVarDeviceID (const std::string &varname) const override ;
52
52
53
53
private:
Original file line number Diff line number Diff line change @@ -38,7 +38,7 @@ class SSAGraphBuilder {
38
38
public:
39
39
SSAGraphBuilder () {}
40
40
virtual ~SSAGraphBuilder () {}
41
- virtual std::unique_ptr<Graph> Build (const ProgramDesc &program ) const = 0;
41
+ virtual std::unique_ptr<Graph> Build (std::unique_ptr<Graph> graph ) const = 0;
42
42
virtual int GetVarDeviceID (const std::string &var_name) const = 0;
43
43
44
44
DISABLE_COPY_AND_ASSIGN (SSAGraphBuilder);
Original file line number Diff line number Diff line change @@ -29,10 +29,10 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
29
29
std::unique_ptr<SSAGraphBuilder>&& builder)
30
30
: builder_(std::move(builder)) {}
31
31
32
- std::unique_ptr<Graph> Build (const ProgramDesc& program ) const override {
33
- auto graph = builder_->Build (program );
34
- PADDLE_ENFORCE (IsValidGraph (graph .get ()));
35
- return graph ;
32
+ std::unique_ptr<Graph> Build (std::unique_ptr<Graph> graph ) const override {
33
+ auto new_graph = builder_->Build (std::move (graph) );
34
+ PADDLE_ENFORCE (IsValidGraph (new_graph .get ()));
35
+ return new_graph ;
36
36
}
37
37
38
38
int GetVarDeviceID (const std::string& var_name) const override {
Original file line number Diff line number Diff line change @@ -50,10 +50,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
50
50
stream_ptr_(std::move(sout)),
51
51
stream_ref_(*stream_ptr_) {}
52
52
53
- std::unique_ptr<Graph> Build (const ProgramDesc& program ) const override {
54
- auto graph = builder_->Build (program );
55
- printer_->Print (*graph , stream_ref_);
56
- return graph ;
53
+ std::unique_ptr<Graph> Build (std::unique_ptr<Graph> graph ) const override {
54
+ auto new_graph = builder_->Build (std::move (graph) );
55
+ printer_->Print (*new_graph , stream_ref_);
56
+ return new_graph ;
57
57
}
58
58
59
59
int GetVarDeviceID (const std::string& var_name) const override {
Original file line number Diff line number Diff line change @@ -15,5 +15,12 @@ limitations under the License. */
15
15
#include " paddle/fluid/framework/ir/graph.h"
16
16
17
17
namespace paddle {
18
- namespace framework {} // namespace framework
18
+ namespace framework {
19
+
20
+ std::unique_ptr<Graph> ProgramToGraph (const ProgramDesc &program) {
21
+ std::unique_ptr<Graph> graph (new Graph (program));
22
+ return std::move (graph);
23
+ }
24
+
25
+ } // namespace framework
19
26
} // namespace paddle
Original file line number Diff line number Diff line change @@ -20,6 +20,7 @@ limitations under the License. */
20
20
#include < vector>
21
21
22
22
#include " paddle/fluid/framework/ir/node.h"
23
+ #include " paddle/fluid/framework/program_desc.h"
23
24
#include " paddle/fluid/platform/enforce.h"
24
25
#include " paddle/fluid/platform/variant.h"
25
26
@@ -28,6 +29,8 @@ namespace framework {
28
29
29
30
class Graph {
30
31
public:
32
+ explicit Graph (const ProgramDesc& program) : program_(program) {}
33
+
31
34
virtual ~Graph () {
32
35
for (auto & attr : attrs_) {
33
36
attr_dels_[attr.first ]();
@@ -36,6 +39,8 @@ class Graph {
36
39
attr_dels_.clear ();
37
40
}
38
41
42
+ const ProgramDesc& Program () const { return program_; }
43
+
39
44
template <typename AttrType>
40
45
AttrType& Get (const std::string& attr_name) const {
41
46
return *boost::any_cast<AttrType*>(attrs_.at (attr_name));
@@ -63,9 +68,12 @@ class Graph {
63
68
std::vector<std::unique_ptr<ir::Node>> nodes;
64
69
65
70
private:
71
+ const ProgramDesc& program_;
66
72
std::map<std::string, boost::any> attrs_;
67
73
std::map<std::string, std::function<void (void )>> attr_dels_;
68
74
};
69
75
76
+ std::unique_ptr<Graph> ProgramToGraph (const ProgramDesc& program);
77
+
70
78
} // namespace framework
71
79
} // namespace paddle
Original file line number Diff line number Diff line change @@ -30,11 +30,5 @@ class Pass {
30
30
}
31
31
};
32
32
33
- std::unique_ptr<Graph> ProgramToGraph (const ProgramDesc& program) {
34
- std::unique_ptr<Graph> g (new Graph);
35
-
36
- return std::move (g);
37
- }
38
-
39
33
} // namespace framework
40
34
} // namespace paddle
Original file line number Diff line number Diff line change @@ -133,7 +133,7 @@ ParallelExecutor::ParallelExecutor(
133
133
}
134
134
135
135
builder_ = builder_factory.Create ();
136
- std::unique_ptr<Graph> graph = builder_->Build (main_program);
136
+ std::unique_ptr<Graph> graph = builder_->Build (ProgramToGraph ( main_program) );
137
137
138
138
std::unique_ptr<details::SSAGraph> ssa_graph (new details::SSAGraph);
139
139
ssa_graph->vars_ = std::move (graph->Get <details::GraphVars>(" vars" ));
You can’t perform that action at this time.
0 commit comments