@@ -59,21 +59,54 @@ void TensorRTSplitTest(const std::vector<int> &in_shape,
59
59
validator.Execute (BatchSize);
60
60
}
61
61
62
- TEST (split_op, test_same_shape_batch1) {
62
+ // batch = 0, axis = 1, same shape
63
+ TEST (split_op, test_same_shape_axis1_batch1) {
63
64
TensorRTSplitTest<1 , 1 >({4 , 2 , 2 }, {2 , 2 });
64
65
}
65
-
66
- TEST (split_op, test_different_shape_batch1 ) {
66
+ // batch = 0, axis = 1, different shape
67
+ TEST (split_op, test_different_shape_axis1_batch1 ) {
67
68
TensorRTSplitTest<1 , 1 >({3 , 2 , 2 }, {2 , 1 });
68
69
}
69
-
70
- TEST (split_op, test_same_shape_batch10 ) {
70
+ // batch = 10, axis = 1, same shape
71
+ TEST (split_op, test_same_shape_axis1_batch10 ) {
71
72
TensorRTSplitTest<10 , 1 >({4 , 2 , 2 }, {2 , 2 });
72
73
}
73
-
74
- TEST (split_op, test_different_shape_batch10 ) {
74
+ // batch = 10, axis = 1, different shape
75
+ TEST (split_op, test_different_shape_axis1_batch10 ) {
75
76
TensorRTSplitTest<10 , 1 >({3 , 2 , 2 }, {2 , 1 });
76
77
}
78
+ // batch = 0, axis = 2, same shape
79
+ TEST (split_op, test_same_shape_axis2_batch1) {
80
+ TensorRTSplitTest<1 , 2 >({3 , 4 , 2 }, {2 , 2 });
81
+ }
82
+ // batch = 0, axis = 2, different shape
83
+ TEST (split_op, test_different_shape_axis2_batch1) {
84
+ TensorRTSplitTest<1 , 2 >({3 , 3 , 2 }, {2 , 1 });
85
+ }
86
+ // batch = 10, axis = 2, same shape
87
+ TEST (split_op, test_same_shape_axis2_batch10) {
88
+ TensorRTSplitTest<10 , 2 >({3 , 4 , 2 }, {2 , 2 });
89
+ }
90
+ // batch = 10, axis = 2, different shape
91
+ TEST (split_op, test_different_shape_axis2_batch10) {
92
+ TensorRTSplitTest<10 , 2 >({3 , 3 , 2 }, {2 , 1 });
93
+ }
94
+ // batch = 0, axis = 3, same shape
95
+ TEST (split_op, test_same_shape_axis3_batch1) {
96
+ TensorRTSplitTest<1 , 3 >({3 , 2 , 4 }, {2 , 2 });
97
+ }
98
+ // batch = 0, axis = 3, different shape
99
+ TEST (split_op, test_different_shape_axis3_batch1) {
100
+ TensorRTSplitTest<1 , 3 >({3 , 2 , 3 }, {2 , 1 });
101
+ }
102
+ // batch = 10, axis = 3, same shape
103
+ TEST (split_op, test_same_shape_axis3_batch10) {
104
+ TensorRTSplitTest<10 , 3 >({3 , 2 , 4 }, {2 , 2 });
105
+ }
106
+ // batch = 10, axis = 3, different shape
107
+ TEST (split_op, test_different_shape_axis3_batch10) {
108
+ TensorRTSplitTest<10 , 3 >({3 , 2 , 3 }, {2 , 1 });
109
+ }
77
110
78
111
} // namespace tensorrt
79
112
} // namespace inference
0 commit comments