|
2 | 2 |
|
3 | 3 | #include <queue>
|
4 | 4 | #include "core/conversion/conversion.h"
|
| 5 | +#include "core/conversion/evaluators/evaluators.h" |
5 | 6 | #include "core/partitioning/shape_analysis.h"
|
6 | 7 | #include "torch/csrc/jit/passes/constant_pooling.h"
|
7 | 8 | #include "torch/csrc/jit/passes/dead_code_elimination.h"
|
@@ -258,6 +259,21 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo
|
258 | 259 | return;
|
259 | 260 | }
|
260 | 261 |
|
| 262 | +bool check_if_loop_evaluatable(const torch::jit::Node* n); |
| 263 | +bool check_if_loop_evaluatable(const torch::jit::Node* n) { |
| 264 | + bool compile_to_trt = true; |
| 265 | + for (auto bn : n->blocks()[0]->nodes()) { |
| 266 | + if (bn->kind() == torch::jit::prim::Loop) { |
| 267 | + compile_to_trt = compile_to_trt && check_if_loop_evaluatable(bn); |
| 268 | + } else if (bn->kind() == torch::jit::prim::If) { |
| 269 | + return false; |
| 270 | + } else { |
| 271 | + compile_to_trt = compile_to_trt && core::conversion::evaluators::shouldEvalAtConversionTime(bn); |
| 272 | + } |
| 273 | + } |
| 274 | + return compile_to_trt; |
| 275 | +} |
| 276 | + |
261 | 277 | std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info) {
|
262 | 278 | auto min_block_size = partition_info.min_block_size;
|
263 | 279 | std::unordered_set<std::string> forced_fallback_operators(
|
@@ -298,6 +314,17 @@ std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const Partit
|
298 | 314 | }
|
299 | 315 | segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n});
|
300 | 316 | continue;
|
| 317 | + } else if (n->kind() == torch::jit::prim::Loop) { |
| 318 | + if (!pytorch_nodes.empty()) { |
| 319 | + segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes); |
| 320 | + pytorch_nodes.clear(); |
| 321 | + } |
| 322 | + if (check_if_loop_evaluatable(n)) { |
| 323 | + tensorrt_nodes.push_back(n); |
| 324 | + } else { |
| 325 | + pytorch_nodes.push_back(n); |
| 326 | + } |
| 327 | + continue; |
301 | 328 | }
|
302 | 329 | pytorch_nodes.push_back(n);
|
303 | 330 | }
|
|
0 commit comments