@@ -20,6 +20,7 @@ limitations under the License. */
 
 #ifdef __NVCC__
 #include <thrust/iterator/iterator_adaptor.h>
+#include "paddle/fluid/platform/cuda_helper.h"
 constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
 #endif
 
@@ -361,36 +362,27 @@ template <typename T, typename DX_OP, typename DY_OP>
 static __global__ void ElemwiseGradBroadcast1CUDAKernel(
     const T* x, const T* y, const T* out, const T* dout, int h, int w,
     DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
-  extern __shared__ char shm_buffer[];
-  T* shm = reinterpret_cast<T*>(shm_buffer);
-
   int j = blockIdx.x;
   int i = threadIdx.x;
   int tid = threadIdx.x;
-  shm[tid] = 0;
+  T val = 0;
 
   do {
     int x_offset = i * w + j;
     if (dx) {
       dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
     }
     if (dy) {
-      shm[tid] += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+      val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
     }
     i += ELEMWISE_MAX_BLOCK_DIM;
   } while (i < h);
 
   if (dy) {
-    __syncthreads();
-
     h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
-
-    // Sum, could be optimized
+    val = platform::reduceSum(val, tid, h);
     if (threadIdx.x == 0) {
-      for (int k = 1; k < h; ++k) {
-        shm[0] += shm[k];
-      }
-      dy[j] = shm[0];
+      dy[j] = val;
     }
   }
 }
@@ -402,10 +394,8 @@ static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T* x,
                                        T* dx, T* dy) {
   int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
   int gird_size = w;
-  int shared_mem_size = block_size * sizeof(T);
-  ElemwiseGradBroadcast1CUDAKernel<<<gird_size, block_size, shared_mem_size,
-                                     stream>>>(x, y, out, dout, h, w, dx_op,
-                                               dy_op, dx, dy);
+  ElemwiseGradBroadcast1CUDAKernel<<<gird_size, block_size, 0, stream>>>(
+      x, y, out, dout, h, w, dx_op, dy_op, dx, dy);
 }
 
 #endif
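
Note: `platform::reduceSum` itself is not shown in this diff; it comes from the newly included `paddle/fluid/platform/cuda_helper.h`. As a rough illustration of the kind of block-wide sum the kernels now rely on, here is a minimal sketch, assuming the block size equals `len` and is a multiple of the warp size; the real helper also has to cope with partial warps, since `block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h)` can be any value. The name `BlockReduceSumSketch` and all details below are illustrative, not the actual helper:

```cuda
// Illustrative sketch only -- not the code from cuda_helper.h.
// Sums `val` over the first `len` threads of the block and returns the total
// in thread 0. Assumes len == blockDim.x and blockDim.x is a multiple of the
// warp size (32), so every shuffle stays within fully active warps.
template <typename T>
__device__ T BlockReduceSumSketch(T val, int tid, int len) {
  __shared__ T warp_sums[32];  // one partial sum per warp (<= 32 warps)
  const int lane = tid % 32;
  const int warp = tid / 32;

  // 1) Reduce within each warp via shuffle-down.
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(0xFFFFFFFFu, val, offset);
  }
  if (lane == 0) warp_sums[warp] = val;
  __syncthreads();

  // 2) The first warp reduces the per-warp partial sums.
  const int num_warps = len / 32;
  val = (tid < num_warps) ? warp_sums[tid] : static_cast<T>(0);
  if (warp == 0) {
    for (int offset = 16; offset > 0; offset >>= 1) {
      val += __shfl_down_sync(0xFFFFFFFFu, val, offset);
    }
  }
  return val;  // meaningful in thread 0 only
}
```

Calling a helper that synchronizes internally is safe in the kernels above because the `if (dy)` branch is uniform across the block: `dy` is a kernel argument, so either every thread enters the reduction or none does. On toolkits older than CUDA 9 the `__shfl_down_sync` calls would be the legacy `__shfl_down`.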
@@ -436,17 +426,14 @@ static void ElemwiseGradBroadcast2CPU(const T* x, const T* y, const T* out,
 }
 
 #ifdef __NVCC__
-
 template <typename T, typename DX_OP, typename DY_OP>
 static __global__ void ElemwiseGradBroadcast2CUDAKernel(
     const T* x, const T* y, const T* out, const T* dout, int pre, int n,
     int post, DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
   int tid = threadIdx.x;
   int j = blockIdx.x;
 
-  extern __shared__ char shm_buffer[];
-  T* shm = reinterpret_cast<T*>(shm_buffer);
-  shm[tid] = 0;
+  T val = 0;
   int ttid = tid;
 
   while (true) {
@@ -461,23 +448,18 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
     }
 
     if (dy != nullptr) {
-      shm[tid] += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+      val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
     }
 
     ttid += ELEMWISE_MAX_BLOCK_DIM;
   }
 
   if (dy) {
-    __syncthreads();
     int h = pre * post;
     h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
-
-    // Sum, could be optimized
-    if (tid == 0) {
-      for (int i = 1; i < h; ++i) {
-        shm[0] += shm[i];
-      }
-      dy[j] = shm[0];
+    val = platform::reduceSum(val, tid, h);
+    if (threadIdx.x == 0) {
+      dy[j] = val;
     }
   }
 }
@@ -489,10 +471,8 @@ static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x,
                                        DY_OP dy_op, T* dx, T* dy) {
   int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
   int gird_size = n;
-  int shared_mem_size = block_size * sizeof(T);
-  ElemwiseGradBroadcast2CUDAKernel<<<gird_size, block_size, shared_mem_size,
-                                     stream>>>(x, y, out, dout, pre, n, post,
-                                               dx_op, dy_op, dx, dy);
+  ElemwiseGradBroadcast2CUDAKernel<<<gird_size, block_size, 0, stream>>>(
+      x, y, out, dout, pre, n, post, dx_op, dy_op, dx, dy);
 }
 
 #endif
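
A hypothetical smoke test for the new reduction path could look like the sketch below. It is not part of this PR: it assumes compilation inside the Paddle tree with nvcc, that `ElemwiseGradBroadcast1CUDA` lives in `paddle::operators`, and that its parameter order matches the kernel-launch arguments shown above (`stream, x, y, out, dout, h, w, dx_op, dy_op, dx, dy`). With a pass-through `dy_op` that simply returns `dout`, each `dy[j]` should come back as the column sum of `dout`, i.e. `h` when `dout` is all ones.

```cuda
// Hypothetical smoke test -- not part of this PR. PassThroughGrad is an
// illustrative functor, not an operator defined in the Paddle codebase.
#include <cstdio>
#include <vector>

#include "paddle/fluid/operators/elementwise_op_function.h"

// Pass-through gradient: dy_op returns dout, so dy[j] becomes the column sum
// of dout over the h rows handled by block j.
template <typename T>
struct PassThroughGrad {
  __device__ T operator()(T /*x*/, T /*y*/, T /*out*/, T dout) const {
    return dout;
  }
};

int main() {
  const int h = 1000, w = 3;
  std::vector<float> ones(h * w, 1.0f);

  float *d_x, *d_y, *d_out, *d_dout, *d_dy;
  float* d_dx = nullptr;  // skip the dx path; the kernel guards on `if (dx)`
  cudaMalloc(&d_x, h * w * sizeof(float));
  cudaMalloc(&d_y, w * sizeof(float));
  cudaMalloc(&d_out, h * w * sizeof(float));
  cudaMalloc(&d_dout, h * w * sizeof(float));
  cudaMalloc(&d_dy, w * sizeof(float));
  cudaMemcpy(d_x, ones.data(), h * w * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_y, ones.data(), w * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_out, ones.data(), h * w * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_dout, ones.data(), h * w * sizeof(float),
             cudaMemcpyHostToDevice);

  paddle::operators::ElemwiseGradBroadcast1CUDA(
      /*stream=*/nullptr, d_x, d_y, d_out, d_dout, h, w,
      PassThroughGrad<float>(), PassThroughGrad<float>(), d_dx, d_dy);

  std::vector<float> dy(w);
  cudaMemcpy(dy.data(), d_dy, w * sizeof(float), cudaMemcpyDeviceToHost);
  for (int j = 0; j < w; ++j) {
    printf("dy[%d] = %.1f (expected %d)\n", j, dy[j], h);
  }
  return 0;
}
```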