Skip to content

Commit 6833017

Browse files
committed
Add conditional logic
Signed-off-by: Arvind Sridhar <[email protected]>
1 parent 78fee01 commit 6833017

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

core/partitioning/partitioning.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -259,14 +259,13 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo
259259
return;
260260
}
261261

262-
bool check_if_loop_evaluatable(const torch::jit::Node* n);
263-
bool check_if_loop_evaluatable(const torch::jit::Node* n) {
262+
bool checkLoopEvaluatable(torch::jit::Node* n) {
264263
bool compile_to_trt = true;
265264
for (auto bn : n->blocks()[0]->nodes()) {
266265
if (bn->kind() == torch::jit::prim::Loop) {
267-
compile_to_trt = compile_to_trt && check_if_loop_evaluatable(bn);
266+
compile_to_trt = compile_to_trt && checkLoopEvaluatable(bn);
268267
} else if (bn->kind() == torch::jit::prim::If) {
269-
return false;
268+
compile_to_trt = compile_to_trt && containNonTensorOutputs(bn);
270269
} else {
271270
compile_to_trt = compile_to_trt && core::conversion::evaluators::shouldEvalAtConversionTime(bn);
272271
}
@@ -319,10 +318,10 @@ std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const Partit
319318
segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
320319
pytorch_nodes.clear();
321320
}
322-
if (check_if_loop_evaluatable(n)) {
323-
tensorrt_nodes.push_back(n);
321+
if (checkLoopEvaluatable(n)) {
322+
segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, std::vector<torch::jit::Node*>{n});
324323
} else {
325-
pytorch_nodes.push_back(n);
324+
segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n});
326325
}
327326
continue;
328327
}

0 commit comments

Comments
 (0)