@@ -34,7 +34,7 @@ uint64_t Pack(uint32_t LowBits, uint32_t HighBits) {
3434 return (((uint64_t )HighBits) << 32 ) | (uint64_t )LowBits;
3535}
3636
37- int32_t shuffle (uint64_t Mask, int32_t Var, int32_t SrcLane);
37+ int32_t shuffle (uint64_t Mask, int32_t Var, int32_t SrcLane, int32_t Width );
3838int32_t shuffleDown (uint64_t Mask, int32_t Var, uint32_t LaneDelta,
3939 int32_t Width);
4040
@@ -45,8 +45,7 @@ uint64_t ballotSync(uint64_t Mask, int32_t Pred);
4545// /{
4646#pragma omp begin declare variant match(device = {arch(amdgcn)})
4747
48- int32_t shuffle (uint64_t Mask, int32_t Var, int32_t SrcLane) {
49- int Width = mapping::getWarpSize ();
48+ int32_t shuffle (uint64_t Mask, int32_t Var, int32_t SrcLane, int32_t Width) {
5049 int Self = mapping::getThreadIdInWarp ();
5150 int Index = SrcLane + (Self & ~(Width - 1 ));
5251 return __builtin_amdgcn_ds_bpermute (Index << 2 , Var);
@@ -82,8 +81,8 @@ bool isThreadLocalMemPtr(const void *Ptr) {
8281 device = {arch (nvptx, nvptx64)}, \
8382 implementation = {extension (match_any)})
8483
85- int32_t shuffle (uint64_t Mask, int32_t Var, int32_t SrcLane) {
86- return __nvvm_shfl_sync_idx_i32 (Mask, Var, SrcLane, 0x1f );
84+ int32_t shuffle (uint64_t Mask, int32_t Var, int32_t SrcLane, int32_t Width ) {
85+ return __nvvm_shfl_sync_idx_i32 (Mask, Var, SrcLane, Width );
8786}
8887
8988int32_t shuffleDown (uint64_t Mask, int32_t Var, uint32_t Delta, int32_t Width) {
@@ -111,8 +110,9 @@ void utils::unpack(uint64_t Val, uint32_t &LowBits, uint32_t &HighBits) {
111110 impl::Unpack (Val, &LowBits, &HighBits);
112111}
113112
114- int32_t utils::shuffle (uint64_t Mask, int32_t Var, int32_t SrcLane) {
115- return impl::shuffle (Mask, Var, SrcLane);
113+ int32_t utils::shuffle (uint64_t Mask, int32_t Var, int32_t SrcLane,
114+ int32_t Width) {
115+ return impl::shuffle (Mask, Var, SrcLane, Width);
116116}
117117
118118int32_t utils::shuffleDown (uint64_t Mask, int32_t Var, uint32_t Delta,
0 commit comments