@@ -12,11 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

+ #include "cub/cub.cuh"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/float16.h"

+ // set cub base traits in order to handle float16
+ namespace cub {
+ template <>
+ struct NumericTraits<paddle::platform::float16>
+     : BaseTraits<FLOATING_POINT, true, false, uint16_t,
+                  paddle::platform::float16> {};
+ }  // namespace cub
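+ // The specialization registers float16 with cub as a primitive
+ // floating-point key type whose bit pattern is handled as uint16_t,
+ // which the segmented radix sort below relies on to order
+ // half-precision keys correctly.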
+
namespace paddle {
namespace operators {

@@ -303,6 +312,160 @@ inline static int GetDesiredBlockDim(int dim) {
  }
}

+ // Iterator used to move to the next row (segment)
+ struct SegmentOffsetIter {
+   EIGEN_DEVICE_FUNC
+   explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
+
+   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
+     return idx * num_cols_;
+   }
+
+   int num_cols_;
+ };
+
+ // Iterator used to index into a column
+ struct ColumnIndexIter {
+   explicit ColumnIndexIter(int num_cols) : num_cols_(num_cols) {}
+
+   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
+       const Eigen::array<int, 1>& ix) const {
+     return ix[0] % num_cols_;
+   }
+
+   int num_cols_;
+ };
+
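+ // Each block strides over rows and each thread strides over columns,
+ // filling every row of `indices` with the column ids 0..num_cols-1
+ // before the sort.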
+ __global__ void InitIndex(int64_t* indices, int64_t num_rows,
+                           int64_t num_cols) {
+   int col_id = threadIdx.x;
+   int row_id = blockIdx.x;
+
+   for (int64_t j = row_id; j < num_rows; j += gridDim.x) {
+     for (int64_t i = col_id; i < num_cols; i += blockDim.x) {
+       indices[j * num_cols + i] = i;
+     }
+   }
+ }
+
+ template <typename T>
+ bool SortTopk(const platform::CUDADeviceContext& ctx,
+               const framework::Tensor* input_tensor, const int64_t num_cols,
+               const int64_t num_rows, const int k,
+               framework::Tensor* out_tensor,
+               framework::Tensor* indices_tensor) {
+   auto cu_stream = ctx.stream();
+
+   Tensor input_indices;
+   const std::vector<int64_t> dims = {num_rows, num_cols};
+   auto dim = framework::make_ddim(dims);
+   input_indices.Resize(dim);
+   // input_indices.Resize(num_rows*num_cols);
+   input_indices.mutable_data<int64_t>(ctx.GetPlace());
+   size_t temp_storage_bytes = -1;
+
+   auto ComputeBlockSize = [](int col) {
+     if (col > 512)
+       return 1024;
+     else if (col > 256 && col <= 512)
+       return 512;
+     else if (col > 128 && col <= 256)
+       return 256;
+     else if (col > 64 && col <= 128)
+       return 128;
+     else
+       return 64;
+   };
+
+   int block_size = ComputeBlockSize(num_cols);
+
+   unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
+   // in practice num_rows < max_grid_size
+   unsigned int grid_size = num_rows < maxGridDimX
+                                ? static_cast<unsigned int>(num_rows)
+                                : maxGridDimX;
+   // Init an index array
+   InitIndex<<<grid_size, block_size, 0, cu_stream>>>(
+       input_indices.data<int64_t>(), num_rows, num_cols);
+
+   // create iter for counting input
+   cub::CountingInputIterator<int64_t> counting_iter(0);
+   // segment_offset is used for moving to the next row
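+   // segment_offsets_t(i) == i * num_cols, so segment_offsets_t and
+   // segment_offsets_t + 1 give each row's [begin, end) offsets in the
+   // flattened input when passed to the segmented sort below.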
+   cub::TransformInputIterator<int64_t, SegmentOffsetIter,
+                               cub::CountingInputIterator<int64_t>>
+       segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
+
+   T* sorted_values_ptr;
+   int64_t* sorted_indices_ptr;
+
+   Tensor temp_values;
+   Tensor temp_indices;
+
+   const T* input = input_tensor->data<T>();
+   T* values = out_tensor->data<T>();
+   int64_t* indices = indices_tensor->mutable_data<int64_t>(ctx.GetPlace());
+
+   if (k == num_cols) {
+     // Doing a full sort.
+     sorted_values_ptr = values;
+     sorted_indices_ptr = indices;
+   } else {
+     temp_values.Resize(dim);
+     temp_indices.Resize(dim);
+     sorted_values_ptr = temp_values.mutable_data<T>(ctx.GetPlace());
+     sorted_indices_ptr = temp_indices.mutable_data<int64_t>(ctx.GetPlace());
+   }
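+   // When k < num_cols the fully sorted rows land in the temp tensors and
+   // only the leading k columns are copied to the outputs after the sort.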
+
+   // Get temp storage buffer size, maybe can allocate a fixed buffer to save
+   // time.
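+   // cub convention: calling SortPairsDescending with d_temp_storage ==
+   // nullptr only computes temp_storage_bytes; the sort itself runs in the
+   // second call once the buffer is allocated. begin_bit = 0 and
+   // end_bit = sizeof(T) * 8 cover the full key width.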
+   auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
+       nullptr, temp_storage_bytes, input, sorted_values_ptr,
+       input_indices.data<int64_t>(), sorted_indices_ptr, num_cols * num_rows,
+       num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+       cu_stream);
+   if (err != cudaSuccess) {
+     LOG(ERROR)
+         << "TopKOP failed as could not launch "
+            "cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate "
+            "temp_storage_bytes, status: "
+         << cudaGetErrorString(err);
+     return false;
+   }
+   Tensor temp_storage;
+   temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);
+
+   err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
+       temp_storage.data<uint8_t>(), temp_storage_bytes, input,
+       sorted_values_ptr, input_indices.data<int64_t>(), sorted_indices_ptr,
+       num_cols * num_rows, num_rows, segment_offsets_t, segment_offsets_t + 1,
+       0, sizeof(T) * 8, cu_stream);
+   if (err != cudaSuccess) {
+     LOG(ERROR)
+         << "TopKOP failed as could not launch "
+            "cub::DeviceSegmentedRadixSort::SortPairsDescending to sort input, "
+            "temp_storage_bytes: "
+         << temp_storage_bytes << ", status: " << cudaGetErrorString(err);
+     return false;
+   }
+   auto& dev = *ctx.eigen_device();
+   if (k < num_cols) {
+     // copy sliced data to output.
+     const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, 0};
+     const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, k};
+     auto e_indices = EigenMatrix<int64_t>::From(*indices_tensor, dim);
+     auto e_tmp_indices = EigenMatrix<int64_t>::From(temp_indices);
+
+     std::vector<int> odims = {static_cast<int>(num_rows), static_cast<int>(k)};
+     auto dim = framework::make_ddim(odims);
+     auto e_values = EigenMatrix<T>::From(*out_tensor, dim);
+     auto e_tmp_values = EigenMatrix<T>::From(temp_values);
+
+     e_indices.device(dev) = e_tmp_indices.slice(slice_indices, slice_sizes);
+     e_values.device(dev) = e_tmp_values.slice(slice_indices, slice_sizes);
+   }
+   return true;
+ }
+

#define FIXED_BLOCK_DIM_BASE(dim, ...) \
  case (dim): {                        \
    constexpr auto kBlockDim = (dim);  \
@@ -324,7 +487,7 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    auto* indices = ctx.Output<Tensor>("Indices");
-     size_t k = static_cast<int>(ctx.Attr<int>("k"));
+     int k = static_cast<int>(ctx.Attr<int>("k"));

    auto* k_t = ctx.Input<Tensor>("K");
    if (k_t) {
@@ -340,21 +503,31 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>(ctx.GetPlace());
    // FIXME(typhoonzero): data is always converted to type T?
-     int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());

    framework::DDim inputdims = input->dims();
-     const size_t input_height = framework::product(
+     const int64_t input_height = framework::product(
        framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
-     const size_t input_width = inputdims[inputdims.size() - 1];
-
+     const int64_t input_width = inputdims[inputdims.size() - 1];
+     const auto& dev_ctx = ctx.cuda_device_context();
+
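+     // Heuristic (presumably tuned for performance): take the cub-based sort
+     // path when rows are short (<= 1024), k is large (>= 128), or a full
+     // sort is requested; otherwise fall back to the KeMatrixTopK kernel.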
+     if ((input_width <= 1024 || k >= 128 || k == input_width)) {
+       if (SortTopk<T>(dev_ctx, input, input_width, input_height, k, output,
+                       indices)) {
+         // Succeeded, return.
+         return;
+       } else {
+         LOG(INFO) << "TopKOP: Some errors happened when using cub sorting, "
+                      "use default topk kernel.";
+       }
+     }
+     int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
    if (k > input_width) k = input_width;

    // NOTE: pass lds and dim same to input width.
    // NOTE: old matrix implementation of stride is different to eigen.
    // TODO(typhoonzero): refine this kernel.
    const int kMaxHeight = 2048;
    int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
-     auto& dev_ctx = ctx.cuda_device_context();
    switch (GetDesiredBlockDim(input_width)) {
      FIXED_BLOCK_DIM(
          KeMatrixTopK<T, 5,