Skip to content

Commit 8a071ff

Browse files
author
chengduo
authored
Merge pull request #10366 from chengduoZH/feature/fix_shlf_for_cuda9.0
Fix __shfl and __shfl_down for CUDA9.0
2 parents e3b8db0 + e97c1a8 commit 8a071ff

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

paddle/fluid/platform/cuda_device_function.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,16 @@ __forceinline__ __device__ T __shfl_sync(unsigned, T val, int src_line,
3535
#define FULL_WARP_MASK 0xFFFFFFFF
3636
#define CREATE_SHFL_MASK(mask, predicate) \
3737
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
38+
template <typename T>
39+
__forceinline__ __device__ T __shfl_down_sync(unsigned mask, T val, int delta) {
40+
return __shfl_down_sync(mask, val, delta);
41+
}
42+
43+
template <typename T>
44+
__forceinline__ __device__ T __shfl_sync(unsigned mask, T val, int src_line,
45+
int width) {
46+
return __shfl_sync(mask, val, src_line, width);
47+
}
3848
#endif
3949

4050
template <typename T>

0 commit comments

Comments
 (0)