1
-
2
-
3
-
4
-
1
+ #include " core/partitioning/partitioning.h"
5
2
#include < queue>
6
-
7
- #include " torch/csrc/jit/passes/constant_pooling.h"
8
- #include " torch/csrc/jit/passes/dead_code_elimination.h"
9
-
10
3
#include " core/conversion/conversion.h"
11
4
#include " core/conversion/evaluators/evaluators.h"
12
- #include " core/partitioning/partitioning.h"
13
5
#include " core/partitioning/partitioningctx/PartitioningCtx.h"
6
+ #include " torch/csrc/jit/passes/constant_pooling.h"
7
+ #include " torch/csrc/jit/passes/dead_code_elimination.h"
14
8
15
9
namespace torch_tensorrt {
16
10
namespace core {
@@ -35,8 +29,6 @@ bool containNonTensorOutputs(torch::jit::Node* n) {
35
29
return false ;
36
30
}
37
31
38
-
39
-
40
32
// Check if the inputs and outputs of the graph are Tensor. If not, then fallback connected nodes
41
33
void SetInputsOutputsConnectedNodes (PartitioningCtx* ctx, torch::jit::Block* block) {
42
34
// fallback nodes that produce entire graph's nonTensor output
@@ -91,7 +83,7 @@ void SetNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
91
83
// initial_fallback_nodes are the fallback nodes that we have before we run BFS in this function
92
84
std::queue<torch::jit::Node*> q;
93
85
for (auto & node : initial_fallback_nodes) {
94
- q.push (node. first );
86
+ q.push (node);
95
87
}
96
88
97
89
while (!q.empty ()) {
@@ -110,8 +102,7 @@ void SetNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
110
102
if (!isTensor (output)) {
111
103
for (auto use : output->uses ()) {
112
104
auto node = use.user ;
113
- if (node->kind () != torch::jit::prim::Constant &&
114
- ctx->shouldNodeRunInTensorRT (node)) {
105
+ if (node->kind () != torch::jit::prim::Constant && ctx->shouldNodeRunInTensorRT (node)) {
115
106
ctx->setNodeExecutorDecision (node, NodeExecutorDecision::kNON_TENSOR );
116
107
q.push (node);
117
108
}
@@ -147,7 +138,6 @@ std::vector<torch::jit::Node*> TraverseNodesForMinBlockSize(PartitioningCtx* ctx
147
138
return min_block_fallback_nodes;
148
139
}
149
140
150
-
151
141
// Set the nodes that fallback because of min_block_size
152
142
void SetMinBlockFallbackNodes (PartitioningCtx* ctx, torch::jit::Block* block) {
153
143
// first traverse all the nodes to find the initial nodes that don't meet the min_block_size requirement
@@ -328,7 +318,6 @@ bool checkLoopEvaluatable(torch::jit::Node* n) {
328
318
return compile_to_trt;
329
319
}
330
320
331
-
332
321
void SegmentGraph (PartitioningCtx* ctx, torch::jit::Block* block) {
333
322
auto nodes = block->nodes ();
334
323
@@ -429,8 +418,6 @@ void SetNodeExecutorDecision(PartitioningCtx* ctx, torch::jit::Block* block) {
429
418
SetMinBlockFallbackNodes (ctx, block);
430
419
}
431
420
432
-
433
-
434
421
PartitionedGraph Partition (PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& example_tensor_map) {
435
422
LOG_DEBUG (ctx->settings );
436
423
0 commit comments