Skip to content

Commit 3d39d7c

Browse files
committed
test: add tests for TRT conversion, graph stitch, results comparison in partitioning
Signed-off-by: Bo Wang <[email protected]>
1 parent 2840281 commit 3d39d7c

File tree

4 files changed

+189
-0
lines changed

4 files changed

+189
-0
lines changed

tests/core/partitioning/BUILD

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,25 @@ partitioning_test(
2020
name = "test_shape_analysis",
2121
)
2222

23+
partitioning_test(
24+
name = "test_tensorrt_conversion",
25+
)
26+
27+
partitioning_test(
28+
name = "test_stitched_graph",
29+
)
30+
31+
partitioning_test(
32+
name = "test_fallback_graph_output",
33+
)
34+
2335
test_suite(
2436
name = "partitioning_test",
2537
tests = [
2638
":test_segmentation",
2739
":test_shape_analysis",
40+
":test_tensorrt_conversion",
41+
":test_stitched_graph",
42+
":test_fallback_graph_output"
2843
]
2944
)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#include <string>
2+
#include <unordered_set>
3+
#include "gtest/gtest.h"
4+
#include "torch/script.h"
5+
#include "core/compiler.h"
6+
#include "tests/util/util.h"
7+
8+
9+
TEST(Partitioning, StitchSegmentedBlockCorrectly) {
10+
torch::jit::script::Module mod;
11+
try {
12+
mod = torch::jit::load("tests/core/partitioning/test_base_model.jit");
13+
} catch (const c10::Error& e) {
14+
std::cerr << "error loading the model\n";
15+
return;
16+
}
17+
18+
const std::vector<std::vector<int64_t>> input_shapes = {{3, 3, 16, 16}};
19+
std::vector<torch::jit::IValue> jit_inputs_ivalues;
20+
std::vector<torch::jit::IValue> trt_inputs_ivalues;
21+
for (auto in_shape : input_shapes) {
22+
auto in = at::randint(5, in_shape, {at::kCUDA});
23+
jit_inputs_ivalues.push_back(in.clone());
24+
trt_inputs_ivalues.push_back(in.clone());
25+
}
26+
27+
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({3, 3, 16, 16})};
28+
trtorch::core::CompileSpec cfg(input_ranges);
29+
cfg.partition_info.enabled = true;
30+
31+
auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
32+
auto trt_mod = trtorch::core::CompileGraph(mod, cfg);
33+
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
34+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
35+
}
36+
37+
TEST(Partitioning, StitchSegmentedBlockCorrectlyEdge) {
38+
torch::jit::script::Module mod;
39+
try {
40+
mod = torch::jit::load("tests/core/partitioning/test_edge_model.jit");
41+
} catch (const c10::Error& e) {
42+
std::cerr << "error loading the model\n";
43+
return;
44+
}
45+
46+
const std::vector<std::vector<int64_t>> input_shapes = {{3, 3, 16, 16}};
47+
std::vector<torch::jit::IValue> jit_inputs_ivalues;
48+
std::vector<torch::jit::IValue> trt_inputs_ivalues;
49+
for (auto in_shape : input_shapes) {
50+
auto in = at::randint(5, in_shape, {at::kCUDA});
51+
jit_inputs_ivalues.push_back(in.clone());
52+
trt_inputs_ivalues.push_back(in.clone());
53+
}
54+
55+
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({3, 3, 16, 16})};
56+
trtorch::core::CompileSpec cfg(input_ranges);
57+
cfg.partition_info.enabled = true;
58+
59+
auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
60+
auto trt_mod = trtorch::core::CompileGraph(mod, cfg);
61+
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
62+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
63+
}
64+
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#include <string>
2+
#include <unordered_set>
3+
#include "gtest/gtest.h"
4+
#include "torch/script.h"
5+
#include "core/compiler.h"
6+
#include "core/util/trt_util.h"
7+
8+
bool checkAllInputsExistInStitchedGraph(std::shared_ptr<torch::jit::Graph> g) {
9+
std::unordered_set<torch::jit::Value*> available_values;
10+
for (auto v : g->inputs()) {
11+
available_values.insert(v);
12+
}
13+
for (const auto n : g->nodes()) {
14+
for (auto input : n->inputs()) {
15+
if (!available_values.count(input))
16+
return false;
17+
}
18+
for (auto output : n->outputs()) {
19+
available_values.insert(output);
20+
}
21+
}
22+
return true;
23+
}
24+
25+
TEST(Partitioning, StitchSegmentedBlockCorrectly) {
26+
torch::jit::script::Module mod;
27+
try {
28+
mod = torch::jit::load("tests/core/partitioning/test_base_model.jit");
29+
} catch (const c10::Error& e) {
30+
std::cerr << "error loading the model\n";
31+
return;
32+
}
33+
34+
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({3, 3, 16, 16})};
35+
trtorch::core::CompileSpec cfg(input_ranges);
36+
cfg.partition_info.enabled = true;
37+
torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
38+
auto g = new_mod.get_method("forward").graph();
39+
ASSERT_TRUE(checkAllInputsExistInStitchedGraph(g));
40+
}
41+
42+
43+
TEST(Partitioning, StitchSegmentedBlockCorrectlyEdge) {
44+
torch::jit::script::Module mod;
45+
try {
46+
mod = torch::jit::load("tests/core/partitioning/test_edge_model.jit");
47+
} catch (const c10::Error& e) {
48+
std::cerr << "error loading the model\n";
49+
return;
50+
}
51+
52+
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({3, 3, 16, 16})};
53+
trtorch::core::CompileSpec cfg(input_ranges);
54+
cfg.partition_info.enabled = true;
55+
torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
56+
auto g = new_mod.get_method("forward").graph();
57+
ASSERT_TRUE(checkAllInputsExistInStitchedGraph(g));
58+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#include <string>
2+
#include "gtest/gtest.h"
3+
#include "torch/script.h"
4+
#include "core/compiler.h"
5+
#include "core/util/trt_util.h"
6+
7+
int count_trt_engines(std::shared_ptr<torch::jit::Graph> g) {
8+
int count = 0;
9+
for (const auto n : g->nodes()) {
10+
if (n->kind().toQualString() == std::string("tensorrt::execute_engine")) {
11+
++count;
12+
}
13+
}
14+
return count;
15+
}
16+
17+
TEST(Partitioning, ConvertSegmentedBlockCorrectly) {
18+
torch::jit::script::Module mod;
19+
try {
20+
mod = torch::jit::load("tests/core/partitioning/test_base_model.jit");
21+
} catch (const c10::Error& e) {
22+
std::cerr << "error loading the model\n";
23+
return;
24+
}
25+
26+
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({3, 3, 16, 16})};
27+
trtorch::core::CompileSpec cfg(input_ranges);
28+
cfg.partition_info.enabled = true;
29+
torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
30+
auto g = new_mod.get_method("forward").graph();
31+
int count = count_trt_engines(g);
32+
ASSERT_TRUE(count == 2);
33+
}
34+
35+
36+
TEST(Partitioning, ConvertSegmentedBlockCorrectlyEdge) {
37+
torch::jit::script::Module mod;
38+
try {
39+
mod = torch::jit::load("tests/core/partitioning/test_edge_model.jit");
40+
} catch (const c10::Error& e) {
41+
std::cerr << "error loading the model\n";
42+
return;
43+
}
44+
45+
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({3, 3, 16, 16})};
46+
trtorch::core::CompileSpec cfg(input_ranges);
47+
cfg.partition_info.enabled = true;
48+
torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
49+
auto g = new_mod.get_method("forward").graph();
50+
int count = count_trt_engines(g);
51+
ASSERT_TRUE(count == 2);
52+
}

0 commit comments

Comments
 (0)