Skip to content

Commit 3bb99c4

Browse files
author
Qingsheng Li
authored
Added auto transform to beam_search_decode_op (#10286)
* Added auto transform to beam_search_decode_op * Added some comment * Added unittest for beam_search_decode_op on GPU
1 parent ddf6167 commit 3bb99c4

File tree

2 files changed

+60
-10
lines changed

2 files changed

+60
-10
lines changed

paddle/fluid/operators/beam_search_decode_op.cc

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,25 +23,69 @@ struct BeamSearchDecodeFunctor {
2323
BeamSearchDecodeFunctor(const LoDTensorArray& step_ids,
2424
const LoDTensorArray& step_scores,
2525
LoDTensor* id_tensor, LoDTensor* score_tensor)
26-
: step_ids_(step_ids),
27-
step_scores_(step_scores),
26+
: step_ids_origin_(step_ids),
27+
step_scores_origin_(step_scores),
2828
id_tensor_(id_tensor),
29-
score_tensor_(score_tensor) {}
29+
score_tensor_(score_tensor) {
30+
tensor_on_gpu_ = false;
31+
// First make a copy of GPU data on CPU
32+
if (platform::is_gpu_place(step_ids_origin_[0].place())) {
33+
tensor_on_gpu_ = true;
34+
platform::DeviceContextPool& pool =
35+
platform::DeviceContextPool::Instance();
36+
auto* dev_ctx = pool.Get(step_ids_origin_[0].place());
37+
// Copy all tensors in the input tensor array
38+
for (auto& step_id : step_ids_origin_) {
39+
framework::LoDTensor out;
40+
dev_ctx->Wait();
41+
framework::TensorCopy(step_id, platform::CPUPlace(), *dev_ctx, &out);
42+
dev_ctx->Wait();
43+
44+
out.set_lod(step_id.lod());
45+
step_ids_.push_back(out);
46+
}
47+
}
48+
if (platform::is_gpu_place(step_scores_origin_[0].place())) {
49+
tensor_on_gpu_ = true;
50+
platform::DeviceContextPool& pool =
51+
platform::DeviceContextPool::Instance();
52+
auto* dev_ctx = pool.Get(step_scores_origin_[0].place());
53+
// Copy all tensors in the input tensor array
54+
for (auto& step_score : step_scores_origin_) {
55+
framework::LoDTensor out;
56+
dev_ctx->Wait();
57+
framework::TensorCopy(step_score, platform::CPUPlace(), *dev_ctx, &out);
58+
dev_ctx->Wait();
59+
60+
out.set_lod(step_score.lod());
61+
step_scores_.push_back(out);
62+
}
63+
}
64+
}
3065

3166
template <typename T>
3267
void operator()() const;
3368

34-
const LoDTensorArray& step_ids_;
35-
const LoDTensorArray& step_scores_;
69+
bool tensor_on_gpu_;
70+
const LoDTensorArray& step_ids_origin_;
71+
const LoDTensorArray& step_scores_origin_;
72+
LoDTensorArray step_ids_ = LoDTensorArray();
73+
LoDTensorArray step_scores_ = LoDTensorArray();
3674
LoDTensor* id_tensor_;
3775
LoDTensor* score_tensor_;
3876
};
3977

4078
template <typename T>
4179
void BeamSearchDecodeFunctor::operator()() const {
4280
BeamSearchDecoder<T> beam_search_decoder;
43-
beam_search_decoder.PackAllSteps(step_ids_, step_scores_, id_tensor_,
44-
score_tensor_);
81+
// Check if the tensor is on GPU. If so, use the CPU copy instead
82+
if (tensor_on_gpu_) {
83+
beam_search_decoder.PackAllSteps(step_ids_, step_scores_, id_tensor_,
84+
score_tensor_);
85+
} else {
86+
beam_search_decoder.PackAllSteps(step_ids_origin_, step_scores_origin_,
87+
id_tensor_, score_tensor_);
88+
}
4589
}
4690

4791
template <>

python/paddle/fluid/tests/unittests/test_beam_search_decode_op.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@
2222
class TestBeamSearchDecodeOp(unittest.TestCase):
2323
def setUp(self):
2424
self.scope = core.Scope()
25-
self.cpu_place = core.CPUPlace()
25+
self.place = core.CPUPlace()
2626

2727
def append_lod_tensor(self, tensor_array, lod, data):
2828
lod_tensor = core.LoDTensor()
2929
lod_tensor.set_lod(lod)
30-
lod_tensor.set(data, self.cpu_place)
30+
lod_tensor.set(data, self.place)
3131
tensor_array.append(lod_tensor)
3232

3333
def test_get_set(self):
@@ -71,7 +71,7 @@ def test_get_set(self):
7171
SentenceIds="sentence_ids",
7272
SentenceScores="sentence_scores")
7373

74-
beam_search_decode_op.run(self.scope, self.cpu_place)
74+
beam_search_decode_op.run(self.scope, self.place)
7575

7676
expected_lod = [[0, 4, 8], [0, 1, 3, 6, 9, 10, 13, 16, 19]]
7777
self.assertEqual(sentence_ids.lod(), expected_lod)
@@ -84,5 +84,11 @@ def test_get_set(self):
8484
np.array_equal(np.array(sentence_scores), expected_data))
8585

8686

87+
class TestBeamSearchDecodeOpGPU(TestBeamSearchDecodeOp):
88+
def setUp(self):
89+
self.scope = core.Scope()
90+
self.place = core.CUDAPlace(0)
91+
92+
8793
if __name__ == '__main__':
8894
unittest.main()

0 commit comments

Comments
 (0)