Skip to content

Commit 6d826d3

Browse files
committed
test: add tests for graph segmentation and shape analysis in partitioning
Signed-off-by: Bo Wang <[email protected]>
1 parent 569d011 commit 6d826d3

File tree

6 files changed

+316
-0
lines changed

6 files changed

+316
-0
lines changed

tests/core/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@ test_suite(
33
tests = [
44
"//tests/core/conversion:conversion_tests",
55
"//tests/core/lowering:lowering_tests",
6+
"//tests/core/partitioning::partitioning_tests"
67
],
78
)

tests/core/partitioning/BUILD

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
load("//tests/core/partitioning:partitioning_test.bzl", "partitioning_test")
2+
3+
config_setting(
4+
name = "use_pre_cxx11_abi",
5+
values = {
6+
"define": "abi=pre_cxx11_abi",
7+
}
8+
)
9+
10+
filegroup(
11+
name = "jit_models",
12+
srcs = glob(["**/*.jit"])
13+
)
14+
15+
partitioning_test(
16+
name = "test_segmentation",
17+
)
18+
19+
partitioning_test(
20+
name = "test_shape_analysis",
21+
)
22+
23+
test_suite(
24+
name = "partitioning_test",
25+
tests = [
26+
":test_segmentation",
27+
":test_shape_analysis",
28+
]
29+
)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import torch
2+
import pdb
3+
import trtorch
4+
5+
class FallbackBase(torch.nn.Module):
6+
def __init__(self):
7+
super(FallbackBase, self).__init__()
8+
self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
9+
self.conv2 = torch.nn.Conv2d(32, 16, 3, 1, 1)
10+
self.relu1 = torch.nn.ReLU()
11+
self.log_sig = torch.nn.LogSigmoid()
12+
self.conv3 = torch.nn.Conv2d(16, 8, 3, 1, 1)
13+
14+
def forward(self, x):
15+
x = self.conv1(x)
16+
x = self.relu1(x)
17+
x = self.conv2(x)
18+
x = self.log_sig(x)
19+
x = self.conv3(x)
20+
return x
21+
22+
class FallbackEdge(torch.nn.Module):
23+
def __init__(self):
24+
super(FallbackEdge, self).__init__()
25+
self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
26+
self.log_sig = torch.nn.LogSigmoid()
27+
self.conv2 = torch.nn.Conv2d(32, 16, 3, 1, 1)
28+
self.relu = torch.nn.ReLU()
29+
self.pooling = torch.nn.MaxPool2d(2)
30+
31+
def forward(self, x):
32+
x = self.conv1(x)
33+
x1 = self.log_sig(x)
34+
x1 = self.conv2(x1)
35+
x2 = self.conv2(x)
36+
x2 = self.relu(x2)
37+
x = x1 + x2
38+
x = self.pooling(x)
39+
return x
40+
41+
def main():
42+
model1 = FallbackBase().eval().cuda()
43+
44+
scripted_model1 = torch.jit.script(model1)
45+
torch.jit.save(scripted_model1, 'test_base_model.jit')
46+
47+
model2 = FallbackEdge().eval().cuda()
48+
scripted_model2 = torch.jit.script(model2)
49+
torch.jit.save(scripted_model2, 'test_edge_model.jit')
50+
51+
if __name__ == "__main__":
52+
main()
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
def partitioning_test(name, visibility=None):
2+
native.cc_test(
3+
name = name,
4+
srcs = [name + ".cpp"],
5+
visibility = visibility,
6+
deps = [
7+
"//tests/util",
8+
"//core",
9+
"@googletest//:gtest_main",
10+
] + select({
11+
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
12+
"//conditions:default": ["@libtorch//:libtorch"],
13+
}),
14+
data = [
15+
":jit_models"
16+
],
17+
timeout="short"
18+
)
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
#include <string>
2+
#include "gtest/gtest.h"
3+
#include "tests/util/util.h"
4+
#include "torch/script.h"
5+
#include "trtorch/trtorch.h"
6+
#include "core/lowering/lowering.h"
7+
#include "core/partitioning/partitioning.h"
8+
9+
10+
bool checkSegmentedBlockNumber(std::vector<trtorch::core::partitioning::SegmentedBlock>& segmented_blocks,
11+
trtorch::core::partitioning::SegmentedBlock::SegmentedBlockTarget target, int target_count) {
12+
for (auto &seg_block : segmented_blocks) {
13+
if (seg_block.target() == target) {
14+
target_count--;
15+
}
16+
}
17+
return target_count == 0;
18+
}
19+
20+
bool checkSegmentedBlockNodesMapping(std::vector<trtorch::core::partitioning::SegmentedBlock>& segmented_blocks,
21+
std::shared_ptr<torch::jit::Graph> g, std::vector<std::vector<int>> nodes_index) {
22+
std::vector<torch::jit::Node*> graph_nodes;
23+
for (const auto n : g->nodes()) {
24+
if (n->kind() != torch::jit::prim::Constant) {
25+
graph_nodes.push_back(n);
26+
}
27+
}
28+
for (size_t i = 0; i < nodes_index.size(); ++i) {
29+
size_t seg_block_node_id = 0;
30+
for (int j : nodes_index[i]) {
31+
if (segmented_blocks[i].raw_nodes()[seg_block_node_id++] != graph_nodes[j]) {
32+
return false;
33+
}
34+
}
35+
if (seg_block_node_id != segmented_blocks[i].raw_nodes().size()) return false;
36+
}
37+
return true;
38+
}
39+
40+
TEST(Partitioning, SegmentingGraphDefaultCorrectly) {
41+
torch::jit::script::Module mod;
42+
try {
43+
mod = torch::jit::load("tests/core/partitioning/test_base_model.jit");
44+
} catch (const c10::Error& e) {
45+
std::cerr << "error loading the model\n";
46+
return;
47+
}
48+
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
49+
auto g = graph_and_parameters.first;
50+
51+
trtorch::core::conversion::TorchFallback fallback_info;
52+
fallback_info.enabled = true;
53+
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph(g, fallback_info);
54+
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT, 2));
55+
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch, 1));
56+
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3}, {4}}));
57+
}
58+
59+
TEST(Partitioning, SegmentingGraphWithMinBlockSizeCorrectly) {
60+
torch::jit::script::Module mod;
61+
try {
62+
mod = torch::jit::load("tests/core/partitioning/test_base_model.jit");
63+
} catch (const c10::Error& e) {
64+
std::cerr << "error loading the model\n";
65+
return;
66+
}
67+
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
68+
auto g = graph_and_parameters.first;
69+
70+
trtorch::core::conversion::TorchFallback fallback_info;
71+
fallback_info.enabled = true;
72+
fallback_info.min_block_size = 3;
73+
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph(g, fallback_info);
74+
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT, 1));
75+
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch, 1));
76+
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3, 4}}));
77+
}
78+
79+
TEST(Partitioning, SegmentingGraphWithForcedOPeCorrectly) {
80+
torch::jit::script::Module mod;
81+
try {
82+
mod = torch::jit::load("tests/core/partitioning/test_base_model.jit");
83+
} catch (const c10::Error& e) {
84+
std::cerr << "error loading the model\n";
85+
return;
86+
}
87+
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
88+
auto g = graph_and_parameters.first;
89+
90+
trtorch::core::conversion::TorchFallback fallback_info;
91+
fallback_info.enabled = true;
92+
fallback_info.forced_fallback_operators.push_back("aten::relu");
93+
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph(g, fallback_info);
94+
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT, 3));
95+
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch, 2));
96+
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0}, {1}, {2}, {3}, {4}}));
97+
}
98+
99+
TEST(Partitioning, SegmentingGraphDefaultCorrectlyEdge) {
100+
torch::jit::script::Module mod;
101+
try {
102+
mod = torch::jit::load("tests/core/partitioning/test_edge_model.jit");
103+
} catch (const c10::Error& e) {
104+
std::cerr << "error loading the model\n";
105+
return;
106+
}
107+
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
108+
auto g = graph_and_parameters.first;
109+
110+
trtorch::core::conversion::TorchFallback fallback_info;
111+
fallback_info.enabled = true;
112+
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph(g, fallback_info);
113+
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT, 2));
114+
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch, 1));
115+
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1}, {2}, {3, 4, 5, 6}}));
116+
}
117+
118+
TEST(Partitioning, SegmentingGraphWithMinBlockSizeCorrectlyEdge) {
119+
torch::jit::script::Module mod;
120+
try {
121+
mod = torch::jit::load("tests/core/partitioning/test_edge_model.jit");
122+
} catch (const c10::Error& e) {
123+
std::cerr << "error loading the model\n";
124+
return;
125+
}
126+
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
127+
auto g = graph_and_parameters.first;
128+
129+
trtorch::core::conversion::TorchFallback fallback_info;
130+
fallback_info.enabled = true;
131+
fallback_info.min_block_size = 3;
132+
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph(g, fallback_info);
133+
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT, 1));
134+
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch, 1));
135+
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3, 4, 5, 6}}));
136+
}
137+
138+
TEST(Partitioning, SegmentingGraphWithForcedOPeCorrectlyEdge) {
139+
torch::jit::script::Module mod;
140+
try {
141+
mod = torch::jit::load("tests/core/partitioning/test_edge_model.jit");
142+
} catch (const c10::Error& e) {
143+
std::cerr << "error loading the model\n";
144+
return;
145+
}
146+
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
147+
auto g = graph_and_parameters.first;
148+
149+
trtorch::core::conversion::TorchFallback fallback_info;
150+
fallback_info.enabled = true;
151+
fallback_info.forced_fallback_operators.push_back("aten::relu");
152+
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph(g, fallback_info);
153+
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT, 3));
154+
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch, 2));
155+
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1}, {2}, {3}, {4}, {5, 6}}));
156+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#include <string>
2+
#include "gtest/gtest.h"
3+
#include "torch/script.h"
4+
#include "core/lowering/lowering.h"
5+
#include "core/partitioning/partitioning.h"
6+
#include "core/util/trt_util.h"
7+
8+
bool checkSegmentedBlockInputShape(std::vector<trtorch::core::partitioning::SegmentedBlock>& segmented_blocks, std::vector<std::vector<std::vector<int>>> in_shape) {
9+
if (segmented_blocks.size() != in_shape.size()) return false;
10+
for (size_t i = 0; i < segmented_blocks.size(); ++i) {
11+
auto cur_block_in_shapes = segmented_blocks[i].in_shape();
12+
if (cur_block_in_shapes.size() != in_shape[i].size()) return false;
13+
for (size_t j = 0; j < cur_block_in_shapes.size(); ++j) {
14+
auto cur_input_shape = trtorch::core::util::toVec(cur_block_in_shapes[j]);
15+
for (size_t k = 0; k < cur_input_shape.size(); ++k) {
16+
if (cur_input_shape[k] != in_shape[i][j][k])
17+
return false;
18+
}
19+
}
20+
}
21+
return true;
22+
}
23+
24+
TEST(Partitioning, InferSegmentedBlockShapeCorrectly) {
25+
torch::jit::script::Module mod;
26+
try {
27+
mod = torch::jit::load("tests/core/partitioning/test_base_model.jit");
28+
} catch (const c10::Error& e) {
29+
std::cerr << "error loading the model\n";
30+
return;
31+
}
32+
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
33+
auto g = graph_and_parameters.first;
34+
35+
trtorch::core::conversion::TorchFallback fallback_info;
36+
fallback_info.enabled = true;
37+
std::vector<trtorch::core::conversion::InputRange> input_ranges{trtorch::core::conversion::InputRange({3, 3, 16, 16})};
38+
39+
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::Partition(g, input_ranges, fallback_info);
40+
ASSERT_TRUE(checkSegmentedBlockInputShape(segmented_blocks, {{{3, 3, 16, 16}}, {{3, 16, 16, 16}}, {{3, 16, 16, 16}}}));
41+
}
42+
43+
TEST(Partitioning, InferSegmentedBlockShapeCorrectlyEdge) {
44+
torch::jit::script::Module mod;
45+
try {
46+
mod = torch::jit::load("tests/core/partitioning/test_edge_model.jit");
47+
} catch (const c10::Error& e) {
48+
std::cerr << "error loading the model\n";
49+
return;
50+
}
51+
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
52+
auto g = graph_and_parameters.first;
53+
54+
trtorch::core::conversion::TorchFallback fallback_info;
55+
fallback_info.enabled = true;
56+
std::vector<trtorch::core::conversion::InputRange> input_ranges{trtorch::core::conversion::InputRange({3, 3, 16, 16})};
57+
58+
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::Partition(g, input_ranges, fallback_info);
59+
ASSERT_TRUE(checkSegmentedBlockInputShape(segmented_blocks, {{{3, 3, 16, 16}}, {{3, 32, 16, 16}}, {{3, 32, 16, 16}, {3, 16, 16, 16}}}));
60+
}

0 commit comments

Comments
 (0)