Skip to content

Commit 6eba5bd

Browse files
committed
Fix direct copy and refine split ut
test=develop
1 parent 5857fb3 commit 6eba5bd

File tree

2 files changed

+46
-16
lines changed

2 files changed

+46
-16
lines changed

paddle/fluid/inference/tensorrt/convert/test_split_op.cc

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,30 +20,59 @@ namespace paddle {
2020
namespace inference {
2121
namespace tensorrt {
2222

23-
TEST(split_op, test) {
23+
template <int BatchSize, int Axis>
24+
void TensorRTSplitTest(const std::vector<int> &in_shape,
25+
const std::vector<int> &sections) {
2426
std::unordered_set<std::string> parameters({""});
2527
framework::Scope scope;
26-
TRTConvertValidation validator(10, parameters, scope, 1000);
27-
validator.DeclInputVar("split_input", nvinfer1::DimsCHW(3, 2, 2));
28-
validator.DeclOutputVar("split_out1", nvinfer1::DimsCHW(2, 2, 2));
29-
validator.DeclOutputVar("split_out2", nvinfer1::DimsCHW(1, 2, 2));
28+
TRTConvertValidation validator(BatchSize + 1, parameters, scope, 10000);
29+
30+
auto make_dim = [](const std::vector<int> &shape) {
31+
nvinfer1::DimsCHW dim;
32+
dim.c() = shape[0];
33+
dim.h() = shape[1];
34+
dim.w() = shape[2];
35+
return dim;
36+
};
37+
validator.DeclInputVar("split_input", make_dim(in_shape));
38+
std::vector<std::string> output_vars;
39+
for (size_t i = 0; i < sections.size(); ++i) {
40+
auto out_shape = in_shape;
41+
out_shape[Axis - 1] = sections[i];
42+
std::string output_name = "split_out" + std::to_string(i);
43+
validator.DeclOutputVar(output_name, make_dim(out_shape));
44+
output_vars.push_back(output_name);
45+
}
3046

3147
// Prepare Op description
3248
framework::OpDesc desc;
3349
desc.SetType("split");
3450
desc.SetInput("X", {"split_input"});
35-
desc.SetOutput("Out", {"split_out1", "split_out2"});
51+
desc.SetOutput("Out", output_vars);
3652

37-
int num = 0;
38-
int axis = 1;
39-
std::vector<int> output_lengths = {2, 1};
40-
desc.SetAttr("axis", axis);
41-
desc.SetAttr("num", num);
42-
desc.SetAttr("sections", output_lengths);
53+
desc.SetAttr("axis", Axis);
54+
desc.SetAttr("num", 0);
55+
desc.SetAttr("sections", sections);
4356

4457
validator.SetOp(*desc.Proto());
4558

46-
validator.Execute(1);
59+
validator.Execute(BatchSize);
60+
}
61+
62+
TEST(split_op, test_same_shape_batch1) {
63+
TensorRTSplitTest<1, 1>({4, 2, 2}, {2, 2});
64+
}
65+
66+
TEST(split_op, test_different_shape_batch1) {
67+
TensorRTSplitTest<1, 1>({3, 2, 2}, {2, 1});
68+
}
69+
70+
TEST(split_op, test_same_shape_batch10) {
71+
TensorRTSplitTest<10, 1>({4, 2, 2}, {2, 2});
72+
}
73+
74+
TEST(split_op, test_different_shape_batch10) {
75+
TensorRTSplitTest<10, 1>({3, 2, 2}, {2, 1});
4776
}
4877

4978
} // namespace tensorrt

paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,12 @@ inline void Split(cudaStream_t stream, const bool same_shape,
138138
int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
139139
void** outputs, void* workspace, cudaStream_t stream) {
140140
float const* input_ptr = reinterpret_cast<float const*>(inputs[0]);
141-
if (axis_ == -1 && this->getNbOutputs() < 10) {
141+
if (((batchSize == 1 && axis_ == 0) || axis_ == -1) &&
142+
this->getNbOutputs() < 10) {
142143
float** output_ptrs = reinterpret_cast<float**>(outputs);
143144
int data_type_size = (this->getDataType() == nvinfer1::DataType::kFLOAT)
144-
? sizeof(__half)
145-
: sizeof(float);
145+
? sizeof(float)
146+
: sizeof(__half);
146147
for (int i = 0; i < this->getNbOutputs(); ++i) {
147148
PADDLE_ENFORCE(
148149
cudaMemcpyAsync(

0 commit comments

Comments
 (0)