@@ -256,36 +256,65 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
 * 3. go to the second step, until one thread's topk value is null;
 * 4. go to the first step, until the topk value is obtained.
 */
+
template <typename T, int MaxLength, int BlockSize>
__global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
-                             const T* src, int lds, int dim, int k) {
+                             const T* src, int lds, int dim, int k,
+                             int grid_dim, int num) {
  __shared__ Pair<T> sh_topk[BlockSize];
  __shared__ int maxid[BlockSize / 2];
  const int tid = threadIdx.x;
  const int warp = threadIdx.x / 32;
-  output += blockIdx.x * output_stride;
-  indices += blockIdx.x * k;

-  Pair<T> topk[MaxLength];
-  int beam = MaxLength;
-  Pair<T> max;
-  bool is_empty = false;
-  bool firststep = true;
+  const int bid = blockIdx.x;
+  for (int i = bid; i < num; i += grid_dim) {
+    output += i * output_stride;
+    indices += i * k;
+
+    Pair<T> topk[MaxLength];
+    int beam = MaxLength;
+    Pair<T> max;
+    bool is_empty = false;
+    bool firststep = true;
+
+    for (int k = 0; k < MaxLength; k++) {
+      topk[k].set(-INFINITY, -1);
+    }
+    while (k) {
+      ThreadGetTopK<T, MaxLength, BlockSize>(
+          topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid);

-  for (int k = 0; k < MaxLength; k++) {
-    topk[k].set(-INFINITY, -1);
+      sh_topk[tid] = topk[0];
+      BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &output,
+                                           &indices, &beam, &k, tid, warp);
+    }
  }
-  while (k) {
-    ThreadGetTopK<T, MaxLength, BlockSize>(topk, &beam, k,
-                                           src + blockIdx.x * lds, &firststep,
-                                           &is_empty, &max, dim, tid);
-
-    sh_topk[tid] = topk[0];
-    BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &output,
-                                         &indices, &beam, &k, tid, warp);
+}
+
+inline static int GetDesiredBlockDim(int dim) {
+  if (dim > 128) {
+    return 256;
+  } else if (dim > 64) {
+    return 128;
+  } else if (dim > 32) {
+    return 64;
+  } else {
+    return 32;
  }
}

+#define FIXED_BLOCK_DIM_BASE(dim, ...) \
+  case (dim): {                        \
+    constexpr auto kBlockDim = (dim);  \
+    __VA_ARGS__;                       \
+  } break
+
+#define FIXED_BLOCK_DIM(...)                \
+  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
+  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
+  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
+
template <typename T>
class TopkOpCUDAKernel : public framework::OpKernel<T> {
 public:
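The core change in this hunk is that KeMatrixTopK no longer assumes one block per row: each block starts at row blockIdx.x and strides forward by grid_dim, so a launch grid capped at kMaxHeight (see the second hunk below) still covers all `num` rows. A minimal sketch of the same grid-stride pattern, using a hypothetical RowSum kernel that is not part of this patch (dst is assumed zero-initialized before launch):

#include <cuda_runtime.h>

// Illustrative only: each block walks rows blockIdx.x, blockIdx.x + grid_dim, ...
// just like the new KeMatrixTopK loop, so grid size and row count are decoupled.
__global__ void RowSum(const float* src, float* dst, int width, int num_rows,
                       int grid_dim) {
  for (int row = blockIdx.x; row < num_rows; row += grid_dim) {
    // Threads of the block cooperate across one row; dst[row] must start at 0.
    for (int col = threadIdx.x; col < width; col += blockDim.x) {
      atomicAdd(&dst[row], src[row * width + col]);
    }
  }
}

A launch such as RowSum<<<grid, 256>>>(src, dst, width, num_rows, grid) with grid = min(num_rows, 2048) mirrors how the launch site below caps gridx at kMaxHeight.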
@@ -310,18 +339,26 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
    // NOTE: pass lds and dim same to input width.
    // NOTE: old matrix implementation of stride is different to eigen.
    // TODO(typhoonzero): refine this kernel.
-    dim3 threads(256, 1);
-    dim3 grid(input_height, 1);
-
-    KeMatrixTopK<T, 5, 256><<<
-        grid, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
-                              ctx.device_context())
-                              .stream()>>>(
-        output_data, output->dims()[1], indices_data, input_data, input_width,
-        input_width, static_cast<int>(k));
+    const int kMaxHeight = 2048;
+    int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
+    auto& dev_ctx = ctx.cuda_device_context();
+
+    switch (GetDesiredBlockDim(input_width)) {
+      FIXED_BLOCK_DIM(
+          KeMatrixTopK<T, 5,
+                       kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
+              output_data, output->dims()[1], indices_data, input_data,
+              input_width, input_width, static_cast<int>(k), gridx,
+              input_height));
+      default:
+        PADDLE_THROW("Error");
+    }
  }
};

+#undef FIXED_BLOCK_DIM_BASE
+#undef FIXED_BLOCK_DIM
+
}  // namespace operators
}  // namespace paddle
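At the launch site, the switch over GetDesiredBlockDim combined with FIXED_BLOCK_DIM turns the runtime block size into a compile-time constant, which is needed because BlockSize is a template parameter that sizes the kernel's __shared__ arrays. Roughly, the macro invocation above expands to the following; DispatchTopK is a hypothetical wrapper used only to make the sketch self-contained, while KeMatrixTopK, GetDesiredBlockDim, and PADDLE_THROW are the names from this file:

// Rough, illustrative expansion of the FIXED_BLOCK_DIM switch above;
// DispatchTopK is a hypothetical helper, not part of the patch.
template <typename T>
void DispatchTopK(T* output, int output_stride, int64_t* indices, const T* src,
                  int lds, int dim, int k, int gridx, int num,
                  cudaStream_t stream) {
  switch (GetDesiredBlockDim(dim)) {
    case 256: {
      // The constexpr lets the runtime choice feed the BlockSize template
      // argument, fixing the size of sh_topk[] and maxid[] at compile time.
      constexpr auto kBlockDim = 256;
      KeMatrixTopK<T, 5, kBlockDim><<<gridx, kBlockDim, 0, stream>>>(
          output, output_stride, indices, src, lds, dim, k, gridx, num);
    } break;
    case 128: {
      constexpr auto kBlockDim = 128;
      KeMatrixTopK<T, 5, kBlockDim><<<gridx, kBlockDim, 0, stream>>>(
          output, output_stride, indices, src, lds, dim, k, gridx, num);
    } break;
    case 64: {
      constexpr auto kBlockDim = 64;
      KeMatrixTopK<T, 5, kBlockDim><<<gridx, kBlockDim, 0, stream>>>(
          output, output_stride, indices, src, lds, dim, k, gridx, num);
    } break;
    case 32: {
      constexpr auto kBlockDim = 32;
      KeMatrixTopK<T, 5, kBlockDim><<<gridx, kBlockDim, 0, stream>>>(
          output, output_stride, indices, src, lds, dim, k, gridx, num);
    } break;
    default:
      PADDLE_THROW("Error");
  }
}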