Skip to content

Commit b8f7fa9

Browse files
committed
replace __shfl with __shfl_sync
1 parent 7c90d7a commit b8f7fa9

File tree

3 files changed

+18
-5
lines changed

3 files changed

+18
-5
lines changed

paddle/cuda/src/hl_top_k.cu

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "hl_base.h"
16-
#include "hl_sparse.ph"
17-
#include "hl_top_k.h"
15+
#include "paddle/cuda/include/hl_base.h"
16+
#include "paddle/cuda/include/hl_sparse.ph"
17+
#include "paddle/cuda/include/hl_top_k.h"
1818
#include "paddle/utils/Logging.h"
1919

2020
// using namespace hppl;
@@ -244,8 +244,9 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK,
244244
if (--beamSize == 0) break;
245245
__syncthreads();
246246

247+
// temporary solution
247248
unsigned mask = 0u;
248-
// CREATE_SHFL_MASK(mask, tid < len);
249+
CREATE_SHFL_MASK(mask, true);
249250

250251
if (tid == maxId[0]) {
251252
if (beam < maxLength) {

paddle/fluid/operators/top_k_op.cu

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#include "paddle/fluid/framework/op_registry.h"
1616
#include "paddle/fluid/operators/top_k_op.h"
1717
#include "paddle/fluid/platform/assert.h"
18+
#include "paddle/fluid/platform/cuda_primitives.h"
1819

1920
namespace paddle {
2021
namespace operators {
@@ -235,8 +236,12 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
235236
sh_topk[tid] = topk[*beam];
236237
}
237238
}
239+
// temporary solution
240+
unsigned mask = 0u;
241+
CREATE_SHFL_MASK(mask, true);
242+
238243
if (maxid[0] / 32 == warp) {
239-
if (__shfl(*beam, (maxid[0]) % 32, 32) == MaxLength) break;
244+
if (__shfl_sync(mask, *beam, (maxid[0]) % 32, 32) == MaxLength) break;
240245
}
241246
}
242247
}

paddle/fluid/platform/cuda_primitives.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ template <typename T>
7272
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
7373
return __shfl_down(val, delta);
7474
}
75+
76+
template <typename T>
77+
__forceinline__ __device__ T __shfl_sync(unsigned, T val, int src_line,
78+
int width) {
79+
return __shfl(val, src_line, width);
80+
}
81+
7582
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
7683
#else
7784
#define FULL_WARP_MASK 0xFFFFFFFF

0 commit comments

Comments
 (0)