@@ -20,30 +20,59 @@ namespace paddle {
20
20
namespace inference {
21
21
namespace tensorrt {
22
22
23
- TEST (split_op, test) {
23
+ template <int BatchSize, int Axis>
24
+ void TensorRTSplitTest (const std::vector<int > &in_shape,
25
+ const std::vector<int > §ions) {
24
26
std::unordered_set<std::string> parameters ({" " });
25
27
framework::Scope scope;
26
- TRTConvertValidation validator (10 , parameters, scope, 1000 );
27
- validator.DeclInputVar (" split_input" , nvinfer1::DimsCHW (3 , 2 , 2 ));
28
- validator.DeclOutputVar (" split_out1" , nvinfer1::DimsCHW (2 , 2 , 2 ));
29
- validator.DeclOutputVar (" split_out2" , nvinfer1::DimsCHW (1 , 2 , 2 ));
28
+ TRTConvertValidation validator (BatchSize + 1 , parameters, scope, 10000 );
29
+
30
+ auto make_dim = [](const std::vector<int > &shape) {
31
+ nvinfer1::DimsCHW dim;
32
+ dim.c () = shape[0 ];
33
+ dim.h () = shape[1 ];
34
+ dim.w () = shape[2 ];
35
+ return dim;
36
+ };
37
+ validator.DeclInputVar (" split_input" , make_dim (in_shape));
38
+ std::vector<std::string> output_vars;
39
+ for (size_t i = 0 ; i < sections.size (); ++i) {
40
+ auto out_shape = in_shape;
41
+ out_shape[Axis - 1 ] = sections[i];
42
+ std::string output_name = " split_out" + std::to_string (i);
43
+ validator.DeclOutputVar (output_name, make_dim (out_shape));
44
+ output_vars.push_back (output_name);
45
+ }
30
46
31
47
// Prepare Op description
32
48
framework::OpDesc desc;
33
49
desc.SetType (" split" );
34
50
desc.SetInput (" X" , {" split_input" });
35
- desc.SetOutput (" Out" , { " split_out1 " , " split_out2 " } );
51
+ desc.SetOutput (" Out" , output_vars );
36
52
37
- int num = 0 ;
38
- int axis = 1 ;
39
- std::vector<int > output_lengths = {2 , 1 };
40
- desc.SetAttr (" axis" , axis);
41
- desc.SetAttr (" num" , num);
42
- desc.SetAttr (" sections" , output_lengths);
53
+ desc.SetAttr (" axis" , Axis);
54
+ desc.SetAttr (" num" , 0 );
55
+ desc.SetAttr (" sections" , sections);
43
56
44
57
validator.SetOp (*desc.Proto ());
45
58
46
- validator.Execute (1 );
59
+ validator.Execute (BatchSize);
60
+ }
61
+
62
+ TEST (split_op, test_same_shape_batch1) {
63
+ TensorRTSplitTest<1 , 1 >({4 , 2 , 2 }, {2 , 2 });
64
+ }
65
+
66
+ TEST (split_op, test_different_shape_batch1) {
67
+ TensorRTSplitTest<1 , 1 >({3 , 2 , 2 }, {2 , 1 });
68
+ }
69
+
70
+ TEST (split_op, test_same_shape_batch10) {
71
+ TensorRTSplitTest<10 , 1 >({4 , 2 , 2 }, {2 , 2 });
72
+ }
73
+
74
+ TEST (split_op, test_different_shape_batch10) {
75
+ TensorRTSplitTest<10 , 1 >({3 , 2 , 2 }, {2 , 1 });
47
76
}
48
77
49
78
} // namespace tensorrt
0 commit comments