@@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <algorithm>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/transform.h"
 
 #ifdef __NVCC__
+#include <cuda.h>
 #include <thrust/iterator/iterator_adaptor.h>
-#include "paddle/fluid/platform/cuda_helper.h"
 constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
 #endif
 
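(`<cuda.h>` is pulled in here because the compatibility shim added further below keys off the `CUDA_VERSION` macro, which that header defines.)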
@@ -43,35 +44,35 @@ namespace operators {
 */
 inline void get_mid_dims(const framework::DDim& x_dims,
                          const framework::DDim& y_dims, const int axis,
-                         int& pre, int& n, int& post) {
-  pre = 1;
-  n = 1;
-  post = 1;
+                         int* pre, int* n, int* post) {
+  *pre = 1;
+  *n = 1;
+  *post = 1;
   for (int i = 0; i < axis; ++i) {
-    pre *= x_dims[i];
+    (*pre) *= x_dims[i];
   }
 
   for (int i = 0; i < y_dims.size(); ++i) {
     PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
                       "Broadcast dimension mismatch.");
-    n *= y_dims[i];
+    (*n) *= y_dims[i];
   }
 
   for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
-    post *= x_dims[i];
+    (*post) *= x_dims[i];
   }
 }
 
-inline void trim_trailing_singular_dims(framework::DDim& dims) {
+inline void trim_trailing_singular_dims(framework::DDim* dims) {
   // Remove trailing dimensions of size 1 for y
-  auto actual_dims_size = dims.size();
+  auto actual_dims_size = dims->size();
   for (; actual_dims_size != 0; --actual_dims_size) {
-    if (dims[actual_dims_size - 1] != 1) break;
+    if ((*dims)[actual_dims_size - 1] != 1) break;
   }
-  if (actual_dims_size != dims.size()) {
-    auto actual_dims = framework::vectorize(dims);
+  if (actual_dims_size != dims->size()) {
+    auto actual_dims = framework::vectorize(*dims);
     actual_dims.resize(actual_dims_size);
-    dims = framework::make_ddim(actual_dims);
+    *dims = framework::make_ddim(actual_dims);
   }
 }
 
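To make the `pre`/`n`/`post` split concrete, here is a minimal standalone sketch (plain C++, no Paddle dependencies; the shapes and the `main` harness are invented for illustration) that mirrors the pointer-output logic above:

```cpp
#include <cassert>
#include <vector>

// Standalone mirror of get_mid_dims: y broadcasts against x starting at
// `axis`. Everything before the match is folded into `pre`, the matched
// extent into `n`, and everything after into `post`.
void get_mid_dims(const std::vector<int>& x_dims,
                  const std::vector<int>& y_dims, int axis, int* pre, int* n,
                  int* post) {
  *pre = *n = *post = 1;
  for (int i = 0; i < axis; ++i) (*pre) *= x_dims[i];
  for (size_t i = 0; i < y_dims.size(); ++i) {
    assert(x_dims[i + axis] == y_dims[i]);  // broadcast dims must match
    (*n) *= y_dims[i];
  }
  for (size_t i = axis + y_dims.size(); i < x_dims.size(); ++i)
    (*post) *= x_dims[i];
}

int main() {
  int pre, n, post;
  // x: [2, 3, 4, 5], y: [3, 4] aligned at axis 1
  get_mid_dims({2, 3, 4, 5}, {3, 4}, 1, &pre, &n, &post);
  assert(pre == 2 && n == 12 && post == 5);  // x viewed as [2, 12, 5]
  return 0;
}
```

`trim_trailing_singular_dims` feeds this: a y of [3, 4, 1, 1] is first trimmed to [3, 4], so it reaches the same split.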
@@ -159,7 +160,7 @@ class RowwiseTransformIterator<T, platform::CUDADeviceContext>
       RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
       super_t;
   HOSTDEVICE RowwiseTransformIterator(const T* x, int n)
-      : super_t(x), begin_(x), n_(n){};
+      : super_t(x), begin_(x), n_(n) {}
   friend class thrust::iterator_core_access;
 
  private:
@@ -179,7 +180,7 @@ class MidWiseTransformIterator<T, platform::CUDADeviceContext>
       MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
       super_t;
   HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post)
-      : super_t(x), begin_(x), n_(n), post_(post){};
+      : super_t(x), begin_(x), n_(n), post_(post) {}
   friend class thrust::iterator_core_access;
 
  private:
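(The `{};` → `{}` change in these two iterator constructors drops a stray semicolon after the constructor body; the extra `;` forms an empty declaration that compilers can flag, e.g. under Clang's `-Wextra-semi`.)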
@@ -333,6 +334,55 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
   }
 }
 #ifdef __NVCC__
+
+// __shfl_down has been deprecated as of CUDA 9.0.
+#if CUDA_VERSION < 9000
+template <typename T>
+__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
+  return __shfl_down(val, delta);
+}
+#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
+#else
+#define FULL_WARP_MASK 0xFFFFFFFF
+#define CREATE_SHFL_MASK(mask, predicate) \
+  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
+#endif
+
+template <typename T>
+__device__ T reduceSum(T val, int tid, int len) {
+  // TODO(zcd): The warp size should be taken from the
+  // parameters of the GPU rather than hard-coded as 32.
+  // To make reduceSum more efficient, we use
+  // warp-level parallelism and assume the warp size
+  // is 32, which may differ between GPUs,
+  // though most cards' warp size is 32.
+  __shared__ T shm[32];
+  const int warpSize = 32;
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, tid < len);
+
+  for (int offset = warpSize / 2; offset > 0; offset /= 2)
+    val += __shfl_down_sync(mask, val, offset);
+
+  if (tid < warpSize) shm[tid] = 0;
+
+  __syncthreads();
+
+  if (tid % warpSize == 0) {
+    shm[tid / warpSize] = val;
+  }
+
+  CREATE_SHFL_MASK(mask, tid < warpSize);
+
+  if (tid < warpSize) {
+    val = shm[tid];
+    for (int offset = warpSize / 2; offset > 0; offset /= 2)
+      val += __shfl_down_sync(mask, val, offset);
+  }
+
+  return val;
+}
+
 template <typename T, typename DX_OP, typename DY_OP>
 static __global__ void ElemwiseGradBroadcast1CUDAKernel(
     const T* x, const T* y, const T* out, const T* dout, int h, int w,
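For context, a minimal self-contained kernel in the same pattern (hypothetical, not part of the patch: `block_sum`, the managed-memory harness, and the fixed 256-thread launch are all invented here) shows how `CREATE_SHFL_MASK` and `__shfl_down_sync` combine into a block-wide sum. Note the sketch places a second `__syncthreads()` between the per-warp writes to `shm` and warp 0's read of them, which is generally required when different warps communicate through shared memory:

```cuda
#include <cstdio>
#include <cuda.h>

// Same version shim as in the patch, repeated so this sketch compiles alone.
#if CUDA_VERSION < 9000
template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
  return __shfl_down(val, delta);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif

// Block-wide sum: intra-warp shuffles first, then one warp combines the
// per-warp partials staged in shared memory.
__global__ void block_sum(const float* in, float* out, int len) {
  int tid = threadIdx.x;
  float val = tid < len ? in[tid] : 0.f;
  __shared__ float shm[32];
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, tid < len);
  for (int offset = 16; offset > 0; offset /= 2)
    val += __shfl_down_sync(mask, val, offset);
  if (tid < 32) shm[tid] = 0;
  __syncthreads();
  if (tid % 32 == 0) shm[tid / 32] = val;  // lane 0 of each warp stages its partial
  __syncthreads();  // make the partials visible to warp 0
  CREATE_SHFL_MASK(mask, tid < 32);
  if (tid < 32) {
    val = shm[tid];
    for (int offset = 16; offset > 0; offset /= 2)
      val += __shfl_down_sync(mask, val, offset);
  }
  if (tid == 0) *out = val;
}

int main() {
  float *in, *out;
  cudaMallocManaged(&in, 256 * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < 256; ++i) in[i] = 1.f;
  block_sum<<<1, 256>>>(in, out, 256);
  cudaDeviceSynchronize();
  printf("sum = %f\n", *out);  // expect 256
  cudaFree(in);
  cudaFree(out);
  return 0;
}
```

Compiled with nvcc and run, this prints `sum = 256.000000`; the two-level structure (shuffle within a warp, then one warp over the partials) is the same warp-level parallelism the TODO above describes.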
@@ -355,7 +405,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
 
   if (dy) {
     h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
-    val = platform::reduceSum(val, tid, h);
+    val = reduceSum(val, tid, h);
     if (threadIdx.x == 0) {
       dy[j] = val;
     }
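(Here `h` is clamped to ELEMWISE_MAX_BLOCK_DIM, i.e. 1024, before the call, so the `len` handed to the now file-local `reduceSum` presumably never exceeds the thread count of the launching block.)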
@@ -432,7 +482,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
   if (dy) {
     int h = pre * post;
     h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
-    val = platform::reduceSum(val, tid, h);
+    val = reduceSum(val, tid, h);
     if (threadIdx.x == 0) {
       dy[j] = val;
     }
@@ -472,11 +522,11 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
   auto y_dim = y.dims();
 
   axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
-  trim_trailing_singular_dims(y_dim);
+  trim_trailing_singular_dims(&y_dim);
   axis = (y_dim.size() == 0) ? x_dim.size() : axis;
 
   int pre, n, post;
-  get_mid_dims(x_dim, y_dim, axis, pre, n, post);
+  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
   if (post == 1) {
     int h = pre;
     int w = n;
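(Continuing the shape example above: with x = [2, 3, 4, 5] and y = [4, 5] at axis 2, get_mid_dims yields pre = 6, n = 20, post = 1, so this row-wise branch runs with h = 6, w = 20; the earlier case, y = [3, 4] at axis 1 with post = 5, instead falls through to the Broadcast2 path below.)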
@@ -514,7 +564,7 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
       }
     }
   }
-};
+}
 
 template <typename DeviceContext, typename T, typename functor,
           typename broadcastfunctor, typename broadcast2functor>
@@ -543,11 +593,11 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
   }
 
   axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
-  trim_trailing_singular_dims(y_dims);
+  trim_trailing_singular_dims(&y_dims);
   axis = (y_dims.size() == 0) ? x_dims.size() : axis;
 
   int pre, n, post;
-  get_mid_dims(x_dims, y_dims, axis, pre, n, post);
+  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
 
   if (post == 1) {
     broadcastfunctor f;
@@ -582,11 +632,11 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
   axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
   PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                  "Axis should be in range [0, x_dims)");
-  trim_trailing_singular_dims(y_dims);
+  trim_trailing_singular_dims(&y_dims);
   axis = (y_dims.size() == 0) ? x_dims.size() : axis;
 
   int pre, n, post;
-  get_mid_dims(x_dims, y_dims, axis, pre, n, post);
+  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
   if (post == 1) {
     functor.RunRowWise(n, pre);
     return;