Skip to content

Commit fcf3b10

Browse files
authored
fix beam search decode kernel (PaddlePaddle#76238)
1 parent 2dab658 commit fcf3b10

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

paddle/phi/kernels/impl/beam_search_decode_kernel_impl.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,11 @@ struct BeamSearchDecodeFunctor {
3434
score_tensor_(score_tensor) {
3535
tensor_on_gpu_ = false;
3636
// First make a copy of GPU data on CPU
37-
if (step_ids_origin_[0].place().GetType() == phi::AllocationType::GPU) {
38-
if (step_ids_origin_[0].place().GetType() == phi::AllocationType::GPU) {
37+
if (step_ids_origin_[0].place().GetType() == phi::AllocationType::GPU ||
38+
step_ids_origin_[0].place().GetType() == phi::AllocationType::CUSTOM) {
39+
if (step_ids_origin_[0].place().GetType() == phi::AllocationType::GPU ||
40+
step_ids_origin_[0].place().GetType() ==
41+
phi::AllocationType::CUSTOM) {
3942
tensor_on_gpu_ = true;
4043
}
4144
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
@@ -55,9 +58,13 @@ struct BeamSearchDecodeFunctor {
5558
step_ids_.push_back(out);
5659
}
5760
}
58-
if (step_scores_origin_[0].place().GetType() == phi::AllocationType::GPU) {
61+
if (step_scores_origin_[0].place().GetType() == phi::AllocationType::GPU ||
62+
step_scores_origin_[0].place().GetType() ==
63+
phi::AllocationType::CUSTOM) {
5964
if (step_scores_origin_[0].place().GetType() ==
60-
phi::AllocationType::GPU) {
65+
phi::AllocationType::GPU ||
66+
step_scores_origin_[0].place().GetType() ==
67+
phi::AllocationType::CUSTOM) {
6168
tensor_on_gpu_ = true;
6269
}
6370
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();

0 commit comments

Comments
 (0)