1
1
#include " module_test.h"
2
2
3
+ std::vector<trtorch::ExtraInfo::InputRange> toInputRangesDynamic (std::vector<std::vector<int64_t >> opts) {
4
+ std::vector<trtorch::ExtraInfo::InputRange> a;
5
+
6
+ for (auto opt : opts) {
7
+ std::vector<int64_t > min_range (opt);
8
+ std::vector<int64_t > max_range (opt);
9
+
10
+ min_range[3 ] = ceil (opt[3 ]/2.0 );
11
+ max_range[3 ] = 2 *opt[3 ];
12
+ min_range[2 ] = ceil (opt[2 ]/2.0 );
13
+ max_range[2 ] = 2 *opt[2 ];
14
+
15
+ a.push_back (trtorch::ExtraInfo::InputRange (min_range, opt, max_range));
16
+ }
17
+
18
+ return std::move (a);
19
+ }
20
+
3
21
TEST_P (ModuleTests, SerializedModuleIsStillCorrect) {
4
22
std::vector<torch::jit::IValue> post_serialized_inputs_ivalues;
5
23
std::vector<torch::jit::IValue> pre_serialized_inputs_ivalues;
@@ -26,11 +44,37 @@ TEST_P(ModuleTests, SerializedModuleIsStillCorrect) {
26
44
}
27
45
}
28
46
47
+ TEST_P (ModuleTests, SerializedDynamicModuleIsStillCorrect) {
48
+ std::vector<torch::jit::IValue> post_serialized_inputs_ivalues;
49
+ std::vector<torch::jit::IValue> pre_serialized_inputs_ivalues;
50
+ for (auto in_shape : input_shapes) {
51
+ auto in = at::randint (5 , in_shape, {at::kCUDA });
52
+ post_serialized_inputs_ivalues.push_back (in.clone ());
53
+ pre_serialized_inputs_ivalues.push_back (in.clone ());
54
+ }
55
+
56
+ auto pre_serialized_mod = trtorch::CompileGraph (mod, toInputRangesDynamic (input_shapes));
57
+ torch::jit::IValue pre_serialized_results_ivalues = trtorch::tests::util::RunModuleForward (pre_serialized_mod, pre_serialized_inputs_ivalues);
58
+ std::vector<at::Tensor> pre_serialized_results;
59
+ pre_serialized_results.push_back (pre_serialized_results_ivalues.toTensor ());
60
+
61
+ pre_serialized_mod.save (" test_serialization_mod.ts" );
62
+ auto post_serialized_mod = torch::jit::load (" test_serialization_mod.ts" );
63
+
64
+ torch::jit::IValue post_serialized_results_ivalues = trtorch::tests::util::RunModuleForward (post_serialized_mod, post_serialized_inputs_ivalues);
65
+ std::vector<at::Tensor> post_serialized_results;
66
+ post_serialized_results.push_back (post_serialized_results_ivalues.toTensor ());
67
+
68
+ for (size_t i = 0 ; i < pre_serialized_results.size (); i++) {
69
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (post_serialized_results[i], pre_serialized_results[i].reshape_as (post_serialized_results[i]), 2e-5 ));
70
+ }
71
+ }
72
+
29
73
30
74
INSTANTIATE_TEST_SUITE_P (CompiledModuleForwardIsCloseSuite,
31
75
ModuleTests,
32
76
testing::Values (
33
77
PathAndInSize ({" tests/modules/resnet18_traced.jit.pt" ,
34
78
{{1 ,3 ,224 ,224 }}}),
35
- PathAndInSize({" tests/modules/interpolate_traced .jit.pt" ,
36
- {{1 ,3 ,5 , 5 , 5 }}})));
79
+ PathAndInSize({" tests/modules/pooling_traced .jit.pt" ,
80
+ {{1 ,3 ,10 , 10 }}})));
0 commit comments