|
| 1 | + |
| 2 | + |
| 3 | + |
| 4 | + |
1 | 5 | #include <queue>
|
2 | 6 |
|
3 | 7 | #include "torch/csrc/jit/passes/constant_pooling.h"
|
@@ -31,6 +35,136 @@ bool containNonTensorOutputs(torch::jit::Node* n) {
|
31 | 35 | return false;
|
32 | 36 | }
|
33 | 37 |
|
| 38 | + |
| 39 | + |
| 40 | +// Check if the inputs and outputs of the graph are Tensor. If not, then fallback connected nodes |
| 41 | +void SetInputsOutputsConnectedNodes(PartitioningCtx* ctx, torch::jit::Block* block) { |
| 42 | + // fallback nodes that produce entire graph's nonTensor output |
| 43 | + for (auto i : block->outputs()) { |
| 44 | + if (!isTensor(i)) { |
| 45 | + ctx->setNodeExecutorDecision(i->node(), NodeExecutorDecision::kNON_TENSOR); |
| 46 | + } |
| 47 | + } |
| 48 | + |
| 49 | + // fallback nodes that consume entire graph's nonTensor input |
| 50 | + for (auto i : block->inputs()) { |
| 51 | + if (!isTensor(i)) { |
| 52 | + for (auto use : i->uses()) { |
| 53 | + ctx->setNodeExecutorDecision(use.user, NodeExecutorDecision::kNON_TENSOR); |
| 54 | + } |
| 55 | + } |
| 56 | + } |
| 57 | +} |
| 58 | + |
// Assigns an executor decision to every non-constant node of `block`.
// The first matching rule wins, checked in this order:
//   1. conversion::OpSupported(n) is false              -> kUNSUPPORTED
//   2. the op's qualified name is in forced_fallback_ops -> kOPERATOR_FALLBACK
//   3. the node carries attribute "to_compile" == 0      -> kMODULE_FALLBACK
//   4. otherwise                                         -> kCONVERT (run in TensorRT for now)
// prim::Constant nodes are skipped and receive no decision here.
void SetExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
  const auto to_compile_sym = c10::Symbol::attr("to_compile");

  // Picks the decision for a single node; checks are evaluated lazily in priority order.
  const auto classify = [&](torch::jit::Node* node) {
    if (!conversion::OpSupported(node)) {
      // The conversion phase cannot handle this op, so it has to run in PyTorch
      return NodeExecutorDecision::kUNSUPPORTED;
    }
    if (ctx->forced_fallback_ops.find(node->kind().toQualString()) != ctx->forced_fallback_ops.end()) {
      // The user asked for this op to run in Torch
      return NodeExecutorDecision::kOPERATOR_FALLBACK;
    }
    if (node->hasAttribute(to_compile_sym) && node->i(to_compile_sym) == static_cast<int64_t>(false)) {
      // The user asked for the module containing this op to run in Torch
      return NodeExecutorDecision::kMODULE_FALLBACK;
    }
    // Everything else is a TensorRT candidate
    return NodeExecutorDecision::kCONVERT;
  };

  for (const auto node : block->nodes()) {
    if (node->kind() == torch::jit::prim::Constant) {
      continue;
    }
    ctx->setNodeExecutorDecision(node, classify(node));
  }
}
| 87 | + |
| 88 | +// For a given set of fallback nodes, check their inputs/outputs, if any inputs/outputs of them are NonTensor, |
| 89 | +// then the nodes that produces/consumes those values should also fallback |
| 90 | +void SetNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::Node*>& initial_fallback_nodes) { |
| 91 | + // initial_fallback_nodes are the fallback nodes that we have before we run BFS in this function |
| 92 | + std::queue<torch::jit::Node*> q; |
| 93 | + for (auto& node : initial_fallback_nodes) { |
| 94 | + q.push(node.first); |
| 95 | + } |
| 96 | + |
| 97 | + while (!q.empty()) { |
| 98 | + auto cur_node = q.front(); |
| 99 | + q.pop(); |
| 100 | + // for every node that produces this fallback node's NonTensor input, they should fallback too |
| 101 | + for (auto input : cur_node->inputs()) { |
| 102 | + if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant && |
| 103 | + ctx->shouldNodeRunInTensorRT(input->node())) { |
| 104 | + ctx->setNodeExecutorDecision(input->node(), NodeExecutorDecision::kNON_TENSOR); |
| 105 | + q.push(input->node()); |
| 106 | + } |
| 107 | + } |
| 108 | + // for every node that consumes this fallback node's NonTensor output, they should fallback too |
| 109 | + for (auto output : cur_node->outputs()) { |
| 110 | + if (!isTensor(output)) { |
| 111 | + for (auto use : output->uses()) { |
| 112 | + auto node = use.user; |
| 113 | + if (node->kind() != torch::jit::prim::Constant && |
| 114 | + ctx->shouldNodeRunInTensorRT(node)) { |
| 115 | + ctx->setNodeExecutorDecision(node, NodeExecutorDecision::kNON_TENSOR); |
| 116 | + q.push(node); |
| 117 | + } |
| 118 | + } |
| 119 | + } |
| 120 | + } |
| 121 | + } |
| 122 | +} |
| 123 | + |
| 124 | +// Sub-function that traverses the entire block and check if TensorRT node sequence satisfy min_block_size |
| 125 | +std::vector<torch::jit::Node*> TraverseNodesForMinBlockSize(PartitioningCtx* ctx, torch::jit::Block* block) { |
| 126 | + auto nodes = block->nodes(); |
| 127 | + std::vector<torch::jit::Node*> cur_trt_nodes; |
| 128 | + std::vector<torch::jit::Node*> min_block_fallback_nodes; |
| 129 | + for (const auto n : nodes) { |
| 130 | + if (n->kind() == torch::jit::prim::Constant) { |
| 131 | + continue; |
| 132 | + } |
| 133 | + |
| 134 | + // check if current node fallback or not |
| 135 | + if (!ctx->shouldNodeRunInTorch(n)) { |
| 136 | + cur_trt_nodes.push_back(n); |
| 137 | + } else { |
| 138 | + if (cur_trt_nodes.size() < ctx->settings.min_block_size) { |
| 139 | + min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end()); |
| 140 | + } |
| 141 | + cur_trt_nodes.clear(); |
| 142 | + } |
| 143 | + } |
| 144 | + if (cur_trt_nodes.size() < ctx->settings.min_block_size) { |
| 145 | + min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end()); |
| 146 | + } |
| 147 | + return min_block_fallback_nodes; |
| 148 | +} |
| 149 | + |
| 150 | + |
| 151 | +// Set the nodes that fallback because of min_block_size |
| 152 | +void SetMinBlockFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) { |
| 153 | + // first traverse all the nodes to find the initial nodes that don't meet the min_block_size requirement |
| 154 | + auto min_block_fallback_nodes = TraverseNodesForMinBlockSize(ctx, block); |
| 155 | + |
| 156 | + // keep fallback until all segments meet the min_block_size requirement |
| 157 | + while (!min_block_fallback_nodes.empty()) { |
| 158 | + for (const auto i : min_block_fallback_nodes) { |
| 159 | + ctx->setNodeExecutorDecision(i, NodeExecutorDecision::kMIN_BLOCK_FALLBACK); |
| 160 | + } |
| 161 | + // find the fallback nodes because of dependency with min_block_size caused fallback nodes |
| 162 | + SetNonTensorConnectedNodes(ctx, min_block_fallback_nodes); |
| 163 | + // keep traverse the graph until there is no node fallback because of min_block_size |
| 164 | + min_block_fallback_nodes = TraverseNodesForMinBlockSize(ctx, block); |
| 165 | + } |
| 166 | +} |
| 167 | + |
34 | 168 | bool isModifyingNodes(torch::jit::Node* node, torch::jit::Value* val) {
|
35 | 169 | const torch::jit::FunctionSchema* schema = node->maybeSchema();
|
36 | 170 | if (!schema) {
|
@@ -97,62 +231,6 @@ std::vector<torch::jit::Node*> getDependencyNodes(
|
97 | 231 | return stk;
|
98 | 232 | }
|
99 | 233 |
|
100 |
| -// check if the input and output of the graph is Tensor after collection is enabled. If it is, then fallback related |
101 |
| -// nodes |
102 |
| -void fallback_graph_nontensor_in_out(PartitioningCtx* ctx, torch::jit::Block* block) { |
103 |
| - // fallback nodes that produce entire graph's nonTensor output |
104 |
| - for (auto i : block->outputs()) { |
105 |
| - if (!isTensor(i)) { |
106 |
| - ctx->setNodeExecutorDecision(i->node(), NodeExecutorDecision::kNON_TENSOR); |
107 |
| - } |
108 |
| - } |
109 |
| - |
110 |
| - // fallback nodes that consume entire graph's nonTensor input |
111 |
| - for (auto i : block->inputs()) { |
112 |
| - if (!isTensor(i)) { |
113 |
| - for (auto use : i->uses()) { |
114 |
| - ctx->setNodeExecutorDecision(use.user, NodeExecutorDecision::kNON_TENSOR); |
115 |
| - } |
116 |
| - } |
117 |
| - } |
118 |
| -} |
119 |
| - |
120 |
| -void find_all_fallback_nodes(PartitioningCtx* ctx, NodeExecutorDecisionMap& initial_fallback_nodes) { |
121 |
| - // initial_fallback_nodes are the fallback nodes that we have before we run BFS in this function |
122 |
| - // global_fallback_nodes are the fallback nodes that we maintain globally |
123 |
| - std::queue<torch::jit::Node*> q; |
124 |
| - for (auto& node : initial_fallback_nodes) { |
125 |
| - q.push(node.first); |
126 |
| - } |
127 |
| - |
128 |
| - std::unordered_set<torch::jit::Node*> visited_nodes; |
129 |
| - while (!q.empty()) { |
130 |
| - auto cur_node = q.front(); |
131 |
| - q.pop(); |
132 |
| - // for every node that produces this fallback node's NonTensor input, they should fallback too |
133 |
| - for (auto input : cur_node->inputs()) { |
134 |
| - // NOTE: This does not make sense, does this rely on shortciruiting to work right? |
135 |
| - if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant && |
136 |
| - ctx->setNodeExecutorDecision(input->node(), NodeExecutorDecision::kNON_TENSOR)) { |
137 |
| - q.push(input->node()); |
138 |
| - } |
139 |
| - } |
140 |
| - // for every node that consumes this fallback node's NonTensor output, they should fallback too |
141 |
| - for (auto output : cur_node->outputs()) { |
142 |
| - if (!isTensor(output)) { |
143 |
| - for (auto use : output->uses()) { |
144 |
| - auto node = use.user; |
145 |
| - // NOTE: This does not make sense, does this rely on shortciruiting to work right? |
146 |
| - if (node->kind() != torch::jit::prim::Constant && |
147 |
| - ctx->setNodeExecutorDecision(node, NodeExecutorDecision::kNON_TENSOR)) { |
148 |
| - q.push(node); |
149 |
| - } |
150 |
| - } |
151 |
| - } |
152 |
| - } |
153 |
| - } |
154 |
| -} |
155 |
| - |
156 | 234 | void resolveTRTNonTensorInputs(PartitioningCtx* ctx) {
|
157 | 235 | // if a TRT segment has nonTensor Inputs, the nodes that produce this nonTensor Inputs must in another TensorRT engine
|
158 | 236 | // because we have already found the interface between Torch and TRT in segmentation phase
|
@@ -250,102 +328,10 @@ bool checkLoopEvaluatable(torch::jit::Node* n) {
|
250 | 328 | return compile_to_trt;
|
251 | 329 | }
|
252 | 330 |
|
253 |
| -// use this function to get all initial fallback nodes (nodes that are unsupported or forced fallback) |
254 |
| -// we use a map to indicate the reason why it's fallback to torch |
255 |
| -void get_fallback_nodes(PartitioningCtx* ctx, torch::jit::Block* block) { |
256 |
| - auto nodes = block->nodes(); |
257 |
| - for (const auto n : nodes) { |
258 |
| - if (n->kind() == torch::jit::prim::Constant) { |
259 |
| - continue; |
260 |
| - } |
261 |
| - |
262 |
| - // If the op is not supported by the conversion phase it should run in PyTorch |
263 |
| - if (!conversion::OpSupported(n)) { |
264 |
| - ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kUNSUPPORTED); |
265 |
| - } |
266 |
| - |
267 |
| - // If the user specifies the op to run in Torch it should run in PyTorch |
268 |
| - if (ctx->forced_fallback_ops.find(n->kind().toQualString()) != ctx->forced_fallback_ops.end()) { |
269 |
| - ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kOPERATOR_FALLBACK); |
270 |
| - } |
271 |
| - |
272 |
| - // If the user specifies the module containing this op to run in torch it should run in PyTorch |
273 |
| - const auto to_compile_sym = c10::Symbol::attr("to_compile"); |
274 |
| - if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) { |
275 |
| - ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kMODULE_FALLBACK); |
276 |
| - } |
277 |
| - } |
278 |
| - return; |
279 |
| -} |
280 |
| - |
281 |
| -std::vector<torch::jit::Node*> traverse_nodes_for_min_block_size(PartitioningCtx* ctx, torch::jit::Block* block) { |
282 |
| - auto nodes = block->nodes(); |
283 |
| - std::vector<torch::jit::Node*> cur_trt_nodes; |
284 |
| - std::vector<torch::jit::Node*> min_block_fallback_nodes; |
285 |
| - for (const auto n : nodes) { |
286 |
| - if (n->kind() == torch::jit::prim::Constant) { |
287 |
| - continue; |
288 |
| - } |
289 |
| - |
290 |
| - // check if current node fallback or not |
291 |
| - if (!ctx->shouldNodeRunInTorch(n)) { |
292 |
| - // if this node is not in fallback nodes, then it's in trt segments |
293 |
| - cur_trt_nodes.push_back(n); |
294 |
| - } else { |
295 |
| - if (cur_trt_nodes.size() < ctx->settings.min_block_size) { |
296 |
| - min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end()); |
297 |
| - } |
298 |
| - cur_trt_nodes.clear(); |
299 |
| - } |
300 |
| - } |
301 |
| - if (cur_trt_nodes.size() < ctx->settings.min_block_size) { |
302 |
| - min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end()); |
303 |
| - } |
304 |
| - return min_block_fallback_nodes; |
305 |
| -} |
306 |
| - |
307 |
| -void find_min_block_size_fallback_nodes(PartitioningCtx* ctx, torch::jit::Block* block) { |
308 |
| - // first traverse all the nodes to find the initial nodes that don't meet the min_block_size requirement |
309 |
| - auto min_block_fallback_nodes = traverse_nodes_for_min_block_size(ctx, block); |
310 |
| - NodeExecutorDecisionMap initial_fallback_nodes; |
311 |
| - |
312 |
| - // keep fallback until all segments meet the min_block_size requirement |
313 |
| - while (!min_block_fallback_nodes.empty()) { |
314 |
| - for (const auto i : min_block_fallback_nodes) { |
315 |
| - initial_fallback_nodes.insert({i, NodeExecutorDecision::kMIN_BLOCK_FALLBACK}); |
316 |
| - ctx->setNodeExecutorDecision(i, NodeExecutorDecision::kMIN_BLOCK_FALLBACK); |
317 |
| - } |
318 |
| - // find the fallback nodes because of dependency with min_block_size caused fallback nodes |
319 |
| - find_all_fallback_nodes(ctx, initial_fallback_nodes); |
320 |
| - // keep traverse the graph until there is no node fallback because of min_block_size |
321 |
| - min_block_fallback_nodes = traverse_nodes_for_min_block_size(ctx, block); |
322 |
| - } |
323 |
| -} |
324 |
| - |
325 |
| -void segment_graph(PartitioningCtx* ctx, torch::jit::Block* block) { |
326 |
| - // get the initial fallback nodes (nodes that are unsupported or forced fallback) |
327 |
| - get_fallback_nodes(ctx, block); |
328 |
| - |
329 |
| - // For fallback nodes, if it consumes any NonTensor inputs or TensorList inputs, then the node that produces this |
330 |
| - // input should also fallback Similarly, if it produces any NonTensor outputs or TensorList outputs, then the node |
331 |
| - // that produces this input should also fallback |
332 |
| - // TODO: don't need to fallback the TensorList related nodes once the collection feature is supported |
333 |
| - find_all_fallback_nodes(ctx, ctx->node_executor_decision_map); |
334 |
| - |
335 |
| - // find all fallback nodes because of the min_block_size requirement |
336 |
| - find_min_block_size_fallback_nodes(ctx, block); |
337 | 331 |
|
| 332 | +void SegmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) { |
338 | 333 | auto nodes = block->nodes();
|
339 | 334 |
|
340 |
| - // NOTE: Realize this may be redundant, but will let us have an explicit state for each node. Maybe there is a better |
341 |
| - // way for (auto n : nodes) { |
342 |
| - // if (!ctx->shouldNodeRunInTorch(n) && !ctx->isNodeExecutorKnown(n)) { |
343 |
| - // if (conversion::OpSupported(n)) { |
344 |
| - // ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kCONVERT); |
345 |
| - // } |
346 |
| - // } |
347 |
| - // } |
348 |
| - |
349 | 335 | // segment the nodes
|
350 | 336 | std::vector<torch::jit::Node*> in_prog_trt_blk_nodes, in_prog_pyt_blk_nodes;
|
351 | 337 | for (const auto n : nodes) {
|
@@ -420,18 +406,41 @@ void segment_graph(PartitioningCtx* ctx, torch::jit::Block* block) {
|
420 | 406 | return;
|
421 | 407 | }
|
422 | 408 |
|
423 |
| -PartitionedGraph partition(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& example_tensor_map) { |
| 409 | +void SetNodeExecutorDecision(PartitioningCtx* ctx, torch::jit::Block* block) { |
| 410 | + // First, find all the explicit fallback nodes that should run in Torch: |
| 411 | + // 1. nodes that are unsupported |
| 412 | + // 2. nodes that the user specifies to run in torch |
| 413 | + // 3. nodes that the user specifies the module containing this op to run in torch |
| 414 | + // At the same time, set all the rest nodes to NodeExecutorDecision::kCONVERT |
| 415 | + SetExplicitFallbackNodes(ctx, block); |
| 416 | + |
| 417 | + // Second, check if there is nonTensor input/output for the block, if there is, then fallback the nodes that |
| 418 | + // consume/produce this nonTensor value |
| 419 | + SetInputsOutputsConnectedNodes(ctx, block); |
| 420 | + |
| 421 | + // Third, for fallback nodes, if it consumes any NonTensor inputs, then the nodes that produce this |
| 422 | + // input should also fallback. Similarly, if it produces any NonTensor outputs, then the nodes |
| 423 | + // that consume this output should also fallback |
| 424 | + auto cur_fallback_nodes = ctx->getNodesRunInTorch(); |
| 425 | + SetNonTensorConnectedNodes(ctx, cur_fallback_nodes); |
| 426 | + |
| 427 | + // Finally, check if all current tensorrt blocks satisfy the min_block_size requirement. |
| 428 | + // We need to traverse the whole graph many times here |
| 429 | + SetMinBlockFallbackNodes(ctx, block); |
| 430 | +} |
| 431 | + |
| 432 | + |
| 433 | + |
| 434 | +PartitionedGraph Partition(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& example_tensor_map) { |
424 | 435 | LOG_DEBUG(ctx->settings);
|
425 |
| - // if there is nonTensor input/output for the entire graph, fallback the node that consumes/produces this nonTensor |
426 |
| - // output |
427 |
| - fallback_graph_nontensor_in_out(ctx, block); |
| 436 | + |
| 437 | + SetNodeExecutorDecision(ctx, block); |
428 | 438 |
|
429 | 439 | // segment lowering global graph into blocks
|
430 | 440 | LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks");
|
431 |
| - segment_graph(ctx, block); |
| 441 | + SegmentGraph(ctx, block); |
432 | 442 |
|
433 | 443 | // It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks
|
434 |
| - |
435 | 444 | // resolve nonTensor inputs/outputs
|
436 | 445 | resolveTRTNonTensorInputs(ctx);
|
437 | 446 |
|
|
0 commit comments