Skip to content

Commit 54797ab

Browse files
Author: chengduo
Merge pull request #10347 from chengduoZH/replace___shfl_with__shfl_sync

Wrap __shfl

2 parents (62fed4c + 0cc6354), commit 54797ab

File tree

3 files changed: +12 additions, -6 deletions (lines changed)

paddle/cuda/src/hl_top_k.cu

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "hl_base.h"
16-
#include "hl_sparse.ph"
17-
#include "hl_top_k.h"
15+
#include "paddle/cuda/include/hl_base.h"
16+
#include "paddle/cuda/include/hl_sparse.ph"
17+
#include "paddle/cuda/include/hl_top_k.h"
1818
#include "paddle/utils/Logging.h"
1919

2020
// using namespace hppl;
@@ -244,8 +244,9 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK,
244244
if (--beamSize == 0) break;
245245
__syncthreads();
246246

247+
// NOTE(zcd): temporary solution
247248
unsigned mask = 0u;
248-
// CREATE_SHFL_MASK(mask, tid < len);
249+
CREATE_SHFL_MASK(mask, true);
249250

250251
if (tid == maxId[0]) {
251252
if (beam < maxLength) {

paddle/fluid/operators/top_k_op.cu

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#include "paddle/fluid/framework/op_registry.h"
1616
#include "paddle/fluid/operators/top_k_op.h"
1717
#include "paddle/fluid/platform/assert.h"
18+
#include "paddle/fluid/platform/cuda_device_function.h"
1819

1920
namespace paddle {
2021
namespace operators {
@@ -235,8 +236,13 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
235236
sh_topk[tid] = topk[*beam];
236237
}
237238
}
239+
// NOTE(zcd): temporary solution
240+
unsigned mask = 0u;
241+
CREATE_SHFL_MASK(mask, true);
242+
238243
if (maxid[0] / 32 == warp) {
239-
if (__shfl(*beam, (maxid[0]) % 32, 32) == MaxLength) break;
244+
if (platform::__shfl_sync(mask, *beam, (maxid[0]) % 32, 32) == MaxLength)
245+
break;
240246
}
241247
}
242248
}

paddle/fluid/platform/cuda_primitives.h

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,5 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
6565
return __longlong_as_double(old);
6666
}
6767
#endif
68-
6968
} // namespace platform
7069
} // namespace paddle

0 commit comments

Comments
 (0)