@@ -34,8 +34,11 @@ struct BeamSearchDecodeFunctor {
3434 score_tensor_(score_tensor) {
3535 tensor_on_gpu_ = false ;
3636 // First make a copy of GPU data on CPU
37- if (step_ids_origin_[0 ].place ().GetType () == phi::AllocationType::GPU) {
38- if (step_ids_origin_[0 ].place ().GetType () == phi::AllocationType::GPU) {
37+ if (step_ids_origin_[0 ].place ().GetType () == phi::AllocationType::GPU ||
38+ step_ids_origin_[0 ].place ().GetType () == phi::AllocationType::CUSTOM) {
39+ if (step_ids_origin_[0 ].place ().GetType () == phi::AllocationType::GPU ||
40+ step_ids_origin_[0 ].place ().GetType () ==
41+ phi::AllocationType::CUSTOM) {
3942 tensor_on_gpu_ = true ;
4043 }
4144 phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance ();
@@ -55,9 +58,13 @@ struct BeamSearchDecodeFunctor {
5558 step_ids_.push_back (out);
5659 }
5760 }
58- if (step_scores_origin_[0 ].place ().GetType () == phi::AllocationType::GPU) {
61+ if (step_scores_origin_[0 ].place ().GetType () == phi::AllocationType::GPU ||
62+ step_scores_origin_[0 ].place ().GetType () ==
63+ phi::AllocationType::CUSTOM) {
5964 if (step_scores_origin_[0 ].place ().GetType () ==
60- phi::AllocationType::GPU) {
65+ phi::AllocationType::GPU ||
66+ step_scores_origin_[0 ].place ().GetType () ==
67+ phi::AllocationType::CUSTOM) {
6168 tensor_on_gpu_ = true ;
6269 }
6370 phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance ();
0 commit comments