4
4
#include " torch/script.h"
5
5
#include " trtorch/trtorch.h"
6
6
7
- TEST (CppAPITests, LowerResNetModuleFallbackCorrectly ) {
7
+ TEST (CppAPITest, ResNetModuleFallbacksCorrectly ) {
8
8
torch::jit::script::Module mod;
9
9
try {
10
10
mod = torch::jit::load (" tests/modules/resnet18_traced.jit.pt" );
@@ -22,9 +22,7 @@ TEST(CppAPITests, LowerResNetModuleFallbackCorrectly) {
22
22
trt_inputs_ivalues.push_back (in.clone ());
23
23
}
24
24
25
- std::vector<trtorch::CompileSpec::Input> input_ranges{
26
- trtorch::CompileSpec::Input (std::vector<int64_t >({1 , 3 , 224 , 224 }))};
27
- trtorch::CompileSpec cfg (input_ranges);
25
+ trtorch::CompileSpec cfg (input_shapes);
28
26
cfg.torch_fallback .enabled = true ;
29
27
cfg.torch_fallback .forced_fallback_modules .push_back (" torchvision.models.resnet.BasicBlock" );
30
28
@@ -34,7 +32,7 @@ TEST(CppAPITests, LowerResNetModuleFallbackCorrectly) {
34
32
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results, trt_results, 2e-6 ));
35
33
}
36
34
37
- TEST (CppAPITests, LowerAndPartitionMobileNetModuleFallbackCorrectly ) {
35
+ TEST (CppAPITest, MobileNetModuleFallbacksCorrectlyWithOneEngine ) {
38
36
torch::jit::script::Module mod;
39
37
try {
40
38
mod = torch::jit::load (" tests/modules/mobilenet_v2_traced.jit.pt" );
@@ -52,16 +50,24 @@ TEST(CppAPITests, LowerAndPartitionMobileNetModuleFallbackCorrectly) {
52
50
trt_inputs_ivalues.push_back (in.clone ());
53
51
}
54
52
55
- std::vector<trtorch::CompileSpec::Input> input_ranges{
56
- trtorch::CompileSpec::Input (std::vector<int64_t >({1 , 3 , 224 , 224 }))};
57
- trtorch::CompileSpec cfg (input_ranges);
53
+ trtorch::CompileSpec cfg (input_shapes);
58
54
cfg.torch_fallback .enabled = true ;
59
55
cfg.torch_fallback .min_block_size = 5 ;
60
56
cfg.torch_fallback .forced_fallback_modules .push_back (" torchvision.models.mobilenetv2.ConvBNActivation" );
61
57
62
58
auto jit_results = mod.forward (jit_inputs_ivalues).toTensor ();
63
59
auto trt_mod = trtorch::CompileGraph (mod, cfg);
64
60
61
+ auto g = trt_mod.get_method (" forward" ).graph ();
62
+ auto nodes = g->block ()->nodes ();
63
+ std::size_t trt_count = 0 ;
64
+ for (const auto n : nodes) {
65
+ if (n->kind ().toQualString () == std::string (" tensorrt::execute_engine" )) {
66
+ trt_count++;
67
+ }
68
+ }
69
+ ASSERT_TRUE (trt_count == 1 );
70
+
65
71
auto trt_results = trt_mod.forward (trt_inputs_ivalues).toTensor ();
66
72
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results, trt_results, 2e-6 ));
67
73
}
0 commit comments