@@ -212,9 +212,15 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
212
212
prob_greater_than_threshold[j] = pred (prob_vec[j]) ? prob_vec[j] : 0 ;
213
213
valid[j] = pred (prob_vec[j]) && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
214
214
}
215
+ #ifdef PADDLE_WITH_COREX
216
+ float aggregate_local =
217
+ BlockReduce<float , BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim .reduce )
218
+ .Sum (prob_greater_than_threshold);
219
+ #else
215
220
float aggregate_local =
216
221
BlockReduce<float , BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim .reduce )
217
222
.Sum <VEC_SIZE>(prob_greater_than_threshold);
223
+ #endif
218
224
if (tx == 0 ) {
219
225
temp_storage->block_aggregate .value = aggregate_local;
220
226
}
@@ -226,8 +232,13 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
226
232
DeterministicInclusiveSum<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>(
227
233
prob_greater_than_threshold, inclusive_cdf, temp_storage);
228
234
} else {
235
+ #ifdef PADDLE_WITH_COREX
236
+ BlockScan<float , BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim .scan )
237
+ .InclusiveSum (prob_greater_than_threshold, inclusive_cdf);
238
+ #else
229
239
BlockScan<float , BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim .scan )
230
240
.InclusiveSum <VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);
241
+ #endif
231
242
232
243
__syncthreads ();
233
244
}
@@ -239,11 +250,21 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
239
250
240
251
bool greater_than_u_diff[VEC_SIZE];
241
252
#ifdef SAMPLING_CUB_SUBTRACTLEFT_DEFINED
242
- BlockAdjacentDifference<bool , BLOCK_THREADS>(temp_storage->block_prim .adj_diff )
243
- .SubtractLeft <VEC_SIZE>(greater_than_u, greater_than_u_diff, BoolDiffOp ());
253
+ #ifdef PADDLE_WITH_COREX
254
+ BlockAdjacentDifference<bool , BLOCK_THREADS>(temp_storage->block_prim .adj_diff )
255
+ .SubtractLeft (greater_than_u, greater_than_u_diff, BoolDiffOp ());
256
+ #else
257
+ BlockAdjacentDifference<bool , BLOCK_THREADS>(temp_storage->block_prim .adj_diff )
258
+ .SubtractLeft <VEC_SIZE>(greater_than_u, greater_than_u_diff, BoolDiffOp ());
259
+ #endif
244
260
#else
245
- BlockAdjacentDifference<bool , BLOCK_THREADS>(temp_storage->block_prim .adj_diff )
246
- .FlagHeads <VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp (), 0 );
261
+ #ifdef PADDLE_WITH_COREX
262
+ BlockAdjacentDifference<bool , BLOCK_THREADS>(temp_storage->block_prim .adj_diff )
263
+ .FlagHeads (greater_than_u_diff, greater_than_u, BoolDiffOp (), 0 );
264
+ #else
265
+ BlockAdjacentDifference<bool , BLOCK_THREADS>(temp_storage->block_prim .adj_diff )
266
+ .FlagHeads <VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp (), 0 );
267
+ #endif
247
268
#endif
248
269
__syncthreads ();
249
270
@@ -355,18 +376,30 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
355
376
(probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
356
377
}
357
378
379
+ #ifdef PADDLE_WITH_COREX
380
+ aggregate_gt_pivot_0 +=
381
+ BlockReduce<ValueCount<float >, BLOCK_THREADS>(temp_storage.block_prim .reduce_value_count )
382
+ .Sum (probs_gt_pivot_0);
383
+ #else
358
384
aggregate_gt_pivot_0 +=
359
385
BlockReduce<ValueCount<float >, BLOCK_THREADS>(temp_storage.block_prim .reduce_value_count )
360
386
.Sum <VEC_SIZE>(probs_gt_pivot_0);
387
+ #endif
361
388
if (tx == 0 ) {
362
389
temp_storage.block_aggregate .pair = aggregate_gt_pivot_0;
363
390
}
364
391
__syncthreads ();
365
392
aggregate_gt_pivot_0 = temp_storage.block_aggregate .pair ;
366
393
394
+ #ifdef PADDLE_WITH_COREX
395
+ aggregate_gt_pivot_1 +=
396
+ BlockReduce<ValueCount<float >, BLOCK_THREADS>(temp_storage.block_prim .reduce_value_count )
397
+ .Sum (probs_gt_pivot_1);
398
+ #else
367
399
aggregate_gt_pivot_1 +=
368
400
BlockReduce<ValueCount<float >, BLOCK_THREADS>(temp_storage.block_prim .reduce_value_count )
369
401
.Sum <VEC_SIZE>(probs_gt_pivot_1);
402
+ #endif
370
403
if (tx == 0 ) {
371
404
temp_storage.block_aggregate .pair = aggregate_gt_pivot_1;
372
405
}
@@ -466,16 +499,26 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
466
499
probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? probs_vec[j] : 0 ;
467
500
}
468
501
502
+ #ifdef PADDLE_WITH_COREX
503
+ aggregate_gt_pivot_0 += BlockReduce<float , BLOCK_THREADS>(temp_storage.block_prim .reduce )
504
+ .Sum (probs_gt_pivot_0);
505
+ #else
469
506
aggregate_gt_pivot_0 += BlockReduce<float , BLOCK_THREADS>(temp_storage.block_prim .reduce )
470
507
.Sum <VEC_SIZE>(probs_gt_pivot_0);
508
+ #endif
471
509
if (tx == 0 ) {
472
510
temp_storage.block_aggregate .value = aggregate_gt_pivot_0;
473
511
}
474
512
__syncthreads ();
475
513
aggregate_gt_pivot_0 = temp_storage.block_aggregate .value ;
476
514
515
+ #ifdef PADDLE_WITH_COREX
516
+ aggregate_gt_pivot_1 += BlockReduce<float , BLOCK_THREADS>(temp_storage.block_prim .reduce )
517
+ .Sum (probs_gt_pivot_1);
518
+ #else
477
519
aggregate_gt_pivot_1 += BlockReduce<float , BLOCK_THREADS>(temp_storage.block_prim .reduce )
478
520
.Sum <VEC_SIZE>(probs_gt_pivot_1);
521
+ #endif
479
522
if (tx == 0 ) {
480
523
temp_storage.block_aggregate .value = aggregate_gt_pivot_1;
481
524
}
@@ -521,9 +564,15 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
521
564
for (uint32_t j = 0 ; j < VEC_SIZE; ++j) {
522
565
in_data_[j] = in_data_vec[j];
523
566
}
567
+ #ifdef PADDLE_WITH_COREX
568
+ max_val = max (
569
+ max_val, BlockReduce<float , BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
570
+ .Reduce (in_data_, cub::Max ()));
571
+ #else
524
572
max_val = max (
525
573
max_val, BlockReduce<float , BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
526
574
.Reduce <VEC_SIZE>(in_data_, cub::Max ()));
575
+ #endif
527
576
__syncthreads ();
528
577
}
529
578
if (tx == 0 ) {
@@ -605,7 +654,11 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
605
654
const uint32_t bx = blockIdx .x , tx = threadIdx .x ;
606
655
const uint32_t row_idx = bx;
607
656
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
657
+ #ifdef PADDLE_WITH_COREX
658
+ double pivot = -std::numeric_limits<float >::infinity (), normalizer = 1 ;
659
+ #else
608
660
double pivot = -cuda::std::numeric_limits<float >::infinity (), normalizer = 1 ;
661
+ #endif
609
662
vec_t <float , VEC_SIZE> probs_vec;
610
663
if (k < d) {
611
664
extern __shared__ __align__ (alignof (RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>))
@@ -659,14 +712,26 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
659
712
}
660
713
}
661
714
715
+ #ifdef PADDLE_WITH_COREX
716
+ aggregate_gt_pivot_0 += BlockReduce<ValueCount<float >, BLOCK_THREADS, REDUCE_ALGORITHM>(
717
+ temp_storage.block_prim .reduce_value_count )
718
+ .Sum (probs_gt_pivot_0_pair);
719
+ #else
662
720
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float >, BLOCK_THREADS, REDUCE_ALGORITHM>(
663
721
temp_storage.block_prim .reduce_value_count )
664
722
.Sum <VEC_SIZE>(probs_gt_pivot_0_pair);
723
+ #endif
665
724
__syncthreads ();
666
725
726
+ #ifdef PADDLE_WITH_COREX
727
+ aggregate_gt_pivot_1 += BlockReduce<ValueCount<float >, BLOCK_THREADS, REDUCE_ALGORITHM>(
728
+ temp_storage.block_prim .reduce_value_count )
729
+ .Sum (probs_gt_pivot_1_pair);
730
+ #else
667
731
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float >, BLOCK_THREADS, REDUCE_ALGORITHM>(
668
732
temp_storage.block_prim .reduce_value_count )
669
733
.Sum <VEC_SIZE>(probs_gt_pivot_1_pair);
734
+ #endif
670
735
__syncthreads ();
671
736
}
672
737
min_gt_low =
0 commit comments