@@ -112,19 +112,19 @@ void GetGradXOrYOut(const GPUContext &dev_ctx,
112112******************************
113113*/
114114
115- template <typename T>
115+ template <typename T, typename IndexT = int >
116116static __global__ void SimpleElemwiseAddGradCUDAKernel (
117- const T *__restrict__ dout, int size, int vec_size, T *dx, T *dy) {
118- int tid = BLOCK_ID_X * BLOCK_NUM_X + THREAD_ID_X;
119- int stride = GRID_NUM_X * BLOCK_NUM_X;
120- int loop = size / vec_size;
121- int remainder = size % vec_size;
117+ const T *__restrict__ dout, IndexT size, int vec_size, T *dx, T *dy) {
118+ IndexT tid = static_cast <IndexT>( BLOCK_ID_X) * BLOCK_NUM_X + THREAD_ID_X;
119+ IndexT stride = static_cast <IndexT>( GRID_NUM_X) * BLOCK_NUM_X;
120+ IndexT loop = size / vec_size;
121+ IndexT remainder = size % vec_size;
122122 const float4 *dout_vec = reinterpret_cast <const float4 *>(dout);
123123 float4 *dx_vec = reinterpret_cast <float4 *>(dx);
124124 float4 *dy_vec = reinterpret_cast <float4 *>(dy);
125125 float4 tmp_loop;
126126
127- for (int i = tid; i < loop; i += stride) {
127+ for (IndexT i = tid; i < loop; i += stride) {
128128 tmp_loop = dout_vec[i];
129129 dx_vec[i] = tmp_loop;
130130 dy_vec[i] = tmp_loop;
@@ -133,7 +133,7 @@ static __global__ void SimpleElemwiseAddGradCUDAKernel(
133133 if (tid == loop && remainder != 0 ) {
134134 T tmp_rem;
135135 while (remainder) {
136- int idx = size - remainder;
136+ IndexT idx = size - remainder;
137137 remainder--;
138138 tmp_rem = dout[idx];
139139 dx[idx] = tmp_rem;
@@ -219,13 +219,24 @@ void ElementwiseAddGrad(const GPUContext &dev_ctx,
219219 dim3 (((size + vec_size - 1 ) / vec_size + PREDEFINED_BLOCK_SIZE - 1 ) /
220220 PREDEFINED_BLOCK_SIZE,
221221 1 );
222- SimpleElemwiseAddGradCUDAKernel<T>
223- <<<grid_size, block_size, 0 , dev_ctx.stream ()>>>(
224- dout.data <T>(),
225- size,
226- vec_size,
227- dev_ctx.template Alloc <T>(dx),
228- dev_ctx.template Alloc <T>(dy));
222+ if (size < std::numeric_limits<int >::max ()) {
223+ SimpleElemwiseAddGradCUDAKernel<T>
224+ <<<grid_size, block_size, 0 , dev_ctx.stream ()>>>(
225+ dout.data <T>(),
226+ size,
227+ vec_size,
228+ dev_ctx.template Alloc <T>(dx),
229+ dev_ctx.template Alloc <T>(dy));
230+ } else {
231+ SimpleElemwiseAddGradCUDAKernel<T, int64_t >
232+ <<<grid_size, block_size, 0 , dev_ctx.stream ()>>>(
233+ dout.data <T>(),
234+ size,
235+ vec_size,
236+ dev_ctx.template Alloc <T>(dx),
237+ dev_ctx.template Alloc <T>(dy));
238+ }
239+
229240 } else {
230241 VLOG (4 ) << " Special case when dy_data is the same as dout_data, "
231242 " and dx_data is the same as dout_data, do not need "
0 commit comments