@@ -70,7 +70,7 @@ __global__ void KernelConcat(T** inputs, const int input_col,
70
70
const int output_rows, const int output_cols,
71
71
T* output) {
72
72
int tid_x = blockIdx .x * blockDim .x + threadIdx .x ;
73
- float inv_input_col = 1.0 / input_col;
73
+ double inv_input_col = 1.0 / input_col;
74
74
for (; tid_x < output_cols; tid_x += blockDim .x * gridDim .x ) {
75
75
int split = tid_x * inv_input_col;
76
76
int in_offset = tid_x - split * input_col;
@@ -113,7 +113,7 @@ __global__ void KernelConcatGrad(const T* input, const int input_row,
113
113
const int input_col, const int output_cols,
114
114
T** outputs) {
115
115
int tid_x = blockIdx .x * blockDim .x + threadIdx .x ;
116
- float inv_input_col = 1.0 / input_col;
116
+ double inv_input_col = 1.0 / input_col;
117
117
for (; tid_x < input_col; tid_x += blockDim .x * gridDim .x ) {
118
118
int split = tid_x * inv_input_col;
119
119
int in_offset = tid_x - split * input_col;
@@ -145,8 +145,8 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
145
145
int cols = input[0 ].numel () / rows;
146
146
int out_rows = rows, out_cols = 0 ;
147
147
148
- paddle:: framework::Vector<int16_t > inputs_data (num * sizeof (T*) / 2 );
149
- paddle:: framework::Vector<int > inputs_cols (num + 1 );
148
+ framework::Vector<int16_t > inputs_data (num * sizeof (T*) / 2 );
149
+ framework::Vector<int > inputs_cols (num + 1 );
150
150
inputs_cols[0 ] = 0 ;
151
151
T** inputs_ptr = reinterpret_cast <T**>(inputs_data.data ());
152
152
@@ -168,15 +168,14 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
168
168
// computation
169
169
// set the thread block and grid according to CurrentDeviceId
170
170
const int kThreadsPerBlock = 1024 ;
171
- int block_cols = std::min (out_cols, kThreadsPerBlock );
172
- int block_rows = std::max (kThreadsPerBlock / block_cols, 1 );
171
+ int block_cols = kThreadsPerBlock ;
172
+ if (out_cols < kThreadsPerBlock ) { // block_cols is aligned by 32.
173
+ block_cols = ((out_cols + 31 ) >> 5 ) << 5 ;
174
+ }
175
+ int block_rows = kThreadsPerBlock / block_cols;
173
176
dim3 block_size = dim3 (block_cols, block_rows, 1 );
174
177
175
- int dev_id = paddle::platform::GetCurrentDeviceId ();
176
- int multi_process = paddle::platform::GetCUDAMultiProcessors (dev_id);
177
- int max_threads_per_mp =
178
- paddle::platform::GetCUDAMaxThreadsPerMultiProcessor (dev_id);
179
- int max_threads = multi_process * max_threads_per_mp;
178
+ int max_threads = context.GetMaxPhysicalThreadCount ();
180
179
int max_blocks = std::max (max_threads / kThreadsPerBlock , 1 );
181
180
182
181
int grid_cols =
@@ -218,8 +217,8 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
218
217
int input_col = 0 ;
219
218
bool sameShape = true ;
220
219
221
- paddle:: framework::Vector<int16_t > outputs_data (num * sizeof (T*) / 2 );
222
- paddle:: framework::Vector<int > outputs_cols (num + 1 );
220
+ framework::Vector<int16_t > outputs_data (num * sizeof (T*) / 2 );
221
+ framework::Vector<int > outputs_cols (num + 1 );
223
222
outputs_cols[0 ] = 0 ;
224
223
T** outputs_ptr = reinterpret_cast <T**>(outputs_data.data ());
225
224
@@ -239,12 +238,20 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
239
238
240
239
// computation
241
240
const int kThreadsPerBlock = 1024 ;
242
- int block_cols = std::min (input_col, kThreadsPerBlock );
243
- int block_rows = std::max (kThreadsPerBlock / block_cols, 1 );
241
+ int block_cols = kThreadsPerBlock ;
242
+ if (input_col < kThreadsPerBlock ) { // block_cols is aligned by 32.
243
+ block_cols = ((input_col + 31 ) >> 5 ) << 5 ;
244
+ }
245
+ int block_rows = kThreadsPerBlock / block_cols;
244
246
dim3 block_size = dim3 (block_cols, block_rows, 1 );
245
247
246
- int grid_cols = (input_col + block_cols - 1 ) / block_cols;
247
- int grid_rows = (input_row + block_rows - 1 ) / block_rows;
248
+ int max_threads = context.GetMaxPhysicalThreadCount ();
249
+ int max_blocks = std::max (max_threads / kThreadsPerBlock , 1 );
250
+
251
+ int grid_cols =
252
+ std::min ((input_col + block_cols - 1 ) / block_cols, max_blocks);
253
+ int grid_rows =
254
+ std::min (max_blocks / grid_cols, std::max (input_row / block_rows, 1 ));
248
255
dim3 grid_size = dim3 (grid_cols, grid_rows, 1 );
249
256
250
257
if (sameShape) {
0 commit comments