@@ -23,25 +23,69 @@ struct BeamSearchDecodeFunctor {
23
23
BeamSearchDecodeFunctor (const LoDTensorArray& step_ids,
24
24
const LoDTensorArray& step_scores,
25
25
LoDTensor* id_tensor, LoDTensor* score_tensor)
26
- : step_ids_ (step_ids),
27
- step_scores_ (step_scores),
26
+ : step_ids_origin_ (step_ids),
27
+ step_scores_origin_ (step_scores),
28
28
id_tensor_(id_tensor),
29
- score_tensor_(score_tensor) {}
29
+ score_tensor_(score_tensor) {
30
+ tensor_on_gpu_ = false ;
31
+ // First make a copy of GPU data on CPU
32
+ if (platform::is_gpu_place (step_ids_origin_[0 ].place ())) {
33
+ tensor_on_gpu_ = true ;
34
+ platform::DeviceContextPool& pool =
35
+ platform::DeviceContextPool::Instance ();
36
+ auto * dev_ctx = pool.Get (step_ids_origin_[0 ].place ());
37
+ // Copy all tensors in the input tensor array
38
+ for (auto & step_id : step_ids_origin_) {
39
+ framework::LoDTensor out;
40
+ dev_ctx->Wait ();
41
+ framework::TensorCopy (step_id, platform::CPUPlace (), *dev_ctx, &out);
42
+ dev_ctx->Wait ();
43
+
44
+ out.set_lod (step_id.lod ());
45
+ step_ids_.push_back (out);
46
+ }
47
+ }
48
+ if (platform::is_gpu_place (step_scores_origin_[0 ].place ())) {
49
+ tensor_on_gpu_ = true ;
50
+ platform::DeviceContextPool& pool =
51
+ platform::DeviceContextPool::Instance ();
52
+ auto * dev_ctx = pool.Get (step_scores_origin_[0 ].place ());
53
+ // Copy all tensors in the input tensor array
54
+ for (auto & step_score : step_scores_origin_) {
55
+ framework::LoDTensor out;
56
+ dev_ctx->Wait ();
57
+ framework::TensorCopy (step_score, platform::CPUPlace (), *dev_ctx, &out);
58
+ dev_ctx->Wait ();
59
+
60
+ out.set_lod (step_score.lod ());
61
+ step_scores_.push_back (out);
62
+ }
63
+ }
64
+ }
30
65
31
66
template <typename T>
32
67
void operator ()() const ;
33
68
34
- const LoDTensorArray& step_ids_;
35
- const LoDTensorArray& step_scores_;
69
+ bool tensor_on_gpu_;
70
+ const LoDTensorArray& step_ids_origin_;
71
+ const LoDTensorArray& step_scores_origin_;
72
+ LoDTensorArray step_ids_ = LoDTensorArray();
73
+ LoDTensorArray step_scores_ = LoDTensorArray();
36
74
LoDTensor* id_tensor_;
37
75
LoDTensor* score_tensor_;
38
76
};
39
77
40
78
template <typename T>
41
79
void BeamSearchDecodeFunctor::operator ()() const {
42
80
BeamSearchDecoder<T> beam_search_decoder;
43
- beam_search_decoder.PackAllSteps (step_ids_, step_scores_, id_tensor_,
44
- score_tensor_);
81
+ // Check if the tensor is on GPU. If so, use the CPU copy instead
82
+ if (tensor_on_gpu_) {
83
+ beam_search_decoder.PackAllSteps (step_ids_, step_scores_, id_tensor_,
84
+ score_tensor_);
85
+ } else {
86
+ beam_search_decoder.PackAllSteps (step_ids_origin_, step_scores_origin_,
87
+ id_tensor_, score_tensor_);
88
+ }
45
89
}
46
90
47
91
template <>
0 commit comments