Skip to content

Commit e52d90a

Browse files
authored
Merge pull request #14527 from hjchen2/develop
Refine split TensorRT plugin
2 parents 4531281 + 1adda8e commit e52d90a

File tree

4 files changed

+211
-56
lines changed

4 files changed

+211
-56
lines changed

paddle/fluid/inference/tensorrt/convert/split_op.cc

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@ namespace paddle {
1919
namespace inference {
2020
namespace tensorrt {
2121

22-
/*
23-
* SplitOp.
24-
*/
2522
class SplitOpConverter : public OpConverter {
2623
public:
2724
void operator()(const framework::proto::OpDesc& op,
@@ -40,16 +37,11 @@ class SplitOpConverter : public OpConverter {
4037
int axis = boost::get<int>(op_desc.GetAttr("axis"));
4138
std::vector<int> output_lengths =
4239
boost::get<std::vector<int>>(op_desc.GetAttr("sections"));
40+
// split on batch is not supported in TensorRT
4341
PADDLE_ENFORCE(axis != 0);
44-
if (axis < 0) {
45-
axis += input_dims.nbDims;
46-
} else {
47-
axis -= 1;
48-
}
42+
axis += (axis < 0) ? input_dims.nbDims : -1;
4943

5044
PADDLE_ENFORCE(output_lengths.size() == output_num);
51-
52-
//
5345
plugin::SplitPlugin* plugin = new plugin::SplitPlugin(axis, output_lengths);
5446
nvinfer1::IPluginLayer* layer =
5547
engine_->AddPlugin(&input, input_num, plugin);

paddle/fluid/inference/tensorrt/convert/test_split_op.cc

Lines changed: 75 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,30 +20,92 @@ namespace paddle {
2020
namespace inference {
2121
namespace tensorrt {
2222

23-
TEST(split_op, test) {
23+
template <int BatchSize, int Axis>
24+
void TensorRTSplitTest(const std::vector<int> &in_shape,
25+
const std::vector<int> &sections) {
2426
std::unordered_set<std::string> parameters({""});
2527
framework::Scope scope;
26-
TRTConvertValidation validator(10, parameters, scope, 1000);
27-
validator.DeclInputVar("split_input", nvinfer1::DimsCHW(3, 2, 2));
28-
validator.DeclOutputVar("split_out1", nvinfer1::DimsCHW(2, 2, 2));
29-
validator.DeclOutputVar("split_out2", nvinfer1::DimsCHW(1, 2, 2));
28+
TRTConvertValidation validator(BatchSize + 1, parameters, scope, 10000);
29+
30+
auto make_dim = [](const std::vector<int> &shape) {
31+
nvinfer1::DimsCHW dim;
32+
dim.c() = shape[0];
33+
dim.h() = shape[1];
34+
dim.w() = shape[2];
35+
return dim;
36+
};
37+
validator.DeclInputVar("split_input", make_dim(in_shape));
38+
std::vector<std::string> output_vars;
39+
for (size_t i = 0; i < sections.size(); ++i) {
40+
auto out_shape = in_shape;
41+
out_shape[Axis - 1] = sections[i];
42+
std::string output_name = "split_out" + std::to_string(i);
43+
validator.DeclOutputVar(output_name, make_dim(out_shape));
44+
output_vars.push_back(output_name);
45+
}
3046

3147
// Prepare Op description
3248
framework::OpDesc desc;
3349
desc.SetType("split");
3450
desc.SetInput("X", {"split_input"});
35-
desc.SetOutput("Out", {"split_out1", "split_out2"});
51+
desc.SetOutput("Out", output_vars);
3652

37-
int num = 0;
38-
int axis = 1;
39-
std::vector<int> output_lengths = {2, 1};
40-
desc.SetAttr("axis", axis);
41-
desc.SetAttr("num", num);
42-
desc.SetAttr("sections", output_lengths);
53+
desc.SetAttr("axis", Axis);
54+
desc.SetAttr("num", 0);
55+
desc.SetAttr("sections", sections);
4356

4457
validator.SetOp(*desc.Proto());
4558

46-
validator.Execute(1);
59+
validator.Execute(BatchSize);
60+
}
61+
62+
// batch = 1, axis = 1, same shape
63+
TEST(split_op, test_same_shape_axis1_batch1) {
64+
TensorRTSplitTest<1, 1>({4, 2, 2}, {2, 2});
65+
}
66+
// batch = 1, axis = 1, different shape
67+
TEST(split_op, test_different_shape_axis1_batch1) {
68+
TensorRTSplitTest<1, 1>({3, 2, 2}, {2, 1});
69+
}
70+
// batch = 10, axis = 1, same shape
71+
TEST(split_op, test_same_shape_axis1_batch10) {
72+
TensorRTSplitTest<10, 1>({4, 2, 2}, {2, 2});
73+
}
74+
// batch = 10, axis = 1, different shape
75+
TEST(split_op, test_different_shape_axis1_batch10) {
76+
TensorRTSplitTest<10, 1>({3, 2, 2}, {2, 1});
77+
}
78+
// batch = 1, axis = 2, same shape
79+
TEST(split_op, test_same_shape_axis2_batch1) {
80+
TensorRTSplitTest<1, 2>({3, 4, 2}, {2, 2});
81+
}
82+
// batch = 1, axis = 2, different shape
83+
TEST(split_op, test_different_shape_axis2_batch1) {
84+
TensorRTSplitTest<1, 2>({3, 3, 2}, {2, 1});
85+
}
86+
// batch = 10, axis = 2, same shape
87+
TEST(split_op, test_same_shape_axis2_batch10) {
88+
TensorRTSplitTest<10, 2>({3, 4, 2}, {2, 2});
89+
}
90+
// batch = 10, axis = 2, different shape
91+
TEST(split_op, test_different_shape_axis2_batch10) {
92+
TensorRTSplitTest<10, 2>({3, 3, 2}, {2, 1});
93+
}
94+
// batch = 1, axis = 3, same shape
95+
TEST(split_op, test_same_shape_axis3_batch1) {
96+
TensorRTSplitTest<1, 3>({3, 2, 4}, {2, 2});
97+
}
98+
// batch = 1, axis = 3, different shape
99+
TEST(split_op, test_different_shape_axis3_batch1) {
100+
TensorRTSplitTest<1, 3>({3, 2, 3}, {2, 1});
101+
}
102+
// batch = 10, axis = 3, same shape
103+
TEST(split_op, test_same_shape_axis3_batch10) {
104+
TensorRTSplitTest<10, 3>({3, 2, 4}, {2, 2});
105+
}
106+
// batch = 10, axis = 3, different shape
107+
TEST(split_op, test_different_shape_axis3_batch10) {
108+
TensorRTSplitTest<10, 3>({3, 2, 3}, {2, 1});
47109
}
48110

49111
} // namespace tensorrt

paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu

Lines changed: 127 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,61 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#include <cuda_fp16.h>
16+
#include <algorithm>
1517
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
1618

1719
namespace paddle {
1820
namespace inference {
1921
namespace tensorrt {
2022
namespace plugin {
2123

24+
// copied from operators::math::SplitFunctor
25+
template <typename T>
26+
__global__ void SplitKernel(const T* input_data, const int in_row,
27+
const int in_col, const int* out_cols,
28+
int out_cols_size, T** outputs_data) {
29+
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
30+
int curr_segment = 0;
31+
int curr_offset = out_cols[0];
32+
for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
33+
int curr_col_offset = out_cols[curr_segment + 1];
34+
while (curr_col_offset <= tid_x) {
35+
curr_offset = curr_col_offset;
36+
++curr_segment;
37+
curr_col_offset = out_cols[curr_segment + 1];
38+
}
39+
40+
int local_col = tid_x - curr_offset;
41+
int segment_width = curr_col_offset - curr_offset;
42+
T* output_ptr = outputs_data[curr_segment];
43+
if (output_ptr != nullptr) {
44+
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
45+
for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
46+
output_ptr[tid_y * segment_width + local_col] =
47+
input_data[tid_y * in_col + tid_x];
48+
}
49+
}
50+
}
51+
52+
template <typename T>
53+
__global__ void SplitKernel(const T* input_data, const int in_row,
54+
const int in_col, const int fixed_out_col,
55+
T** outputs_data) {
56+
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
57+
for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
58+
int split = tid_x / fixed_out_col;
59+
int in_offset = tid_x - split * fixed_out_col;
60+
T* output_ptr = outputs_data[split];
61+
if (output_ptr != nullptr) {
62+
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
63+
for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
64+
output_ptr[tid_y * fixed_out_col + in_offset] =
65+
input_data[tid_y * in_col + tid_x];
66+
}
67+
}
68+
}
69+
2270
nvinfer1::Dims SplitPlugin::getOutputDimensions(
2371
int index, const nvinfer1::Dims* input_dims, int num_inputs) {
2472
PADDLE_ENFORCE_EQ(num_inputs, 1);
@@ -31,48 +79,96 @@ nvinfer1::Dims SplitPlugin::getOutputDimensions(
3179

3280
int SplitPlugin::initialize() {
3381
PADDLE_ENFORCE_LE(axis_, nvinfer1::Dims::MAX_DIMS);
34-
82+
// Note that the input dims are [C, H, W].
83+
nvinfer1::Dims dims = this->getInputDims(0);
84+
outer_rows_ = 1;
85+
inner_cols_ = 1;
86+
for (int i = 0; i < axis_; ++i) {
87+
outer_rows_ *= dims.d[i];
88+
}
89+
for (int i = axis_ + 1; i < dims.nbDims; ++i) {
90+
inner_cols_ *= dims.d[i];
91+
}
92+
same_shape_ = true;
3593
std::vector<int> segment_offsets(1, 0);
3694
for (int i = 0; i < this->getNbOutputs(); ++i) {
37-
segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
95+
if (output_length_[i] != output_length_[0]) {
96+
same_shape_ = false;
97+
}
98+
segment_offsets.push_back(segment_offsets.back() +
99+
output_length_[i] * inner_cols_);
38100
}
39-
segment_offsets_ = segment_offsets;
40-
nvinfer1::Dims dims = this->getInputDims(0);
41-
nx_ = 1;
42-
for (int i = dims.nbDims - 1; i > axis_; --i) {
43-
nx_ *= dims.d[i];
101+
inner_cols_ *= dims.d[axis_];
102+
d_segment_offsets_ = segment_offsets;
103+
segment_offsets_ = std::move(segment_offsets);
104+
d_output_ptrs_.resize(this->getNbOutputs(), nullptr);
105+
return 0;
106+
}
107+
108+
template <typename T>
109+
inline void Split(cudaStream_t stream, const bool same_shape,
110+
const int outer_rows, const int inner_cols,
111+
const std::vector<int>& segment_offsets,
112+
const int* d_segment_offsets, const T* input, T** outputs) {
113+
const int kThreadsPerBlock = 1024;
114+
const int kMaxBlocks = 65535;
115+
int block_cols = kThreadsPerBlock;
116+
if (inner_cols < kThreadsPerBlock) { // block_cols is aligned by 32.
117+
block_cols = ((inner_cols + 31) >> 5) << 5;
44118
}
45-
ny_ = dims.d[axis_];
46-
nz_ = 1;
47-
for (int i = axis_ - 1; i >= 0; --i) {
48-
nz_ *= dims.d[i];
119+
int block_rows = kThreadsPerBlock / block_cols;
120+
dim3 block_size = dim3(block_cols, block_rows, 1);
121+
122+
int grid_cols =
123+
std::min((inner_cols + block_cols - 1) / block_cols, kMaxBlocks);
124+
int grid_rows =
125+
std::min(kMaxBlocks / grid_cols, std::max(outer_rows / block_rows, 1));
126+
dim3 grid_size = dim3(grid_cols, grid_rows, 1);
127+
128+
if (same_shape) {
129+
SplitKernel<<<grid_size, block_size, 0, stream>>>(
130+
input, outer_rows, inner_cols, segment_offsets[1], outputs);
131+
} else {
132+
SplitKernel<<<grid_size, block_size, 0, stream>>>(
133+
input, outer_rows, inner_cols, d_segment_offsets,
134+
static_cast<int>(segment_offsets.size()), outputs);
49135
}
50-
return 0;
51136
}
52137

53138
int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
54139
void** outputs, void* workspace, cudaStream_t stream) {
55-
auto const& input_dims = this->getInputDims(0);
56-
int input_size = 0;
57-
float const* idata = reinterpret_cast<float const*>(inputs[0]);
58-
float** odatas = reinterpret_cast<float**>(outputs);
59-
60-
// kernel impl here.
61-
int inputBatchOffset = nx_ * ny_ * nz_;
62-
for (size_t i = 0; i < this->getNbOutputs(); i++) {
63-
for (size_t j = 0; j < batchSize; j++) {
64-
cudaMemcpyAsync(
65-
odatas[i] +
66-
j * (segment_offsets_[i + 1] - segment_offsets_[i]) * nx_ *
67-
sizeof(float),
68-
inputs[0] +
69-
(inputBatchOffset * j + segment_offsets_[i] * nx_) *
70-
sizeof(float),
71-
(segment_offsets_[i + 1] - segment_offsets_[i]) * nx_ * sizeof(float),
72-
cudaMemcpyDeviceToDevice, stream);
140+
float const* input_ptr = reinterpret_cast<float const*>(inputs[0]);
141+
if (((batchSize == 1 && axis_ == 0) || axis_ == -1) &&
142+
this->getNbOutputs() < 10) {
143+
float** output_ptrs = reinterpret_cast<float**>(outputs);
144+
int data_type_size = (this->getDataType() == nvinfer1::DataType::kFLOAT)
145+
? sizeof(float)
146+
: sizeof(__half);
147+
for (int i = 0; i < this->getNbOutputs(); ++i) {
148+
PADDLE_ENFORCE(
149+
cudaMemcpyAsync(
150+
output_ptrs[i], input_ptr + segment_offsets_[i],
151+
(segment_offsets_[i + 1] - segment_offsets_[i]) * data_type_size,
152+
cudaMemcpyDeviceToDevice, stream) == cudaSuccess);
153+
}
154+
} else {
155+
outer_rows_ *= batchSize;
156+
const int* d_segment_offsets_ptr =
157+
thrust::raw_pointer_cast(&d_segment_offsets_[0]);
158+
float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs_[0]);
159+
PADDLE_ENFORCE(cudaMemcpyAsync(output_ptrs, outputs,
160+
this->getNbOutputs() * sizeof(float*),
161+
cudaMemcpyHostToDevice,
162+
stream) == cudaSuccess);
163+
if (this->getDataType() == nvinfer1::DataType::kFLOAT) {
164+
Split(stream, same_shape_, outer_rows_, inner_cols_, segment_offsets_,
165+
d_segment_offsets_ptr, input_ptr, output_ptrs);
166+
} else {
167+
Split(stream, same_shape_, outer_rows_, inner_cols_, segment_offsets_,
168+
d_segment_offsets_ptr, (__half*)input_ptr, // NOLINT
169+
(__half**)output_ptrs); // NOLINT
73170
}
74171
}
75-
76172
return cudaGetLastError() != cudaSuccess;
77173
}
78174

paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#pragma once
1616

17+
#include <thrust/device_vector.h>
1718
#include <vector>
1819
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
1920

@@ -25,7 +26,7 @@ namespace plugin {
2526
class SplitPlugin : public PluginTensorRT {
2627
public:
2728
SplitPlugin(int axis, std::vector<int> const &output_lengths)
28-
: axis_(axis), output_length_(output_lengths) {}
29+
: axis_(axis), same_shape_(true), output_length_(output_lengths) {}
2930

3031
SplitPlugin(void const *serial_data, size_t serial_length) {
3132
deserializeBase(serial_data, serial_length);
@@ -60,9 +61,13 @@ class SplitPlugin : public PluginTensorRT {
6061
}
6162

6263
int axis_;
64+
int outer_rows_;
65+
int inner_cols_;
66+
bool same_shape_;
6367
std::vector<int> output_length_;
64-
int nx_, ny_, nz_;
6568
std::vector<int> segment_offsets_;
69+
thrust::device_vector<int> d_segment_offsets_;
70+
thrust::device_vector<float *> d_output_ptrs_;
6671
};
6772

6873
} // namespace plugin

0 commit comments

Comments
 (0)