Skip to content

Commit a473bcf

Browse files
committed
refactor: way simpler module test
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent b72a5fe commit a473bcf

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

tests/cpp/test_module_fallback.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include "torch/script.h"
55
#include "trtorch/trtorch.h"
66

7-
TEST(CppAPITests, LowerResNetModuleFallbackCorrectly) {
7+
TEST(CppAPITest, ResNetModuleFallbacksCorrectly) {
88
torch::jit::script::Module mod;
99
try {
1010
mod = torch::jit::load("tests/modules/resnet18_traced.jit.pt");
@@ -22,9 +22,7 @@ TEST(CppAPITests, LowerResNetModuleFallbackCorrectly) {
2222
trt_inputs_ivalues.push_back(in.clone());
2323
}
2424

25-
std::vector<trtorch::CompileSpec::Input> input_ranges{
26-
trtorch::CompileSpec::Input(std::vector<int64_t>({1, 3, 224, 224}))};
27-
trtorch::CompileSpec cfg(input_ranges);
25+
trtorch::CompileSpec cfg(input_shapes);
2826
cfg.torch_fallback.enabled = true;
2927
cfg.torch_fallback.forced_fallback_modules.push_back("torchvision.models.resnet.BasicBlock");
3028

@@ -34,7 +32,7 @@ TEST(CppAPITests, LowerResNetModuleFallbackCorrectly) {
3432
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
3533
}
3634

37-
TEST(CppAPITests, LowerAndPartitionMobileNetModuleFallbackCorrectly) {
35+
TEST(CppAPITest, MobileNetModuleFallbacksCorrectlyWithOneEngine) {
3836
torch::jit::script::Module mod;
3937
try {
4038
mod = torch::jit::load("tests/modules/mobilenet_v2_traced.jit.pt");
@@ -52,16 +50,24 @@ TEST(CppAPITests, LowerAndPartitionMobileNetModuleFallbackCorrectly) {
5250
trt_inputs_ivalues.push_back(in.clone());
5351
}
5452

55-
std::vector<trtorch::CompileSpec::Input> input_ranges{
56-
trtorch::CompileSpec::Input(std::vector<int64_t>({1, 3, 224, 224}))};
57-
trtorch::CompileSpec cfg(input_ranges);
53+
trtorch::CompileSpec cfg(input_shapes);
5854
cfg.torch_fallback.enabled = true;
5955
cfg.torch_fallback.min_block_size = 5;
6056
cfg.torch_fallback.forced_fallback_modules.push_back("torchvision.models.mobilenetv2.ConvBNActivation");
6157

6258
auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
6359
auto trt_mod = trtorch::CompileGraph(mod, cfg);
6460

61+
auto g = trt_mod.get_method("forward").graph();
62+
auto nodes = g->block()->nodes();
63+
std::size_t trt_count = 0;
64+
for (const auto n : nodes) {
65+
if (n->kind().toQualString() == std::string("tensorrt::execute_engine")) {
66+
trt_count++;
67+
}
68+
}
69+
ASSERT_TRUE(trt_count == 1);
70+
6571
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
6672
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
6773
}

0 commit comments

Comments
 (0)