// See the License for the specific language governing permissions and
// limitations under the License.

+#include <cuda_fp16.h>
+#include <algorithm>
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

+// copied from operators::math::SplitFunctor
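+// Variable-width case: out_cols holds the prefix offsets of each output's
+// columns, so each thread first locates the segment its column falls into.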
+template <typename T>
+__global__ void SplitKernel(const T* input_data, const int in_row,
+                            const int in_col, const int* out_cols,
+                            int out_cols_size, T** outputs_data) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  int curr_segment = 0;
+  int curr_offset = out_cols[0];
+  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
+    int curr_col_offset = out_cols[curr_segment + 1];
+    while (curr_col_offset <= tid_x) {
+      curr_offset = curr_col_offset;
+      ++curr_segment;
+      curr_col_offset = out_cols[curr_segment + 1];
+    }
+
+    int local_col = tid_x - curr_offset;
+    int segment_width = curr_col_offset - curr_offset;
+    T* output_ptr = outputs_data[curr_segment];
+    if (output_ptr != nullptr) {
+      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
+        output_ptr[tid_y * segment_width + local_col] =
+            input_data[tid_y * in_col + tid_x];
+    }
+  }
+}
+
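+// Same-width case: every output has fixed_out_col columns, so the target
+// segment is obtained directly by integer division.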
+template <typename T>
+__global__ void SplitKernel(const T* input_data, const int in_row,
+                            const int in_col, const int fixed_out_col,
+                            T** outputs_data) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
+    int split = tid_x / fixed_out_col;
+    int in_offset = tid_x - split * fixed_out_col;
+    T* output_ptr = outputs_data[split];
+    if (output_ptr != nullptr) {
+      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
+        output_ptr[tid_y * fixed_out_col + in_offset] =
+            input_data[tid_y * in_col + tid_x];
+    }
+  }
+}
+
nvinfer1::Dims SplitPlugin::getOutputDimensions(
    int index, const nvinfer1::Dims* input_dims, int num_inputs) {
  PADDLE_ENFORCE_EQ(num_inputs, 1);
@@ -31,48 +79,96 @@ nvinfer1::Dims SplitPlugin::getOutputDimensions(

int SplitPlugin::initialize() {
  PADDLE_ENFORCE_LE(axis_, nvinfer1::Dims::MAX_DIMS);
-
+  // notice input dims is [C, H, W]
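+  // Collapse the dims before axis_ into outer_rows_ and the dims after it
+  // into inner_cols_, so the split reduces to a 2-D copy problem.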
+  nvinfer1::Dims dims = this->getInputDims(0);
+  outer_rows_ = 1;
+  inner_cols_ = 1;
+  for (int i = 0; i < axis_; ++i) {
+    outer_rows_ *= dims.d[i];
+  }
+  for (int i = axis_ + 1; i < dims.nbDims; ++i) {
+    inner_cols_ *= dims.d[i];
+  }
+  same_shape_ = true;
  std::vector<int> segment_offsets(1, 0);
  for (int i = 0; i < this->getNbOutputs(); ++i) {
-    segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
+    if (output_length_[i] != output_length_[0]) {
+      same_shape_ = false;
+    }
+    segment_offsets.push_back(segment_offsets.back() +
+                              output_length_[i] * inner_cols_);
  }
-  segment_offsets_ = segment_offsets;
-  nvinfer1::Dims dims = this->getInputDims(0);
-  nx_ = 1;
-  for (int i = dims.nbDims - 1; i > axis_; --i) {
-    nx_ *= dims.d[i];
+  inner_cols_ *= dims.d[axis_];
+  d_segment_offsets_ = segment_offsets;
+  segment_offsets_ = std::move(segment_offsets);
+  d_output_ptrs_.resize(this->getNbOutputs(), nullptr);
+  return 0;
+}
+
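+// Host-side launcher: picks a 2-D launch configuration and dispatches to the
+// fixed-width or variable-width kernel depending on same_shape.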
+template <typename T>
+inline void Split(cudaStream_t stream, const bool same_shape,
+                  const int outer_rows, const int inner_cols,
+                  const std::vector<int>& segment_offsets,
+                  const int* d_segment_offsets, const T* input, T** outputs) {
+  const int kThreadsPerBlock = 1024;
+  const int kMaxBlocks = 65535;
+  int block_cols = kThreadsPerBlock;
+  if (inner_cols < kThreadsPerBlock) {  // block_cols is aligned by 32.
+    block_cols = ((inner_cols + 31) >> 5) << 5;
  }
-  ny_ = dims.d[axis_];
-  nz_ = 1;
-  for (int i = axis_ - 1; i >= 0; --i) {
-    nz_ *= dims.d[i];
+  int block_rows = kThreadsPerBlock / block_cols;
+  dim3 block_size = dim3(block_cols, block_rows, 1);
+
+  int grid_cols =
+      std::min((inner_cols + block_cols - 1) / block_cols, kMaxBlocks);
+  int grid_rows =
+      std::min(kMaxBlocks / grid_cols, std::max(outer_rows / block_rows, 1));
+  dim3 grid_size = dim3(grid_cols, grid_rows, 1);
+
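+  // With equal output widths the common width (segment_offsets[1]) is enough;
+  // otherwise the kernel needs the device-side offset table.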
+  if (same_shape) {
+    SplitKernel<<<grid_size, block_size, 0, stream>>>(
+        input, outer_rows, inner_cols, segment_offsets[1], outputs);
+  } else {
+    SplitKernel<<<grid_size, block_size, 0, stream>>>(
+        input, outer_rows, inner_cols, d_segment_offsets,
+        static_cast<int>(segment_offsets.size()), outputs);
  }
-  return 0;
}

int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
                         void** outputs, void* workspace, cudaStream_t stream) {
-  auto const& input_dims = this->getInputDims(0);
-  int input_size = 0;
-  float const* idata = reinterpret_cast<float const*>(inputs[0]);
-  float** odatas = reinterpret_cast<float**>(outputs);
-
-  // kernel impl here.
-  int inputBatchOffset = nx_ * ny_ * nz_;
-  for (size_t i = 0; i < this->getNbOutputs(); i++) {
-    for (size_t j = 0; j < batchSize; j++) {
-      cudaMemcpyAsync(
-          odatas[i] +
-              j * (segment_offsets_[i + 1] - segment_offsets_[i]) * nx_ *
-                  sizeof(float),
-          inputs[0] +
-              (inputBatchOffset * j + segment_offsets_[i] * nx_) *
-                  sizeof(float),
-          (segment_offsets_[i + 1] - segment_offsets_[i]) * nx_ * sizeof(float),
-          cudaMemcpyDeviceToDevice, stream);
+  float const* input_ptr = reinterpret_cast<float const*>(inputs[0]);
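+  // Fast path: when every output is a contiguous slice of the input and there
+  // are few outputs, plain async device-to-device copies are enough.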
+  if (((batchSize == 1 && axis_ == 0) || axis_ == -1) &&
+      this->getNbOutputs() < 10) {
+    float** output_ptrs = reinterpret_cast<float**>(outputs);
+    int data_type_size = (this->getDataType() == nvinfer1::DataType::kFLOAT)
+                             ? sizeof(float)
+                             : sizeof(__half);
+    for (int i = 0; i < this->getNbOutputs(); ++i) {
+      PADDLE_ENFORCE(
+          cudaMemcpyAsync(
+              output_ptrs[i], input_ptr + segment_offsets_[i],
+              (segment_offsets_[i + 1] - segment_offsets_[i]) * data_type_size,
+              cudaMemcpyDeviceToDevice, stream) == cudaSuccess);
+    }
+  } else {
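+    // General path: upload the output pointer array to the device, then
+    // launch the split kernel over the whole batch.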
+    int outer_rows = outer_rows_ * batchSize;
+    const int* d_segment_offsets_ptr =
+        thrust::raw_pointer_cast(&d_segment_offsets_[0]);
+    float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs_[0]);
+    PADDLE_ENFORCE(cudaMemcpyAsync(output_ptrs, outputs,
+                                   this->getNbOutputs() * sizeof(float*),
+                                   cudaMemcpyHostToDevice,
+                                   stream) == cudaSuccess);
+    if (this->getDataType() == nvinfer1::DataType::kFLOAT) {
+      Split(stream, same_shape_, outer_rows, inner_cols_, segment_offsets_,
+            d_segment_offsets_ptr, input_ptr, output_ptrs);
+    } else {
+      Split(stream, same_shape_, outer_rows, inner_cols_, segment_offsets_,
+            d_segment_offsets_ptr, (__half*)input_ptr,  // NOLINT
+            (__half**)output_ptrs);  // NOLINT
    }
  }
-
  return cudaGetLastError() != cudaSuccess;
}