@@ -74,25 +74,208 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) {
74
74
node_map.at (b)->attr (kUnionFindParent ).Int32 () = a_ancestor;
75
75
}
76
76
77
+ // This is a simple representation of a graph.
78
+ // The BriefNode hold the pointer of the Node.
79
+ // This is to avoid changing the original graph
80
+ // in the process of trt graph analysis.
81
+ struct BriefNode {
82
+ explicit BriefNode (Node *n) { node = n; }
83
+ Node *node;
84
+ std::vector<BriefNode *> inlinks;
85
+ std::vector<BriefNode *> outlinks;
86
+ };
87
+
88
+ // Union two adjacent BriefNode.
89
+ // Suppose we have two adjacent nodes src and dst.
90
+ // We will perform the following operations:
91
+ // 1. add all inputs(except src) of dst to src inlinks.
92
+ // 2. add all outputs of dst to src outlinks.
93
+ // 3. change all the dst's inputs and outputs
94
+ // corresponding inlinks and outlinks to src node.
95
+ // 4. delete all dst's inlinks and outlinks.
96
+ void UnionContractedNodes (const std::unordered_map<int , BriefNode *> &node_map,
97
+ int src_id, int dst_id) {
98
+ // merge the two adjacent nodes into one node.
99
+ BriefNode *src_node = node_map.at (src_id);
100
+ BriefNode *dst_node = node_map.at (dst_id);
101
+
102
+ std::unordered_set<BriefNode *> inputs (src_node->inlinks .begin (),
103
+ src_node->inlinks .end ());
104
+ std::unordered_set<BriefNode *> outputs;
105
+
106
+ for (auto *n : src_node->outlinks ) {
107
+ if (n != dst_node) outputs.insert (n);
108
+ }
109
+
110
+ // Add the inlinks and outlinks of dst node to src node.
111
+ std::vector<BriefNode *> dst_in_nodes = dst_node->inlinks ;
112
+ for (BriefNode *node : dst_in_nodes) {
113
+ if (node != src_node) {
114
+ inputs.insert (node);
115
+ }
116
+ }
117
+
118
+ std::vector<BriefNode *> dst_out_nodes = dst_node->outlinks ;
119
+ for (BriefNode *node : dst_out_nodes) {
120
+ outputs.insert (node);
121
+ }
122
+
123
+ // update the dst and src node's inlinks and outlinks.
124
+ src_node->inlinks =
125
+ std::move (std::vector<BriefNode *>(inputs.begin (), inputs.end ()));
126
+ src_node->outlinks =
127
+ std::move (std::vector<BriefNode *>(outputs.begin (), outputs.end ()));
128
+ dst_node->inlinks .clear ();
129
+ dst_node->outlinks .clear ();
130
+
131
+ auto inlink_or_outlink_cleaner = [&](std::vector<BriefNode *> &nodes) {
132
+ for (auto *&n : nodes) {
133
+ if (n == src_node || n == dst_node) {
134
+ n = src_node;
135
+ }
136
+ }
137
+ };
138
+ // Change all the dst inputs and outputs corresponding inlink and
139
+ // outlink to the src node.
140
+ for (auto *node : src_node->inlinks ) {
141
+ inlink_or_outlink_cleaner (node->outlinks );
142
+ }
143
+
144
+ for (auto *node : src_node->outlinks ) {
145
+ inlink_or_outlink_cleaner (node->inlinks );
146
+ }
147
+ }
148
+
149
+ // FlexibleDFS
150
+ // If reverse is true, do reverse dfs.
151
+ // If enter func is not nullptr, calls enter(node) before visiting any children
152
+ // of node.
153
+ // If leave func not nullptr, calls leave(node) after visiting all parents of
154
+ // node.
155
+ void FlexibleDFS (const std::vector<BriefNode *> &source, bool reverse,
156
+ const std::function<bool (const BriefNode *)> &enter,
157
+ const std::function<bool(const BriefNode *)> &leave) {
158
+ typedef struct {
159
+ const BriefNode *node;
160
+ bool leave;
161
+ } FNode;
162
+
163
+ std::vector<FNode> stack;
164
+ for (auto &node : source) {
165
+ stack.push_back (FNode{node, false });
166
+ }
167
+ std::unordered_set<const BriefNode *> visited;
168
+ while (!stack.empty ()) {
169
+ auto fnode = stack.back ();
170
+ stack.pop_back ();
171
+
172
+ if (fnode.leave ) {
173
+ if (leave && !leave (fnode.node )) return ;
174
+ }
175
+ if (visited.count (fnode.node )) continue ;
176
+ visited.insert (fnode.node );
177
+
178
+ if (enter && !enter (fnode.node )) return ;
179
+
180
+ if (leave) stack.push_back (FNode{fnode.node , true });
181
+ const std::vector<BriefNode *> iter_nodes =
182
+ reverse == true ? fnode.node ->inlinks : fnode.node ->outlinks ;
183
+ for (const BriefNode *node : iter_nodes) {
184
+ if (!visited.count (node)) {
185
+ stack.push_back (FNode{node, false });
186
+ }
187
+ }
188
+ }
189
+ }
190
+
77
191
std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs () {
192
+ // Run the Extract algorithm to find all subgraphs.
78
193
std::vector<Node *> marked_nodes;
194
+ // We use brief_node_map to represent the original graph in order to avoid
195
+ // changing the original graph.
196
+ std::unordered_map<int , BriefNode *> brief_node_map;
197
+
79
198
for (auto &node : GraphTraits<DataFlowGraph>(*graph_).nodes_in_TS ()) {
199
+ brief_node_map[node.id ()] = new BriefNode (&node);
80
200
if (node.attr (kMarkerAttrName ).Bool ()) {
81
201
marked_nodes.push_back (&node);
82
202
}
83
203
}
204
+
84
205
// extract sub-graphs in the marked node set, use Union Find algorithm.
85
206
node_map_t node_map; // id to ptr
86
207
for (auto *n : marked_nodes) {
87
208
// n's parent == n.id means it is the ancestor
88
209
n->attr (kUnionFindParent ).Int32 () = n->id ();
89
210
node_map[n->id ()] = n;
90
211
}
91
- std::unordered_set<Node *> visited;
92
- for (auto *n : marked_nodes) {
93
- for (auto *out : n->outlinks ) {
94
- if (node_map.count (out->id ())) {
95
- UnionFindCombine (node_map, n->id (), out->id ());
212
+
213
+ // create brief node map
214
+ for (auto &itr : brief_node_map) {
215
+ for (Node *node : itr.second ->node ->inlinks ) {
216
+ itr.second ->inlinks .push_back (brief_node_map[node->id ()]);
217
+ }
218
+
219
+ for (Node *node : itr.second ->node ->outlinks ) {
220
+ itr.second ->outlinks .push_back (brief_node_map[node->id ()]);
221
+ }
222
+ }
223
+
224
+ for (auto &itr : brief_node_map) {
225
+ BriefNode *brief_node = itr.second ;
226
+
227
+ if (!brief_node->node ->attr (kMarkerAttrName ).Bool ()) {
228
+ VLOG (4 ) << brief_node->node ->id () << " node not a trt candicate." ;
229
+ continue ;
230
+ }
231
+
232
+ // Our algorithm must guarantee that:
233
+ // 1. The graph is always directed acyclic graph(DAG).
234
+ // 2. If there is a path in the subgraph from X to Y (X and Y are both
235
+ // nodes in the subgraph), then all paths from X to Y are in the
236
+ // subgraph.
237
+ //
238
+ // In order to achieve the above guarantee.
239
+ // For adjacent nodes src -> dst.
240
+ // 1. Get all dst input nodes except src.
241
+ // 2. Reverse DFS from those input nodes
242
+ // 3. If there is a path from input nodes to src,
243
+ // then the src and dst nodes can not be fused into one node,
244
+ // otherwise it can be done.
245
+
246
+ while (true ) {
247
+ std::unordered_set<BriefNode *> contract_nodes;
248
+ for (auto *out : brief_node->outlinks ) {
249
+ // must be a trt candidate
250
+ if (!out->node ->attr (kMarkerAttrName ).Bool ()) continue ;
251
+ // get all dst input nodes except src.
252
+ std::vector<BriefNode *> source_nodes;
253
+ for (auto *n : out->inlinks ) {
254
+ if (n != brief_node) {
255
+ source_nodes.push_back (n);
256
+ }
257
+ }
258
+
259
+ // Reverse DFS from the source_nodes.
260
+ bool have_excess_path = false ;
261
+ FlexibleDFS (source_nodes, true , nullptr ,
262
+ [&have_excess_path, brief_node](const BriefNode *n) {
263
+ if (n == brief_node) {
264
+ have_excess_path = true ;
265
+ return false ;
266
+ }
267
+ return true ;
268
+ });
269
+ if (have_excess_path) continue ;
270
+ contract_nodes.insert (out);
271
+ }
272
+ if (contract_nodes.empty ()) break ;
273
+
274
+ for (auto dst_node : contract_nodes) {
275
+ UnionFindCombine (node_map, brief_node->node ->id (),
276
+ dst_node->node ->id ());
277
+ UnionContractedNodes (brief_node_map, brief_node->node ->id (),
278
+ dst_node->node ->id ());
96
279
}
97
280
}
98
281
}
@@ -128,6 +311,7 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() {
128
311
auto io = ExtractInputAndOutputOfSubGraph (subgraph);
129
312
block_node->inlinks = std::move (io.first );
130
313
block_node->outlinks = std::move (io.second );
314
+
131
315
for (auto *node : subgraph) {
132
316
// TODO(Superjomn) need a unified mechanism to treat deleted node in each
133
317
// pass.
0 commit comments