Skip to content

Commit 1adda8e

Browse files
committed
Add more unit tests for split plugin
test=develop
1 parent 6eba5bd commit 1adda8e

File tree

2 files changed

+43
-17
lines changed

2 files changed

+43
-17
lines changed

paddle/fluid/inference/tensorrt/convert/split_op.cc

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@ namespace paddle {
1919
namespace inference {
2020
namespace tensorrt {
2121

22-
/*
23-
* SplitOp.
24-
*/
2522
class SplitOpConverter : public OpConverter {
2623
public:
2724
void operator()(const framework::proto::OpDesc& op,
@@ -40,15 +37,11 @@ class SplitOpConverter : public OpConverter {
4037
int axis = boost::get<int>(op_desc.GetAttr("axis"));
4138
std::vector<int> output_lengths =
4239
boost::get<std::vector<int>>(op_desc.GetAttr("sections"));
43-
// PADDLE_ENFORCE(axis != 0);
44-
if (axis < 0) {
45-
axis += input_dims.nbDims;
46-
} else {
47-
axis -= 1;
48-
}
40+
// split on batch is not supported in TensorRT
41+
PADDLE_ENFORCE(axis != 0);
42+
axis += (axis < 0) ? input_dims.nbDims : -1;
4943

5044
PADDLE_ENFORCE(output_lengths.size() == output_num);
51-
//
5245
plugin::SplitPlugin* plugin = new plugin::SplitPlugin(axis, output_lengths);
5346
nvinfer1::IPluginLayer* layer =
5447
engine_->AddPlugin(&input, input_num, plugin);

paddle/fluid/inference/tensorrt/convert/test_split_op.cc

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,21 +59,54 @@ void TensorRTSplitTest(const std::vector<int> &in_shape,
5959
validator.Execute(BatchSize);
6060
}
6161

62-
TEST(split_op, test_same_shape_batch1) {
62+
// batch = 0, axis = 1, same shape
63+
TEST(split_op, test_same_shape_axis1_batch1) {
6364
TensorRTSplitTest<1, 1>({4, 2, 2}, {2, 2});
6465
}
65-
66-
TEST(split_op, test_different_shape_batch1) {
66+
// batch = 0, axis = 1, different shape
67+
TEST(split_op, test_different_shape_axis1_batch1) {
6768
TensorRTSplitTest<1, 1>({3, 2, 2}, {2, 1});
6869
}
69-
70-
TEST(split_op, test_same_shape_batch10) {
70+
// batch = 10, axis = 1, same shape
71+
TEST(split_op, test_same_shape_axis1_batch10) {
7172
TensorRTSplitTest<10, 1>({4, 2, 2}, {2, 2});
7273
}
73-
74-
TEST(split_op, test_different_shape_batch10) {
74+
// batch = 10, axis = 1, different shape
75+
TEST(split_op, test_different_shape_axis1_batch10) {
7576
TensorRTSplitTest<10, 1>({3, 2, 2}, {2, 1});
7677
}
78+
// batch = 0, axis = 2, same shape
79+
TEST(split_op, test_same_shape_axis2_batch1) {
80+
TensorRTSplitTest<1, 2>({3, 4, 2}, {2, 2});
81+
}
82+
// batch = 0, axis = 2, different shape
83+
TEST(split_op, test_different_shape_axis2_batch1) {
84+
TensorRTSplitTest<1, 2>({3, 3, 2}, {2, 1});
85+
}
86+
// batch = 10, axis = 2, same shape
87+
TEST(split_op, test_same_shape_axis2_batch10) {
88+
TensorRTSplitTest<10, 2>({3, 4, 2}, {2, 2});
89+
}
90+
// batch = 10, axis = 2, different shape
91+
TEST(split_op, test_different_shape_axis2_batch10) {
92+
TensorRTSplitTest<10, 2>({3, 3, 2}, {2, 1});
93+
}
94+
// batch = 0, axis = 3, same shape
95+
TEST(split_op, test_same_shape_axis3_batch1) {
96+
TensorRTSplitTest<1, 3>({3, 2, 4}, {2, 2});
97+
}
98+
// batch = 0, axis = 3, different shape
99+
TEST(split_op, test_different_shape_axis3_batch1) {
100+
TensorRTSplitTest<1, 3>({3, 2, 3}, {2, 1});
101+
}
102+
// batch = 10, axis = 3, same shape
103+
TEST(split_op, test_same_shape_axis3_batch10) {
104+
TensorRTSplitTest<10, 3>({3, 2, 4}, {2, 2});
105+
}
106+
// batch = 10, axis = 3, different shape
107+
TEST(split_op, test_different_shape_axis3_batch10) {
108+
TensorRTSplitTest<10, 3>({3, 2, 3}, {2, 1});
109+
}
77110

78111
} // namespace tensorrt
79112
} // namespace inference

0 commit comments

Comments
 (0)