@@ -419,23 +419,6 @@ __global__ void GroupNormBackward(const T* x, const T* d_y, const T* scale,
 }
 }
 
-template <typename T, typename AccT, int VecSize>
-__global__ void VectorizedGetDsDbCUDAKernel(int imsize, const T* x, const T* dy,
-                                            T* ds, T* db) {
-  int i = blockIdx.x;
-  AccT ds_sum = static_cast<AccT>(0);
-  AccT db_sum = static_cast<AccT>(0);
-  x += i * imsize;
-  const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
-
-  phi::Array<const T*, 2> ins;
-  ins[0] = x;
-  ins[1] = dy;
-  ThreadReduce<T, AccT, VecSize, 2>(ins, imsize, input_offset, &db_sum,
-                                    &ds_sum);
-  ReduceMeanAndVar<AccT>(db, ds, db_sum, ds_sum, 1);
-}
-
 template <typename T>
 __global__ void ScalarGetDsDbCUDAKernel(int imsize, const T* x, const T* dy,
                                         T* ds, T* db) {
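Note: both the removed vectorized kernel and the remaining scalar kernel perform the same per-(batch, channel) reduction, with db accumulating the sum of dy and ds accumulating the sum of dy * x over one image slice, one CUDA block per (n, c) pair. Below is a minimal standalone sketch of that reduction for context; it is not the PaddlePaddle implementation, and the kernel name, shared-memory scheme, and parameter layout are illustrative assumptions.

// Sketch of the ds/db reduction: one block handles one (n, c) slice of
// `imsize` elements; ds = sum(dy * x), db = sum(dy). Names are illustrative.
#include <cuda_runtime.h>

template <typename T, typename AccT>
__global__ void GetDsDbSketch(int imsize, const T* x, const T* dy,
                              T* ds, T* db) {
  const int slice = blockIdx.x;  // one block per (n, c) pair
  x += static_cast<int64_t>(slice) * imsize;
  dy += static_cast<int64_t>(slice) * imsize;

  AccT ds_sum = static_cast<AccT>(0);
  AccT db_sum = static_cast<AccT>(0);
  for (int i = threadIdx.x; i < imsize; i += blockDim.x) {
    AccT g = static_cast<AccT>(dy[i]);
    ds_sum += g * static_cast<AccT>(x[i]);
    db_sum += g;
  }

  // Block-wide tree reduction in dynamic shared memory
  // (assumes a power-of-two blockDim.x and 2 * blockDim.x * sizeof(AccT)
  // bytes of shared memory supplied at launch).
  extern __shared__ unsigned char smem_raw[];
  AccT* smem = reinterpret_cast<AccT*>(smem_raw);
  smem[threadIdx.x] = ds_sum;
  smem[blockDim.x + threadIdx.x] = db_sum;
  __syncthreads();
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) {
      smem[threadIdx.x] += smem[threadIdx.x + stride];
      smem[blockDim.x + threadIdx.x] += smem[blockDim.x + threadIdx.x + stride];
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    ds[slice] = static_cast<T>(smem[0]);
    db[slice] = static_cast<T>(smem[blockDim.x]);
  }
}

// Example launch (illustrative): one block per (n, c) slice.
// GetDsDbSketch<float, float><<<N * C, block, 2 * block * sizeof(float)>>>(
//     imsize, x, dy, ds, db);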
@@ -622,25 +605,17 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
     int flags =
         (scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias;
     if (data_layout == DataLayout::kNCHW) {
-      using AccT = typename details::MPTypeTrait<T>::Type;
-      constexpr int vec_size = sizeof(float4) / sizeof(T);
       const int max_num_threads = 1024;
-      int max_block_size = std::min(imsize / vec_size, max_num_threads);
+      int max_block_size = std::min(imsize, max_num_threads);
       int block_size_nchw = 1;
       while (block_size_nchw < max_block_size) {
         block_size_nchw *= 2;
       }
       block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
       dim3 blocks(block_size_nchw);
-      if (imsize < vec_size * block_size_nchw) {
-        ScalarGetDsDbCUDAKernel<
-            T><<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
-            imsize, x_data, dy_data, ds_data, db_data);
-      } else {
-        VectorizedGetDsDbCUDAKernel<
-            T, AccT, vec_size><<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
-            imsize, x_data, dy_data, ds_data, db_data);
-      }
+      ScalarGetDsDbCUDAKernel<
+          T><<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
+          imsize, x_data, dy_data, ds_data, db_data);
 
       if (d_scale || d_bias) {
         const int block = 256;
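Note: the launch-size logic kept in this hunk rounds the block size up to a power of two, caps it at 1024 threads, and floors it at the warp size, launching one block per (batch, channel) pair. A hedged host-side sketch of that calculation follows; the helper name and the hard-coded warp size of 32 are assumptions for illustration, not the Paddle API.

// Illustrative helper: pick the NCHW block size the same way the launch
// above does -- next power of two covering `imsize`, clamped to
// [warp size, 1024].
#include <algorithm>

inline int ChooseBlockSizeNCHW(int imsize) {
  const int kMaxNumThreads = 1024;  // per-block thread limit used in the diff
  const int kWarpSize = 32;         // assumed warp size
  int max_block_size = std::min(imsize, kMaxNumThreads);
  int block_size_nchw = 1;
  while (block_size_nchw < max_block_size) {
    block_size_nchw *= 2;           // round up to a power of two
  }
  return std::max(block_size_nchw, kWarpSize);
}
// Grid size is x_dims[0] * C: one block per (batch, channel) slice.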