2
2
3
3
#include < queue>
4
4
#include " core/conversion/conversion.h"
5
+ #include " core/conversion/evaluators/evaluators.h"
5
6
#include " core/partitioning/shape_analysis.h"
6
7
#include " torch/csrc/jit/passes/constant_pooling.h"
7
8
#include " torch/csrc/jit/passes/dead_code_elimination.h"
@@ -114,7 +115,7 @@ std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg
114
115
pytorch_nodes.push_back (n);
115
116
prev_non_tensor_outputs = containNonTensorOutputs (n);
116
117
} else {
117
- // If pytorch_nodes is not empty, the previous nodes were all tensorrt_nodes . Construct a
118
+ // If pytorch_nodes is not empty, the previous nodes were all pytorch_nodes . Construct a
118
119
// Pytorch segmented_block and clear the pytorch_nodes list to be later used for new Pytorch segments.
119
120
if (!pytorch_nodes.empty ()) {
120
121
new_seg_blocks.emplace_back (SegmentedBlock::kTorch , pytorch_nodes);
@@ -131,6 +132,7 @@ std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg
131
132
new_seg_blocks.emplace_back (SegmentedBlock::kTorch , pytorch_nodes);
132
133
}
133
134
}
135
+
134
136
return std::move (new_seg_blocks);
135
137
}
136
138
@@ -158,6 +160,8 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shar
158
160
}
159
161
}
160
162
163
+ // For each non-tensor value in the usage_counts map, keep updating the produce_id to the earliest segmented block
164
+ // that has/produces it.
161
165
for (auto & use : usage_counts) {
162
166
// Set the produce_id to the segmented block index that contains/produces this non-tensor torch::jit::Value
163
167
if (segmented_blocks[i].contain_raw_value (use.first )) {
@@ -177,9 +181,8 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shar
177
181
// Segmented Blocks with non-tensor inputs will have to be re-segmented as
178
182
// TRTorch doesn't support non-tensor inputs for a module.
179
183
auto to_inject_blocks = segmentBlocksWithNonTensorInputs (segmented_blocks[first_torch_id]);
180
- segmented_blocks.erase (segmented_blocks.begin () + first_torch_id);
181
- segmented_blocks.insert (
182
- segmented_blocks.begin () + first_torch_id, to_inject_blocks.begin (), to_inject_blocks.end ());
184
+ auto next_iter = segmented_blocks_list.erase (idx_to_iter[first_torch_id]);
185
+ segmented_blocks_list.insert (next_iter, to_inject_blocks.begin (), to_inject_blocks.end ());
183
186
updated_segments.insert (first_torch_id);
184
187
}
185
188
}
@@ -258,6 +261,20 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo
258
261
return ;
259
262
}
260
263
264
+ bool checkLoopEvaluatable (torch::jit::Node* n) {
265
+ bool compile_to_trt = true ;
266
+ for (auto bn : n->blocks ()[0 ]->nodes ()) {
267
+ if (bn->kind () == torch::jit::prim::Loop) {
268
+ compile_to_trt = compile_to_trt && checkLoopEvaluatable (bn);
269
+ } else if (bn->kind () == torch::jit::prim::If) {
270
+ compile_to_trt = compile_to_trt && containNonTensorOutputs (bn);
271
+ } else {
272
+ compile_to_trt = compile_to_trt && core::conversion::evaluators::shouldEvalAtConversionTime (bn);
273
+ }
274
+ }
275
+ return compile_to_trt;
276
+ }
277
+
261
278
std::vector<SegmentedBlock> segment_graph (torch::jit::Block* block, const PartitionInfo& partition_info) {
262
279
auto min_block_size = partition_info.min_block_size ;
263
280
std::unordered_set<std::string> forced_fallback_operators (
@@ -298,6 +315,17 @@ std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const Partit
298
315
}
299
316
segmented_blocks.emplace_back (SegmentedBlock::kTorch , std::vector<torch::jit::Node*>{n});
300
317
continue ;
318
+ } else if (n->kind () == torch::jit::prim::Loop) {
319
+ if (!pytorch_nodes.empty ()) {
320
+ segmented_blocks.emplace_back (SegmentedBlock::kTorch , pytorch_nodes);
321
+ pytorch_nodes.clear ();
322
+ }
323
+ if (checkLoopEvaluatable (n)) {
324
+ tensorrt_nodes.push_back (n);
325
+ } else {
326
+ segmented_blocks.emplace_back (SegmentedBlock::kTorch , std::vector<torch::jit::Node*>{n});
327
+ }
328
+ continue ;
301
329
}
302
330
pytorch_nodes.push_back (n);
303
331
}
0 commit comments