1
1
#include < string>
2
2
#include " gtest/gtest.h"
3
- #include " torch/csrc/jit/irparser .h"
3
+ #include " torch/script .h"
4
4
#include " tests/util/util.h"
5
- #include " cpp /trtorch.h"
5
+ #include " trtorch /trtorch.h"
6
6
7
7
TEST (ModuleTests, CanRunMultipleEngines) {
8
8
torch::jit::script::Module mod1;
@@ -16,7 +16,7 @@ TEST(ModuleTests, CanRunMultipleEngines) {
16
16
return ;
17
17
}
18
18
19
- const std::vector<int64_t > input_shape = {1 ,3 ,224 ,224 };
19
+ const std::vector<std::vector< int64_t >> input_shapes = {{ 1 ,3 ,224 ,224 } };
20
20
21
21
std::vector<torch::jit::IValue> jit1_inputs_ivalues;
22
22
std::vector<torch::jit::IValue> trt1_inputs_ivalues;
@@ -38,18 +38,18 @@ TEST(ModuleTests, CanRunMultipleEngines) {
38
38
std::vector<at::Tensor> jit1_results;
39
39
jit1_results.push_back (jit1_results_ivalues.toTensor ());
40
40
41
- torch::jit::IValue jit2_results_ivalues = trtorch::tests::util::RunModuleForward (mod2, jit2_inputs_ivalues);
41
+ torch::jit::IValue jit2_results_ivalues = trtorch::tests::util::RunModuleForward (mod2, jit2_inputs_ivalues);
42
42
std::vector<at::Tensor> jit2_results;
43
43
jit2_results.push_back (jit2_results_ivalues.toTensor ());
44
44
45
45
46
46
auto trt_mod1 = trtorch::CompileGraph (mod1, input_shapes);
47
- torch::jit::IValue trt1_results_ivalues = trtorch::tests::util::RunModuleForward (trt1_mod , trt1_inputs_ivalues);
47
+ torch::jit::IValue trt1_results_ivalues = trtorch::tests::util::RunModuleForward (trt_mod1 , trt1_inputs_ivalues);
48
48
std::vector<at::Tensor> trt1_results;
49
49
trt1_results.push_back (trt1_results_ivalues.toTensor ());
50
50
51
51
auto trt_mod2 = trtorch::CompileGraph (mod2, input_shapes);
52
- torch::jit::IValue trt2_results_ivalues = trtorch::tests::util::RunModuleForward (trt2_mod , trt2_inputs_ivalues);
52
+ torch::jit::IValue trt2_results_ivalues = trtorch::tests::util::RunModuleForward (trt_mod2 , trt2_inputs_ivalues);
53
53
std::vector<at::Tensor> trt2_results;
54
54
trt2_results.push_back (trt2_results_ivalues.toTensor ());
55
55
0 commit comments