Skip to content

Commit 532efed

Browse files
ArvindSridharnarendasan
authored and committed
Add tests
Signed-off-by: Arvind Sridhar <[email protected]>
1 parent 5cdaaa5 commit 532efed

File tree

5 files changed

+192
-8
lines changed

5 files changed

+192
-8
lines changed

core/lowering/lowering.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,21 @@ void LowerBlock(torch::jit::Block* b) {
2525
}
2626

2727
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
28-
passes::MarkNodesForFallback(g, false);
29-
passes::UnpackHardSwish(g);
3028
torch::jit::EliminateRedundantGuards(g);
3129
torch::jit::RemoveListMutation(g);
3230
torch::jit::RemoveTensorMutation(g);
3331
torch::jit::CreateFunctionalGraphs(g);
3432
torch::jit::InlineFunctionalGraphs(g);
3533
torch::jit::PeepholeOptimize(g, false);
36-
passes::EliminateExceptionOrPassPattern(g);
3734
torch::jit::FuseLinear(g);
3835
torch::jit::LowerAllTuples(g);
36+
if (!lower_info.disable_cse) {
37+
torch::jit::EliminateCommonSubexpression(g);
38+
}
39+
torch::jit::EliminateDeadCode(g);
40+
passes::MarkNodesForFallback(g, true);
41+
passes::UnpackHardSwish(g);
42+
passes::EliminateExceptionOrPassPattern(g);
3943
passes::ReduceToOperation(g);
4044
passes::RemoveContiguous(g);
4145
passes::RemoveDropout(g);
@@ -44,9 +48,6 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
4448
passes::Conv3DToConvolution(g);
4549
passes::FuseAddMMBranches(g);
4650
passes::RemoveBNDimCheck(g);
47-
if (!lower_info.disable_cse) {
48-
torch::jit::EliminateCommonSubexpression(g);
49-
}
5051
// torch::jit::UnrollLoops(g);
5152
passes::UnpackAddMM(g);
5253
// passes::UnpackBatchNorm(g);
@@ -56,8 +57,6 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
5657
passes::RemoveNOPs(g);
5758
passes::AliasOperators(g);
5859
passes::SiluToSigmoidMultipication(g);
59-
torch::jit::EliminateDeadCode(g);
60-
passes::MarkNodesForFallback(g, true);
6160
LOG_GRAPH(*g);
6261
}
6362

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "gtest/gtest.h"
44
#include "tests/util/util.h"
55
#include "torch/csrc/jit/ir/irparser.h"
6+
#include "torch/torch.h"
67

78
TEST(Evaluators, DivIntEvaluatesCorrectly) {
89
const auto graph = R"IR(

tests/core/lowering/BUILD

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,22 @@ lowering_test(
1111
name = "test_linear_to_addmm",
1212
)
1313

14+
cc_test(
15+
name = "test_module_level_fallback",
16+
srcs = ["test_module_level_fallback.cpp"],
17+
deps = [
18+
"//tests/util",
19+
"//core",
20+
"@googletest//:gtest_main",
21+
] + select({
22+
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
23+
"//conditions:default": ["@libtorch//:libtorch"],
24+
}),
25+
data = [
26+
"//tests/modules:jit_models"
27+
]
28+
)
29+
1430
lowering_test(
1531
name = "test_remove_contiguous_pass",
1632
)
@@ -47,6 +63,7 @@ test_suite(
4763
name = "lowering_tests",
4864
tests = [
4965
":test_linear_to_addmm",
66+
":test_module_level_fallback",
5067
":test_operator_aliasing_pass",
5168
":test_remove_contiguous_pass",
5269
":test_remove_detach_pass",
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
#include <string>
2+
#include <unordered_set>
3+
#include "core/compiler.h"
4+
#include "core/lowering/lowering.h"
5+
#include "gtest/gtest.h"
6+
#include "tests/util/util.h"
7+
#include "torch/script.h"
8+
9+
TEST(Lowering, LowerResNet18ModuleFallbackCorrectly) {
10+
torch::jit::script::Module mod;
11+
try {
12+
mod = torch::jit::load("tests/modules/resnet18_traced.jit.pt");
13+
} catch (const c10::Error& e) {
14+
std::cerr << "error loading the model\n";
15+
return;
16+
}
17+
18+
const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}};
19+
std::vector<torch::jit::IValue> jit_inputs_ivalues;
20+
std::vector<torch::jit::IValue> trt_inputs_ivalues;
21+
for (auto in_shape : input_shapes) {
22+
auto in = at::randint(5, in_shape, {at::kCUDA});
23+
jit_inputs_ivalues.push_back(in.clone());
24+
trt_inputs_ivalues.push_back(in.clone());
25+
}
26+
27+
std::vector<trtorch::core::ir::Input> input_ranges{trtorch::core::ir::Input({1, 3, 224, 224})};
28+
trtorch::core::CompileSpec cfg(input_ranges);
29+
cfg.partition_info.enabled = true;
30+
cfg.lower_info.forced_fallback_modules.push_back("torchvision.models.resnet.BasicBlock");
31+
32+
auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
33+
auto trt_mod = trtorch::core::CompileGraph(mod, cfg);
34+
35+
auto g = trt_mod.get_method("forward").graph();
36+
auto nodes = g->block()->nodes();
37+
std::size_t count = 0;
38+
for (const auto n : nodes) {
39+
auto has_compile_attribute = n->hasAttribute(c10::Symbol::attr("to_compile"));
40+
if (has_compile_attribute && n->i(c10::Symbol::attr("to_compile")) == (int64_t) false) {
41+
count++;
42+
}
43+
}
44+
ASSERT_TRUE(count == 62);
45+
46+
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
47+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
48+
}
49+
50+
TEST(Lowering, LowerAndPartitionSimpleModuleFallbackCorrectly) {
51+
torch::jit::script::Module mod;
52+
try {
53+
mod = torch::jit::load("tests/modules/module_fallback_scripted.jit.pt");
54+
} catch (const c10::Error& e) {
55+
std::cerr << "error loading the model\n";
56+
return;
57+
}
58+
59+
const std::vector<std::vector<int64_t>> input_shapes = {{1, 1, 16, 16}};
60+
std::vector<torch::jit::IValue> jit_inputs_ivalues;
61+
std::vector<torch::jit::IValue> trt_inputs_ivalues;
62+
for (auto in_shape : input_shapes) {
63+
auto in = at::randint(5, in_shape, {at::kCUDA});
64+
jit_inputs_ivalues.push_back(in.clone());
65+
trt_inputs_ivalues.push_back(in.clone());
66+
}
67+
68+
std::vector<trtorch::core::ir::Input> input_ranges{trtorch::core::ir::Input({1, 1, 16, 16})};
69+
trtorch::core::CompileSpec cfg(input_ranges);
70+
cfg.partition_info.enabled = true;
71+
cfg.lower_info.forced_fallback_modules.push_back("ModuleFallbackSub");
72+
73+
auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
74+
auto trt_mod = trtorch::core::CompileGraph(mod, cfg);
75+
76+
auto g = trt_mod.get_method("forward").graph();
77+
auto nodes = g->block()->nodes();
78+
std::size_t curr_node = 0;
79+
for (const auto n : nodes) {
80+
if (curr_node == 5) {
81+
ASSERT_TRUE(n->kind() == torch::jit::aten::conv2d);
82+
ASSERT_TRUE(n->i(c10::Symbol::attr("to_compile")) == (int64_t) false);
83+
} else if (curr_node == 6) {
84+
ASSERT_TRUE(n->kind() == torch::jit::aten::relu);
85+
ASSERT_TRUE(n->i(c10::Symbol::attr("to_compile")) == (int64_t) false);
86+
} else if (curr_node == 7) {
87+
ASSERT_TRUE(n->kind() == torch::jit::prim::GetAttr);
88+
ASSERT_TRUE(n->s(c10::Symbol::attr("name")).find("trt_engine") != std::string::npos);
89+
}
90+
curr_node++;
91+
}
92+
93+
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
94+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
95+
}
96+
97+
TEST(Lowering, LowerAndPartitionMobileNetModuleFallbackCorrectly) {
98+
torch::jit::script::Module mod;
99+
try {
100+
mod = torch::jit::load("tests/modules/mobilenet_v2_traced.jit.pt");
101+
} catch (const c10::Error& e) {
102+
std::cerr << "error loading the model\n";
103+
return;
104+
}
105+
106+
const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}};
107+
std::vector<torch::jit::IValue> jit_inputs_ivalues;
108+
std::vector<torch::jit::IValue> trt_inputs_ivalues;
109+
for (auto in_shape : input_shapes) {
110+
auto in = at::randint(5, in_shape, {at::kCUDA});
111+
jit_inputs_ivalues.push_back(in.clone());
112+
trt_inputs_ivalues.push_back(in.clone());
113+
}
114+
115+
std::vector<trtorch::core::ir::Input> input_ranges{trtorch::core::ir::Input({1, 3, 224, 224})};
116+
trtorch::core::CompileSpec cfg(input_ranges);
117+
cfg.partition_info.enabled = true;
118+
cfg.partition_info.min_block_size = 5;
119+
cfg.lower_info.forced_fallback_modules.push_back("torchvision.models.mobilenetv2.ConvBNActivation");
120+
121+
auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
122+
auto trt_mod = trtorch::core::CompileGraph(mod, cfg);
123+
124+
auto g = trt_mod.get_method("forward").graph();
125+
auto nodes = g->block()->nodes();
126+
std::size_t trt_count = 0;
127+
std::size_t fallback_count = 0;
128+
for (const auto n : nodes) {
129+
auto has_name_attribute = n->hasAttribute(c10::Symbol::attr("name"));
130+
auto has_compile_attribute = n->hasAttribute(c10::Symbol::attr("to_compile"));
131+
if (has_name_attribute && n->s(c10::Symbol::attr("name")).find("trt_engine") != std::string::npos) {
132+
trt_count++;
133+
} else if (has_compile_attribute && n->i(c10::Symbol::attr("to_compile")) == (int64_t) false) {
134+
fallback_count++;
135+
}
136+
}
137+
ASSERT_TRUE(trt_count == 1);
138+
ASSERT_TRUE(fallback_count == 105);
139+
140+
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
141+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
142+
}

tests/modules/hub.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,31 @@ def forward(self, x):
9797
trace_model = torch.jit.trace(model, x)
9898
torch.jit.save(trace_model, "pooling_traced.jit.pt")
9999

100+
# Sample Nested Module (for module-level fallback testing)
101+
class ModuleFallbackSub(nn.Module):
102+
103+
def __init__(self):
104+
super(ModuleFallbackSub, self).__init__()
105+
self.conv = nn.Conv2d(1, 3, 3)
106+
self.relu = nn.ReLU()
107+
108+
def forward(self, x):
109+
return self.relu(self.conv(x))
110+
111+
class ModuleFallbackMain(nn.Module):
112+
113+
def __init__(self):
114+
super(ModuleFallbackMain, self).__init__()
115+
self.layer1 = ModuleFallbackSub()
116+
self.conv = nn.Conv2d(3, 6, 3)
117+
self.relu = nn.ReLU()
118+
119+
def forward(self, x):
120+
return self.relu(self.conv(self.layer1(x)))
121+
122+
module_fallback_model = ModuleFallbackMain().eval().cuda()
123+
module_fallback_script_model = torch.jit.script(module_fallback_model)
124+
torch.jit.save(module_fallback_script_model, "module_fallback_scripted.jit.pt")
100125

101126
# Sample Conditional Model (for testing partitioning and fallback in conditionals)
102127
class FallbackIf(torch.nn.Module):

0 commit comments

Comments (0)