#include <string>
#include <unordered_set>
#include "core/compiler.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/script.h"
7-
8- TEST (Partitioning, ComputeResNet50FallbackGraphCorrectly) {
9- torch::jit::script::Module mod;
10- try {
11- mod = torch::jit::load (" tests/modules/resnet50_traced.jit.pt" );
12- } catch (const c10::Error& e) {
13- std::cerr << " error loading the model\n " ;
14- return ;
15- }
16-
17- const std::vector<std::vector<int64_t >> input_shapes = {{1 , 3 , 224 , 224 }};
18- std::vector<torch::jit::IValue> jit_inputs_ivalues;
19- std::vector<torch::jit::IValue> trt_inputs_ivalues;
20- for (auto in_shape : input_shapes) {
21- auto in = at::randint (5 , in_shape, {at::kCUDA });
22- jit_inputs_ivalues.push_back (in.clone ());
23- trt_inputs_ivalues.push_back (in.clone ());
24- }
25-
26- std::vector<torch_tensorrt::core::ir::Input> input_ranges{torch_tensorrt::core::ir::Input ({1 , 3 , 224 , 224 })};
27-
28- torch_tensorrt::core::CompileSpec cfg (input_ranges);
29- cfg.partition_info .enabled = true ;
30- cfg.partition_info .forced_fallback_operators .push_back (" aten::add" );
31-
32- auto jit_results = mod.forward (jit_inputs_ivalues).toTensor ();
33- auto trt_mod = torch_tensorrt::core::CompileGraph (mod, cfg);
34- auto trt_results = trt_mod.forward (trt_inputs_ivalues).toTensor ();
35- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results, trt_results, 2e-6 ));
36- }
37-
38- TEST (Partitioning, ComputeMobileNetFallbackGraphCorrectly) {
39- torch::jit::script::Module mod;
40- try {
41- mod = torch::jit::load (" tests/modules/mobilenet_v2_traced.jit.pt" );
42- } catch (const c10::Error& e) {
43- std::cerr << " error loading the model\n " ;
44- return ;
45- }
46-
47- const std::vector<std::vector<int64_t >> input_shapes = {{1 , 3 , 224 , 224 }};
48- std::vector<torch::jit::IValue> jit_inputs_ivalues;
49- std::vector<torch::jit::IValue> trt_inputs_ivalues;
50- for (auto in_shape : input_shapes) {
51- auto in = at::randint (5 , in_shape, {at::kCUDA });
52- jit_inputs_ivalues.push_back (in.clone ());
53- trt_inputs_ivalues.push_back (in.clone ());
54- }
55-
56- std::vector<torch_tensorrt::core::ir::Input> input_ranges{torch_tensorrt::core::ir::Input ({1 , 3 , 224 , 224 })};
57- auto g = mod.get_method (" forward" ).graph ();
58- torch_tensorrt::core::CompileSpec cfg (input_ranges);
59- cfg.partition_info .enabled = true ;
60- cfg.partition_info .forced_fallback_operators .push_back (" aten::hardtanh" );
61-
62- auto jit_results = mod.forward (jit_inputs_ivalues).toTensor ();
63- auto trt_mod = torch_tensorrt::core::CompileGraph (mod, cfg);
64- auto trt_results = trt_mod.forward (trt_inputs_ivalues).toTensor ();
65- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results, trt_results, 2e-6 ));
66- }
67-
68- TEST (Partitioning, ComputeResNet50HalfFallbackGraphCorrectly) {
69- torch::jit::script::Module mod;
70- try {
71- mod = torch::jit::load (" tests/modules/resnet50_traced.jit.pt" );
72- } catch (const c10::Error& e) {
73- std::cerr << " error loading the model\n " ;
74- return ;
75- }
76-
77- mod.to (torch::kHalf );
78-
79- const std::vector<std::vector<int64_t >> input_shapes = {{1 , 3 , 224 , 224 }};
80- std::vector<torch::jit::IValue> jit_inputs_ivalues;
81- std::vector<torch::jit::IValue> trt_inputs_ivalues;
82- for (auto in_shape : input_shapes) {
83- auto in = at::randint (5 , in_shape, {at::kCUDA }).to (torch::kHalf );
84- jit_inputs_ivalues.push_back (in.clone ());
85- trt_inputs_ivalues.push_back (in.clone ());
86- }
87-
88- auto in_shape = torch_tensorrt::core::ir::Input ({1 , 3 , 224 , 224 });
89- in_shape.dtype = nvinfer1::DataType::kHALF ;
90-
91- std::vector<torch_tensorrt::core::ir::Input> input_ranges ({in_shape});
92- auto g = mod.get_method (" forward" ).graph ();
93- torch_tensorrt::core::CompileSpec cfg (input_ranges);
94- cfg.partition_info .enabled = true ;
95- cfg.partition_info .forced_fallback_operators .push_back (" aten::add" );
96-
97- auto jit_results = mod.forward (jit_inputs_ivalues).toTensor ();
98- auto trt_mod = torch_tensorrt::core::CompileGraph (mod, cfg);
99- auto trt_results = trt_mod.forward (trt_inputs_ivalues).toTensor ();
100- // Lower threshold because FP16
101- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results, trt_results, 2e-1 ));
102- }
1+ #include < string>
2+ #include < unordered_set>
3+ #include " core/compiler.h"
4+ #include " gtest/gtest.h"
5+ #include " tests/util/util.h"
6+ #include " torch/script.h"
7+
8+ #ifndef DISABLE_TEST_IN_CI
9+
10+ TEST (Partitioning, ComputeResNet50FallbackGraphCorrectly) {
11+ torch::jit::script::Module mod;
12+ try {
13+ mod = torch::jit::load (" tests/modules/resnet50_traced.jit.pt" );
14+ } catch (const c10::Error& e) {
15+ std::cerr << " error loading the model\n " ;
16+ return ;
17+ }
18+
19+ const std::vector<std::vector<int64_t >> input_shapes = {{1 , 3 , 224 , 224 }};
20+ std::vector<torch::jit::IValue> jit_inputs_ivalues;
21+ std::vector<torch::jit::IValue> trt_inputs_ivalues;
22+ for (auto in_shape : input_shapes) {
23+ auto in = at::randint (5 , in_shape, {at::kCUDA });
24+ jit_inputs_ivalues.push_back (in.clone ());
25+ trt_inputs_ivalues.push_back (in.clone ());
26+ }
27+
28+ std::vector<torch_tensorrt::core::ir::Input> input_ranges{torch_tensorrt::core::ir::Input ({1 , 3 , 224 , 224 })};
29+
30+ torch_tensorrt::core::CompileSpec cfg (input_ranges);
31+ cfg.partition_info .enabled = true ;
32+ cfg.partition_info .forced_fallback_operators .push_back (" aten::add" );
33+
34+ auto jit_results = mod.forward (jit_inputs_ivalues).toTensor ();
35+ auto trt_mod = torch_tensorrt::core::CompileGraph (mod, cfg);
36+ auto trt_results = trt_mod.forward (trt_inputs_ivalues).toTensor ();
37+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results, trt_results, 2e-6 ));
38+ }
39+
40+ TEST (Partitioning, ComputeMobileNetFallbackGraphCorrectly) {
41+ torch::jit::script::Module mod;
42+ try {
43+ mod = torch::jit::load (" tests/modules/mobilenet_v2_traced.jit.pt" );
44+ } catch (const c10::Error& e) {
45+ std::cerr << " error loading the model\n " ;
46+ return ;
47+ }
48+
49+ const std::vector<std::vector<int64_t >> input_shapes = {{1 , 3 , 224 , 224 }};
50+ std::vector<torch::jit::IValue> jit_inputs_ivalues;
51+ std::vector<torch::jit::IValue> trt_inputs_ivalues;
52+ for (auto in_shape : input_shapes) {
53+ auto in = at::randint (5 , in_shape, {at::kCUDA });
54+ jit_inputs_ivalues.push_back (in.clone ());
55+ trt_inputs_ivalues.push_back (in.clone ());
56+ }
57+
58+ std::vector<torch_tensorrt::core::ir::Input> input_ranges{torch_tensorrt::core::ir::Input ({1 , 3 , 224 , 224 })};
59+ auto g = mod.get_method (" forward" ).graph ();
60+ torch_tensorrt::core::CompileSpec cfg (input_ranges);
61+ cfg.partition_info .enabled = true ;
62+ cfg.partition_info .forced_fallback_operators .push_back (" aten::hardtanh" );
63+
64+ auto jit_results = mod.forward (jit_inputs_ivalues).toTensor ();
65+ auto trt_mod = torch_tensorrt::core::CompileGraph (mod, cfg);
66+ auto trt_results = trt_mod.forward (trt_inputs_ivalues).toTensor ();
67+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results, trt_results, 2e-6 ));
68+ }
69+
70+ TEST (Partitioning, ComputeResNet50HalfFallbackGraphCorrectly) {
71+ torch::jit::script::Module mod;
72+ try {
73+ mod = torch::jit::load (" tests/modules/resnet50_traced.jit.pt" );
74+ } catch (const c10::Error& e) {
75+ std::cerr << " error loading the model\n " ;
76+ return ;
77+ }
78+
79+ mod.to (torch::kHalf );
80+
81+ const std::vector<std::vector<int64_t >> input_shapes = {{1 , 3 , 224 , 224 }};
82+ std::vector<torch::jit::IValue> jit_inputs_ivalues;
83+ std::vector<torch::jit::IValue> trt_inputs_ivalues;
84+ for (auto in_shape : input_shapes) {
85+ auto in = at::randint (5 , in_shape, {at::kCUDA }).to (torch::kHalf );
86+ jit_inputs_ivalues.push_back (in.clone ());
87+ trt_inputs_ivalues.push_back (in.clone ());
88+ }
89+
90+ auto in_shape = torch_tensorrt::core::ir::Input ({1 , 3 , 224 , 224 });
91+ in_shape.dtype = nvinfer1::DataType::kHALF ;
92+
93+ std::vector<torch_tensorrt::core::ir::Input> input_ranges ({in_shape});
94+ auto g = mod.get_method (" forward" ).graph ();
95+ torch_tensorrt::core::CompileSpec cfg (input_ranges);
96+ cfg.partition_info .enabled = true ;
97+ cfg.partition_info .forced_fallback_operators .push_back (" aten::add" );
98+
99+ auto jit_results = mod.forward (jit_inputs_ivalues).toTensor ();
100+ auto trt_mod = torch_tensorrt::core::CompileGraph (mod, cfg);
101+ auto trt_results = trt_mod.forward (trt_inputs_ivalues).toTensor ();
102+ // Lower threshold because FP16
103+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results, trt_results, 2e-1 ));
104+ }
105+ #endif
0 commit comments