Skip to content

Commit da28184

Browse files
committed
test(serialization): Changes the serialization test to cover both
dynamic (plugin) and static, uses a pooling based model instead of interpolate since that is handled by TensorRT now Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 010b801 commit da28184

File tree

2 files changed

+53
-9
lines changed

2 files changed

+53
-9
lines changed

tests/modules/hub.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,18 +79,18 @@
7979
script_model = torch.jit.script(m["model"])
8080
torch.jit.save(script_model, n + '_scripted.jit.pt')
8181

82-
# Sample Interpolation Model (align_corners=False, for Testing Interpolate Plugin specifically)
83-
class Interpolate(nn.Module):
82+
# Sample Pool Model (for testing plugin serialization)
83+
class Pool(nn.Module):
8484
def __init__(self):
85-
super(Interpolate, self).__init__()
85+
super(Pool, self).__init__()
8686

8787
def forward(self, x):
88-
return F.interpolate(x, size=(10, 10, 10), align_corners=False, mode="trilinear")
88+
return F.adaptive_avg_pool2d(x, (5, 5))
8989

90-
model = Interpolate().eval().cuda()
91-
x = torch.ones([1, 3, 5, 5, 5]).cuda()
90+
model = Pool().eval().cuda()
91+
x = torch.ones([1, 3, 10, 10]).cuda()
9292

9393
trace_model = torch.jit.trace(model, x)
94-
torch.jit.save(trace_model, "interpolate_traced.jit.pt")
94+
torch.jit.save(trace_model, "pooling_traced.jit.pt")
9595

9696

tests/modules/test_serialization.cpp

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,23 @@
11
#include "module_test.h"
22

3+
std::vector<trtorch::ExtraInfo::InputRange> toInputRangesDynamic(std::vector<std::vector<int64_t>> opts) {
4+
std::vector<trtorch::ExtraInfo::InputRange> a;
5+
6+
for (auto opt : opts) {
7+
std::vector<int64_t> min_range(opt);
8+
std::vector<int64_t> max_range(opt);
9+
10+
min_range[3] = ceil(opt[3]/2.0);
11+
max_range[3] = 2*opt[3];
12+
min_range[2] = ceil(opt[2]/2.0);
13+
max_range[2] = 2*opt[2];
14+
15+
a.push_back(trtorch::ExtraInfo::InputRange(min_range, opt, max_range));
16+
}
17+
18+
return std::move(a);
19+
}
20+
321
TEST_P(ModuleTests, SerializedModuleIsStillCorrect) {
422
std::vector<torch::jit::IValue> post_serialized_inputs_ivalues;
523
std::vector<torch::jit::IValue> pre_serialized_inputs_ivalues;
@@ -26,11 +44,37 @@ TEST_P(ModuleTests, SerializedModuleIsStillCorrect) {
2644
}
2745
}
2846

47+
TEST_P(ModuleTests, SerializedDynamicModuleIsStillCorrect) {
48+
std::vector<torch::jit::IValue> post_serialized_inputs_ivalues;
49+
std::vector<torch::jit::IValue> pre_serialized_inputs_ivalues;
50+
for (auto in_shape : input_shapes) {
51+
auto in = at::randint(5, in_shape, {at::kCUDA});
52+
post_serialized_inputs_ivalues.push_back(in.clone());
53+
pre_serialized_inputs_ivalues.push_back(in.clone());
54+
}
55+
56+
auto pre_serialized_mod = trtorch::CompileGraph(mod, toInputRangesDynamic(input_shapes));
57+
torch::jit::IValue pre_serialized_results_ivalues = trtorch::tests::util::RunModuleForward(pre_serialized_mod, pre_serialized_inputs_ivalues);
58+
std::vector<at::Tensor> pre_serialized_results;
59+
pre_serialized_results.push_back(pre_serialized_results_ivalues.toTensor());
60+
61+
pre_serialized_mod.save("test_serialization_mod.ts");
62+
auto post_serialized_mod = torch::jit::load("test_serialization_mod.ts");
63+
64+
torch::jit::IValue post_serialized_results_ivalues = trtorch::tests::util::RunModuleForward(post_serialized_mod, post_serialized_inputs_ivalues);
65+
std::vector<at::Tensor> post_serialized_results;
66+
post_serialized_results.push_back(post_serialized_results_ivalues.toTensor());
67+
68+
for (size_t i = 0; i < pre_serialized_results.size(); i++) {
69+
ASSERT_TRUE(trtorch::tests::util::almostEqual(post_serialized_results[i], pre_serialized_results[i].reshape_as(post_serialized_results[i]), 2e-5));
70+
}
71+
}
72+
2973

3074
INSTANTIATE_TEST_SUITE_P(CompiledModuleForwardIsCloseSuite,
3175
ModuleTests,
3276
testing::Values(
3377
PathAndInSize({"tests/modules/resnet18_traced.jit.pt",
3478
{{1,3,224,224}}}),
35-
PathAndInSize({"tests/modules/interpolate_traced.jit.pt",
36-
{{1,3,5,5,5}}})));
79+
PathAndInSize({"tests/modules/pooling_traced.jit.pt",
80+
{{1,3,10,10}}})));

0 commit comments

Comments
 (0)