
Commit 299cd24

Merge pull request #617 from NVIDIA/arvind/loop_fallback

Loop Fallback

2 parents a1180ce + 0b3cf89

8 files changed: +152 -9 lines

core/conversion/conversion.cpp

Lines changed: 1 addition & 1 deletion
@@ -326,7 +326,7 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
   MapIValues(ctx, n->outputs(), n->blocks()[0]->inputs(), 0, 1);
   for (auto bn : n->blocks()[0]->nodes()) {
     if (bn->kind() == torch::jit::prim::Loop) {
-      EvaluateLoopBlock(ctx, n);
+      EvaluateLoopBlock(ctx, bn);
     } else if (bn->kind() == torch::jit::prim::If) {
       EvaluateConditionalBlock(ctx, bn, true);
     } else {
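This one-character fix matters: the old call recursed on the outer node n, so any nested prim::Loop re-entered EvaluateLoopBlock on the same loop and never made progress; passing the child node bn descends into the nested loop instead. A toy illustration of the bug pattern, using stand-in types rather than TRTorch's ConversionCtx and torch::jit::Node:

#include <vector>

struct Block { std::vector<Block*> nested; };

void evaluate(Block* b) {
  for (auto* child : b->nested) {
    // evaluate(b);   // old bug: re-enters the same block, infinite recursion
    evaluate(child);  // fixed: descend into the child block
  }
}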

core/lowering/passes/module_fallback.cpp

Lines changed: 1 addition & 1 deletion
@@ -133,4 +133,4 @@ void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_del
 } // namespace passes
 } // namespace lowering
 } // namespace core
-} // namespace trtorch
+} // namespace trtorch

The removed and added lines are textually identical; the change evidently adds a missing trailing newline at the end of the file.

core/lowering/passes/unpack_var.cpp

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph) {
   torch::jit::SubgraphRewriter var_rewriter;
   var_rewriter.RegisterRewritePattern(var_pattern, unpacked_pattern);
   var_rewriter.runOnGraph(graph);
-  LOG_DEBUG("Post unpack var: " << *graph);
+  LOG_GRAPH("Post unpack var: " << *graph);
 }
 
 } // namespace passes
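LOG_DEBUG is swapped for LOG_GRAPH so the full graph dump only appears at the dedicated graph log level instead of cluttering ordinary debug output. A minimal sketch of how such a two-level macro pair can be structured, assuming a simple severity enum; this mirrors the macro names but is not TRTorch's actual core/util logging implementation:

#include <iostream>
#include <sstream>

enum class LogLevel { kERROR, kWARNING, kINFO, kDEBUG, kGRAPH };

// Messages above this threshold are dropped; at kDEBUG, LOG_DEBUG prints
// but the noisier LOG_GRAPH dumps stay silent.
static LogLevel reportable_level = LogLevel::kDEBUG;

#define LOG_AT_LEVEL(lvl, msg)               \
  do {                                       \
    if ((lvl) <= reportable_level) {         \
      std::ostringstream ss;                 \
      ss << msg; /* msg may be a << chain */ \
      std::cerr << ss.str() << std::endl;    \
    }                                        \
  } while (0)

#define LOG_DEBUG(msg) LOG_AT_LEVEL(LogLevel::kDEBUG, msg)
#define LOG_GRAPH(msg) LOG_AT_LEVEL(LogLevel::kGRAPH, msg)

int main() {
  LOG_DEBUG("visible at the default threshold");
  LOG_GRAPH("suppressed until reportable_level is raised to kGRAPH");
}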

core/partitioning/partitioning.cpp

Lines changed: 32 additions & 4 deletions
@@ -2,6 +2,7 @@
 
 #include <queue>
 #include "core/conversion/conversion.h"
+#include "core/conversion/evaluators/evaluators.h"
 #include "core/partitioning/shape_analysis.h"
 #include "torch/csrc/jit/passes/constant_pooling.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
@@ -114,7 +115,7 @@ std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg
       pytorch_nodes.push_back(n);
       prev_non_tensor_outputs = containNonTensorOutputs(n);
     } else {
-      // If pytorch_nodes is not empty, the previous nodes were all tensorrt_nodes. Construct a
+      // If pytorch_nodes is not empty, the previous nodes were all pytorch_nodes. Construct a
       // Pytorch segmented_block and clear the pytorch_nodes list to be later used for new Pytorch segments.
       if (!pytorch_nodes.empty()) {
         new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
@@ -131,6 +132,7 @@ std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg
       new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
     }
   }
+
   return std::move(new_seg_blocks);
 }
 

@@ -158,6 +160,8 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shar
       }
     }
 
+    // For each non-tensor value in the usage_counts map, keep updating the produce_id to the earliest segmented block
+    // that has/produces it.
     for (auto& use : usage_counts) {
       // Set the produce_id to the segmented block index that contains/produces this non-tensor torch::jit::Value
       if (segmented_blocks[i].contain_raw_value(use.first)) {
@@ -177,9 +181,8 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shar
       // Segmented Blocks with non-tensor inputs will have to be re-segmented as
       // TRTorch doesn't support non-tensor inputs for a module.
       auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]);
-      segmented_blocks.erase(segmented_blocks.begin() + first_torch_id);
-      segmented_blocks.insert(
-          segmented_blocks.begin() + first_torch_id, to_inject_blocks.begin(), to_inject_blocks.end());
+      auto next_iter = segmented_blocks_list.erase(idx_to_iter[first_torch_id]);
+      segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
       updated_segments.insert(first_torch_id);
     }
   }
@@ -258,6 +261,20 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo
     return;
   }
 
+bool checkLoopEvaluatable(torch::jit::Node* n) {
+  bool compile_to_trt = true;
+  for (auto bn : n->blocks()[0]->nodes()) {
+    if (bn->kind() == torch::jit::prim::Loop) {
+      compile_to_trt = compile_to_trt && checkLoopEvaluatable(bn);
+    } else if (bn->kind() == torch::jit::prim::If) {
+      compile_to_trt = compile_to_trt && containNonTensorOutputs(bn);
+    } else {
+      compile_to_trt = compile_to_trt && core::conversion::evaluators::shouldEvalAtConversionTime(bn);
+    }
+  }
+  return compile_to_trt;
+}
+
 std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info) {
   auto min_block_size = partition_info.min_block_size;
   std::unordered_set<std::string> forced_fallback_operators(
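checkLoopEvaluatable encodes the heart of this PR: a prim::Loop may join a TensorRT segment only if every node in its body, recursing through nested loops, can be evaluated at conversion time, in which case the evaluator effectively unrolls the loop away. A toy restatement of the predicate with simplified stand-in types (the real function additionally special-cases prim::If via containNonTensorOutputs; the toy folds everything into a single evaluatable flag):

#include <string>
#include <vector>

// Stand-in for a JIT node: its kind, the nodes of its first block (for loops),
// and whether the evaluator library could run it at conversion time.
struct ToyNode {
  std::string kind;
  std::vector<ToyNode> body;
  bool evaluatable;
};

bool loopIsEvaluatable(const ToyNode& loop) {
  for (const auto& n : loop.body) {
    if (n.kind == "prim::Loop") {
      if (!loopIsEvaluatable(n)) return false; // nested loops must also qualify
    } else if (!n.evaluatable) {
      return false; // one TRT-only op forces the whole loop to fall back
    }
  }
  return true;
}

int main() {
  ToyNode trt_only{"aten::add", {}, false};
  ToyNode loop{"prim::Loop", {trt_only}, false};
  return loopIsEvaluatable(loop) ? 0 : 1; // returns 1: this loop falls back to Torch
}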
@@ -298,6 +315,17 @@ std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const Partit
       }
       segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n});
       continue;
+    } else if (n->kind() == torch::jit::prim::Loop) {
+      if (!pytorch_nodes.empty()) {
+        segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
+        pytorch_nodes.clear();
+      }
+      if (checkLoopEvaluatable(n)) {
+        tensorrt_nodes.push_back(n);
+      } else {
+        segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n});
+      }
+      continue;
     }
     pytorch_nodes.push_back(n);
   }

core/partitioning/shape_analysis.cpp

Lines changed: 3 additions & 1 deletion
@@ -56,7 +56,9 @@ void getSegmentsOutputByRunning(
   for (auto& input : seg_block.raw_inputs()) {
     TRTORCH_CHECK(
         ivalues_maps.count(input),
-        "Could not find torch::jit::Value* " << input->debugName() << " in lowering graph for mini graph input.\n");
+        "Could not find torch::jit::Value* " << input->debugName() << " produced from "
+            << util::node_info(input->node())
+            << " in lowering graph for mini graph input.\n");
     if (input->node()->kind() == torch::jit::prim::Param) {
       jit_inputs_ivalues.push_back(ivalues_maps[input]);
     } else if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) {
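The multi-part message works because TRTORCH_CHECK takes its message as a streamed expression: the << chain is pasted after an internal stream object by the macro. A generic sketch of that common idiom, assuming a simple throwing check (this is not TRTorch's exact macro body):

#include <sstream>
#include <stdexcept>

// EXAMPLE_CHECK(cond, "value " << v << " missing"): the msg argument is a
// << chain that the macro splices after an internal ostringstream.
#define EXAMPLE_CHECK(cond, msg)          \
  do {                                    \
    if (!(cond)) {                        \
      std::ostringstream ss;              \
      ss << msg;                          \
      throw std::runtime_error(ss.str()); \
    }                                     \
  } while (0)

int main() {
  int found = 0;
  EXAMPLE_CHECK(found == 0, "Could not find value " << found << " in map\n");
}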

tests/core/partitioning/BUILD

Lines changed: 20 additions & 1 deletion
@@ -11,7 +11,9 @@ filegroup(
     name = "jit_models",
     srcs = ["//tests/modules:resnet50_traced.jit.pt",
             "//tests/modules:mobilenet_v2_traced.jit.pt",
-            "//tests/modules:conditional_scripted.jit.pt"]
+            "//tests/modules:conditional_scripted.jit.pt",
+            "//tests/modules:loop_fallback_eval_scripted.jit.pt",
+            "//tests/modules:loop_fallback_no_eval_scripted.jit.pt"]
 )
 
 partitioning_test(
@@ -46,6 +48,22 @@ cc_test(
     ]
 )
 
+cc_test(
+    name = "test_loop_fallback",
+    srcs = ["test_loop_fallback.cpp"],
+    deps = [
+        "//tests/util",
+        "//core",
+        "@googletest//:gtest_main",
+    ] + select({
+        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
+        "//conditions:default": ["@libtorch//:libtorch"],
+    }),
+    data = [
+        ":jit_models"
+    ]
+)
+
 cc_test(
     name = "test_conditionals",
     srcs = ["test_conditionals.cpp"],
@@ -70,6 +88,7 @@ test_suite(
         ":test_tensorrt_conversion",
         ":test_stitched_graph",
        ":test_fallback_graph_output",
+        ":test_loop_fallback",
         ":test_conditionals"
     ]
 )
tests/core/partitioning/test_loop_fallback.cpp

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
+#include <string>
+#include <unordered_set>
+#include "core/compiler.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/script.h"
+
+TEST(Partitioning, CheckLoopFallbackEvalCompilesCorrectly) {
+  torch::jit::script::Module mod;
+  try {
+    mod = torch::jit::load("tests/modules/loop_fallback_eval_scripted.jit.pt");
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+    return;
+  }
+
+  const std::vector<std::vector<int64_t>> input_shapes = {{1, 10}};
+  std::vector<torch::jit::IValue> jit_inputs_ivalues;
+  std::vector<torch::jit::IValue> trt_inputs_ivalues;
+  for (auto in_shape : input_shapes) {
+    auto in = at::randint(5, in_shape, {at::kCUDA});
+    jit_inputs_ivalues.push_back(in.clone());
+    trt_inputs_ivalues.push_back(in.clone());
+  }
+
+  std::vector<trtorch::core::ir::Input> input_ranges{trtorch::core::ir::Input({1, 10})};
+  trtorch::core::CompileSpec cfg(input_ranges);
+  cfg.partition_info.enabled = true;
+
+  auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
+  auto trt_mod = trtorch::core::CompileGraph(mod, cfg);
+  auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
+}
+
+TEST(Partitioning, CheckLoopFallbackNoEvalCompilesCorrectly) {
+  torch::jit::script::Module mod;
+  try {
+    mod = torch::jit::load("tests/modules/loop_fallback_no_eval_scripted.jit.pt");
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+    return;
+  }
+
+  const std::vector<std::vector<int64_t>> input_shapes = {{1, 10}};
+  std::vector<torch::jit::IValue> jit_inputs_ivalues;
+  std::vector<torch::jit::IValue> trt_inputs_ivalues;
+  for (auto in_shape : input_shapes) {
+    auto in = at::randint(5, in_shape, {at::kCUDA});
+    jit_inputs_ivalues.push_back(in.clone());
+    trt_inputs_ivalues.push_back(in.clone());
+  }
+
+  std::vector<trtorch::core::ir::Input> input_ranges{trtorch::core::ir::Input({1, 10})};
+  trtorch::core::CompileSpec cfg(input_ranges);
+  cfg.partition_info.enabled = true;
+
+  auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
+  auto trt_mod = trtorch::core::CompileGraph(mod, cfg);
+  auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
+}
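With the cc_test target added above, this test should be runnable on its own via something like bazel test //tests/core/partitioning:test_loop_fallback (target path inferred from the BUILD file; the exact invocation depends on the repo's configured build flags).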

tests/modules/hub.py

Lines changed: 32 additions & 0 deletions
@@ -129,6 +129,38 @@ def forward(self, x):
 torch.jit.save(module_fallback_script_model, "module_fallback_scripted.jit.pt")
 
 
+# Sample Looping Modules (for loop fallback testing)
+class LoopFallbackEval(nn.Module):
+
+    def __init__(self):
+        super(LoopFallbackEval, self).__init__()
+
+    def forward(self, x):
+        add_list = torch.empty(0).to(x.device)
+        for i in range(x.shape[1]):
+            add_list = torch.cat((add_list, torch.tensor([x.shape[1]]).to(x.device)), 0)
+        return x + add_list
+
+
+class LoopFallbackNoEval(nn.Module):
+
+    def __init__(self):
+        super(LoopFallbackNoEval, self).__init__()
+
+    def forward(self, x):
+        for _ in range(x.shape[1]):
+            x = x + torch.ones_like(x)
+        return x
+
+
+loop_fallback_eval_model = LoopFallbackEval().eval().cuda()
+loop_fallback_eval_script_model = torch.jit.script(loop_fallback_eval_model)
+torch.jit.save(loop_fallback_eval_script_model, "loop_fallback_eval_scripted.jit.pt")
+loop_fallback_no_eval_model = LoopFallbackNoEval().eval().cuda()
+loop_fallback_no_eval_script_model = torch.jit.script(loop_fallback_no_eval_model)
+torch.jit.save(loop_fallback_no_eval_script_model, "loop_fallback_no_eval_scripted.jit.pt")
+
+
 # Sample Conditional Model (for testing partitioning and fallback in conditionals)
 class FallbackIf(torch.nn.Module):
 
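These two modules pin down the boundary the partitioner now draws: LoopFallbackEval's loop body only manipulates shape-derived values, so it is intended to satisfy checkLoopEvaluatable and be unrolled at conversion time, while LoopFallbackNoEval adds to the runtime tensor x each iteration, so its loop falls back to Torch. A hedged C++ snippet for inspecting the scripted loop, using only public libtorch API; the file path assumes the module was saved by tests/modules/hub.py and that you run from the repo root:

#include <iostream>
#include "torch/script.h"

int main() {
  // Load the saved TorchScript module and dump its IR: the forward graph
  // contains a prim::Loop node whose body is what checkLoopEvaluatable walks.
  auto mod = torch::jit::load("tests/modules/loop_fallback_no_eval_scripted.jit.pt");
  std::cout << *mod.get_method("forward").graph() << "\n";
  return 0;
}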