Skip to content

Commit 3a72dc3

Browse files
committed
test: remove the jit file dependency from tests
Signed-off-by: Bo Wang <[email protected]>
1 parent 116b001 commit 3a72dc3

File tree

6 files changed

+224
-121
lines changed

6 files changed

+224
-121
lines changed

core/compiler.cpp

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,6 @@
2727
namespace trtorch {
2828
namespace core {
2929

30-
c10::FunctionSchema GenerateGraphSchema(
31-
torch::jit::script::Module mod,
32-
std::string method_name,
33-
std::shared_ptr<torch::jit::Graph>& g) {
34-
std::vector<c10::Argument> args;
35-
for (auto in : g->inputs()) {
36-
args.push_back(c10::Argument(in->debugName(), in->type()));
37-
}
38-
39-
std::vector<c10::Argument> returns;
40-
for (auto out : g->outputs()) {
41-
returns.push_back(c10::Argument(out->debugName(), out->type()));
42-
}
43-
44-
return c10::FunctionSchema(method_name, method_name, args, returns);
45-
}
46-
4730
void AddEngineToGraph(
4831
torch::jit::script::Module mod,
4932
std::shared_ptr<torch::jit::Graph>& g,
@@ -246,7 +229,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
246229
LOG_INFO(*new_g << "(FallbackGraph)\n");
247230

248231
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
249-
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
232+
auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
250233
new_mod.type()->addMethod(new_method);
251234
new_method->setSchema(schema);
252235
}
@@ -272,7 +255,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
272255
auto new_g = std::make_shared<torch::jit::Graph>();
273256
AddEngineToGraph(new_mod, new_g, engine);
274257
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
275-
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
258+
auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
276259
new_mod.type()->addMethod(new_method);
277260
new_method->setSchema(schema);
278261
}

core/partitioning/shape_analysis.cpp

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,6 @@ std::vector<torch::jit::IValue> generateRandomInputs(std::vector<ir::InputRange>
1919
return random_inputs;
2020
}
2121

22-
c10::FunctionSchema getFunctionSchema(std::string method_name, std::shared_ptr<torch::jit::Graph>& g) {
23-
std::vector<c10::Argument> args;
24-
for (auto in : g->inputs()) {
25-
args.push_back(c10::Argument(in->debugName(), in->type()));
26-
}
27-
28-
std::vector<c10::Argument> returns;
29-
for (auto out : g->outputs()) {
30-
returns.push_back(c10::Argument(out->debugName(), out->type()));
31-
}
32-
33-
return c10::FunctionSchema(method_name, method_name, args, returns);
34-
}
35-
3622
void getSegmentsOutputByRunning(
3723
SegmentedBlock& seg_block,
3824
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
@@ -56,7 +42,7 @@ void getSegmentsOutputByRunning(
5642
self->setType(cur_mod.type());
5743

5844
auto cur_method = cur_mod._ivalue()->compilation_unit()->create_function(c10::QualifiedName("forward"), copy_g);
59-
auto schema = getFunctionSchema(cur_method->name(), copy_g);
45+
auto schema = util::GenerateGraphSchema(cur_method->name(), copy_g);
6046
cur_mod.type()->addMethod(cur_method);
6147
cur_method->setSchema(schema);
6248

core/util/trt_util.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,20 @@ c10::optional<nvinfer1::DataType> toTRTDataType(caffe2::TypeMeta dtype) {
350350
}
351351
}
352352

353+
c10::FunctionSchema GenerateGraphSchema(std::string method_name, std::shared_ptr<torch::jit::Graph>& g) {
354+
std::vector<c10::Argument> args;
355+
for (auto in : g->inputs()) {
356+
args.push_back(c10::Argument(in->debugName(), in->type()));
357+
}
358+
359+
std::vector<c10::Argument> returns;
360+
for (auto out : g->outputs()) {
361+
returns.push_back(c10::Argument(out->debugName(), out->type()));
362+
}
363+
364+
return c10::FunctionSchema(method_name, method_name, args, returns);
365+
}
366+
353367
torch::jit::Value* getOrAddInputForValue(
354368
torch::jit::Value* old_value,
355369
std::shared_ptr<torch::jit::Graph>& graph,

core/util/trt_util.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "ATen/Tensor.h"
44
#include "ATen/core/List.h"
55
#include "NvInfer.h"
6+
#include "torch/csrc/jit/api/module.h"
67
#include "torch/csrc/jit/ir/ir.h"
78

89
namespace nvinfer1 {
@@ -108,6 +109,7 @@ std::string toStr(nvinfer1::Dims d);
108109
at::ScalarType toATenDType(nvinfer1::DataType t);
109110
nvinfer1::DataType toTRTDataType(at::ScalarType t);
110111
c10::optional<nvinfer1::DataType> toTRTDataType(caffe2::TypeMeta dtype);
112+
c10::FunctionSchema GenerateGraphSchema(std::string method_name, std::shared_ptr<torch::jit::Graph>& g);
111113
torch::jit::Node* cloneNode(
112114
torch::jit::Node* node,
113115
std::shared_ptr<torch::jit::Graph>& graph,

tests/core/partitioning/test_segmentation.cpp

Lines changed: 136 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#include <string>
2-
#include "core/lowering/lowering.h"
32
#include "core/partitioning/partitioning.h"
43
#include "gtest/gtest.h"
54
#include "tests/util/util.h"
5+
#include "torch/csrc/jit/ir/irparser.h"
66
#include "torch/script.h"
77
#include "trtorch/trtorch.h"
88

@@ -41,16 +41,28 @@ bool checkSegmentedBlockNodesMapping(
4141
return true;
4242
}
4343

44-
TEST(Partitioning, SegmentingGraphDefaultCorrectly) {
45-
torch::jit::script::Module mod;
46-
try {
47-
mod = torch::jit::load("tests/core/partitioning/test_base_model.jit");
48-
} catch (const c10::Error& e) {
49-
std::cerr << "error loading the model\n";
50-
return;
51-
}
52-
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
53-
auto g = graph_and_parameters.first;
44+
TEST(Partitioning, SegmentSequentialModelCorrectly) {
45+
const auto graph = R"IR(
46+
graph(%0 : Tensor,
47+
%w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
48+
%b1 : Float(32),
49+
%w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
50+
%b2 : Float(16),
51+
%w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]),
52+
%b3 : Float(8)):
53+
%2 : int[] = prim::Constant[value=[1, 1]]()
54+
%3 : int = prim::Constant[value=1]()
55+
%10 : bool = prim::Constant[value=0]()
56+
%11 : int[] = prim::Constant[value=[0, 0]]()
57+
%12: Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
58+
%13 : Tensor = aten::relu(%12)
59+
%14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
60+
%15 : Tensor = aten::log_sigmoid(%14)
61+
%16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
62+
return (%16))IR";
63+
64+
auto g = std::make_shared<torch::jit::Graph>();
65+
torch::jit::parseIR(graph, g.get());
5466

5567
trtorch::core::partitioning::PartitionInfo partition_info;
5668
partition_info.enabled = true;
@@ -61,16 +73,28 @@ TEST(Partitioning, SegmentingGraphDefaultCorrectly) {
6173
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3}, {4}}));
6274
}
6375

64-
TEST(Partitioning, SegmentingGraphWithMinBlockSizeCorrectly) {
65-
torch::jit::script::Module mod;
66-
try {
67-
mod = torch::jit::load("tests/core/partitioning/test_base_model.jit");
68-
} catch (const c10::Error& e) {
69-
std::cerr << "error loading the model\n";
70-
return;
71-
}
72-
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
73-
auto g = graph_and_parameters.first;
76+
TEST(Partitioning, SegmentSequentialModelWithMinBlockSizeCorrectly) {
77+
const auto graph = R"IR(
78+
graph(%0 : Tensor,
79+
%w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
80+
%b1 : Float(32),
81+
%w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
82+
%b2 : Float(16),
83+
%w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]),
84+
%b3 : Float(8)):
85+
%2 : int[] = prim::Constant[value=[1, 1]]()
86+
%3 : int = prim::Constant[value=1]()
87+
%10 : bool = prim::Constant[value=0]()
88+
%11 : int[] = prim::Constant[value=[0, 0]]()
89+
%12: Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
90+
%13 : Tensor = aten::relu(%12)
91+
%14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
92+
%15 : Tensor = aten::log_sigmoid(%14)
93+
%16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
94+
return (%16))IR";
95+
96+
auto g = std::make_shared<torch::jit::Graph>();
97+
torch::jit::parseIR(graph, g.get());
7498

7599
trtorch::core::partitioning::PartitionInfo partition_info;
76100
partition_info.enabled = true;
@@ -82,16 +106,28 @@ TEST(Partitioning, SegmentingGraphWithMinBlockSizeCorrectly) {
82106
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3, 4}}));
83107
}
84108

85-
TEST(Partitioning, SegmentingGraphWithForcedOPeCorrectly) {
86-
torch::jit::script::Module mod;
87-
try {
88-
mod = torch::jit::load("tests/core/partitioning/test_base_model.jit");
89-
} catch (const c10::Error& e) {
90-
std::cerr << "error loading the model\n";
91-
return;
92-
}
93-
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
94-
auto g = graph_and_parameters.first;
109+
TEST(Partitioning, SegmentSequentialModelWithForcedOPCorrectly) {
110+
const auto graph = R"IR(
111+
graph(%0 : Tensor,
112+
%w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
113+
%b1 : Float(32),
114+
%w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
115+
%b2 : Float(16),
116+
%w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]),
117+
%b3 : Float(8)):
118+
%2 : int[] = prim::Constant[value=[1, 1]]()
119+
%3 : int = prim::Constant[value=1]()
120+
%10 : bool = prim::Constant[value=0]()
121+
%11 : int[] = prim::Constant[value=[0, 0]]()
122+
%12: Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
123+
%13 : Tensor = aten::relu(%12)
124+
%14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
125+
%15 : Tensor = aten::log_sigmoid(%14)
126+
%16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
127+
return (%16))IR";
128+
129+
auto g = std::make_shared<torch::jit::Graph>();
130+
torch::jit::parseIR(graph, g.get());
95131

96132
trtorch::core::partitioning::PartitionInfo partition_info;
97133
partition_info.enabled = true;
@@ -103,16 +139,29 @@ TEST(Partitioning, SegmentingGraphWithForcedOPeCorrectly) {
103139
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0}, {1}, {2}, {3}, {4}}));
104140
}
105141

106-
TEST(Partitioning, SegmentingGraphDefaultCorrectlyEdge) {
107-
torch::jit::script::Module mod;
108-
try {
109-
mod = torch::jit::load("tests/core/partitioning/test_edge_model.jit");
110-
} catch (const c10::Error& e) {
111-
std::cerr << "error loading the model\n";
112-
return;
113-
}
114-
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
115-
auto g = graph_and_parameters.first;
142+
TEST(Partitioning, SegmentBranchModelCorrectly) {
143+
const auto graph = R"IR(
144+
graph(%0 : Tensor,
145+
%1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
146+
%2 : Float(32),
147+
%3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
148+
%4 : Float(16)):
149+
%5 : int[] = prim::Constant[value=[0, 0]]()
150+
%6 : int[] = prim::Constant[value=[2, 2]]()
151+
%7 : bool = prim::Constant[value=0]()
152+
%8 : int[] = prim::Constant[value=[1, 1]]()
153+
%9 : int = prim::Constant[value=1]()
154+
%10: Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
155+
%11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
156+
%12: Tensor = aten::log_sigmoid(%10)
157+
%13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
158+
%14 : Tensor = aten::relu(%11)
159+
%15 : Tensor = aten::add(%13, %14, %9)
160+
%16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7)
161+
return (%16))IR";
162+
163+
auto g = std::make_shared<torch::jit::Graph>();
164+
torch::jit::parseIR(graph, g.get());
116165

117166
trtorch::core::partitioning::PartitionInfo partition_info;
118167
partition_info.enabled = true;
@@ -123,16 +172,29 @@ TEST(Partitioning, SegmentingGraphDefaultCorrectlyEdge) {
123172
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1}, {2}, {3, 4, 5, 6}}));
124173
}
125174

126-
TEST(Partitioning, SegmentingGraphWithMinBlockSizeCorrectlyEdge) {
127-
torch::jit::script::Module mod;
128-
try {
129-
mod = torch::jit::load("tests/core/partitioning/test_edge_model.jit");
130-
} catch (const c10::Error& e) {
131-
std::cerr << "error loading the model\n";
132-
return;
133-
}
134-
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
135-
auto g = graph_and_parameters.first;
175+
TEST(Partitioning, SegmentBranchModelWithMinBlockSizeCorrectly) {
176+
const auto graph = R"IR(
177+
graph(%0 : Tensor,
178+
%1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
179+
%2 : Float(32),
180+
%3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
181+
%4 : Float(16)):
182+
%5 : int[] = prim::Constant[value=[0, 0]]()
183+
%6 : int[] = prim::Constant[value=[2, 2]]()
184+
%7 : bool = prim::Constant[value=0]()
185+
%8 : int[] = prim::Constant[value=[1, 1]]()
186+
%9 : int = prim::Constant[value=1]()
187+
%10: Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
188+
%11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
189+
%12: Tensor = aten::log_sigmoid(%10)
190+
%13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
191+
%14 : Tensor = aten::relu(%11)
192+
%15 : Tensor = aten::add(%13, %14, %9)
193+
%16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7)
194+
return (%16))IR";
195+
196+
auto g = std::make_shared<torch::jit::Graph>();
197+
torch::jit::parseIR(graph, g.get());
136198

137199
trtorch::core::partitioning::PartitionInfo partition_info;
138200
partition_info.enabled = true;
@@ -144,16 +206,29 @@ TEST(Partitioning, SegmentingGraphWithMinBlockSizeCorrectlyEdge) {
144206
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3, 4, 5, 6}}));
145207
}
146208

147-
TEST(Partitioning, SegmentingGraphWithForcedOPeCorrectlyEdge) {
148-
torch::jit::script::Module mod;
149-
try {
150-
mod = torch::jit::load("tests/core/partitioning/test_edge_model.jit");
151-
} catch (const c10::Error& e) {
152-
std::cerr << "error loading the model\n";
153-
return;
154-
}
155-
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
156-
auto g = graph_and_parameters.first;
209+
TEST(Partitioning, SegmentBranchModelWithForcedFallbackOPCorrectly) {
210+
const auto graph = R"IR(
211+
graph(%0 : Tensor,
212+
%1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
213+
%2 : Float(32),
214+
%3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
215+
%4 : Float(16)):
216+
%5 : int[] = prim::Constant[value=[0, 0]]()
217+
%6 : int[] = prim::Constant[value=[2, 2]]()
218+
%7 : bool = prim::Constant[value=0]()
219+
%8 : int[] = prim::Constant[value=[1, 1]]()
220+
%9 : int = prim::Constant[value=1]()
221+
%10: Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
222+
%11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
223+
%12: Tensor = aten::log_sigmoid(%10)
224+
%13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
225+
%14 : Tensor = aten::relu(%11)
226+
%15 : Tensor = aten::add(%13, %14, %9)
227+
%16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7)
228+
return (%16))IR";
229+
230+
auto g = std::make_shared<torch::jit::Graph>();
231+
torch::jit::parseIR(graph, g.get());
157232

158233
trtorch::core::partitioning::PartitionInfo partition_info;
159234
partition_info.enabled = true;

0 commit comments

Comments
 (0)