
Commit d3f7e14

fix idx overflow in cuda kernel (PaddlePaddle#76376)
1 parent: bdae14e

File tree

1 file changed: +26, -15 lines

paddle/phi/kernels/gpu/elementwise_grad.h

Lines changed: 26 additions & 15 deletions
@@ -112,19 +112,19 @@ void GetGradXOrYOut(const GPUContext &dev_ctx,
  ******************************
  */
 
-template <typename T>
+template <typename T, typename IndexT = int>
 static __global__ void SimpleElemwiseAddGradCUDAKernel(
-    const T *__restrict__ dout, int size, int vec_size, T *dx, T *dy) {
-  int tid = BLOCK_ID_X * BLOCK_NUM_X + THREAD_ID_X;
-  int stride = GRID_NUM_X * BLOCK_NUM_X;
-  int loop = size / vec_size;
-  int remainder = size % vec_size;
+    const T *__restrict__ dout, IndexT size, int vec_size, T *dx, T *dy) {
+  IndexT tid = static_cast<IndexT>(BLOCK_ID_X) * BLOCK_NUM_X + THREAD_ID_X;
+  IndexT stride = static_cast<IndexT>(GRID_NUM_X) * BLOCK_NUM_X;
+  IndexT loop = size / vec_size;
+  IndexT remainder = size % vec_size;
   const float4 *dout_vec = reinterpret_cast<const float4 *>(dout);
   float4 *dx_vec = reinterpret_cast<float4 *>(dx);
   float4 *dy_vec = reinterpret_cast<float4 *>(dy);
   float4 tmp_loop;
 
-  for (int i = tid; i < loop; i += stride) {
+  for (IndexT i = tid; i < loop; i += stride) {
     tmp_loop = dout_vec[i];
     dx_vec[i] = tmp_loop;
     dy_vec[i] = tmp_loop;
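
Note on the hunk above: the cast to IndexT has to happen before the multiply. With the old int tid = BLOCK_ID_X * BLOCK_NUM_X + THREAD_ID_X, the product is evaluated entirely in 32-bit arithmetic and overflows before any widening could help; signed overflow is undefined behavior in C++, and on real hardware it typically wraps to a garbage index. A minimal host-side sketch of the pitfall (plain C++; block_id and block_num are hypothetical stand-ins for the CUDA built-ins):

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical launch-geometry values large enough that
  // blocks * threads_per_block exceeds INT_MAX (2^31 - 1).
  int block_id = 3000000;  // stand-in for BLOCK_ID_X
  int block_num = 1024;    // stand-in for BLOCK_NUM_X

  // Wrong: both operands are int, so the multiply runs in 32-bit
  // arithmetic and overflows *before* the widening assignment.
  int64_t bad = block_id * block_num;

  // Right (what the patch does): widen one operand first, so the
  // whole expression is evaluated in 64-bit arithmetic.
  int64_t good = static_cast<int64_t>(block_id) * block_num;

  std::printf("bad = %lld, good = %lld\n",
              static_cast<long long>(bad), static_cast<long long>(good));
}
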
@@ -133,7 +133,7 @@ static __global__ void SimpleElemwiseAddGradCUDAKernel(
   if (tid == loop && remainder != 0) {
     T tmp_rem;
     while (remainder) {
-      int idx = size - remainder;
+      IndexT idx = size - remainder;
       remainder--;
       tmp_rem = dout[idx];
       dx[idx] = tmp_rem;
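
The same reasoning applies to the scalar tail loop above: once size can exceed INT_MAX, the derived index size - remainder no longer fits in int, so a 32-bit idx truncates (typically to a negative value) and the last few elements would be read and written through a bogus offset. A small host-side sketch, assuming a hypothetical 3-billion-element tensor:

#include <cstdint>
#include <cstdio>

int main() {
  int64_t size = 3000000000LL;  // hypothetical element count > INT_MAX
  int64_t remainder = 3;        // non-vectorized tail elements

  // Truncated: the conversion to int is implementation-defined and
  // typically wraps to a large negative index.
  int idx32 = static_cast<int>(size - remainder);
  // The patched kernel keeps the index in IndexT (int64_t here).
  int64_t idx64 = size - remainder;

  std::printf("idx32 = %d, idx64 = %lld\n", idx32,
              static_cast<long long>(idx64));
}
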
@@ -219,13 +219,24 @@ void ElementwiseAddGrad(const GPUContext &dev_ctx,
         dim3(((size + vec_size - 1) / vec_size + PREDEFINED_BLOCK_SIZE - 1) /
                  PREDEFINED_BLOCK_SIZE,
              1);
-    SimpleElemwiseAddGradCUDAKernel<T>
-        <<<grid_size, block_size, 0, dev_ctx.stream()>>>(
-            dout.data<T>(),
-            size,
-            vec_size,
-            dev_ctx.template Alloc<T>(dx),
-            dev_ctx.template Alloc<T>(dy));
+    if (size < std::numeric_limits<int>::max()) {
+      SimpleElemwiseAddGradCUDAKernel<T>
+          <<<grid_size, block_size, 0, dev_ctx.stream()>>>(
+              dout.data<T>(),
+              size,
+              vec_size,
+              dev_ctx.template Alloc<T>(dx),
+              dev_ctx.template Alloc<T>(dy));
+    } else {
+      SimpleElemwiseAddGradCUDAKernel<T, int64_t>
+          <<<grid_size, block_size, 0, dev_ctx.stream()>>>(
+              dout.data<T>(),
+              size,
+              vec_size,
+              dev_ctx.template Alloc<T>(dx),
+              dev_ctx.template Alloc<T>(dy));
+    }
+
   } else {
     VLOG(4) << "Special case when dy_data is the same as dout_data, "
                "and dx_data is the same as dout_data, do not need "
