Skip to content

Commit 5756169

Browse files
committed
refactor: Refactor testing to use cosine similarity, remove redundancy models and restructuring
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 81f2a5c commit 5756169

20 files changed

+402
-552
lines changed

tests/core/lowering/test_module_fallback_passes.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,5 +124,5 @@ TEST(Lowering, LowerAndPartitionSimpleModuleFallbackCorrectly) {
124124
}
125125

126126
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
127-
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results, trt_results, 2e-6));
127+
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results, 0.99));
128128
}

tests/core/partitioning/BUILD

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,21 +55,6 @@ cc_test(
5555
}),
5656
)
5757

58-
cc_test(
59-
name = "test_fallback_graph_output",
60-
srcs = ["test_fallback_graph_output.cpp"],
61-
data = [
62-
":jit_models",
63-
],
64-
deps = [
65-
"//tests/util",
66-
"@googletest//:gtest_main",
67-
] + select({
68-
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
69-
"//conditions:default": ["@libtorch//:libtorch"],
70-
}),
71-
)
72-
7358
cc_test(
7459
name = "test_loop_fallback",
7560
srcs = ["test_loop_fallback.cpp"],
@@ -104,7 +89,6 @@ test_suite(
10489
name = "partitioning_tests",
10590
tests = [
10691
":test_conditionals",
107-
":test_fallback_graph_output",
10892
":test_loading_model",
10993
":test_loop_fallback",
11094
":test_resolve_nontensor_inputs",

tests/core/partitioning/test_fallback_graph_output.cpp

Lines changed: 0 additions & 69 deletions
This file was deleted.

tests/cpp/BUILD

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,9 @@ test_suite(
1313
name = "api_tests",
1414
tests = [
1515
":test_collections",
16-
":test_compiled_modules",
1716
":test_default_input_types",
1817
":test_example_tensors",
19-
":test_module_fallback",
2018
":test_modules_as_engines",
21-
":test_multiple_registered_engines",
2219
":test_runtime_thread_safety",
2320
":test_serialization",
2421
],
@@ -28,12 +25,9 @@ test_suite(
2825
name = "aarch64_api_tests",
2926
tests = [
3027
":test_collections",
31-
":test_compiled_modules",
3228
":test_default_input_types",
3329
":test_example_tensors",
34-
":test_module_fallback",
3530
":test_modules_as_engines",
36-
":test_multiple_registered_engines",
3731
":test_runtime_thread_safety",
3832
":test_serialization",
3933
],
@@ -72,21 +66,6 @@ cc_test(
7266
],
7367
)
7468

75-
cc_test(
76-
name = "test_multiple_registered_engines",
77-
srcs = ["test_multiple_registered_engines.cpp"],
78-
data = [
79-
"//tests/modules:jit_models",
80-
],
81-
deps = [
82-
"//tests/util",
83-
"@googletest//:gtest_main",
84-
] + select({
85-
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
86-
"//conditions:default": ["@libtorch//:libtorch"],
87-
}),
88-
)
89-
9069
cc_test(
9170
name = "test_modules_as_engines",
9271
timeout = "long",
@@ -110,21 +89,6 @@ cc_test(
11089
],
11190
)
11291

113-
cc_test(
114-
name = "test_module_fallback",
115-
srcs = ["test_module_fallback.cpp"],
116-
data = [
117-
"//tests/modules:jit_models",
118-
],
119-
deps = [
120-
"//tests/util",
121-
"@googletest//:gtest_main",
122-
] + select({
123-
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
124-
"//conditions:default": ["@libtorch//:libtorch"],
125-
}),
126-
)
127-
12892
cc_test(
12993
name = "test_collections",
13094
srcs = ["test_collections.cpp"],
@@ -140,17 +104,6 @@ cc_test(
140104
}),
141105
)
142106

143-
cc_test(
144-
name = "test_compiled_modules",
145-
srcs = ["test_compiled_modules.cpp"],
146-
data = [
147-
"//tests/modules:jit_models",
148-
],
149-
deps = [
150-
":cpp_api_test",
151-
],
152-
)
153-
154107
cc_test(
155108
name = "test_multi_gpu_serde",
156109
srcs = ["test_multi_gpu_serde.cpp"],

tests/cpp/test_compiled_modules.cpp

Lines changed: 0 additions & 65 deletions
This file was deleted.

tests/cpp/test_module_fallback.cpp

Lines changed: 0 additions & 74 deletions
This file was deleted.

tests/cpp/test_modules_as_engines.cpp

Lines changed: 5 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -15,40 +15,7 @@ TEST_P(CppAPITests, ModuleAsEngineIsClose) {
1515
auto trt_results = torch_tensorrt::tests::util::RunModuleForwardAsEngine(mod, inputs);
1616

1717
ASSERT_TRUE(
18-
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), threshold));
19-
}
20-
21-
TEST_P(CppAPITests, ModuleToEngineToModuleIsClose) {
22-
std::vector<at::Tensor> inputs;
23-
std::vector<torch::jit::IValue> inputs_ivalues;
24-
for (uint64_t i = 0; i < input_shapes.size(); i++) {
25-
inputs.push_back(at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]));
26-
inputs_ivalues.push_back(inputs[inputs.size() - 1].clone());
27-
}
28-
29-
torch::jit::IValue jit_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(mod, inputs_ivalues);
30-
std::vector<at::Tensor> jit_results;
31-
jit_results.push_back(jit_results_ivalues.toTensor());
32-
33-
std::vector<c10::ArrayRef<int64_t>> input_ranges;
34-
for (auto in : inputs) {
35-
input_ranges.push_back(in.sizes());
36-
}
37-
38-
auto compile_spec = torch_tensorrt::ts::CompileSpec({input_ranges});
39-
int device_id = 0;
40-
cudaGetDevice(&device_id);
41-
compile_spec.device.device_type = torch_tensorrt::Device::DeviceType::kGPU;
42-
compile_spec.device.gpu_id = device_id;
43-
auto engine = torch_tensorrt::ts::convert_method_to_trt_engine(mod, "forward", input_ranges);
44-
auto trt_mod = torch_tensorrt::ts::embed_engine_in_new_module(engine, compile_spec.device);
45-
46-
torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, inputs_ivalues);
47-
std::vector<at::Tensor> trt_results;
48-
trt_results.push_back(trt_results_ivalues.toTensor());
49-
50-
ASSERT_TRUE(
51-
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), threshold));
18+
torch_tensorrt::tests::util::cosineSimEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), threshold));
5219
}
5320

5421
#ifndef DISABLE_TEST_IN_CI
@@ -57,12 +24,8 @@ INSTANTIATE_TEST_SUITE_P(
5724
ModuleAsEngineForwardIsCloseSuite,
5825
CppAPITests,
5926
testing::Values(
60-
PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
61-
PathAndInput({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
62-
PathAndInput({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
63-
PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
64-
PathAndInput({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
65-
PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
66-
PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 1e-4}),
67-
PathAndInput({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 8e-2})));
27+
PathAndInput({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 0.99}),
28+
PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 0.99}),
29+
PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 0.99}),
30+
PathAndInput({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 0.99})));
6831
#endif

0 commit comments

Comments
 (0)