Skip to content

Commit 3222cf1

Browse files
author
chengduo
authored
Merge pull request #10325 from chengduoZH/fix_shfl_sync
Fix shfl_sync for CUDA8.0
2 parents 4613aeb + 90d73c7 commit 3222cf1

File tree

4 files changed

+28
-10
lines changed

4 files changed

+28
-10
lines changed

paddle/cuda/include/hl_base.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,21 @@ extern __thread cudaStream_t default_stream;
228228
<< "CUDA error: " << hl_get_device_error_string((size_t)err); \
229229
}
230230

// __shfl has been deprecated as of CUDA 9.0 in favor of the *_sync
// variants, which take an explicit warp-participation mask. On pre-9.0
// toolkits we provide a source-compatible shim so device code can use
// the CUDA 9 spelling unconditionally.
#if CUDA_VERSION < 9000
// Shim for the CUDA 9 intrinsic. The mask argument is ignored because the
// legacy __shfl implicitly operates on the (converged) warp. The `width`
// default matches the CUDA 9 signature, so call sites that omit it still
// compile on CUDA 8.
template <typename T>
__forceinline__ __device__ T
__shfl_sync(unsigned, T val, int src_line, int width = 32) {
  return __shfl(val, src_line, width);
}

// Legacy shuffles have no mask concept; the predicate is not evaluated.
// No trailing semicolon here: callers terminate the statement themselves,
// exactly as with the CUDA 9 definition below (the original `mask = 0u;`
// expanded `CREATE_SHFL_MASK(m, p);` into `m = 0u;;`).
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u
#else
#define FULL_WARP_MASK 0xFFFFFFFF
// Mask of the lanes (within a fully active warp) for which `predicate`
// holds, suitable for passing to the *_sync warp intrinsics.
#define CREATE_SHFL_MASK(mask, predicate) \
  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
231246
#endif /* __NVCC__ */
232247

233248
#endif /* HL_BASE_H_ */

paddle/cuda/src/hl_cuda_lstm.cu

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -341,12 +341,15 @@ void hl_lstm_parallel_forward(real *gateValue,
341341
}
342342

343343
__device__ __forceinline__ void transpose_32x32(real a[], const int idx) {
344-
int addr = idx % 32;
344+
const int warp_size = 32;
345+
int addr = idx % warp_size;
346+
unsigned mask = 0u;
347+
CREATE_SHFL_MASK(mask, addr < warp_size);
345348
#pragma unroll
346349
for (int k = 1; k < 32; k++) {
347350
// rSrc[k] = __shfl_sync(rSrc[k], (threadIdx.x + k) % 32, 32);
348-
addr = __shfl_sync(addr, (idx + 1) % 32, 32);
349-
a[k] = __shfl_sync(a[k], addr, 32);
351+
addr = __shfl_sync(mask, addr, (idx + 1) % 32, 32);
352+
a[k] = __shfl_sync(mask, a[k], addr, 32);
350353
}
351354

352355
#pragma unroll
@@ -360,10 +363,11 @@ __device__ __forceinline__ void transpose_32x32(real a[], const int idx) {
360363
}
361364

362365
addr = (32 - idx) % 32;
366+
CREATE_SHFL_MASK(mask, idx % 32 < warp_size);
363367
#pragma unroll
364368
for (int k = 0; k < 32; k++) {
365-
a[k] = __shfl_sync(a[k], addr, 32);
366-
addr = __shfl_sync(addr, (idx + 31) % 32, 32);
369+
a[k] = __shfl_sync(mask, a[k], addr, 32);
370+
addr = __shfl_sync(mask, addr, (idx + 31) % 32, 32);
367371
}
368372
}
369373

paddle/cuda/src/hl_top_k.cu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,13 +244,16 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK,
244244
if (--beamSize == 0) break;
245245
__syncthreads();
246246

247+
unsigned mask = 0u;
248+
// CREATE_SHFL_MASK(mask, tid < len);
249+
247250
if (tid == maxId[0]) {
248251
if (beam < maxLength) {
249252
shTopK[tid] = topK[beam];
250253
}
251254
}
252255
if (maxId[0] / 32 == warp) {
253-
if (__shfl_sync(beam, (maxId[0]) % 32, 32) == maxLength) break;
256+
if (__shfl_sync(mask, beam, (maxId[0]) % 32, 32) == maxLength) break;
254257
}
255258
}
256259
}

paddle/fluid/platform/cuda_primitives.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,6 @@ __forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
7474
}
7575
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
7676
#else
77-
template <typename T>
78-
__forceinline__ __device__ T __shfl_down_sync(unsigned mask, T val, int delta) {
79-
return __shfl_down(mask, val, delta);
80-
}
8177
#define FULL_WARP_MASK 0xFFFFFFFF
8278
#define CREATE_SHFL_MASK(mask, predicate) \
8379
mask = __ballot_sync(FULL_WARP_MASK, (predicate))

0 commit comments

Comments
 (0)