@@ -48,9 +48,9 @@ TEST(Partitioning, SegmentingGraphDefaultCorrectly) {
48
48
auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
49
49
auto g = graph_and_parameters.first ;
50
50
51
- trtorch::core::conversion::TorchFallback fallback_info ;
52
- fallback_info .enabled = true ;
53
- std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, fallback_info );
51
+ trtorch::core::partitioning::PartitionInfo partition_info ;
52
+ partition_info .enabled = true ;
53
+ std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, partition_info );
54
54
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT , 2 ));
55
55
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch , 1 ));
56
56
ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 , 2 }, {3 }, {4 }}));
@@ -67,10 +67,10 @@ TEST(Partitioning, SegmentingGraphWithMinBlockSizeCorrectly) {
67
67
auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
68
68
auto g = graph_and_parameters.first ;
69
69
70
- trtorch::core::conversion::TorchFallback fallback_info ;
71
- fallback_info .enabled = true ;
72
- fallback_info .min_block_size = 3 ;
73
- std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, fallback_info );
70
+ trtorch::core::partitioning::PartitionInfo partition_info ;
71
+ partition_info .enabled = true ;
72
+ partition_info .min_block_size = 3 ;
73
+ std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, partition_info );
74
74
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT , 1 ));
75
75
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch , 1 ));
76
76
ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 , 2 }, {3 , 4 }}));
@@ -87,10 +87,10 @@ TEST(Partitioning, SegmentingGraphWithForcedOPeCorrectly) {
87
87
auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
88
88
auto g = graph_and_parameters.first ;
89
89
90
- trtorch::core::conversion::TorchFallback fallback_info ;
91
- fallback_info .enabled = true ;
92
- fallback_info .forced_fallback_operators .push_back (" aten::relu" );
93
- std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, fallback_info );
90
+ trtorch::core::partitioning::PartitionInfo partition_info ;
91
+ partition_info .enabled = true ;
92
+ partition_info .forced_fallback_operators .push_back (" aten::relu" );
93
+ std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, partition_info );
94
94
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT , 3 ));
95
95
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch , 2 ));
96
96
ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 }, {1 }, {2 }, {3 }, {4 }}));
@@ -107,9 +107,9 @@ TEST(Partitioning, SegmentingGraphDefaultCorrectlyEdge) {
107
107
auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
108
108
auto g = graph_and_parameters.first ;
109
109
110
- trtorch::core::conversion::TorchFallback fallback_info ;
111
- fallback_info .enabled = true ;
112
- std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, fallback_info );
110
+ trtorch::core::partitioning::PartitionInfo partition_info ;
111
+ partition_info .enabled = true ;
112
+ std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, partition_info );
113
113
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT , 2 ));
114
114
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch , 1 ));
115
115
ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 }, {2 }, {3 , 4 , 5 , 6 }}));
@@ -126,10 +126,10 @@ TEST(Partitioning, SegmentingGraphWithMinBlockSizeCorrectlyEdge) {
126
126
auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
127
127
auto g = graph_and_parameters.first ;
128
128
129
- trtorch::core::conversion::TorchFallback fallback_info ;
130
- fallback_info .enabled = true ;
131
- fallback_info .min_block_size = 3 ;
132
- std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, fallback_info );
129
+ trtorch::core::partitioning::PartitionInfo partition_info ;
130
+ partition_info .enabled = true ;
131
+ partition_info .min_block_size = 3 ;
132
+ std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, partition_info );
133
133
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT , 1 ));
134
134
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch , 1 ));
135
135
ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 , 2 }, {3 , 4 , 5 , 6 }}));
@@ -146,10 +146,10 @@ TEST(Partitioning, SegmentingGraphWithForcedOPeCorrectlyEdge) {
146
146
auto graph_and_parameters = trtorch::core::lowering::Lower (mod, " forward" );
147
147
auto g = graph_and_parameters.first ;
148
148
149
- trtorch::core::conversion::TorchFallback fallback_info ;
150
- fallback_info .enabled = true ;
151
- fallback_info .forced_fallback_operators .push_back (" aten::relu" );
152
- std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, fallback_info );
149
+ trtorch::core::partitioning::PartitionInfo partition_info ;
150
+ partition_info .enabled = true ;
151
+ partition_info .forced_fallback_operators .push_back (" aten::relu" );
152
+ std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph (g, partition_info );
153
153
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT , 3 ));
154
154
ASSERT_TRUE (checkSegmentedBlockNumber (segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch , 2 ));
155
155
ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 }, {2 }, {3 }, {4 }, {5 , 6 }}));
0 commit comments