1
1
#include < string>
2
- #include " core/lowering/lowering.h"
3
2
#include " core/partitioning/partitioning.h"
4
3
#include " gtest/gtest.h"
5
4
#include " tests/util/util.h"
5
+ #include " torch/csrc/jit/ir/irparser.h"
6
6
#include " torch/script.h"
7
7
#include " trtorch/trtorch.h"
8
8
@@ -41,16 +41,28 @@ bool checkSegmentedBlockNodesMapping(
41
41
return true ;
42
42
}
43
43
44
- TEST (Partitioning, SegmentingGraphDefaultCorrectly) {
45
- torch::jit::script::Module mod;
46
- try {
47
- mod = torch::jit::load (" tests/core/partitioning/test_base_model.jit" );
48
- } catch (const c10::Error& e) {
49
- std::cerr << " error loading the model\n " ;
50
- return ;
51
- }
52
- auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
53
- auto g = graph_and_parameters.first ;
44
+ TEST (Partitioning, SegmentSequentialModelCorrectly) {
45
+ const auto graph = R"IR(
46
+ graph(%0 : Tensor,
47
+ %w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
48
+ %b1 : Float(32),
49
+ %w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
50
+ %b2 : Float(16),
51
+ %w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]),
52
+ %b3 : Float(8)):
53
+ %2 : int[] = prim::Constant[value=[1, 1]]()
54
+ %3 : int = prim::Constant[value=1]()
55
+ %10 : bool = prim::Constant[value=0]()
56
+ %11 : int[] = prim::Constant[value=[0, 0]]()
57
+ %12: Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
58
+ %13 : Tensor = aten::relu(%12)
59
+ %14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
60
+ %15 : Tensor = aten::log_sigmoid(%14)
61
+ %16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
62
+ return (%16))IR" ;
63
+
64
+ auto g = std::make_shared<torch::jit::Graph>();
65
+ torch::jit::parseIR (graph, g.get ());
54
66
55
67
trtorch::core::partitioning::PartitionInfo partition_info;
56
68
partition_info.enabled = true ;
@@ -61,16 +73,28 @@ TEST(Partitioning, SegmentingGraphDefaultCorrectly) {
61
73
ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 , 2 }, {3 }, {4 }}));
62
74
}
63
75
64
- TEST (Partitioning, SegmentingGraphWithMinBlockSizeCorrectly) {
65
- torch::jit::script::Module mod;
66
- try {
67
- mod = torch::jit::load (" tests/core/partitioning/test_base_model.jit" );
68
- } catch (const c10::Error& e) {
69
- std::cerr << " error loading the model\n " ;
70
- return ;
71
- }
72
- auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
73
- auto g = graph_and_parameters.first ;
76
+ TEST (Partitioning, SegmentSequentialModelWithMinBlockSizeCorrectly) {
77
+ const auto graph = R"IR(
78
+ graph(%0 : Tensor,
79
+ %w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
80
+ %b1 : Float(32),
81
+ %w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
82
+ %b2 : Float(16),
83
+ %w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]),
84
+ %b3 : Float(8)):
85
+ %2 : int[] = prim::Constant[value=[1, 1]]()
86
+ %3 : int = prim::Constant[value=1]()
87
+ %10 : bool = prim::Constant[value=0]()
88
+ %11 : int[] = prim::Constant[value=[0, 0]]()
89
+ %12: Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
90
+ %13 : Tensor = aten::relu(%12)
91
+ %14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
92
+ %15 : Tensor = aten::log_sigmoid(%14)
93
+ %16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
94
+ return (%16))IR" ;
95
+
96
+ auto g = std::make_shared<torch::jit::Graph>();
97
+ torch::jit::parseIR (graph, g.get ());
74
98
75
99
trtorch::core::partitioning::PartitionInfo partition_info;
76
100
partition_info.enabled = true ;
@@ -82,16 +106,28 @@ TEST(Partitioning, SegmentingGraphWithMinBlockSizeCorrectly) {
82
106
ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 , 2 }, {3 , 4 }}));
83
107
}
84
108
85
- TEST (Partitioning, SegmentingGraphWithForcedOPeCorrectly) {
86
- torch::jit::script::Module mod;
87
- try {
88
- mod = torch::jit::load (" tests/core/partitioning/test_base_model.jit" );
89
- } catch (const c10::Error& e) {
90
- std::cerr << " error loading the model\n " ;
91
- return ;
92
- }
93
- auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
94
- auto g = graph_and_parameters.first ;
109
+ TEST (Partitioning, SegmentSequentialModelWithForcedOPCorrectly) {
110
+ const auto graph = R"IR(
111
+ graph(%0 : Tensor,
112
+ %w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
113
+ %b1 : Float(32),
114
+ %w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
115
+ %b2 : Float(16),
116
+ %w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]),
117
+ %b3 : Float(8)):
118
+ %2 : int[] = prim::Constant[value=[1, 1]]()
119
+ %3 : int = prim::Constant[value=1]()
120
+ %10 : bool = prim::Constant[value=0]()
121
+ %11 : int[] = prim::Constant[value=[0, 0]]()
122
+ %12: Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
123
+ %13 : Tensor = aten::relu(%12)
124
+ %14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
125
+ %15 : Tensor = aten::log_sigmoid(%14)
126
+ %16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
127
+ return (%16))IR" ;
128
+
129
+ auto g = std::make_shared<torch::jit::Graph>();
130
+ torch::jit::parseIR (graph, g.get ());
95
131
96
132
trtorch::core::partitioning::PartitionInfo partition_info;
97
133
partition_info.enabled = true ;
@@ -103,16 +139,29 @@ TEST(Partitioning, SegmentingGraphWithForcedOPeCorrectly) {
103
139
ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 }, {1 }, {2 }, {3 }, {4 }}));
104
140
}
105
141
106
- TEST (Partitioning, SegmentingGraphDefaultCorrectlyEdge) {
107
- torch::jit::script::Module mod;
108
- try {
109
- mod = torch::jit::load (" tests/core/partitioning/test_edge_model.jit" );
110
- } catch (const c10::Error& e) {
111
- std::cerr << " error loading the model\n " ;
112
- return ;
113
- }
114
- auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
115
- auto g = graph_and_parameters.first ;
142
+ TEST (Partitioning, SegmentBranchModelCorrectly) {
143
+ const auto graph = R"IR(
144
+ graph(%0 : Tensor,
145
+ %1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
146
+ %2 : Float(32),
147
+ %3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
148
+ %4 : Float(16)):
149
+ %5 : int[] = prim::Constant[value=[0, 0]]()
150
+ %6 : int[] = prim::Constant[value=[2, 2]]()
151
+ %7 : bool = prim::Constant[value=0]()
152
+ %8 : int[] = prim::Constant[value=[1, 1]]()
153
+ %9 : int = prim::Constant[value=1]()
154
+ %10: Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
155
+ %11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
156
+ %12: Tensor = aten::log_sigmoid(%10)
157
+ %13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
158
+ %14 : Tensor = aten::relu(%11)
159
+ %15 : Tensor = aten::add(%13, %14, %9)
160
+ %16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7)
161
+ return (%16))IR" ;
162
+
163
+ auto g = std::make_shared<torch::jit::Graph>();
164
+ torch::jit::parseIR (graph, g.get ());
116
165
117
166
trtorch::core::partitioning::PartitionInfo partition_info;
118
167
partition_info.enabled = true ;
@@ -123,16 +172,29 @@ TEST(Partitioning, SegmentingGraphDefaultCorrectlyEdge) {
123
172
ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 }, {2 }, {3 , 4 , 5 , 6 }}));
124
173
}
125
174
126
- TEST (Partitioning, SegmentingGraphWithMinBlockSizeCorrectlyEdge) {
127
- torch::jit::script::Module mod;
128
- try {
129
- mod = torch::jit::load (" tests/core/partitioning/test_edge_model.jit" );
130
- } catch (const c10::Error& e) {
131
- std::cerr << " error loading the model\n " ;
132
- return ;
133
- }
134
- auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
135
- auto g = graph_and_parameters.first ;
175
+ TEST (Partitioning, SegmentBranchModelWithMinBlockSizeCorrectly) {
176
+ const auto graph = R"IR(
177
+ graph(%0 : Tensor,
178
+ %1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
179
+ %2 : Float(32),
180
+ %3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
181
+ %4 : Float(16)):
182
+ %5 : int[] = prim::Constant[value=[0, 0]]()
183
+ %6 : int[] = prim::Constant[value=[2, 2]]()
184
+ %7 : bool = prim::Constant[value=0]()
185
+ %8 : int[] = prim::Constant[value=[1, 1]]()
186
+ %9 : int = prim::Constant[value=1]()
187
+ %10: Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
188
+ %11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
189
+ %12: Tensor = aten::log_sigmoid(%10)
190
+ %13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
191
+ %14 : Tensor = aten::relu(%11)
192
+ %15 : Tensor = aten::add(%13, %14, %9)
193
+ %16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7)
194
+ return (%16))IR" ;
195
+
196
+ auto g = std::make_shared<torch::jit::Graph>();
197
+ torch::jit::parseIR (graph, g.get ());
136
198
137
199
trtorch::core::partitioning::PartitionInfo partition_info;
138
200
partition_info.enabled = true ;
@@ -144,16 +206,29 @@ TEST(Partitioning, SegmentingGraphWithMinBlockSizeCorrectlyEdge) {
144
206
ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 , 2 }, {3 , 4 , 5 , 6 }}));
145
207
}
146
208
147
- TEST (Partitioning, SegmentingGraphWithForcedOPeCorrectlyEdge) {
148
- torch::jit::script::Module mod;
149
- try {
150
- mod = torch::jit::load (" tests/core/partitioning/test_edge_model.jit" );
151
- } catch (const c10::Error& e) {
152
- std::cerr << " error loading the model\n " ;
153
- return ;
154
- }
155
- auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
156
- auto g = graph_and_parameters.first ;
209
+ TEST (Partitioning, SegmentBranchModelWithForcedFallbackOPCorrectly) {
210
+ const auto graph = R"IR(
211
+ graph(%0 : Tensor,
212
+ %1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
213
+ %2 : Float(32),
214
+ %3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
215
+ %4 : Float(16)):
216
+ %5 : int[] = prim::Constant[value=[0, 0]]()
217
+ %6 : int[] = prim::Constant[value=[2, 2]]()
218
+ %7 : bool = prim::Constant[value=0]()
219
+ %8 : int[] = prim::Constant[value=[1, 1]]()
220
+ %9 : int = prim::Constant[value=1]()
221
+ %10: Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
222
+ %11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
223
+ %12: Tensor = aten::log_sigmoid(%10)
224
+ %13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
225
+ %14 : Tensor = aten::relu(%11)
226
+ %15 : Tensor = aten::add(%13, %14, %9)
227
+ %16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7)
228
+ return (%16))IR" ;
229
+
230
+ auto g = std::make_shared<torch::jit::Graph>();
231
+ torch::jit::parseIR (graph, g.get ());
157
232
158
233
trtorch::core::partitioning::PartitionInfo partition_info;
159
234
partition_info.enabled = true ;
0 commit comments