Skip to content

Commit 3f3f5a2

Browse files
authored
[Cherry-Pick] Fix group_norm vectorized address misalignment (#41585)
[cherry-pick] #41531 and #41570
1 parent 727dcbd commit 3f3f5a2

File tree

1 file changed

+4
-29
lines changed

1 file changed

+4
-29
lines changed

paddle/fluid/operators/group_norm_op.cu

Lines changed: 4 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -419,23 +419,6 @@ __global__ void GroupNormBackward(const T* x, const T* d_y, const T* scale,
419419
}
420420
}
421421

422-
template <typename T, typename AccT, int VecSize>
423-
__global__ void VectorizedGetDsDbCUDAKernel(int imsize, const T* x, const T* dy,
424-
T* ds, T* db) {
425-
int i = blockIdx.x;
426-
AccT ds_sum = static_cast<AccT>(0);
427-
AccT db_sum = static_cast<AccT>(0);
428-
x += i * imsize;
429-
const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
430-
431-
phi::Array<const T*, 2> ins;
432-
ins[0] = x;
433-
ins[1] = dy;
434-
ThreadReduce<T, AccT, VecSize, 2>(ins, imsize, input_offset, &db_sum,
435-
&ds_sum);
436-
ReduceMeanAndVar<AccT>(db, ds, db_sum, ds_sum, 1);
437-
}
438-
439422
template <typename T>
440423
__global__ void ScalarGetDsDbCUDAKernel(int imsize, const T* x, const T* dy,
441424
T* ds, T* db) {
@@ -622,25 +605,17 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
622605
int flags =
623606
(scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias;
624607
if (data_layout == DataLayout::kNCHW) {
625-
using AccT = typename details::MPTypeTrait<T>::Type;
626-
constexpr int vec_size = sizeof(float4) / sizeof(T);
627608
const int max_num_threads = 1024;
628-
int max_block_size = std::min(imsize / vec_size, max_num_threads);
609+
int max_block_size = std::min(imsize, max_num_threads);
629610
int block_size_nchw = 1;
630611
while (block_size_nchw < max_block_size) {
631612
block_size_nchw *= 2;
632613
}
633614
block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
634615
dim3 blocks(block_size_nchw);
635-
if (imsize < vec_size * block_size_nchw) {
636-
ScalarGetDsDbCUDAKernel<
637-
T><<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
638-
imsize, x_data, dy_data, ds_data, db_data);
639-
} else {
640-
VectorizedGetDsDbCUDAKernel<
641-
T, AccT, vec_size><<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
642-
imsize, x_data, dy_data, ds_data, db_data);
643-
}
616+
ScalarGetDsDbCUDAKernel<
617+
T><<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
618+
imsize, x_data, dy_data, ds_data, db_data);
644619

645620
if (d_scale || d_bias) {
646621
const int block = 256;

0 commit comments

Comments
 (0)