Skip to content

Commit 2840281

Browse files
committed
chore: change test code according to new APIs
Signed-off-by: Bo Wang <[email protected]>
1 parent 6d826d3 commit 2840281

File tree

3 files changed

+32
-30
lines changed

3 files changed

+32
-30
lines changed

core/partitioning/partitioning.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ namespace partitioning {
1414

1515
typedef std::vector<SegmentedBlock> PartitionedGraph;
1616

17+
PartitionedGraph segment_graph(std::shared_ptr<torch::jit::Graph> g, const PartitionInfo& partition_info);
18+
1719
std::vector<SegmentedBlock> Partition(
1820
std::shared_ptr<torch::jit::Graph> g,
1921
std::vector<ir::InputRange>& input_ranges,

tests/core/partitioning/test_segmentation.cpp

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ TEST(Partitioning, SegmentingGraphDefaultCorrectly) {
4848
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
4949
auto g = graph_and_parameters.first;
5050

51-
trtorch::core::conversion::TorchFallback fallback_info;
52-
fallback_info.enabled = true;
53-
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph(g, fallback_info);
51+
trtorch::core::partitioning::PartitionInfo partition_info;
52+
partition_info.enabled = true;
53+
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph(g, partition_info);
5454
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT, 2));
5555
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch, 1));
5656
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3}, {4}}));
@@ -67,10 +67,10 @@ TEST(Partitioning, SegmentingGraphWithMinBlockSizeCorrectly) {
6767
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
6868
auto g = graph_and_parameters.first;
6969

70-
trtorch::core::conversion::TorchFallback fallback_info;
71-
fallback_info.enabled = true;
72-
fallback_info.min_block_size = 3;
73-
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph(g, fallback_info);
70+
trtorch::core::partitioning::PartitionInfo partition_info;
71+
partition_info.enabled = true;
72+
partition_info.min_block_size = 3;
73+
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph(g, partition_info);
7474
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT, 1));
7575
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch, 1));
7676
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3, 4}}));
@@ -87,10 +87,10 @@ TEST(Partitioning, SegmentingGraphWithForcedOPeCorrectly) {
8787
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
8888
auto g = graph_and_parameters.first;
8989

90-
trtorch::core::conversion::TorchFallback fallback_info;
91-
fallback_info.enabled = true;
92-
fallback_info.forced_fallback_operators.push_back("aten::relu");
93-
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph(g, fallback_info);
90+
trtorch::core::partitioning::PartitionInfo partition_info;
91+
partition_info.enabled = true;
92+
partition_info.forced_fallback_operators.push_back("aten::relu");
93+
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph(g, partition_info);
9494
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT, 3));
9595
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch, 2));
9696
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0}, {1}, {2}, {3}, {4}}));
@@ -107,9 +107,9 @@ TEST(Partitioning, SegmentingGraphDefaultCorrectlyEdge) {
107107
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
108108
auto g = graph_and_parameters.first;
109109

110-
trtorch::core::conversion::TorchFallback fallback_info;
111-
fallback_info.enabled = true;
112-
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph(g, fallback_info);
110+
trtorch::core::partitioning::PartitionInfo partition_info;
111+
partition_info.enabled = true;
112+
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph(g, partition_info);
113113
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT, 2));
114114
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch, 1));
115115
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1}, {2}, {3, 4, 5, 6}}));
@@ -126,10 +126,10 @@ TEST(Partitioning, SegmentingGraphWithMinBlockSizeCorrectlyEdge) {
126126
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
127127
auto g = graph_and_parameters.first;
128128

129-
trtorch::core::conversion::TorchFallback fallback_info;
130-
fallback_info.enabled = true;
131-
fallback_info.min_block_size = 3;
132-
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph(g, fallback_info);
129+
trtorch::core::partitioning::PartitionInfo partition_info;
130+
partition_info.enabled = true;
131+
partition_info.min_block_size = 3;
132+
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph(g, partition_info);
133133
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT, 1));
134134
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch, 1));
135135
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3, 4, 5, 6}}));
@@ -146,10 +146,10 @@ TEST(Partitioning, SegmentingGraphWithForcedOPeCorrectlyEdge) {
146146
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
147147
auto g = graph_and_parameters.first;
148148

149-
trtorch::core::conversion::TorchFallback fallback_info;
150-
fallback_info.enabled = true;
151-
fallback_info.forced_fallback_operators.push_back("aten::relu");
152-
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph(g, fallback_info);
149+
trtorch::core::partitioning::PartitionInfo partition_info;
150+
partition_info.enabled = true;
151+
partition_info.forced_fallback_operators.push_back("aten::relu");
152+
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::segment_graph(g, partition_info);
153153
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTensorRT, 3));
154154
ASSERT_TRUE(checkSegmentedBlockNumber(segmented_blocks, trtorch::core::partitioning::SegmentedBlock::kTorch, 2));
155155
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1}, {2}, {3}, {4}, {5, 6}}));

tests/core/partitioning/test_shape_analysis.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ TEST(Partitioning, InferSegmentedBlockShapeCorrectly) {
3232
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
3333
auto g = graph_and_parameters.first;
3434

35-
trtorch::core::conversion::TorchFallback fallback_info;
36-
fallback_info.enabled = true;
37-
std::vector<trtorch::core::conversion::InputRange> input_ranges{trtorch::core::conversion::InputRange({3, 3, 16, 16})};
35+
trtorch::core::partitioning::PartitionInfo partition_info;
36+
partition_info.enabled = true;
37+
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({3, 3, 16, 16})};
3838

39-
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::Partition(g, input_ranges, fallback_info);
39+
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::Partition(g, input_ranges, partition_info);
4040
ASSERT_TRUE(checkSegmentedBlockInputShape(segmented_blocks, {{{3, 3, 16, 16}}, {{3, 16, 16, 16}}, {{3, 16, 16, 16}}}));
4141
}
4242

@@ -51,10 +51,10 @@ TEST(Partitioning, InferSegmentedBlockShapeCorrectlyEdge) {
5151
auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
5252
auto g = graph_and_parameters.first;
5353

54-
trtorch::core::conversion::TorchFallback fallback_info;
55-
fallback_info.enabled = true;
56-
std::vector<trtorch::core::conversion::InputRange> input_ranges{trtorch::core::conversion::InputRange({3, 3, 16, 16})};
54+
trtorch::core::partitioning::PartitionInfo partition_info;
55+
partition_info.enabled = true;
56+
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({3, 3, 16, 16})};
5757

58-
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::Partition(g, input_ranges, fallback_info);
58+
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks = trtorch::core::partitioning::Partition(g, input_ranges, partition_info);
5959
ASSERT_TRUE(checkSegmentedBlockInputShape(segmented_blocks, {{{3, 3, 16, 16}}, {{3, 32, 16, 16}}, {{3, 32, 16, 16}, {3, 16, 16, 16}}}));
6060
}

0 commit comments

Comments
 (0)