1
+ #include < string>
2
+ #include " gtest/gtest.h"
3
+ #include " tests/util/util.h"
4
+ #include " torch/script.h"
5
+ #include " trtorch/trtorch.h"
6
+ #include " core/lowering/lowering.h"
7
+ #include " core/partitioning/partitioning.h"
8
+
9
+
10
+ bool checkSegmentedBlockNumber (std::vector<trtorch::core::partitioning::SegmentedBlock>& segmented_blocks,
11
+ trtorch::core::partitioning::SegmentedBlock::SegmentedBlockTarget target, int target_count) {
12
+ for (auto &seg_block : segmented_blocks) {
13
+ if (seg_block.target () == target) {
14
+ target_count--;
15
+ }
16
+ }
17
+ return target_count == 0 ;
18
+ }
19
+
20
+ bool checkSegmentedBlockNodesMapping (std::vector<trtorch::core::partitioning::SegmentedBlock>& segmented_blocks,
21
+ std::shared_ptr<torch::jit::Graph> g, std::vector<std::vector<int >> nodes_index) {
22
+ std::vector<torch::jit::Node*> graph_nodes;
23
+ for (const auto n : g->nodes ()) {
24
+ if (n->kind () != torch::jit::prim::Constant) {
25
+ graph_nodes.push_back (n);
26
+ }
27
+ }
28
+ for (size_t i = 0 ; i < nodes_index.size (); ++i) {
29
+ size_t seg_block_node_id = 0 ;
30
+ for (int j : nodes_index[i]) {
31
+ if (segmented_blocks[i].raw_nodes ()[seg_block_node_id++] != graph_nodes[j]) {
32
+ return false ;
33
+ }
34
+ }
35
+ if (seg_block_node_id != segmented_blocks[i].raw_nodes ().size ()) return false ;
36
+ }
37
+ return true ;
38
+ }
39
+
40
+ TEST (Partitioning, SegmentingGraphDefaultCorrectly) {
41
+ torch::jit::script::Module mod;
42
+ try {
43
+ mod = torch::jit::load (" tests/core/partitioning/test_base_model.jit" );
44
+ } catch (const c10::Error& e) {
45
+ std::cerr << " error loading the model\n " ;
46
+ return ;
47
+ }
48
+ auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
49
+ auto g = graph_and_parameters.first ;
50
+
51
+ trtorch::core::conversion::TorchFallback fallback_info;
52
+ fallback_info.enabled = true ;
53
+ std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, fallback_info);
54
+ ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT , 2 ));
55
+ ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch , 1 ));
56
+ ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 , 2 }, {3 }, {4 }}));
57
+ }
58
+
59
+ TEST (Partitioning, SegmentingGraphWithMinBlockSizeCorrectly) {
60
+ torch::jit::script::Module mod;
61
+ try {
62
+ mod = torch::jit::load (" tests/core/partitioning/test_base_model.jit" );
63
+ } catch (const c10::Error& e) {
64
+ std::cerr << " error loading the model\n " ;
65
+ return ;
66
+ }
67
+ auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
68
+ auto g = graph_and_parameters.first ;
69
+
70
+ trtorch::core::conversion::TorchFallback fallback_info;
71
+ fallback_info.enabled = true ;
72
+ fallback_info.min_block_size = 3 ;
73
+ std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, fallback_info);
74
+ ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT , 1 ));
75
+ ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch , 1 ));
76
+ ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 , 2 }, {3 , 4 }}));
77
+ }
78
+
79
+ TEST (Partitioning, SegmentingGraphWithForcedOPeCorrectly) {
80
+ torch::jit::script::Module mod;
81
+ try {
82
+ mod = torch::jit::load (" tests/core/partitioning/test_base_model.jit" );
83
+ } catch (const c10::Error& e) {
84
+ std::cerr << " error loading the model\n " ;
85
+ return ;
86
+ }
87
+ auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
88
+ auto g = graph_and_parameters.first ;
89
+
90
+ trtorch::core::conversion::TorchFallback fallback_info;
91
+ fallback_info.enabled = true ;
92
+ fallback_info.forced_fallback_operators .push_back (" aten::relu" );
93
+ std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, fallback_info);
94
+ ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT , 3 ));
95
+ ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch , 2 ));
96
+ ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 }, {1 }, {2 }, {3 }, {4 }}));
97
+ }
98
+
99
+ TEST (Partitioning, SegmentingGraphDefaultCorrectlyEdge) {
100
+ torch::jit::script::Module mod;
101
+ try {
102
+ mod = torch::jit::load (" tests/core/partitioning/test_edge_model.jit" );
103
+ } catch (const c10::Error& e) {
104
+ std::cerr << " error loading the model\n " ;
105
+ return ;
106
+ }
107
+ auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
108
+ auto g = graph_and_parameters.first ;
109
+
110
+ trtorch::core::conversion::TorchFallback fallback_info;
111
+ fallback_info.enabled = true ;
112
+ std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, fallback_info);
113
+ ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT , 2 ));
114
+ ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch , 1 ));
115
+ ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 }, {2 }, {3 , 4 , 5 , 6 }}));
116
+ }
117
+
118
+ TEST (Partitioning, SegmentingGraphWithMinBlockSizeCorrectlyEdge) {
119
+ torch::jit::script::Module mod;
120
+ try {
121
+ mod = torch::jit::load (" tests/core/partitioning/test_edge_model.jit" );
122
+ } catch (const c10::Error& e) {
123
+ std::cerr << " error loading the model\n " ;
124
+ return ;
125
+ }
126
+ auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
127
+ auto g = graph_and_parameters.first ;
128
+
129
+ trtorch::core::conversion::TorchFallback fallback_info;
130
+ fallback_info.enabled = true ;
131
+ fallback_info.min_block_size = 3 ;
132
+ std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, fallback_info);
133
+ ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT , 1 ));
134
+ ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch , 1 ));
135
+ ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 , 2 }, {3 , 4 , 5 , 6 }}));
136
+ }
137
+
138
+ TEST (Partitioning, SegmentingGraphWithForcedOPeCorrectlyEdge) {
139
+ torch::jit::script::Module mod;
140
+ try {
141
+ mod = torch::jit::load (" tests/core/partitioning/test_edge_model.jit" );
142
+ } catch (const c10::Error& e) {
143
+ std::cerr << " error loading the model\n " ;
144
+ return ;
145
+ }
146
+ auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
147
+ auto g = graph_and_parameters.first ;
148
+
149
+ trtorch::core::conversion::TorchFallback fallback_info;
150
+ fallback_info.enabled = true ;
151
+ fallback_info.forced_fallback_operators .push_back (" aten::relu" );
152
+ std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, fallback_info);
153
+ ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT , 3 ));
154
+ ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch , 2 ));
155
+ ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 }, {2 }, {3 }, {4 }, {5 , 6 }}));
156
+ }
0 commit comments