Skip to content

Commit c864096

Browse files
committed
chore: Add cpp tests with cosine sim
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 0ca049f commit c864096

File tree

6 files changed

+332
-0
lines changed

6 files changed

+332
-0
lines changed

tests/core/partitioning/BUILD

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,21 @@ cc_test(
5555
}),
5656
)
5757

58+
# Partitioning test: checks graph outputs stay correct when operators are
# forced to fall back to Torch (see test_fallback_graph_output.cpp).
cc_test(
59+
name = "test_fallback_graph_output",
60+
srcs = ["test_fallback_graph_output.cpp"],
61+
data = [
62+
":jit_models",
63+
],
64+
deps = [
65+
"//tests/util",
66+
"@googletest//:gtest_main",
67+
] + select({
68+
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
69+
"//conditions:default": ["@libtorch//:libtorch"],
70+
}),
71+
)
72+
5873
cc_test(
5974
name = "test_loop_fallback",
6075
srcs = ["test_loop_fallback.cpp"],
@@ -89,6 +104,7 @@ test_suite(
89104
name = "partitioning_tests",
90105
tests = [
91106
":test_conditionals",
107+
":test_fallback_graph_output",
92108
":test_loading_model",
93109
":test_loop_fallback",
94110
":test_resolve_nontensor_inputs",
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#include <string>
2+
#include <unordered_set>
3+
#include "core/compiler.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/script.h"
7+
8+
#ifndef DISABLE_TEST_IN_CI
9+
10+
// Compiles ResNet50 with aten::add forced to run in Torch and checks that the
// partitioned (hybrid Torch/TensorRT) module matches the pure-JIT output
// within a 0.99 cosine-similarity bound.
TEST(Partitioning, ComputeResNet50FallbackGraphCorrectly) {
  torch::jit::script::Module mod;
  try {
    mod = torch::jit::load("tests/modules/resnet50_traced.jit.pt");
  } catch (const c10::Error& e) {
    // A bare `return;` here would make the test pass vacuously when the
    // model file is missing; fail explicitly with the load error instead.
    FAIL() << "error loading the model: " << e.what();
  }

  const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}};
  std::vector<torch::jit::IValue> jit_inputs_ivalues;
  std::vector<torch::jit::IValue> trt_inputs_ivalues;
  for (const auto& in_shape : input_shapes) {
    // Same random tensor feeds both paths so the outputs are comparable.
    auto in = at::randint(5, in_shape, {at::kCUDA});
    jit_inputs_ivalues.push_back(in.clone());
    trt_inputs_ivalues.push_back(in.clone());
  }

  std::vector<torch_tensorrt::core::ir::Input> input_ranges{torch_tensorrt::core::ir::Input({1, 3, 224, 224})};

  torch_tensorrt::core::CompileSpec cfg(input_ranges);
  cfg.partition_info.enabled = true;
  cfg.partition_info.forced_fallback_operators.push_back("aten::add");

  auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
  auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
  auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results, 0.99));
}
39+
40+
// Compiles MobileNetV2 with aten::hardtanh forced to run in Torch and checks
// the partitioned module matches the pure-JIT output (cosine sim >= 0.99).
TEST(Partitioning, ComputeMobileNetFallbackGraphCorrectly) {
  torch::jit::script::Module mod;
  try {
    mod = torch::jit::load("tests/modules/mobilenet_v2_traced.jit.pt");
  } catch (const c10::Error& e) {
    // Fail loudly instead of silently passing when the model is missing.
    FAIL() << "error loading the model: " << e.what();
  }

  const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}};
  std::vector<torch::jit::IValue> jit_inputs_ivalues;
  std::vector<torch::jit::IValue> trt_inputs_ivalues;
  for (const auto& in_shape : input_shapes) {
    auto in = at::randint(5, in_shape, {at::kCUDA});
    jit_inputs_ivalues.push_back(in.clone());
    trt_inputs_ivalues.push_back(in.clone());
  }

  std::vector<torch_tensorrt::core::ir::Input> input_ranges{torch_tensorrt::core::ir::Input({1, 3, 224, 224})};

  // NOTE: the unused `auto g = mod.get_method("forward").graph();` local from
  // the original was removed; nothing read it.
  torch_tensorrt::core::CompileSpec cfg(input_ranges);
  cfg.partition_info.enabled = true;
  cfg.partition_info.forced_fallback_operators.push_back("aten::hardtanh");

  auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
  auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
  auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results, 0.99));
}
69+
#endif

tests/cpp/BUILD

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,12 @@ test_suite(
1313
name = "api_tests",
1414
tests = [
1515
":test_collections",
16+
":test_compiled_modules",
1617
":test_default_input_types",
1718
":test_example_tensors",
19+
":test_module_fallback",
1820
":test_modules_as_engines",
21+
":test_multiple_registered_engines",
1922
":test_runtime_thread_safety",
2023
":test_serialization",
2124
],
@@ -25,9 +28,12 @@ test_suite(
2528
name = "aarch64_api_tests",
2629
tests = [
2730
":test_collections",
31+
":test_compiled_modules",
2832
":test_default_input_types",
2933
":test_example_tensors",
34+
":test_module_fallback",
3035
":test_modules_as_engines",
36+
":test_multiple_registered_engines",
3137
":test_runtime_thread_safety",
3238
":test_serialization",
3339
],
@@ -66,6 +72,21 @@ cc_test(
6672
],
6773
)
6874

75+
# API test: two independently compiled TRT engines registered and executed in
# one process (see test_multiple_registered_engines.cpp).
cc_test(
76+
name = "test_multiple_registered_engines",
77+
srcs = ["test_multiple_registered_engines.cpp"],
78+
data = [
79+
"//tests/modules:jit_models",
80+
],
81+
deps = [
82+
"//tests/util",
83+
"@googletest//:gtest_main",
84+
] + select({
85+
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
86+
"//conditions:default": ["@libtorch//:libtorch"],
87+
}),
88+
)
89+
6990
cc_test(
7091
name = "test_modules_as_engines",
7192
timeout = "long",
@@ -89,6 +110,21 @@ cc_test(
89110
],
90111
)
91112

113+
# API test: module-level fallback via torch_executed_modules
# (see test_module_fallback.cpp).
cc_test(
114+
name = "test_module_fallback",
115+
srcs = ["test_module_fallback.cpp"],
116+
data = [
117+
"//tests/modules:jit_models",
118+
],
119+
deps = [
120+
"//tests/util",
121+
"@googletest//:gtest_main",
122+
] + select({
123+
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
124+
"//conditions:default": ["@libtorch//:libtorch"],
125+
}),
126+
)
127+
92128
cc_test(
93129
name = "test_collections",
94130
srcs = ["test_collections.cpp"],
@@ -104,6 +140,17 @@ cc_test(
104140
}),
105141
)
106142

143+
# Parameterized cosine-similarity test over the model zoo; gets gtest/libtorch
# transitively via the shared :cpp_api_test fixture library.
cc_test(
144+
name = "test_compiled_modules",
145+
srcs = ["test_compiled_modules.cpp"],
146+
data = [
147+
"//tests/modules:jit_models",
148+
],
149+
deps = [
150+
":cpp_api_test",
151+
],
152+
)
153+
107154
cc_test(
108155
name = "test_multi_gpu_serde",
109156
srcs = ["test_multi_gpu_serde.cpp"],

tests/cpp/test_compiled_modules.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#include "cpp_api_test.h"
2+
3+
// Parameterized end-to-end check: compile each parameterized model with
// Torch-TensorRT and require cosine similarity >= 0.99 against JIT output.
// `mod`, `input_shapes`, and `input_types` come from the CppAPITests fixture.
TEST_P(CppAPITests, CompiledModuleIsClose) {
  std::vector<torch::jit::IValue> jit_inputs_ivalues;
  std::vector<torch::jit::IValue> trt_inputs_ivalues;
  std::vector<torch_tensorrt::Input> shapes;
  for (uint64_t i = 0; i < input_shapes.size(); i++) {
    auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]);
    jit_inputs_ivalues.push_back(in.clone());
    trt_inputs_ivalues.push_back(in.clone());
    auto in_spec = torch_tensorrt::Input(input_shapes[i]);
    in_spec.dtype = input_types[i];
    shapes.push_back(in_spec);
    // (Removed leftover debug `std::cout << in_spec << std::endl;`.)
  }

  // Flattens a forward() result — a single tensor or a tuple of tensors —
  // into a vector; shared by the JIT and TRT paths (was duplicated inline).
  auto to_tensor_list = [](const torch::jit::IValue& results) {
    std::vector<at::Tensor> tensors;
    if (results.isTuple()) {
      for (const auto& t : results.toTuple()->elements()) {
        tensors.push_back(t.toTensor());
      }
    } else {
      tensors.push_back(results.toTensor());
    }
    return tensors;
  };

  torch::jit::IValue jit_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(mod, jit_inputs_ivalues);
  std::vector<at::Tensor> jit_results = to_tensor_list(jit_results_ivalues);

  auto spec = torch_tensorrt::ts::CompileSpec(shapes);
  // Needed for models carrying int64/double values (e.g. the BERT entry in
  // the suite uses at::kInt inputs) — presumably truncated to int32/float.
  spec.truncate_long_and_double = true;

  auto trt_mod = torch_tensorrt::ts::compile(mod, spec);
  torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
  std::vector<at::Tensor> trt_results = to_tensor_list(trt_results_ivalues);

  // Guard the paired indexing below against a mismatched number of outputs.
  ASSERT_EQ(jit_results.size(), trt_results.size());
  for (size_t i = 0; i < trt_results.size(); i++) {
    ASSERT_TRUE(
        torch_tensorrt::tests::util::cosineSimEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]), 0.99));
  }
}
47+
48+
#ifndef DISABLE_TEST_IN_CI
49+

// Model zoo for CompiledModuleIsClose. Each PathAndInput bundles
// {model path, input shapes, input dtypes, threshold}. The last field is
// presumably a per-model numeric tolerance — confirm against PathAndInput's
// definition in cpp_api_test.h; the test body above uses a fixed 0.99
// cosine-similarity bound regardless.
INSTANTIATE_TEST_SUITE_P(
51+
CompiledModuleForwardIsCloseSuite,
52+
CppAPITests,
53+
testing::Values(
54+
PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
55+
PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
56+
PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 8e-3}),
57+
PathAndInput({"tests/modules/bert_base_uncased_traced.jit.pt", {{1, 14}, {1, 14}}, {at::kInt, at::kInt}, 8e-2}),
58+
PathAndInput({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 8e-2})));
59+
60+
#endif

tests/cpp/test_module_fallback.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#include <string>
2+
#include "gtest/gtest.h"
3+
#include "tests/util/util.h"
4+
#include "torch/script.h"
5+
#include "torch_tensorrt/torch_tensorrt.h"
6+
7+
#ifndef DISABLE_TEST_IN_CI
8+
9+
// Forces torchvision's BasicBlock submodules to execute in Torch and checks
// the resulting hybrid ResNet18 still matches the pure-JIT output.
TEST(CppAPITest, ResNetModuleFallbacksCorrectly) {
  torch::jit::script::Module mod;
  try {
    mod = torch::jit::load("tests/modules/resnet18_scripted.jit.pt");
  } catch (const c10::Error& e) {
    // FAIL() is the idiomatic gtest abort-with-message; replaces the
    // cerr print followed by ASSERT_TRUE(false).
    FAIL() << "error loading the model: " << e.what();
  }

  const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}};
  std::vector<torch::jit::IValue> jit_inputs_ivalues;
  std::vector<torch::jit::IValue> trt_inputs_ivalues;
  for (const auto& in_shape : input_shapes) {
    // Same random tensor for both paths so outputs are comparable.
    auto in = at::randint(5, in_shape, {at::kCUDA});
    jit_inputs_ivalues.push_back(in.clone());
    trt_inputs_ivalues.push_back(in.clone());
  }

  torch_tensorrt::ts::CompileSpec cfg(input_shapes);
  cfg.torch_executed_modules.push_back("torchvision.models.resnet.BasicBlock");

  auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
  auto trt_mod = torch_tensorrt::ts::compile(mod, cfg);
  auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results, 0.99));
}
35+
36+
// Forces ConvBNActivation submodules to Torch with min_block_size = 5 and
// verifies (a) the partitioner emits exactly one TRT engine node and
// (b) the hybrid MobileNetV2 output still matches pure-JIT execution.
TEST(CppAPITest, MobileNetModuleFallbacksCorrectlyWithOneEngine) {
  torch::jit::script::Module mod;
  try {
    mod = torch::jit::load("tests/modules/mobilenet_v2_scripted.jit.pt");
  } catch (const c10::Error& e) {
    // Idiomatic gtest abort-with-message (replaces cerr + ASSERT_TRUE(false)).
    FAIL() << "error loading the model: " << e.what();
  }

  const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}};
  std::vector<torch::jit::IValue> jit_inputs_ivalues;
  std::vector<torch::jit::IValue> trt_inputs_ivalues;
  for (const auto& in_shape : input_shapes) {
    auto in = at::randint(5, in_shape, {at::kCUDA});
    jit_inputs_ivalues.push_back(in.clone());
    trt_inputs_ivalues.push_back(in.clone());
  }

  torch_tensorrt::ts::CompileSpec cfg(input_shapes);
  cfg.min_block_size = 5;
  cfg.torch_executed_modules.push_back("torchvision.models.mobilenetv2.ConvBNActivation");

  auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
  auto trt_mod = torch_tensorrt::ts::compile(mod, cfg);

  // Count tensorrt::execute_engine nodes in the top-level block of the
  // compiled graph.
  auto g = trt_mod.get_method("forward").graph();
  std::size_t trt_count = 0;
  for (const auto n : g->block()->nodes()) {
    if (std::string(n->kind().toQualString()) == "tensorrt::execute_engine") {
      trt_count++;
    }
  }
  // ASSERT_EQ reports the actual count on failure, unlike ASSERT_TRUE(a == b).
  ASSERT_EQ(trt_count, 1U);

  auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results, 0.99));
}
74+
#endif
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#include <string>
2+
#include "gtest/gtest.h"
3+
#include "tests/util/util.h"
4+
#include "torch/script.h"
5+
#include "torch_tensorrt/torch_tensorrt.h"
6+
7+
#ifndef DISABLE_TEST_IN_CI
8+
9+
// Compiles two independent copies of ResNet18 and verifies that both TRT
// engines can be registered and executed in the same process, each matching
// its own JIT baseline (cosine sim >= 0.99).
TEST(CppAPITest, CanRunMultipleEngines) {
  torch::jit::script::Module mod1;
  torch::jit::script::Module mod2;
  try {
    mod1 = torch::jit::load("tests/modules/resnet18_traced.jit.pt");
    mod2 = torch::jit::load("tests/modules/resnet18_traced.jit.pt");
  } catch (const c10::Error& e) {
    // A bare `return;` here would make the test pass vacuously when the
    // model file is missing; fail explicitly instead.
    FAIL() << "error loading the model: " << e.what();
  }

  const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}};

  // Fills a (jit, trt) input pair with clones of the same random tensors so
  // both paths see identical data; deduplicates the two identical loops the
  // original had.
  auto make_inputs = [&input_shapes](
                         std::vector<torch::jit::IValue>& jit_inputs, std::vector<torch::jit::IValue>& trt_inputs) {
    for (const auto& in_shape : input_shapes) {
      auto in = at::randint(5, in_shape, {at::kCUDA});
      jit_inputs.push_back(in.clone());
      trt_inputs.push_back(in.clone());
    }
  };

  std::vector<torch::jit::IValue> jit1_inputs_ivalues;
  std::vector<torch::jit::IValue> trt1_inputs_ivalues;
  make_inputs(jit1_inputs_ivalues, trt1_inputs_ivalues);

  std::vector<torch::jit::IValue> jit2_inputs_ivalues;
  std::vector<torch::jit::IValue> trt2_inputs_ivalues;
  make_inputs(jit2_inputs_ivalues, trt2_inputs_ivalues);

  // Each forward() of these traced modules returns a single tensor, so the
  // single-element result vectors of the original are unnecessary.
  auto jit1_result = torch_tensorrt::tests::util::RunModuleForward(mod1, jit1_inputs_ivalues).toTensor();
  auto jit2_result = torch_tensorrt::tests::util::RunModuleForward(mod2, jit2_inputs_ivalues).toTensor();

  auto trt_mod1 = torch_tensorrt::ts::compile(mod1, input_shapes);
  auto trt1_result = torch_tensorrt::tests::util::RunModuleForward(trt_mod1, trt1_inputs_ivalues).toTensor();

  auto trt_mod2 = torch_tensorrt::ts::compile(mod2, input_shapes);
  auto trt2_result = torch_tensorrt::tests::util::RunModuleForward(trt_mod2, trt2_inputs_ivalues).toTensor();

  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit1_result, trt1_result.reshape_as(jit1_result), 0.99));
  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit2_result, trt2_result.reshape_as(jit2_result), 0.99));
}
66+
#endif

0 commit comments

Comments
 (0)