Skip to content

Commit d73dc42

Browse files
committed
tests: update the dependent models for fallback graph conversion, stitch and computation
Signed-off-by: Bo Wang <[email protected]>
1 parent 824b555 commit d73dc42

File tree

5 files changed

+33
-76
lines changed

5 files changed

+33
-76
lines changed

tests/core/partitioning/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ config_setting(
99

1010
filegroup(
1111
name = "jit_models",
12-
srcs = glob(["**/*.jit"])
12+
srcs = ["//tests/modules:resnet50_traced.jit.pt",
13+
"//tests/modules:mobilenet_v2_traced.jit.pt"]
1314
)
1415

1516
partitioning_test(

tests/core/partitioning/gen_test_model.py

Lines changed: 0 additions & 52 deletions
This file was deleted.

tests/core/partitioning/test_fallback_graph_output.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55
#include "tests/util/util.h"
66
#include "torch/script.h"
77

8-
TEST(Partitioning, StitchSegmentedBlockCorrectly) {
8+
TEST(Partitioning, ComputeResNet50FallbackGraphCorrectly) {
99
torch::jit::script::Module mod;
1010
try {
11-
mod = torch::jit::load("tests/core/partitioning/test_base_model.jit");
11+
mod = torch::jit::load("tests/modules/resnet50_traced.jit.pt");
1212
} catch (const c10::Error& e) {
1313
std::cerr << "error loading the model\n";
1414
return;
1515
}
1616

17-
const std::vector<std::vector<int64_t>> input_shapes = {{3, 3, 16, 16}};
17+
const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}};
1818
std::vector<torch::jit::IValue> jit_inputs_ivalues;
1919
std::vector<torch::jit::IValue> trt_inputs_ivalues;
2020
for (auto in_shape : input_shapes) {
@@ -23,26 +23,27 @@ TEST(Partitioning, StitchSegmentedBlockCorrectly) {
2323
trt_inputs_ivalues.push_back(in.clone());
2424
}
2525

26-
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({3, 3, 16, 16})};
26+
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({1, 3, 224, 224})};
2727
trtorch::core::CompileSpec cfg(input_ranges);
2828
cfg.partition_info.enabled = true;
29+
cfg.partition_info.forced_fallback_operators.push_back("aten::add");
2930

3031
auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
3132
auto trt_mod = trtorch::core::CompileGraph(mod, cfg);
3233
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
33-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
34+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-8));
3435
}
3536

36-
TEST(Partitioning, StitchSegmentedBlockCorrectlyEdge) {
37+
TEST(Partitioning, ComputeMobileNetFallbackGraphCorrectly) {
3738
torch::jit::script::Module mod;
3839
try {
39-
mod = torch::jit::load("tests/core/partitioning/test_edge_model.jit");
40+
mod = torch::jit::load("tests/modules/mobilenet_v2_traced.jit.pt");
4041
} catch (const c10::Error& e) {
4142
std::cerr << "error loading the model\n";
4243
return;
4344
}
4445

45-
const std::vector<std::vector<int64_t>> input_shapes = {{3, 3, 16, 16}};
46+
const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}};
4647
std::vector<torch::jit::IValue> jit_inputs_ivalues;
4748
std::vector<torch::jit::IValue> trt_inputs_ivalues;
4849
for (auto in_shape : input_shapes) {
@@ -51,9 +52,10 @@ TEST(Partitioning, StitchSegmentedBlockCorrectlyEdge) {
5152
trt_inputs_ivalues.push_back(in.clone());
5253
}
5354

54-
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({3, 3, 16, 16})};
55+
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({1, 3, 224, 224})};
5556
trtorch::core::CompileSpec cfg(input_ranges);
5657
cfg.partition_info.enabled = true;
58+
cfg.partition_info.forced_fallback_operators.push_back("aten::hardtanh");
5759

5860
auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
5961
auto trt_mod = trtorch::core::CompileGraph(mod, cfg);

tests/core/partitioning/test_stitched_graph.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,35 +22,38 @@ bool checkAllInputsExistInStitchedGraph(std::shared_ptr<torch::jit::Graph> g) {
2222
return true;
2323
}
2424

25-
TEST(Partitioning, StitchSegmentedBlockCorrectly) {
25+
TEST(Partitioning, StitchResNet50SegmentedBlockCorrectly) {
2626
torch::jit::script::Module mod;
2727
try {
28-
mod = torch::jit::load("tests/core/partitioning/test_base_model.jit");
28+
mod = torch::jit::load("tests/modules/resnet50_traced.jit.pt");
2929
} catch (const c10::Error& e) {
3030
std::cerr << "error loading the model\n";
3131
return;
3232
}
3333

34-
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({3, 3, 16, 16})};
34+
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({1, 3, 224, 224})};
3535
trtorch::core::CompileSpec cfg(input_ranges);
3636
cfg.partition_info.enabled = true;
37+
cfg.partition_info.forced_fallback_operators.push_back("aten::add");
3738
torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
3839
auto g = new_mod.get_method("forward").graph();
3940
ASSERT_TRUE(checkAllInputsExistInStitchedGraph(g));
4041
}
4142

42-
TEST(Partitioning, StitchSegmentedBlockCorrectlyEdge) {
43+
TEST(Partitioning, StitchMobileNetSegmentedBlockCorrectlyEdge) {
4344
torch::jit::script::Module mod;
4445
try {
45-
mod = torch::jit::load("tests/core/partitioning/test_edge_model.jit");
46+
mod = torch::jit::load("tests/modules/mobilenet_v2_traced.jit.pt");
4647
} catch (const c10::Error& e) {
4748
std::cerr << "error loading the model\n";
4849
return;
4950
}
5051

51-
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({3, 3, 16, 16})};
52+
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({1, 3, 224, 224})};
5253
trtorch::core::CompileSpec cfg(input_ranges);
5354
cfg.partition_info.enabled = true;
55+
cfg.partition_info.forced_fallback_operators.push_back("aten::hardtanh");
56+
5457
torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
5558
auto g = new_mod.get_method("forward").graph();
5659
ASSERT_TRUE(checkAllInputsExistInStitchedGraph(g));

tests/core/partitioning/test_tensorrt_conversion.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "core/compiler.h"
33
#include "core/util/trt_util.h"
44
#include "gtest/gtest.h"
5+
#include "torch/csrc/jit/ir/irparser.h"
56
#include "torch/script.h"
67

78
int count_trt_engines(std::shared_ptr<torch::jit::Graph> g) {
@@ -14,38 +15,40 @@ int count_trt_engines(std::shared_ptr<torch::jit::Graph> g) {
1415
return count;
1516
}
1617

17-
TEST(Partitioning, ConvertSegmentedBlockCorrectly) {
18+
TEST(Partitioning, ConvertResNet50SegmentedBlockCorrectly) {
1819
torch::jit::script::Module mod;
1920
try {
20-
mod = torch::jit::load("tests/core/partitioning/test_base_model.jit");
21+
mod = torch::jit::load("tests/modules/resnet50_traced.jit.pt");
2122
} catch (const c10::Error& e) {
2223
std::cerr << "error loading the model\n";
2324
return;
2425
}
2526

26-
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({3, 3, 16, 16})};
27+
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({1, 3, 224, 224})};
2728
trtorch::core::CompileSpec cfg(input_ranges);
2829
cfg.partition_info.enabled = true;
30+
cfg.partition_info.forced_fallback_operators.push_back("aten::add");
2931
torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
3032
auto g = new_mod.get_method("forward").graph();
3133
int count = count_trt_engines(g);
32-
ASSERT_TRUE(count == 2);
34+
ASSERT_TRUE(count == 17);
3335
}
3436

35-
TEST(Partitioning, ConvertSegmentedBlockCorrectlyEdge) {
37+
TEST(Partitioning, ConvertMobileNetSegmentedBlockCorrectly) {
3638
torch::jit::script::Module mod;
3739
try {
38-
mod = torch::jit::load("tests/core/partitioning/test_edge_model.jit");
40+
mod = torch::jit::load("tests/modules/mobilenet_v2_traced.jit.pt");
3941
} catch (const c10::Error& e) {
4042
std::cerr << "error loading the model\n";
4143
return;
4244
}
4345

44-
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({3, 3, 16, 16})};
46+
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({1, 3, 224, 224})};
4547
trtorch::core::CompileSpec cfg(input_ranges);
4648
cfg.partition_info.enabled = true;
49+
cfg.partition_info.forced_fallback_operators.push_back("aten::add");
4750
torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
4851
auto g = new_mod.get_method("forward").graph();
4952
int count = count_trt_engines(g);
50-
ASSERT_TRUE(count == 2);
53+
ASSERT_TRUE(count == 11);
5154
}

0 commit comments

Comments
 (0)