Skip to content

Commit efee3e6

Browse files
authored
Merge pull request #657 from NVIDIA/fix_tests
fix: Fix modules_as_engines test case to use trt_mod instead of pyt_mod
2 parents b47e926 + 32e8b53 commit efee3e6

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

tests/cpp/test_module_fallback.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
TEST(CppAPITest, ResNetModuleFallbacksCorrectly) {
88
torch::jit::script::Module mod;
99
try {
10-
mod = torch::jit::load("tests/modules/resnet18_traced.jit.pt");
10+
mod = torch::jit::load("tests/modules/resnet18_scripted.jit.pt");
1111
} catch (const c10::Error& e) {
1212
std::cerr << "error loading the model\n";
1313
ASSERT_TRUE(false);
@@ -35,7 +35,7 @@ TEST(CppAPITest, ResNetModuleFallbacksCorrectly) {
3535
TEST(CppAPITest, MobileNetModuleFallbacksCorrectlyWithOneEngine) {
3636
torch::jit::script::Module mod;
3737
try {
38-
mod = torch::jit::load("tests/modules/mobilenet_v2_traced.jit.pt");
38+
mod = torch::jit::load("tests/modules/mobilenet_v2_scripted.jit.pt");
3939
} catch (const c10::Error& e) {
4040
std::cerr << "error loading the model\n";
4141
ASSERT_TRUE(false);

tests/cpp/test_modules_as_engines.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ TEST_P(CppAPITests, ModuleToEngineToModuleIsClose) {
2929
std::vector<at::Tensor> jit_results;
3030
jit_results.push_back(jit_results_ivalues.toTensor());
3131

32-
auto forward_graph = mod.get_method("forward");
3332
std::vector<c10::ArrayRef<int64_t>> input_ranges;
3433
for (auto in : inputs) {
3534
input_ranges.push_back(in.sizes());
@@ -43,7 +42,7 @@ TEST_P(CppAPITests, ModuleToEngineToModuleIsClose) {
4342
auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", input_ranges);
4443
auto trt_mod = trtorch::EmbedEngineInNewModule(engine, compile_spec.device);
4544

46-
torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(mod, inputs_ivalues);
45+
torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, inputs_ivalues);
4746
std::vector<at::Tensor> trt_results;
4847
trt_results.push_back(trt_results_ivalues.toTensor());
4948

@@ -61,4 +60,4 @@ INSTANTIATE_TEST_SUITE_P(
6160
PathAndInSize({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
6261
PathAndInSize({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
6362
PathAndInSize({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
64-
PathAndInSize({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, 8e-3})));
63+
PathAndInSize({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, 8e-3})));

0 commit comments

Comments
 (0)