@@ -18,34 +18,33 @@ limitations under the License. */
18
18
namespace paddle {
19
19
namespace platform {
20
20
21
- // __shfl_down and __shfl have been deprecated as of CUDA 9.0.
22
21
#if CUDA_VERSION < 9000
23
- template <typename T>
24
- __forceinline__ __device__ T __shfl_down_sync (unsigned , T val, int delta) {
25
- return __shfl_down (val, delta);
26
- }
27
-
28
- template <typename T>
29
- __forceinline__ __device__ T __shfl_sync (unsigned , T val, int src_line,
30
- int width) {
31
- return __shfl (val, src_line, width);
32
- }
33
22
#define CREATE_SHFL_MASK (mask, predicate ) mask = 0u ;
34
23
#else
35
24
#define FULL_WARP_MASK 0xFFFFFFFF
36
25
#define CREATE_SHFL_MASK (mask, predicate ) \
37
26
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
27
+ #endif
28
+
38
29
template <typename T>
39
- __forceinline__ __device__ T __shfl_down_sync (unsigned mask, T val, int delta) {
40
- return __shfl_down_sync (mask, val, delta);
30
+ __forceinline__ __device__ T CudaShuffleDownSync (unsigned mask, T val,
31
+ int delta, int width = 32 ) {
32
+ #if CUDA_VERSION < 9000
33
+ return __shfl_down (val, delta, width);
34
+ #else
35
+ return __shfl_down_sync (mask, val, delta, width);
36
+ #endif
41
37
}
42
38
43
39
template <typename T>
44
- __forceinline__ __device__ T __shfl_sync (unsigned mask, T val, int src_line,
45
- int width) {
40
+ __forceinline__ __device__ T CudaShuffleSync (unsigned mask, T val, int src_line,
41
+ int width = 32 ) {
42
+ #if CUDA_VERSION < 9000
43
+ return __shfl (val, src_line, width);
44
+ #else
46
45
return __shfl_sync (mask, val, src_line, width);
47
- }
48
46
#endif
47
+ }
49
48
50
49
template <typename T>
51
50
__device__ T reduceSum (T val, int tid, int len) {
@@ -61,7 +60,7 @@ __device__ T reduceSum(T val, int tid, int len) {
61
60
CREATE_SHFL_MASK (mask, tid < len);
62
61
63
62
for (int offset = warpSize / 2 ; offset > 0 ; offset /= 2 )
64
- val += platform::__shfl_down_sync (mask, val, offset);
63
+ val += platform::CudaShuffleDownSync (mask, val, offset);
65
64
66
65
if (tid < warpSize) shm[tid] = 0 ;
67
66
@@ -75,7 +74,7 @@ __device__ T reduceSum(T val, int tid, int len) {
75
74
if (tid < warpSize) {
76
75
val = shm[tid];
77
76
for (int offset = warpSize / 2 ; offset > 0 ; offset /= 2 )
78
- val += platform::__shfl_down_sync (mask, val, offset);
77
+ val += platform::CudaShuffleDownSync (mask, val, offset);
79
78
}
80
79
return val;
81
80
}
0 commit comments