Skip to content

Commit f722035

Browse files
committed
tests: use IRParser in test_tensorrt_conversion and test_stitched_graph
Signed-off-by: Bo Wang <[email protected]>
1 parent 437670e commit f722035

File tree

3 files changed

+206
-48
lines changed

3 files changed

+206
-48
lines changed

core/compiler.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,6 @@ void AddSegmentedBlockToGraph(
172172
old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]];
173173
}
174174

175-
LOG_INFO(*g << "(AddSegmentedBlockToGraph)\n");
176175
return;
177176
}
178177

@@ -187,7 +186,6 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
187186
if (method.name().rfind("_", 0)) {
188187
auto new_g = std::make_shared<torch::jit::Graph>();
189188
auto graph_and_parameters = lowering::Lower(mod, method.name());
190-
// LOG_INFO(*(method.graph()) << "Original graph\n");
191189

192190
auto g = graph_and_parameters.first;
193191
auto params = graph_and_parameters.second;
@@ -204,6 +202,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
204202
int trt_engine_id = 1;
205203
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
206204
for (auto& seg_block : segmented_blocks) {
205+
LOG_INFO(*g << "(MiniGraphInSegmentedBlock)\n");
207206
if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
208207
std::vector<ir::InputRange> input_ranges;
209208
for (auto& shape : seg_block.in_shape()) {

tests/core/partitioning/test_stitched_graph.cpp

Lines changed: 102 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#include "core/compiler.h"
44
#include "core/util/trt_util.h"
55
#include "gtest/gtest.h"
6+
#include "torch/csrc/jit/ir/constants.h"
7+
#include "torch/csrc/jit/ir/irparser.h"
68
#include "torch/script.h"
79

810
bool checkAllInputsExistInStitchedGraph(std::shared_ptr<torch::jit::Graph> g) {
@@ -22,39 +24,117 @@ bool checkAllInputsExistInStitchedGraph(std::shared_ptr<torch::jit::Graph> g) {
2224
return true;
2325
}
2426

25-
TEST(Partitioning, StitchResNet50SegmentedBlockCorrectly) {
26-
torch::jit::script::Module mod;
27-
try {
28-
mod = torch::jit::load("tests/modules/resnet50_traced.jit.pt");
29-
} catch (const c10::Error& e) {
30-
std::cerr << "error loading the model\n";
31-
return;
27+
TEST(Partitioning, StitchSequentialModelSegmentedBlockCorrectly) {
28+
const auto graph = R"IR(
29+
graph(%0 : Tensor,
30+
%w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
31+
%b1 : Float(32),
32+
%w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
33+
%b2 : Float(16),
34+
%w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]),
35+
%b3 : Float(8)):
36+
%2 : int[] = prim::Constant[value=[1, 1]]()
37+
%3 : int = prim::Constant[value=1]()
38+
%10 : bool = prim::Constant[value=0]()
39+
%11 : int[] = prim::Constant[value=[0, 0]]()
40+
%12: Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
41+
%13 : Tensor = aten::relu(%12)
42+
%14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
43+
%15 : Tensor = aten::log_sigmoid(%14)
44+
%16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
45+
return (%16))IR";
46+
47+
auto parsed_g = std::make_shared<torch::jit::Graph>();
48+
torch::jit::parseIR(graph, parsed_g.get());
49+
50+
auto g = std::make_shared<torch::jit::Graph>();
51+
std::vector<std::vector<int64_t>> all_shapes{{32, 3, 3, 3}, {32}, {16, 32, 3, 3}, {16}, {8, 16, 3, 3}, {8}};
52+
std::unordered_map<torch::jit::Value*, torch::jit::Value*> tensor_to_constant;
53+
for (size_t i = 0; i < all_shapes.size(); ++i) {
54+
auto in = at::randint(5, all_shapes[i], {at::kCUDA});
55+
torch::jit::IValue cur_val = in.clone();
56+
auto new_val = g->insertConstant(cur_val);
57+
tensor_to_constant[parsed_g->inputs()[i + 1]] = new_val;
58+
}
59+
for (auto node : parsed_g->nodes()) {
60+
if (node->kind() == torch::jit::prim::Constant)
61+
continue;
62+
trtorch::core::util::cloneNode(node, g, tensor_to_constant);
3263
}
64+
g->registerOutput(tensor_to_constant[parsed_g->outputs()[0]]);
3365

34-
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({1, 3, 224, 224})};
66+
std::vector<trtorch::core::ir::InputRange> input_ranges;
67+
input_ranges.push_back(trtorch::core::ir::InputRange({3, 3, 16, 16}));
3568
trtorch::core::CompileSpec cfg(input_ranges);
3669
cfg.partition_info.enabled = true;
37-
cfg.partition_info.forced_fallback_operators.push_back("aten::add");
70+
torch::jit::script::Module mod(c10::QualifiedName("module"));
71+
72+
auto self = g->insertInput(0, "self_1");
73+
self->setType(mod.type());
74+
auto cur_method = mod._ivalue()->compilation_unit()->create_function(c10::QualifiedName("forward"), g);
75+
auto schema = trtorch::core::util::GenerateGraphSchema(cur_method->name(), g);
76+
mod.type()->addMethod(cur_method);
77+
cur_method->setSchema(schema);
78+
3879
torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
39-
auto g = new_mod.get_method("forward").graph();
40-
ASSERT_TRUE(checkAllInputsExistInStitchedGraph(g));
80+
auto fallback_g = new_mod.get_method("forward").graph();
81+
ASSERT_TRUE(checkAllInputsExistInStitchedGraph(fallback_g));
4182
}
4283

43-
TEST(Partitioning, StitchMobileNetSegmentedBlockCorrectlyEdge) {
44-
torch::jit::script::Module mod;
45-
try {
46-
mod = torch::jit::load("tests/modules/mobilenet_v2_traced.jit.pt");
47-
} catch (const c10::Error& e) {
48-
std::cerr << "error loading the model\n";
49-
return;
84+
TEST(Partitioning, StitchBranchModelSegmentedBlockCorrectly) {
85+
const auto graph = R"IR(
86+
graph(%0 : Tensor,
87+
%1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
88+
%2 : Float(32),
89+
%3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
90+
%4 : Float(16)):
91+
%5 : int[] = prim::Constant[value=[0, 0]]()
92+
%6 : int[] = prim::Constant[value=[2, 2]]()
93+
%7 : bool = prim::Constant[value=0]()
94+
%8 : int[] = prim::Constant[value=[1, 1]]()
95+
%9 : int = prim::Constant[value=1]()
96+
%10: Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
97+
%11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
98+
%12: Tensor = aten::log_sigmoid(%10)
99+
%13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
100+
%14 : Tensor = aten::relu(%11)
101+
%15 : Tensor = aten::add(%13, %14, %9)
102+
%16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7)
103+
return (%16))IR";
104+
105+
auto parsed_g = std::make_shared<torch::jit::Graph>();
106+
torch::jit::parseIR(graph, parsed_g.get());
107+
108+
auto g = std::make_shared<torch::jit::Graph>();
109+
std::vector<std::vector<int64_t>> all_shapes{{32, 3, 3, 3}, {32}, {16, 32, 3, 3}, {16}};
110+
std::unordered_map<torch::jit::Value*, torch::jit::Value*> tensor_to_constant;
111+
for (size_t i = 0; i < all_shapes.size(); ++i) {
112+
auto in = at::randint(5, all_shapes[i], {at::kCUDA});
113+
torch::jit::IValue cur_val = in.clone();
114+
auto new_val = g->insertConstant(cur_val);
115+
tensor_to_constant[parsed_g->inputs()[i + 1]] = new_val;
50116
}
117+
for (auto node : parsed_g->nodes()) {
118+
if (node->kind() == torch::jit::prim::Constant)
119+
continue;
120+
trtorch::core::util::cloneNode(node, g, tensor_to_constant);
121+
}
122+
g->registerOutput(tensor_to_constant[parsed_g->outputs()[0]]);
51123

52-
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({1, 3, 224, 224})};
124+
std::vector<trtorch::core::ir::InputRange> input_ranges;
125+
input_ranges.push_back(trtorch::core::ir::InputRange({3, 3, 16, 16}));
53126
trtorch::core::CompileSpec cfg(input_ranges);
54127
cfg.partition_info.enabled = true;
55-
cfg.partition_info.forced_fallback_operators.push_back("aten::hardtanh");
128+
torch::jit::script::Module mod(c10::QualifiedName("module"));
129+
130+
auto self = g->insertInput(0, "self_1");
131+
self->setType(mod.type());
132+
auto cur_method = mod._ivalue()->compilation_unit()->create_function(c10::QualifiedName("forward"), g);
133+
auto schema = trtorch::core::util::GenerateGraphSchema(cur_method->name(), g);
134+
mod.type()->addMethod(cur_method);
135+
cur_method->setSchema(schema);
56136

57137
torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
58-
auto g = new_mod.get_method("forward").graph();
59-
ASSERT_TRUE(checkAllInputsExistInStitchedGraph(g));
138+
auto fallback_g = new_mod.get_method("forward").graph();
139+
ASSERT_TRUE(checkAllInputsExistInStitchedGraph(fallback_g));
60140
}

tests/core/partitioning/test_tensorrt_conversion.cpp

Lines changed: 103 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,40 +15,119 @@ int count_trt_engines(std::shared_ptr<torch::jit::Graph> g) {
1515
return count;
1616
}
1717

18-
TEST(Partitioning, ConvertResNet50SegmentedBlockCorrectly) {
19-
torch::jit::script::Module mod;
20-
try {
21-
mod = torch::jit::load("tests/modules/resnet50_traced.jit.pt");
22-
} catch (const c10::Error& e) {
23-
std::cerr << "error loading the model\n";
24-
return;
18+
TEST(Partitioning, ConvertSequentialModelSegmentedBlockCorrectly) {
19+
const auto graph = R"IR(
20+
graph(%0 : Tensor,
21+
%w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
22+
%b1 : Float(32),
23+
%w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
24+
%b2 : Float(16),
25+
%w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]),
26+
%b3 : Float(8)):
27+
%2 : int[] = prim::Constant[value=[1, 1]]()
28+
%3 : int = prim::Constant[value=1]()
29+
%10 : bool = prim::Constant[value=0]()
30+
%11 : int[] = prim::Constant[value=[0, 0]]()
31+
%12: Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
32+
%13 : Tensor = aten::relu(%12)
33+
%14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
34+
%15 : Tensor = aten::log_sigmoid(%14)
35+
%16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
36+
return (%16))IR";
37+
38+
auto parsed_g = std::make_shared<torch::jit::Graph>();
39+
torch::jit::parseIR(graph, parsed_g.get());
40+
41+
auto g = std::make_shared<torch::jit::Graph>();
42+
std::vector<std::vector<int64_t>> all_shapes{{32, 3, 3, 3}, {32}, {16, 32, 3, 3}, {16}, {8, 16, 3, 3}, {8}};
43+
std::unordered_map<torch::jit::Value*, torch::jit::Value*> tensor_to_constant;
44+
for (size_t i = 0; i < all_shapes.size(); ++i) {
45+
auto in = at::randint(5, all_shapes[i], {at::kCUDA});
46+
torch::jit::IValue cur_val = in.clone();
47+
auto new_val = g->insertConstant(cur_val);
48+
tensor_to_constant[parsed_g->inputs()[i + 1]] = new_val;
49+
}
50+
for (auto node : parsed_g->nodes()) {
51+
if (node->kind() == torch::jit::prim::Constant)
52+
continue;
53+
trtorch::core::util::cloneNode(node, g, tensor_to_constant);
2554
}
55+
g->registerOutput(tensor_to_constant[parsed_g->outputs()[0]]);
2656

27-
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({1, 3, 224, 224})};
57+
std::vector<trtorch::core::ir::InputRange> input_ranges;
58+
input_ranges.push_back(trtorch::core::ir::InputRange({3, 3, 16, 16}));
2859
trtorch::core::CompileSpec cfg(input_ranges);
2960
cfg.partition_info.enabled = true;
30-
cfg.partition_info.forced_fallback_operators.push_back("aten::add");
61+
torch::jit::script::Module mod(c10::QualifiedName("module"));
62+
63+
auto self = g->insertInput(0, "self_1");
64+
self->setType(mod.type());
65+
auto cur_method = mod._ivalue()->compilation_unit()->create_function(c10::QualifiedName("forward"), g);
66+
auto schema = trtorch::core::util::GenerateGraphSchema(cur_method->name(), g);
67+
mod.type()->addMethod(cur_method);
68+
cur_method->setSchema(schema);
69+
3170
torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
32-
auto g = new_mod.get_method("forward").graph();
33-
int count = count_trt_engines(g);
34-
ASSERT_TRUE(count == 17);
71+
auto fallback_g = new_mod.get_method("forward").graph();
72+
int count = count_trt_engines(fallback_g);
73+
ASSERT_TRUE(count == 2);
3574
}
3675

37-
TEST(Partitioning, ConvertMobileNetSegmentedBlockCorrectly) {
38-
torch::jit::script::Module mod;
39-
try {
40-
mod = torch::jit::load("tests/modules/mobilenet_v2_traced.jit.pt");
41-
} catch (const c10::Error& e) {
42-
std::cerr << "error loading the model\n";
43-
return;
76+
TEST(Partitioning, ConvertBranchModelSegmentedBlockCorrectly) {
77+
const auto graph = R"IR(
78+
graph(%0 : Tensor,
79+
%1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
80+
%2 : Float(32),
81+
%3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
82+
%4 : Float(16)):
83+
%5 : int[] = prim::Constant[value=[0, 0]]()
84+
%6 : int[] = prim::Constant[value=[2, 2]]()
85+
%7 : bool = prim::Constant[value=0]()
86+
%8 : int[] = prim::Constant[value=[1, 1]]()
87+
%9 : int = prim::Constant[value=1]()
88+
%10: Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
89+
%11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
90+
%12: Tensor = aten::log_sigmoid(%10)
91+
%13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
92+
%14 : Tensor = aten::relu(%11)
93+
%15 : Tensor = aten::add(%13, %14, %9)
94+
%16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7)
95+
return (%16))IR";
96+
97+
auto parsed_g = std::make_shared<torch::jit::Graph>();
98+
torch::jit::parseIR(graph, parsed_g.get());
99+
100+
auto g = std::make_shared<torch::jit::Graph>();
101+
std::vector<std::vector<int64_t>> all_shapes{{32, 3, 3, 3}, {32}, {16, 32, 3, 3}, {16}};
102+
std::unordered_map<torch::jit::Value*, torch::jit::Value*> tensor_to_constant;
103+
for (size_t i = 0; i < all_shapes.size(); ++i) {
104+
auto in = at::randint(5, all_shapes[i], {at::kCUDA});
105+
torch::jit::IValue cur_val = in.clone();
106+
auto new_val = g->insertConstant(cur_val);
107+
tensor_to_constant[parsed_g->inputs()[i + 1]] = new_val;
108+
}
109+
for (auto node : parsed_g->nodes()) {
110+
if (node->kind() == torch::jit::prim::Constant)
111+
continue;
112+
trtorch::core::util::cloneNode(node, g, tensor_to_constant);
44113
}
114+
g->registerOutput(tensor_to_constant[parsed_g->outputs()[0]]);
45115

46-
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({1, 3, 224, 224})};
116+
std::vector<trtorch::core::ir::InputRange> input_ranges;
117+
input_ranges.push_back(trtorch::core::ir::InputRange({3, 3, 16, 16}));
47118
trtorch::core::CompileSpec cfg(input_ranges);
48119
cfg.partition_info.enabled = true;
49-
cfg.partition_info.forced_fallback_operators.push_back("aten::add");
120+
torch::jit::script::Module mod(c10::QualifiedName("module"));
121+
122+
auto self = g->insertInput(0, "self_1");
123+
self->setType(mod.type());
124+
auto cur_method = mod._ivalue()->compilation_unit()->create_function(c10::QualifiedName("forward"), g);
125+
auto schema = trtorch::core::util::GenerateGraphSchema(cur_method->name(), g);
126+
mod.type()->addMethod(cur_method);
127+
cur_method->setSchema(schema);
128+
50129
torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
51-
auto g = new_mod.get_method("forward").graph();
52-
int count = count_trt_engines(g);
53-
ASSERT_TRUE(count == 11);
130+
auto fallback_g = new_mod.get_method("forward").graph();
131+
int count = count_trt_engines(fallback_g);
132+
ASSERT_TRUE(count == 2);
54133
}

0 commit comments

Comments (0)