@@ -22,7 +22,7 @@ namespace math {
 // TODO(zcd): This can be replaced by a Tensor;
 // if so, maybe we should add int8 to VarType::Type.
 // Or it can be replaced by a TensorArray.
-static constexpr int MaxSize = 32;
+static constexpr int MaxSize = 8;
 template <typename T>
 struct CUDADeviceArray {
   T data[MaxSize];
@@ -54,7 +54,6 @@ __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
                              const int output_rows, const int output_cols,
                              T* output) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
   int segment = upper_bound<int>(input_cols.data, input_cols.size, tid_x) - 1;
 
   int curr_offset = input_cols.data[segment];
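The segment search above relies on a device-side `upper_bound` helper that is defined earlier in `concat.cu` and does not appear in this diff. For reference, a minimal sketch of the kind of binary search it presumably performs (the exact name and signature here are assumptions):

```cuda
// Sketch only: index of the first element in first[0, count) that is
// strictly greater than val, for an ascending array. Subtracting 1, as the
// kernels do, gives the segment whose column range contains val.
template <typename T>
__device__ T upper_bound(const T* first, T count, T val) {
  const T* orig = first;
  const T* it = nullptr;
  T step = 0;
  while (count > 0) {
    it = first;
    step = count / 2;
    it += step;
    if (!(val < *it)) {   // *it <= val: discard the left half, including *it
      first = ++it;
      count -= step + 1;
    } else {
      count = step;
    }
  }
  return first - orig;
}
```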
@@ -69,31 +68,87 @@ __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
     int local_col = tid_x - curr_offset;
     int segment_width = curr_col_offset - curr_offset;
     const T* input_ptr = inputs.data[curr_segment];
-
+    // Reset the row index for every column this thread handles: the inner
+    // loop below advances tid_y, so it must not be hoisted out of this loop.
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
     for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
       output[tid_y * output_cols + tid_x] =
           input_ptr[tid_y * segment_width + local_col];
   }
 }
 
+// Special case: every input has the same number of columns (input_col), so
+// the owning input can be found with a multiply instead of a search.
+template <typename T>
+__global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
+                             const int input_col, const int output_rows,
+                             const int output_cols, T* output) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  float inv_input_col = 1.0 / input_col;
+  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
+    // split = tid_x / input_col, computed via the precomputed reciprocal.
+    int split = tid_x * inv_input_col;
+    int in_offset = tid_x - split * input_col;
+    const T* input_ptr = inputs.data[split];
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
+      output[tid_y * output_cols + tid_x] =
+          input_ptr[tid_y * input_col + in_offset];
+  }
+}
+
+// Gradient kernel for outputs with different column widths: scatter the
+// columns of `input` back into the per-output tensors.
+template <typename T>
+__global__ void KernelConcatGrad(const T* input, const int input_row,
+                                 const int input_col,
+                                 CUDADeviceArray<int> output_cols,
+                                 CUDADeviceArray<T*> outputs) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  int segment = upper_bound<int>(output_cols.data, output_cols.size, tid_x) - 1;
+  int curr_offset = output_cols.data[segment];
+  int curr_segment = segment;
+  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
+    int curr_col_offset;
+    while ((curr_col_offset = output_cols.data[curr_segment + 1]) <= tid_x) {
+      curr_offset = curr_col_offset;
+      ++curr_segment;
+    }
+
+    int local_col = tid_x - curr_offset;
+    int segment_width = curr_col_offset - curr_offset;
+    T* output_ptr = outputs.data[curr_segment];
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
+      output_ptr[tid_y * segment_width + local_col] =
+          input[tid_y * input_col + tid_x];
+  }
+}
+
+// Gradient special case: every output has the same column width (output_cols),
+// so the destination tensor is found with a multiply instead of a search.
+template <typename T>
+__global__ void KernelConcatGrad(const T* input, const int input_row,
+                                 const int input_col, const int output_cols,
+                                 CUDADeviceArray<T*> outputs) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  float inv_output_col = 1.0 / output_cols;
+  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
+    // split = tid_x / output_cols picks the destination tensor.
+    int split = tid_x * inv_output_col;
+    int in_offset = tid_x - split * output_cols;
+    T* output_ptr = outputs.data[split];
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
+      output_ptr[tid_y * output_cols + in_offset] =
+          input[tid_y * input_col + tid_x];
+  }
+}
+
 /*
  * All input tensors must have the same number of dimensions.
  */
 template <typename T>
 class ConcatFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  std::vector<framework::Tensor>& input, const int axis,
+                  const std::vector<framework::Tensor>& input, const int axis,
                   framework::Tensor* output) {
     // assume the number of inputs is less than MaxSize; check the performance
-    // save origin dim
     int num = input.size();
-    // std::vector<paddle::framework::DDim> origin_dim(num);
-    // for (int j = 0; j < num; ++j) {
-    //   origin_dim[j] = input[j].dims();
-    // }
-    auto out_dim = output->dims();
-
+    PADDLE_ENFORCE_LT(num, MaxSize, "input number should be less than %d",
+                      MaxSize);
     // get the matrix size
     int rows = 1;
     auto dim_0 = input[0].dims();
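Before launching a kernel, both functors flatten each tensor into a 2-D row-major view: every dimension before `axis` is folded into `rows`, everything from `axis` onward into columns. A small host-side sketch of that arithmetic with hypothetical shapes (a plain `std::vector<int64_t>` stands in for `framework::DDim`):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Mirror of the flattening loop in ConcatFunctor:
// rows = prod(dims[0..axis)), cols = numel / rows.
void FlattenAroundAxis(const std::vector<int64_t>& dims, int axis,
                       int64_t* rows, int64_t* cols) {
  int64_t r = 1, numel = 1;
  for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
    if (i < axis) r *= dims[i];
    numel *= dims[i];
  }
  *rows = r;
  *cols = numel / r;
}

int main() {
  // Concatenating shapes [2, 3, 4] and [2, 5, 4] along axis = 1:
  int64_t rows = 0, cols_a = 0, cols_b = 0;
  FlattenAroundAxis({2, 3, 4}, 1, &rows, &cols_a);  // rows = 2, cols_a = 12
  FlattenAroundAxis({2, 5, 4}, 1, &rows, &cols_b);  // rows = 2, cols_b = 20
  assert(cols_a + cols_b == 32);                    // out_cols = 32
  return 0;
}
```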
@@ -117,30 +172,96 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
         if (t_cols != cols) sameShape = false;
       }
       out_cols += t_cols;
-      input[i].Resize({rows, t_cols});
       inputs_cols.data[i + 1] = out_cols;
       inputs_data.data[i] = input[i].data<T>();
     }
-    output->Resize({out_rows, out_cols});
 
     // computation
-    const int kThreadsPerBlock = 256;
+    // set the thread block and grid according to CurrentDeviceId
+    const int kThreadsPerBlock = 1024;
     int block_cols = std::min(out_cols, kThreadsPerBlock);
     int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
     dim3 block_size = dim3(block_cols, block_rows, 1);
 
-    int grid_cols = (out_cols + block_cols - 1) / block_cols;
-    int grid_rows = (out_rows + block_rows - 1) / block_rows;
+    int dev_id = paddle::platform::GetCurrentDeviceId();
+    int multi_process = paddle::platform::GetCUDAMultiProcessors(dev_id);
+    int max_threads_per_mp =
+        paddle::platform::GetCUDAMaxThreadsPerMultiProcessor(dev_id);
+    int max_threads = multi_process * max_threads_per_mp;
+    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
+
+    int grid_cols =
+        std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
+    int grid_rows =
+        std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));
     dim3 grid_size = dim3(grid_cols, grid_rows, 1);
 
-    KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
-        inputs_data, inputs_cols, out_rows, out_cols, output->data<T>());
+    if (sameShape) {
+      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
+          inputs_data, cols, out_rows, out_cols, output->data<T>());
+    } else {
+      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
+          inputs_data, inputs_cols, out_rows, out_cols, output->data<T>());
+    }
+  }
+};
+
+template <typename T>
+class ConcatGradFunctor<platform::CUDADeviceContext, T> {
+ public:
+  void operator()(const platform::CUDADeviceContext& context,
+                  const framework::Tensor& input, const int axis,
+                  std::vector<framework::Tensor>& outputs) {
+    // assume the number of outputs is less than MaxSize; check the performance
+    int num = outputs.size();
+    PADDLE_ENFORCE_LT(num, MaxSize, "output number should be less than %d",
+                      MaxSize);
+
+    // get the matrix size
+    int input_row = 1;
+    auto dim_0 = outputs[0].dims();
+    for (int i = 0; i < axis; ++i) {
+      input_row *= dim_0[i];
+    }
+
+    int output_col_0 = outputs[0].numel() / input_row;
+    int input_col = 0;
+    bool sameShape = true;
+
+    CUDADeviceArray<T*> outputs_data;
+    CUDADeviceArray<int> outputs_cols;
+    outputs_data.size = num;
+    outputs_cols.size = num + 1;
+    outputs_cols.data[0] = 0;
 
-    // recover origin dim
-    // for (int j = 0; j < num; ++j) {
-    //   input[j].Resize(origin_dim[j]);
-    // }
-    output->Resize(out_dim);
+    for (int i = 0; i < num; ++i) {
+      int t_col = outputs[i].numel() / input_row;
+      if (sameShape) {
+        if (t_col != output_col_0) sameShape = false;
+      }
+      input_col += t_col;
+      outputs_cols.data[i + 1] = input_col;
+      outputs_data.data[i] = outputs[i].data<T>();
+    }
+
+    // computation
+    const int kThreadsPerBlock = 256;
+    int block_cols = std::min(input_col, kThreadsPerBlock);
+    int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
+    dim3 block_size = dim3(block_cols, block_rows, 1);
+
+    int grid_cols = (input_col + block_cols - 1) / block_cols;
+    int grid_rows = (input_row + block_rows - 1) / block_rows;
+    dim3 grid_size = dim3(grid_cols, grid_rows, 1);
+
+    if (sameShape) {
+      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
+          input.data<T>(), input_row, input_col, output_col_0, outputs_data);
+    } else {
+      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
+          input.data<T>(), input_row, input_col, outputs_cols, outputs_data);
+    }
   }
 };
 
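The forward functor now bounds the launch grid by a rough occupancy estimate, `max_blocks = (SM count × max threads per SM) / kThreadsPerBlock`, and clamps both grid dimensions so the kernels rely on their grid-stride loops rather than one block per tile. The same computation expressed with plain CUDA runtime queries instead of the `paddle::platform` helpers (a sketch; the device numbers in the comment are only an illustration):

```cuda
#include <algorithm>
#include <cuda_runtime.h>

// Reproduces the grid sizing in ConcatFunctor. On a hypothetical device with
// 80 SMs and 2048 resident threads per SM: max_threads = 163840 and
// max_blocks = 160 for 1024-thread blocks.
void ComputeLaunchConfig(int out_rows, int out_cols, dim3* block, dim3* grid) {
  const int kThreadsPerBlock = 1024;
  int block_cols = std::min(out_cols, kThreadsPerBlock);
  int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
  *block = dim3(block_cols, block_rows, 1);

  int dev_id = 0;
  cudaGetDevice(&dev_id);
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, dev_id);
  int max_threads = prop.multiProcessorCount * prop.maxThreadsPerMultiProcessor;
  int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

  int grid_cols =
      std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
  int grid_rows =
      std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));
  *grid = dim3(grid_cols, grid_rows, 1);
}
```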
@@ -149,6 +270,11 @@ template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
 template class ConcatFunctor<platform::CUDADeviceContext, float>;
 template class ConcatFunctor<platform::CUDADeviceContext, double>;
 
+template class ConcatGradFunctor<platform::CUDADeviceContext, int>;
+template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
+template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
+template class ConcatGradFunctor<platform::CUDADeviceContext, double>;
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
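As a sanity check for the indexing in the kernels, here is a small host-side reference of what concatenation over the flattened 2-D view computes; it is illustrative only and not part of this change:

```cpp
#include <cstdio>
#include <vector>

// Host reference of KernelConcat's indexing: input i is a rows x cols[i]
// row-major matrix whose columns land at a running column offset in output.
std::vector<float> ConcatRef(const std::vector<std::vector<float>>& inputs,
                             const std::vector<int>& cols, int rows) {
  int out_cols = 0;
  for (int c : cols) out_cols += c;
  std::vector<float> output(rows * out_cols);
  int offset = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    for (int r = 0; r < rows; ++r)
      for (int c = 0; c < cols[i]; ++c)
        output[r * out_cols + offset + c] = inputs[i][r * cols[i] + c];
    offset += cols[i];
  }
  return output;
}

int main() {
  // Two inputs with 2 rows and 2 / 3 columns respectively.
  std::vector<std::vector<float>> ins = {{1, 2, 3, 4}, {5, 6, 7, 8, 9, 10}};
  std::vector<float> out = ConcatRef(ins, {2, 3}, 2);
  for (float v : out) std::printf("%.0f ", v);  // 1 2 5 6 7 3 4 8 9 10
  std::printf("\n");
  return 0;
}
```

The gradient kernels invert exactly this mapping: they read the concatenated matrix row by row and scatter each column range back into its source tensor.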