Commit 4558c0e

Author: chengduo
Merge pull request #10414 from chengduoZH/wrap_shfl_x_sync
Wrap shfl_x_sync
2 parents 2d98a41 + d36af62 commit 4558c0e

File tree

3 files changed: +21 -21 lines changed

paddle/fluid/operators/row_conv_op.cu

Lines changed: 2 additions & 2 deletions

@@ -224,7 +224,7 @@ __global__ void RowConvGradFilterImproved(const T *in, const T *dout,

     for (int offset = 16; offset > 0;
          offset = offset / 2) {  // blockDim.x is 32.
-      val += platform::__shfl_down_sync(mask, val, offset);
+      val += platform::CudaShuffleDownSync(mask, val, offset);
     }
     __syncthreads();

@@ -284,7 +284,7 @@ __global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence,

     for (int offset = 16; offset > 0;
          offset = offset / 2) {  // blockDim.x is 32.
-      val += platform::__shfl_down_sync(mask, val, offset);
+      val += platform::CudaShuffleDownSync(mask, val, offset);
     }
     __syncthreads();
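Both hunks touch the same construct: a warp-level tree reduction whose shuffle offset halves each pass (16, 8, 4, 2, 1), leaving lane 0 with the sum of its 32-thread warp. For reference, here is a minimal standalone sketch of that pattern; the kernel name `WarpSumKernel` and the host driver are illustrative only (not part of this commit), and the raw intrinsics are guarded on CUDA_VERSION the same way the new `CudaShuffleDownSync` wrapper guards them.

```cuda
// Minimal sketch (not part of the commit) of the warp-sum pattern used in
// RowConvGradFilterImproved / RowConvGradFilter. Compile with nvcc.
#include <cstdio>

__global__ void WarpSumKernel(const float* in, float* out) {
  float val = in[threadIdx.x];  // one value per lane; launched with 32 threads
  unsigned mask = 0xFFFFFFFFu;  // full-warp mask, as CREATE_SHFL_MASK(mask, true) yields on CUDA >= 9.0
  for (int offset = 16; offset > 0; offset /= 2) {
#if CUDA_VERSION < 9000
    val += __shfl_down(val, offset);             // pre-9.0 intrinsic, no mask
#else
    val += __shfl_down_sync(mask, val, offset);  // CUDA 9.0+ intrinsic
#endif
  }
  if (threadIdx.x == 0) *out = val;  // lane 0 ends up with the warp total
}

int main() {
  float h_in[32], h_out = 0.0f, *d_in = nullptr, *d_out = nullptr;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.0f;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  WarpSumKernel<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp sum = %f\n", h_out);  // expected: 32.000000
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```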

paddle/fluid/operators/top_k_op.cu

Lines changed: 2 additions & 1 deletion

@@ -241,7 +241,8 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
     CREATE_SHFL_MASK(mask, true);

     if (maxid[0] / 32 == warp) {
-      if (platform::__shfl_sync(mask, *beam, (maxid[0]) % 32, 32) == MaxLength)
+      if (platform::CudaShuffleSync(mask, *beam, (maxid[0]) % 32, 32) ==
+          MaxLength)
        break;
    }
  }
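Here `CudaShuffleSync` acts as a warp-wide broadcast: each lane reads `*beam` from the lane holding the block maximum (`maxid[0] % 32`), so all 32 threads of that warp see the same count and take the `break` uniformly. Below is a minimal sketch of such a broadcast; the kernel name `BroadcastFromLane` and the host driver are illustrative only (not from this commit), with the intrinsics guarded the same way the wrapper guards them.

```cuda
// Minimal sketch (not part of the commit) of a warp-wide broadcast from an
// arbitrary source lane, the role CudaShuffleSync plays in BlockReduce above.
#include <cstdio>

__global__ void BroadcastFromLane(const int* in, int src_lane, int* out) {
  int val = in[threadIdx.x];
  unsigned mask = 0xFFFFFFFFu;  // full-warp mask
#if CUDA_VERSION < 9000
  int picked = __shfl(val, src_lane, 32);             // pre-9.0 intrinsic
#else
  int picked = __shfl_sync(mask, val, src_lane, 32);  // CUDA 9.0+ intrinsic
#endif
  out[threadIdx.x] = picked;  // every lane now holds lane src_lane's value
}

int main() {
  int h_in[32], h_out[32], *d_in = nullptr, *d_out = nullptr;
  for (int i = 0; i < 32; ++i) h_in[i] = i * 10;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  BroadcastFromLane<<<1, 32>>>(d_in, 7, d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("lane 0 sees %d\n", h_out[0]);  // expected: 70 (lane 7's value)
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```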

paddle/fluid/platform/cuda_device_function.h

Lines changed: 17 additions & 18 deletions

@@ -18,34 +18,33 @@ limitations under the License. */
 namespace paddle {
 namespace platform {

-// __shfl_down and __shfl have been deprecated as of CUDA 9.0.
 #if CUDA_VERSION < 9000
-template <typename T>
-__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
-  return __shfl_down(val, delta);
-}
-
-template <typename T>
-__forceinline__ __device__ T __shfl_sync(unsigned, T val, int src_line,
-                                         int width) {
-  return __shfl(val, src_line, width);
-}
 #define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
 #else
 #define FULL_WARP_MASK 0xFFFFFFFF
 #define CREATE_SHFL_MASK(mask, predicate) \
   mask = __ballot_sync(FULL_WARP_MASK, (predicate))
+#endif
+
 template <typename T>
-__forceinline__ __device__ T __shfl_down_sync(unsigned mask, T val, int delta) {
-  return __shfl_down_sync(mask, val, delta);
+__forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
+                                                 int delta, int width = 32) {
+#if CUDA_VERSION < 9000
+  return __shfl_down(val, delta, width);
+#else
+  return __shfl_down_sync(mask, val, delta, width);
+#endif
 }

 template <typename T>
-__forceinline__ __device__ T __shfl_sync(unsigned mask, T val, int src_line,
-                                         int width) {
+__forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
+                                             int width = 32) {
+#if CUDA_VERSION < 9000
+  return __shfl(val, src_line, width);
+#else
   return __shfl_sync(mask, val, src_line, width);
-}
 #endif
+}

 template <typename T>
 __device__ T reduceSum(T val, int tid, int len) {

@@ -61,7 +60,7 @@ __device__ T reduceSum(T val, int tid, int len) {
   CREATE_SHFL_MASK(mask, tid < len);

   for (int offset = warpSize / 2; offset > 0; offset /= 2)
-    val += platform::__shfl_down_sync(mask, val, offset);
+    val += platform::CudaShuffleDownSync(mask, val, offset);

   if (tid < warpSize) shm[tid] = 0;

@@ -75,7 +74,7 @@ __device__ T reduceSum(T val, int tid, int len) {
   if (tid < warpSize) {
     val = shm[tid];
     for (int offset = warpSize / 2; offset > 0; offset /= 2)
-      val += platform::__shfl_down_sync(mask, val, offset);
+      val += platform::CudaShuffleDownSync(mask, val, offset);
   }
   return val;
 }
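The net effect of the header change is two wrappers, `CudaShuffleDownSync` and `CudaShuffleSync`, whose bodies contain the CUDA_VERSION dispatch, so call sites stay identical on CUDA 8 and CUDA 9+ and the header no longer declares functions that share names with the CUDA 9.0 intrinsics. The sketch below copies the `CudaShuffleDownSync` wrapper and the `CREATE_SHFL_MASK` macro from this patch into a standalone file and drives them with a hypothetical kernel and host `main` (neither is part of the commit) to show the intended usage:

```cuda
// Standalone sketch (not part of the commit): the CudaShuffleDownSync wrapper
// and CREATE_SHFL_MASK macro as introduced above, exercised by a hypothetical
// single-warp reduction kernel. Compile with nvcc.
#include <cstdio>

#if CUDA_VERSION < 9000
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif

template <typename T>
__forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
                                                 int delta, int width = 32) {
#if CUDA_VERSION < 9000
  return __shfl_down(val, delta, width);             // pre-9.0 intrinsic
#else
  return __shfl_down_sync(mask, val, delta, width);  // CUDA 9.0+ intrinsic
#endif
}

__global__ void ReduceKernel(const float* in, float* out, int len) {
  int tid = threadIdx.x;
  float val = tid < len ? in[tid] : 0.0f;  // pad missing lanes with zero
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, tid < len);  // mask of participating lanes on CUDA >= 9.0
  for (int offset = warpSize / 2; offset > 0; offset /= 2)
    val += CudaShuffleDownSync(mask, val, offset);
  if (tid == 0) *out = val;  // single-warp launch: lane 0 holds the sum
}

int main() {
  const int len = 32;  // all 32 lanes participate in this example
  float h_in[32], h_out = 0.0f, *d_in = nullptr, *d_out = nullptr;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.0f;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  ReduceKernel<<<1, 32>>>(d_in, d_out, len);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f\n", h_out);  // expected: 32.000000
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```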
