1
1
#include < string>
2
+ #include " core/lowering/lowering.h"
3
+ #include " core/partitioning/partitioning.h"
2
4
#include " gtest/gtest.h"
3
5
#include " tests/util/util.h"
4
6
#include " torch/script.h"
5
7
#include " trtorch/trtorch.h"
6
- #include " core/lowering/lowering.h"
7
- #include " core/partitioning/partitioning.h"
8
-
9
8
10
- bool checkSegmentedBlockNumber (std::vector<trtorch::core::partitioning::SegmentedBlock>& segmented_blocks,
11
- trtorch::core::partitioning::SegmentedBlock::SegmentedBlockTarget target, int target_count) {
12
- for (auto &seg_block : segmented_blocks) {
9
+ bool checkSegmentedBlockNumber (
10
+ std::vector<trtorch::core::partitioning::SegmentedBlock>& segmented_blocks,
11
+ trtorch::core::partitioning::SegmentedBlock::SegmentedBlockTarget target,
12
+ int target_count) {
13
+ for (auto & seg_block : segmented_blocks) {
13
14
if (seg_block.target () == target) {
14
15
target_count--;
15
16
}
16
17
}
17
18
return target_count == 0 ;
18
19
}
19
20
20
- bool checkSegmentedBlockNodesMapping (std::vector<trtorch::core::partitioning::SegmentedBlock>& segmented_blocks,
21
- std::shared_ptr<torch::jit::Graph> g, std::vector<std::vector<int >> nodes_index) {
21
+ bool checkSegmentedBlockNodesMapping (
22
+ std::vector<trtorch::core::partitioning::SegmentedBlock>& segmented_blocks,
23
+ std::shared_ptr<torch::jit::Graph> g,
24
+ std::vector<std::vector<int >> nodes_index) {
22
25
std::vector<torch::jit::Node*> graph_nodes;
23
26
for (const auto n : g->nodes ()) {
24
27
if (n->kind () != torch::jit::prim::Constant) {
@@ -32,25 +35,27 @@ bool checkSegmentedBlockNodesMapping(std::vector<trtorch::core::partitioning::Se
32
35
return false ;
33
36
}
34
37
}
35
- if (seg_block_node_id != segmented_blocks[i].raw_nodes ().size ()) return false ;
38
+ if (seg_block_node_id != segmented_blocks[i].raw_nodes ().size ())
39
+ return false ;
36
40
}
37
41
return true ;
38
42
}
39
43
40
44
TEST (Partitioning, SegmentingGraphDefaultCorrectly) {
41
45
torch::jit::script::Module mod;
42
46
try {
43
- mod = torch::jit::load (" tests/core/partitioning/test_base_model.jit" );
47
+ mod = torch::jit::load (" tests/core/partitioning/test_base_model.jit" );
44
48
} catch (const c10::Error& e) {
45
- std::cerr << " error loading the model\n " ;
46
- return ;
49
+ std::cerr << " error loading the model\n " ;
50
+ return ;
47
51
}
48
52
auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
49
53
auto g = graph_and_parameters.first ;
50
54
51
55
trtorch::core::partitioning::PartitionInfo partition_info;
52
56
partition_info.enabled = true ;
53
- std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, partition_info);
57
+ std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks =
58
+ trtorch::core::partitioning::segment_graph (g, partition_info);
54
59
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT , 2 ));
55
60
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch , 1 ));
56
61
ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 , 2 }, {3 }, {4 }}));
@@ -59,18 +64,19 @@ TEST(Partitioning, SegmentingGraphDefaultCorrectly) {
59
64
TEST (Partitioning, SegmentingGraphWithMinBlockSizeCorrectly) {
60
65
torch::jit::script::Module mod;
61
66
try {
62
- mod = torch::jit::load (" tests/core/partitioning/test_base_model.jit" );
67
+ mod = torch::jit::load (" tests/core/partitioning/test_base_model.jit" );
63
68
} catch (const c10::Error& e) {
64
- std::cerr << " error loading the model\n " ;
65
- return ;
69
+ std::cerr << " error loading the model\n " ;
70
+ return ;
66
71
}
67
72
auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
68
73
auto g = graph_and_parameters.first ;
69
74
70
75
trtorch::core::partitioning::PartitionInfo partition_info;
71
76
partition_info.enabled = true ;
72
77
partition_info.min_block_size = 3 ;
73
- std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, partition_info);
78
+ std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks =
79
+ trtorch::core::partitioning::segment_graph (g, partition_info);
74
80
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT , 1 ));
75
81
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch , 1 ));
76
82
ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 , 2 }, {3 , 4 }}));
@@ -79,18 +85,19 @@ TEST(Partitioning, SegmentingGraphWithMinBlockSizeCorrectly) {
79
85
TEST (Partitioning, SegmentingGraphWithForcedOPeCorrectly) {
80
86
torch::jit::script::Module mod;
81
87
try {
82
- mod = torch::jit::load (" tests/core/partitioning/test_base_model.jit" );
88
+ mod = torch::jit::load (" tests/core/partitioning/test_base_model.jit" );
83
89
} catch (const c10::Error& e) {
84
- std::cerr << " error loading the model\n " ;
85
- return ;
90
+ std::cerr << " error loading the model\n " ;
91
+ return ;
86
92
}
87
93
auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
88
94
auto g = graph_and_parameters.first ;
89
95
90
96
trtorch::core::partitioning::PartitionInfo partition_info;
91
97
partition_info.enabled = true ;
92
98
partition_info.forced_fallback_operators .push_back (" aten::relu" );
93
- std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, partition_info);
99
+ std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks =
100
+ trtorch::core::partitioning::segment_graph (g, partition_info);
94
101
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT , 3 ));
95
102
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch , 2 ));
96
103
ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 }, {1 }, {2 }, {3 }, {4 }}));
@@ -99,17 +106,18 @@ TEST(Partitioning, SegmentingGraphWithForcedOPeCorrectly) {
99
106
TEST (Partitioning, SegmentingGraphDefaultCorrectlyEdge) {
100
107
torch::jit::script::Module mod;
101
108
try {
102
- mod = torch::jit::load (" tests/core/partitioning/test_edge_model.jit" );
109
+ mod = torch::jit::load (" tests/core/partitioning/test_edge_model.jit" );
103
110
} catch (const c10::Error& e) {
104
- std::cerr << " error loading the model\n " ;
105
- return ;
111
+ std::cerr << " error loading the model\n " ;
112
+ return ;
106
113
}
107
114
auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
108
115
auto g = graph_and_parameters.first ;
109
116
110
117
trtorch::core::partitioning::PartitionInfo partition_info;
111
118
partition_info.enabled = true ;
112
- std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, partition_info);
119
+ std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks =
120
+ trtorch::core::partitioning::segment_graph (g, partition_info);
113
121
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT , 2 ));
114
122
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch , 1 ));
115
123
ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 }, {2 }, {3 , 4 , 5 , 6 }}));
@@ -118,18 +126,19 @@ TEST(Partitioning, SegmentingGraphDefaultCorrectlyEdge) {
118
126
TEST (Partitioning, SegmentingGraphWithMinBlockSizeCorrectlyEdge) {
119
127
torch::jit::script::Module mod;
120
128
try {
121
- mod = torch::jit::load (" tests/core/partitioning/test_edge_model.jit" );
129
+ mod = torch::jit::load (" tests/core/partitioning/test_edge_model.jit" );
122
130
} catch (const c10::Error& e) {
123
- std::cerr << " error loading the model\n " ;
124
- return ;
131
+ std::cerr << " error loading the model\n " ;
132
+ return ;
125
133
}
126
134
auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
127
135
auto g = graph_and_parameters.first ;
128
136
129
137
trtorch::core::partitioning::PartitionInfo partition_info;
130
138
partition_info.enabled = true ;
131
139
partition_info.min_block_size = 3 ;
132
- std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, partition_info);
140
+ std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks =
141
+ trtorch::core::partitioning::segment_graph (g, partition_info);
133
142
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT , 1 ));
134
143
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch , 1 ));
135
144
ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 , 2 }, {3 , 4 , 5 , 6 }}));
@@ -138,18 +147,19 @@ TEST(Partitioning, SegmentingGraphWithMinBlockSizeCorrectlyEdge) {
138
147
TEST (Partitioning, SegmentingGraphWithForcedOPeCorrectlyEdge) {
139
148
torch::jit::script::Module mod;
140
149
try {
141
- mod = torch::jit::load (" tests/core/partitioning/test_edge_model.jit" );
150
+ mod = torch::jit::load (" tests/core/partitioning/test_edge_model.jit" );
142
151
} catch (const c10::Error& e) {
143
- std::cerr << " error loading the model\n " ;
144
- return ;
152
+ std::cerr << " error loading the model\n " ;
153
+ return ;
145
154
}
146
155
auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
147
156
auto g = graph_and_parameters.first ;
148
157
149
158
trtorch::core::partitioning::PartitionInfo partition_info;
150
159
partition_info.enabled = true ;
151
160
partition_info.forced_fallback_operators .push_back (" aten::relu" );
152
- std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, partition_info);
161
+ std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks =
162
+ trtorch::core::partitioning::segment_graph (g, partition_info);
153
163
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT , 3 ));
154
164
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch , 2 ));
155
165
ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 }, {2 }, {3 }, {4 }, {5 , 6 }}));
0 commit comments