Skip to content

Commit 78fee01

Browse files
committed
Add loop eval check in partition logic
Signed-off-by: Arvind Sridhar <[email protected]>
1 parent 15e6863 commit 78fee01

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

core/conversion/conversion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
326326
MapIValues(ctx, n->outputs(), n->blocks()[0]->inputs(), 0, 1);
327327
for (auto bn : n->blocks()[0]->nodes()) {
328328
if (bn->kind() == torch::jit::prim::Loop) {
329-
EvaluateLoopBlock(ctx, n);
329+
EvaluateLoopBlock(ctx, bn);
330330
} else if (bn->kind() == torch::jit::prim::If) {
331331
EvaluateConditionalBlock(ctx, bn, true);
332332
} else {

core/partitioning/partitioning.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <queue>
44
#include "core/conversion/conversion.h"
5+
#include "core/conversion/evaluators/evaluators.h"
56
#include "core/partitioning/shape_analysis.h"
67
#include "torch/csrc/jit/passes/constant_pooling.h"
78
#include "torch/csrc/jit/passes/dead_code_elimination.h"
@@ -258,6 +259,21 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo
258259
return;
259260
}
260261

262+
bool check_if_loop_evaluatable(const torch::jit::Node* n);
263+
bool check_if_loop_evaluatable(const torch::jit::Node* n) {
264+
bool compile_to_trt = true;
265+
for (auto bn : n->blocks()[0]->nodes()) {
266+
if (bn->kind() == torch::jit::prim::Loop) {
267+
compile_to_trt = compile_to_trt && check_if_loop_evaluatable(bn);
268+
} else if (bn->kind() == torch::jit::prim::If) {
269+
return false;
270+
} else {
271+
compile_to_trt = compile_to_trt && core::conversion::evaluators::shouldEvalAtConversionTime(bn);
272+
}
273+
}
274+
return compile_to_trt;
275+
}
276+
261277
std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info) {
262278
auto min_block_size = partition_info.min_block_size;
263279
std::unordered_set<std::string> forced_fallback_operators(
@@ -298,6 +314,17 @@ std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const Partit
298314
}
299315
segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n});
300316
continue;
317+
} else if (n->kind() == torch::jit::prim::Loop) {
318+
if (!pytorch_nodes.empty()) {
319+
segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
320+
pytorch_nodes.clear();
321+
}
322+
if (check_if_loop_evaluatable(n)) {
323+
tensorrt_nodes.push_back(n);
324+
} else {
325+
pytorch_nodes.push_back(n);
326+
}
327+
continue;
301328
}
302329
pytorch_nodes.push_back(n);
303330
}

0 commit comments

Comments
 (0)