@@ -74,25 +74,200 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) {
74
74
node_map.at (b)->attr (kUnionFindParent ).Int32 () = a_ancestor;
75
75
}
76
76
77
+ // This is a simple representation of a graph.
78
+ // The BriefNode hold the pointer of the Node.
79
+ // This is to avoid changing the original graph
80
+ // in the process of trt graph analysis.
81
+ struct BriefNode {
82
+ explicit BriefNode (Node *n) { node = n; }
83
+ Node *node;
84
+ std::vector<BriefNode *> inlinks;
85
+ std::vector<BriefNode *> outlinks;
86
+ };
87
+
88
+ void UnionContractedNodes (const std::unordered_map<int , BriefNode *> &node_map,
89
+ int src_id, int dst_id) {
90
+ // merge the two adjacent nodes into one node.
91
+ BriefNode *src_node = node_map.at (src_id);
92
+ BriefNode *dst_node = node_map.at (dst_id);
93
+
94
+ std::unordered_set<BriefNode *> inputs (src_node->inlinks .begin (),
95
+ src_node->inlinks .end ());
96
+ std::unordered_set<BriefNode *> outputs;
97
+
98
+ for (auto *n : src_node->outlinks ) {
99
+ if (n != dst_node) outputs.insert (n);
100
+ }
101
+
102
+ // Add the inlinks and outlinks of dst node to src node.
103
+ std::vector<BriefNode *> dst_in_nodes = dst_node->inlinks ;
104
+ for (BriefNode *node : dst_in_nodes) {
105
+ if (node != src_node) {
106
+ inputs.insert (node);
107
+ }
108
+ }
109
+
110
+ std::vector<BriefNode *> dst_out_nodes = dst_node->outlinks ;
111
+ for (BriefNode *node : dst_out_nodes) {
112
+ outputs.insert (node);
113
+ }
114
+
115
+ // update the dst and src node's inlinks and outlinks.
116
+ src_node->inlinks =
117
+ std::move (std::vector<BriefNode *>(inputs.begin (), inputs.end ()));
118
+ src_node->outlinks =
119
+ std::move (std::vector<BriefNode *>(outputs.begin (), outputs.end ()));
120
+ dst_node->inlinks .clear ();
121
+ dst_node->outlinks .clear ();
122
+
123
+ auto inlink_or_outlink_cleaner = [&](std::vector<BriefNode *> &nodes) {
124
+ for (auto *&n : nodes) {
125
+ if (n == src_node || n == dst_node) {
126
+ n = src_node;
127
+ }
128
+ }
129
+ };
130
+ // Change all the dst inputs and outputs corresponding inlink and
131
+ // outlink to the src node.
132
+ for (auto *node : src_node->inlinks ) {
133
+ inlink_or_outlink_cleaner (node->outlinks );
134
+ }
135
+
136
+ for (auto *node : src_node->outlinks ) {
137
+ inlink_or_outlink_cleaner (node->inlinks );
138
+ }
139
+ }
140
+
141
+ // FlexibleDfS
142
+ // If reverse is true, do reverse dfs.
143
+ // If enter func is not nullptr, calls enter(node) before visiting any children
144
+ // of node.
145
+ // If leave func not nullptr, calls leave(node) after visiting all parents of
146
+ // node.
147
+ void FlexibleDFS (const std::vector<BriefNode *> &source, bool reverse,
148
+ const std::function<bool (const BriefNode *)> &enter,
149
+ const std::function<bool(const BriefNode *)> &leave) {
150
+ typedef struct {
151
+ const BriefNode *node;
152
+ bool leave;
153
+ } FNode;
154
+
155
+ std::vector<FNode> stack;
156
+ for (auto &node : source) {
157
+ stack.push_back (FNode{node, false });
158
+ }
159
+ std::unordered_set<const BriefNode *> visited;
160
+ while (!stack.empty ()) {
161
+ auto fnode = stack.back ();
162
+ stack.pop_back ();
163
+
164
+ if (fnode.leave ) {
165
+ if (leave && !leave (fnode.node )) return ;
166
+ }
167
+ if (visited.count (fnode.node )) continue ;
168
+ visited.insert (fnode.node );
169
+
170
+ if (enter && !enter (fnode.node )) return ;
171
+
172
+ if (leave) stack.push_back (FNode{fnode.node , true });
173
+ const std::vector<BriefNode *> iter_nodes =
174
+ reverse == true ? fnode.node ->inlinks : fnode.node ->outlinks ;
175
+ for (const BriefNode *node : iter_nodes) {
176
+ if (!visited.count (node)) {
177
+ stack.push_back (FNode{node, false });
178
+ }
179
+ }
180
+ }
181
+ }
182
+
77
183
std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs () {
184
+ // Run the Extract algorithm to find all subgraphs.
78
185
std::vector<Node *> marked_nodes;
186
+ // We use brief_node_map to represent the original graph in order to avoid
187
+ // changing the original graph.
188
+ std::unordered_map<int , BriefNode *> brief_node_map;
189
+
79
190
for (auto &node : GraphTraits<DataFlowGraph>(*graph_).nodes_in_TS ()) {
191
+ brief_node_map[node.id ()] = new BriefNode (&node);
80
192
if (node.attr (kMarkerAttrName ).Bool ()) {
81
193
marked_nodes.push_back (&node);
82
194
}
83
195
}
196
+
84
197
// extract sub-graphs in the marked node set, use Union Find algorithm.
85
198
node_map_t node_map; // id to ptr
86
199
for (auto *n : marked_nodes) {
87
200
// n's parent == n.id means it is the ancestor
88
201
n->attr (kUnionFindParent ).Int32 () = n->id ();
89
202
node_map[n->id ()] = n;
90
203
}
91
- std::unordered_set<Node *> visited;
92
- for (auto *n : marked_nodes) {
93
- for (auto *out : n->outlinks ) {
94
- if (node_map.count (out->id ())) {
95
- UnionFindCombine (node_map, n->id (), out->id ());
204
+
205
+ // create brief node map
206
+ for (auto &itr : brief_node_map) {
207
+ for (Node *node : itr.second ->node ->inlinks ) {
208
+ itr.second ->inlinks .push_back (brief_node_map[node->id ()]);
209
+ }
210
+
211
+ for (Node *node : itr.second ->node ->outlinks ) {
212
+ itr.second ->outlinks .push_back (brief_node_map[node->id ()]);
213
+ }
214
+ }
215
+
216
+ for (auto &itr : brief_node_map) {
217
+ BriefNode *brief_node = itr.second ;
218
+
219
+ if (!brief_node->node ->attr (kMarkerAttrName ).Bool ()) {
220
+ VLOG (4 ) << brief_node->node ->id () << " node not a trt candicate." ;
221
+ continue ;
222
+ }
223
+
224
+ // Our algorithm must guarantee that:
225
+ // 1. The graph is always directed acyclic graph(DAG).
226
+ // 2. If there is a path in the subgraph from X to Y (X and Y are both
227
+ // nodes
228
+ // in the subgraph), then all paths from X to Y are in the subgraph.
229
+ //
230
+ // In order to achieve the above guarantee.
231
+ // For adjacent nodes src -> dst.
232
+ // 1. Get all dst input nodes except src.
233
+ // 2. Reverse DFS from those input nodes
234
+ // 3. If there is a path from input nodes to src,
235
+ // then the src and dst nodes can not be fused into one node,
236
+ // otherwise it can be done.
237
+
238
+ while (true ) {
239
+ std::unordered_set<BriefNode *> contract_nodes;
240
+ for (auto *out : brief_node->outlinks ) {
241
+ // must be a trt candidate
242
+ if (!out->node ->attr (kMarkerAttrName ).Bool ()) continue ;
243
+ // get all dst input nodes except src.
244
+ std::vector<BriefNode *> source_nodes;
245
+ for (auto *n : out->inlinks ) {
246
+ if (n != brief_node) {
247
+ source_nodes.push_back (n);
248
+ }
249
+ }
250
+
251
+ // Reverse DFS from the source_nodes.
252
+ bool have_excess_path = false ;
253
+ FlexibleDFS (source_nodes, true , nullptr ,
254
+ [&have_excess_path, brief_node](const BriefNode *n) {
255
+ if (n == brief_node) {
256
+ have_excess_path = true ;
257
+ return false ;
258
+ }
259
+ return true ;
260
+ });
261
+ if (have_excess_path) continue ;
262
+ contract_nodes.insert (out);
263
+ }
264
+ if (contract_nodes.empty ()) break ;
265
+
266
+ for (auto dst_node : contract_nodes) {
267
+ UnionFindCombine (node_map, brief_node->node ->id (),
268
+ dst_node->node ->id ());
269
+ UnionContractedNodes (brief_node_map, brief_node->node ->id (),
270
+ dst_node->node ->id ());
96
271
}
97
272
}
98
273
}
@@ -128,6 +303,7 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() {
128
303
auto io = ExtractInputAndOutputOfSubGraph (subgraph);
129
304
block_node->inlinks = std::move (io.first );
130
305
block_node->outlinks = std::move (io.second );
306
+
131
307
for (auto *node : subgraph) {
132
308
// TODO(Superjomn) need a unified mechanism to treat deleted node in each
133
309
// pass.
0 commit comments