@@ -64,13 +64,96 @@ can also contain other things that describe some properties of
64
64
the ` Graph ` or ` Graph ` nodes. ` Attribute ` can be passed
65
65
across ` Pass ` . However, it should be used with care.
66
66
67
+ ``` cpp
68
+ class Graph {
69
+ public:
70
+ explicit Graph(const ProgramDesc &program);
71
+
72
+ bool Has(const std::string &attr_name) const;
73
+
74
+ template <typename AttrType >
75
+ AttrType &Get(const std::string &attr_name) const;
76
+
77
+ template <typename AttrType >
78
+ void Set(const std::string &attr_name, AttrType * attr);
79
+ const std::unordered_set<ir::Node * > &Nodes() const;
80
+
81
+ // Create a normal variable with non-null VarDesc.
82
+ ir::Node * CreateVarNode(VarDesc * var_desc);
83
+
84
+ // Create a normal runnable operator with OpDesc.
85
+ ir::Node * CreateOpNode(OpDesc * op_desc);
86
+
87
+ // Create a control dependency var that connects 2 operations. The
88
+ // var doesn't hold any data. Other than that, it's no different from
89
+ // other var, considering dependency analysis.
90
+ ir::Node * CreateControlDepVar();
91
+
92
+ // A more free style way of creating a graph node. Mostly use for test
93
+ // or "copy" from another node. Avoid using it if possible.
94
+ ir::Node * CreateEmptyNode(const std::string &name, ir::Node::Type type);
95
+
96
+ // Clear all node information of the graph and return the ownership of the
97
+ // nodes.
98
+ std::vector< std::unique_ptr<ir::Node > > ReleaseNodes();
99
+ };
100
+ ```
101
+
67
102
#### Pass
68
103
69
104
`Pass` represents a transformation of `Graph`. Its input
70
105
is a `Graph` and its output is also a `Graph`. For example,
71
106
a `Pass` can simply print out the `Graph`. A `Pass`
72
107
can also fuse some `Graph`'s `Node`s.
73
108
109
+ ```cpp
110
+ class Pass {
111
+ public:
112
+
113
+ std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const {
114
+ // Some correctness check.
115
+ auto new_graph = ApplyImpl(std::move(graph));
116
+ // Some correctness check.
117
+ return new_graph;
118
+ }
119
+
120
+ // Get a reference to the attributed previously set.
121
+ template <typename AttrType>
122
+ AttrType &Get(const std::string &attr_name) const;
123
+
124
+ // Set a pointer to the attribute. Pass takes ownership of the attribute.
125
+ template <typename AttrType>
126
+ void Set(const std::string &attr_name, AttrType *attr) ;
127
+
128
+ // Set a pointer to the attribute. Pass doesn't take ownership. Caller
129
+ // should delete the attribute.
130
+ template <typename AttrType>
131
+ void SetNotOwned(const std::string &attr_name, AttrType *attr);
132
+
133
+ protected:
134
+ virtual std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const = 0;
135
+ };
136
+
137
+ // In my_pass.cc
138
+ class MyPass : public Pass {
139
+ protected:
140
+ std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override {
141
+ // do something.
142
+ return graph;
143
+ }
144
+ }
145
+ REGISTER_PASS(my_pass, MyPass)
146
+ .RequirePassAttr("places")
147
+ .RequireGraphAttr("dep_vars");
148
+
149
+
150
+ // To use the pass.
151
+ auto my_pass = ir::PassRegistry::Instance().Get("my_pass");
152
+ graph = my_pass->Apply(std::move(graph));
153
+ // Note: to force link my_pass.cc, in the code:
154
+ USE_PASS(my_pass);
155
+ ```
156
+
74
157
#### Optimize
75
158
76
159
` Optimize ` contains a series of ` Pass ` with defined order.
@@ -86,4 +169,17 @@ maintaining the original modeling logic.
86
169
* Graph is transformed from raw model logic to a
87
170
form that is efficient to execute.
88
171
89
- Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor
172
+ ```
173
+ // Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor
174
+ auto graph = Graph(program);
175
+ graph = PassRegistry::Instance().Get("op_fuse_pass").Apply(std::move(grah));
176
+ // For more complex Pass, Optimize Process can provide Pass attributes.
177
+ auto mem_opt_pass = PassRegistry::Instance().Get("memory_optimization_pass");
178
+ mem_opt_pass.SetNotOwned<int>("optimize_level", 1);
179
+ mem_opt_pass->Apply(std::move(graph));
180
+ graph = PassRegistry::Instance().Get("multi_device_pass").Apply(std::move(grah));
181
+ graph = PassRegistry::Instance().Get("multi_device_check_pass").Apply(std::move(grah));
182
+ Executor exe;
183
+ exe.Run(graph);
184
+
185
+ ```
0 commit comments