
Commit e9f2033

Author: chengduo

Merge pull request #8539 from chengduoZH/feature/refine_elementwise_op_function.h

Refine Sum in elementwise_op_function

2 parents fee90b5 + 90dc33b, commit e9f2033

File tree: 2 files changed, +62 −34 lines

  paddle/fluid/operators/elementwise_op_function.h
  paddle/fluid/platform/cuda_helper.h

paddle/fluid/operators/elementwise_op_function.h

Lines changed: 14 additions & 34 deletions
@@ -20,6 +20,7 @@ limitations under the License. */
 
 #ifdef __NVCC__
 #include <thrust/iterator/iterator_adaptor.h>
+#include "paddle/fluid/platform/cuda_helper.h"
 constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
 #endif
 
@@ -361,36 +362,27 @@ template <typename T, typename DX_OP, typename DY_OP>
 static __global__ void ElemwiseGradBroadcast1CUDAKernel(
     const T* x, const T* y, const T* out, const T* dout, int h, int w,
     DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
-  extern __shared__ char shm_buffer[];
-  T* shm = reinterpret_cast<T*>(shm_buffer);
-
   int j = blockIdx.x;
   int i = threadIdx.x;
   int tid = threadIdx.x;
-  shm[tid] = 0;
+  T val = 0;
 
   do {
     int x_offset = i * w + j;
     if (dx) {
       dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
     }
     if (dy) {
-      shm[tid] += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+      val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
     }
     i += ELEMWISE_MAX_BLOCK_DIM;
   } while (i < h);
 
   if (dy) {
-    __syncthreads();
-
     h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
-
-    // Sum, could be optimized
+    val = platform::reduceSum(val, tid, h);
     if (threadIdx.x == 0) {
-      for (int k = 1; k < h; ++k) {
-        shm[0] += shm[k];
-      }
-      dy[j] = shm[0];
+      dy[j] = val;
     }
   }
 }
@@ -402,10 +394,8 @@ static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T* x,
                                        T* dx, T* dy) {
   int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
   int gird_size = w;
-  int shared_mem_size = block_size * sizeof(T);
-  ElemwiseGradBroadcast1CUDAKernel<<<gird_size, block_size, shared_mem_size,
-                                     stream>>>(x, y, out, dout, h, w, dx_op,
-                                               dy_op, dx, dy);
+  ElemwiseGradBroadcast1CUDAKernel<<<gird_size, block_size, 0, stream>>>(
+      x, y, out, dout, h, w, dx_op, dy_op, dx, dy);
 }
 
 #endif
@@ -436,17 +426,14 @@ static void ElemwiseGradBroadcast2CPU(const T* x, const T* y, const T* out,
 }
 
 #ifdef __NVCC__
-
 template <typename T, typename DX_OP, typename DY_OP>
 static __global__ void ElemwiseGradBroadcast2CUDAKernel(
     const T* x, const T* y, const T* out, const T* dout, int pre, int n,
     int post, DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
   int tid = threadIdx.x;
   int j = blockIdx.x;
 
-  extern __shared__ char shm_buffer[];
-  T* shm = reinterpret_cast<T*>(shm_buffer);
-  shm[tid] = 0;
+  T val = 0;
   int ttid = tid;
 
   while (true) {
@@ -461,23 +448,18 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
     }
 
     if (dy != nullptr) {
-      shm[tid] += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+      val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
    }
 
     ttid += ELEMWISE_MAX_BLOCK_DIM;
   }
 
   if (dy) {
-    __syncthreads();
     int h = pre * post;
     h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
-
-    // Sum, could be optimized
-    if (tid == 0) {
-      for (int i = 1; i < h; ++i) {
-        shm[0] += shm[i];
-      }
-      dy[j] = shm[0];
+    val = platform::reduceSum(val, tid, h);
+    if (threadIdx.x == 0) {
+      dy[j] = val;
     }
   }
 }
@@ -489,10 +471,8 @@ static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x,
                                        DY_OP dy_op, T* dx, T* dy) {
   int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
   int gird_size = n;
-  int shared_mem_size = block_size * sizeof(T);
-  ElemwiseGradBroadcast2CUDAKernel<<<gird_size, block_size, shared_mem_size,
-                                     stream>>>(x, y, out, dout, pre, n, post,
-                                               dx_op, dy_op, dx, dy);
+  ElemwiseGradBroadcast2CUDAKernel<<<gird_size, block_size, 0, stream>>>(
+      x, y, out, dout, pre, n, post, dx_op, dy_op, dx, dy);
 }
 
 #endif
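
Note (not part of the commit): the refactor replaces the dynamically sized shared-memory buffer with a per-thread accumulator plus a block-wide reduction, which is why the launches above now pass 0 as the dynamic shared-memory size. The following is a minimal, self-contained sketch of that pattern; the names BlockReduceSum and ColumnSumKernel, the fixed 256-thread block, and the toy sizes are illustrative assumptions, not Paddle code.

#include <cstdio>
#include <vector>

#include <cuda_runtime.h>

// Block-wide sum using warp shuffles: reduce within each warp, have the warp
// leaders publish their partial sums to shared memory, then let warp 0 reduce
// those partials. Assumes blockDim.x is a multiple of 32 and at most 1024.
__device__ float BlockReduceSum(float val) {
  __shared__ float warp_sums[32];  // one slot per warp
  int lane = threadIdx.x & 31;
  int warp = threadIdx.x >> 5;
  for (int offset = 16; offset > 0; offset >>= 1)
    val += __shfl_down_sync(0xffffffffu, val, offset);
  if (lane == 0) warp_sums[warp] = val;
  __syncthreads();
  if (warp == 0) {
    int num_warps = blockDim.x >> 5;
    val = (lane < num_warps) ? warp_sums[lane] : 0.f;
    for (int offset = 16; offset > 0; offset >>= 1)
      val += __shfl_down_sync(0xffffffffu, val, offset);
  }
  return val;  // the total is valid in thread 0
}

// One block per column j of an h x w row-major matrix, mirroring how the
// refactored ElemwiseGradBroadcast1CUDAKernel accumulates dy.
__global__ void ColumnSumKernel(const float* x, int h, int w, float* col_sum) {
  int j = blockIdx.x;
  float val = 0.f;  // private accumulator instead of a shared-memory buffer
  for (int i = threadIdx.x; i < h; i += blockDim.x) val += x[i * w + j];
  val = BlockReduceSum(val);
  if (threadIdx.x == 0) col_sum[j] = val;
}

int main() {
  const int h = 1000, w = 4;
  std::vector<float> host_x(h * w, 1.0f);  // every column sums to h
  float *x = nullptr, *col_sum = nullptr;
  cudaMalloc(&x, h * w * sizeof(float));
  cudaMalloc(&col_sum, w * sizeof(float));
  cudaMemcpy(x, host_x.data(), h * w * sizeof(float), cudaMemcpyHostToDevice);
  // No dynamic shared memory is needed, so the third launch argument is 0.
  ColumnSumKernel<<<w, 256, 0>>>(x, h, w, col_sum);
  std::vector<float> result(w);
  cudaMemcpy(result.data(), col_sum, w * sizeof(float),
             cudaMemcpyDeviceToHost);
  for (int j = 0; j < w; ++j) printf("col %d sum = %.0f\n", j, result[j]);
  cudaFree(x);
  cudaFree(col_sum);
  return 0;
}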

paddle/fluid/platform/cuda_helper.h

Lines changed: 48 additions & 0 deletions
@@ -62,5 +62,53 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
 }
 #endif
 
+// __shfl_down has been deprecated as of CUDA 9.0.
+#if CUDA_VERSION < 9000
+template <typename T>
+__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
+  return __shfl_down(val, delta);
+}
+#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
+#else
+#define FULL_WARP_MASK 0xFFFFFFFF
+#define CREATE_SHFL_MASK(mask, predicate) \
+  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
+#endif
+
+template <typename T>
+__device__ T reduceSum(T val, int tid, int len) {
+  // TODO(zcd): The warp size should be taken from the
+  // parameters of the GPU but not specified as 32 simply.
+  // To make the reduceSum more efficiently,
+  // I use Warp-Level Parallelism and assume the Warp size
+  // is 32 which may be different for different GPU,
+  // but most card's warp size is 32.
+  __shared__ T shm[32];
+  const int warpSize = 32;
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, tid < len);
+
+  for (int offset = warpSize / 2; offset > 0; offset /= 2)
+    val += __shfl_down_sync(mask, val, offset);
+
+  if (tid < warpSize) shm[tid] = 0;
+
+  __syncthreads();
+
+  if (tid % warpSize == 0) {
+    shm[tid / warpSize] = val;
+  }
+
+  CREATE_SHFL_MASK(mask, tid < warpSize);
+
+  if (tid < warpSize) {
+    val = shm[tid];
+    for (int offset = warpSize / 2; offset > 0; offset /= 2)
+      val += __shfl_down_sync(mask, val, offset);
+  }
+
+  return val;
+}
+
 }  // namespace platform
 }  // namespace paddle
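
Note (not part of the commit): reduceSum works in two stages — each warp reduces its threads' values with __shfl_down_sync, the warp leaders write their partials into the 32-entry shared array, and the first warp then reduces those partials — so a caller must launch at most ELEMWISE_MAX_BLOCK_DIM (1024) threads per block, as the ElemwiseGradBroadcast*CUDA helpers do. Below is a hedged usage sketch of how a kernel is expected to call platform::reduceSum; RowSumKernel and its parameters are made-up for illustration and compile only inside the Paddle source tree.

#include "paddle/fluid/platform/cuda_helper.h"

namespace paddle {
namespace operators {

// Toy kernel: one block per row, each thread accumulates a strided slice of
// the row into a private value, then reduceSum combines the partials.
template <typename T>
__global__ void RowSumKernel(const T* x, int h, int w, T* row_sum) {
  int j = blockIdx.x;   // row handled by this block
  int tid = threadIdx.x;
  int block_size = blockDim.x;
  T val = 0;
  for (int i = tid; i < w; i += block_size) val += x[j * w + i];
  int len = w < block_size ? w : block_size;  // threads holding a partial
  val = platform::reduceSum(val, tid, len);   // block-wide sum
  if (tid == 0) row_sum[j] = val;             // total is valid in thread 0
}

}  // namespace operators
}  // namespace paddle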
