Skip to content

Commit 3388e52

Browse files
authored
Bugfix/beamsearch op (#7611)
1 parent f086ebb commit 3388e52

File tree

4 files changed

+120
-13
lines changed

4 files changed

+120
-13
lines changed

paddle/operators/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,14 +178,13 @@ foreach(src ${GENERAL_OPS})
178178
endforeach()
179179
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
180180

181-
182181
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
183182

184-
185183
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
186184
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
187185
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
188186
cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor)
187+
cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_search_op)
189188
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
190189
if(WITH_GPU)
191190
cc_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)

paddle/operators/beam_search_op.cc

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
2929
PruneEndidCandidates(pre_ids, &selected_items);
3030
// calculate the output tensor's height
3131
size_t num_instances = std::accumulate(
32-
std::begin(items), std::end(items), 0,
32+
std::begin(selected_items), std::end(selected_items), 0,
3333
[](size_t a, std::vector<Item> &b) { return a + b.size(); });
3434
// the output tensor shape should be [num_instances, 1]
3535
auto dims = framework::make_ddim(
@@ -48,12 +48,20 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
4848
size_t low_offset = 0;
4949
for (auto &items : selected_items) {
5050
low_level.push_back(low_offset);
51+
// Order the candidates of this prefix by (offset, id). The comparator must be a strict weak ordering for std::sort.
sort(items.begin(), items.end(), [](const Item &a, const Item &b) {
52+
if (a.offset != b.offset) {  // compare ids only when offsets tie; otherwise comp(a,b) and comp(b,a) could both be true
53+
return a.offset < b.offset;
54+
}
55+
return a.id < b.id;
56+
});
5157
for (auto &item : items) {
5258
ids_data[low_offset] = item.id;
5359
scores_data[low_offset] = item.score;
5460
low_offset++;
5561
}
5662
}
63+
low_level.push_back(low_offset);
64+
5765
// fill lod
5866
auto abs_lod = framework::ToAbsOffset(ids_->lod());
5967
auto &high_level = abs_lod[lod_level_];
@@ -64,16 +72,21 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
6472
selected_scores->set_lod(lod);
6573
}
6674

67-
void BeamSearch::PruneEndidCandidates(const framework::LoDTensor &pre_ids,
68-
std::vector<std::vector<Item>> *items) {
75+
// Empties the candidate list of every prefix whose previous id equals end_id_.
//
// pre_ids: read as a flat int64_t array; element `offset` is the previous id
//          for the candidate list at items->at(offset).
//          NOTE(review): assumes pre_ids holds at least items->size() elements;
//          there is no bounds check here -- confirm against the caller.
// items:   per-prefix candidate lists, modified in place.
// Returns the number of prefixes whose previous id differs from end_id_
// (i.e. the lists that were left untouched).
int BeamSearch::PruneEndidCandidates(const framework::LoDTensor &pre_ids,
76+
std::vector<std::vector<Item>> *items) {
6977
auto *pre_ids_data = pre_ids.data<int64_t>();
7078

79+
// counts the prefixes that were not pruned
int res = 0;
7180
for (size_t offset = 0; offset < items->size(); offset++) {
7281
auto prefix_id = pre_ids_data[offset];
7382
if (prefix_id == end_id_) {
7483
// this prefix already ended: drop all of its candidates
items->at(offset).clear();
84+
} else {
85+
res++;
7586
}
7687
}
88+
89+
return res;
7790
}
7891

7992
std::vector<std::vector<BeamSearch::Item>> BeamSearch::ToMap(
@@ -121,11 +134,7 @@ bool BeamSearch::NextItemSet(std::vector<BeamSearch::Item> *items) {
121134
auto ids = *ids_;
122135
auto scores = *scores_;
123136

124-
auto source_abs_two_level_lod = framework::SliceInLevel(
125-
ids.lod(), lod_level_, sent_offset_, sent_offset_ + 1);
126-
source_abs_two_level_lod = framework::ToAbsOffset(source_abs_two_level_lod);
127137
auto abs_lod = framework::ToAbsOffset(ids.lod());
128-
PADDLE_ENFORCE_GE(source_abs_two_level_lod.size(), 2UL);
129138

130139
auto *ids_data = ids.data<int64_t>();
131140
auto *scores_data = scores.data<float>();

paddle/operators/beam_search_op.h

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,15 @@ namespace operators {
7373
* second level:
7474
* [0, 2, 4]
7575
*
76-
* tensor's data
76+
* id tensor's data
77+
* [[
78+
* 4,
79+
* 1,
80+
* 3,
81+
* 8,
82+
* ]]
83+
*
84+
* score tensor's data
7785
* [[
7886
* 0.5,
7987
* 0.3,
@@ -137,16 +145,21 @@ class BeamSearch {
137145
Item() {}
138146
Item(size_t offset, size_t id, float score)
139147
: offset(offset), id(id), score(score) {}
140-
// offset in the lod_level_+1
148+
// offset in the higher lod level.
141149
size_t offset;
150+
// // prefix id in the lower lod level.
151+
// size_t prefix;
142152
// the candidate id
143153
id_t id;
144154
// the corresponding score
145155
score_t score;
146156
};
147157

148-
void PruneEndidCandidates(const framework::LoDTensor& pre_ids,
149-
std::vector<std::vector<Item>>* items);
158+
/*
159+
* Delete all the records that follow the end token.
160+
*/
161+
int PruneEndidCandidates(const framework::LoDTensor& pre_ids,
162+
std::vector<std::vector<Item>>* items);
150163

151164
/*
152165
* Transform the items into a map whose key is offset, value is the items.
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/beam_search_op.h"
16+
17+
#include <gtest/gtest.h>
18+
#include <vector>
19+
20+
namespace paddle {
21+
namespace test {
22+
23+
using std::vector;
24+
using framework::LoDTensor;
25+
using framework::LoD;
26+
using operators::BeamSearch;
27+
using paddle::platform::CPUPlace;
28+
using std::cout;
29+
using std::endl;
30+
31+
// Test fixture: fills `ids` and `scores` with the same two-level LoD and a
// 4 x 3 block of hard-coded candidate ids / scores placed on the CPU.
//
// ids:    output tensor, resized to [4, 3] and filled with int64_t ids.
// scores: output tensor, resized to [4, 3] and filled with float scores.
void CreateInput(LoDTensor* ids, LoDTensor* scores) {
32+
LoD lod;
33+
vector<size_t> level0({0, 1, 4});
34+
vector<size_t> level1({0, 1, 2, 3, 4});
35+
lod.push_back(level0);
36+
lod.push_back(level1);
37+
// both tensors share the same LoD
ids->set_lod(lod);
38+
scores->set_lod(lod);
39+
40+
// 4 rows x 3 columns = 12 elements, matching the 12 values copied below
auto dims = framework::make_ddim(vector<int64_t>({4, 3}));
41+
ids->Resize(dims);
42+
scores->Resize(dims);
43+
CPUPlace place;
44+
45+
auto* ids_data = ids->mutable_data<int64_t>(place);
46+
auto* scores_data = scores->mutable_data<float>(place);
47+
vector<int64_t> _ids({4, 2, 5, 2, 1, 3, 3, 5, 2, 8, 2, 1});
48+
vector<float> _scores(
49+
{0.5, 0.3, 0.2, 0.6, 0.3, 0.1, 0.9, 0.5, 0.1, 0.7, 0.5, 0.1});
50+
51+
// copy the literals into the tensor buffers element by element
for (int i = 0; i < 12; i++) {
52+
ids_data[i] = _ids[i];
53+
scores_data[i] = _scores[i];
54+
}
55+
}
56+
57+
// Runs BeamSearch on the fixture from CreateInput with pre_ids = {1, 2, 3, 4}
// and checks the first four selected ids and scores, plus that both outputs
// carry the same LoD.
TEST(beam_search_op, run) {
58+
CPUPlace place;
59+
LoDTensor ids, scores;
60+
CreateInput(&ids, &scores);
61+
62+
LoDTensor pre_ids;
63+
pre_ids.Resize(framework::make_ddim(vector<int64_t>({4, 1})));  // shape [4, 1]; vector<int64_t>(4, 1) would be {1, 1, 1, 1}, i.e. a single element
64+
for (int i = 0; i < 4; i++) {
65+
pre_ids.mutable_data<int64_t>(place)[i] = i + 1;
66+
}
67+
68+
BeamSearch beamsearch(ids, scores, (int64_t)0, (int64_t)2, 0);  // NOTE(review): presumably (lod_level, beam_size, end_id) -- confirm against the BeamSearch constructor
69+
LoDTensor sids, sscores;
70+
beamsearch(pre_ids, &sids, &sscores);
71+
72+
LOG(INFO) << "score: " << sscores << endl;
73+
74+
ASSERT_EQ(sids.lod(), sscores.lod());
75+
76+
vector<int> tids({2, 4, 3, 8});
77+
vector<float> tscores({0.3, 0.5, 0.9, 0.7});
78+
79+
for (int i = 0; i < 4; i++) {
80+
ASSERT_EQ(tids[i], sids.data<int64_t>()[i]);
81+
ASSERT_FLOAT_EQ(tscores[i], sscores.data<float>()[i]);  // floats: compare within ULP tolerance rather than bit-exact
82+
}
83+
}
84+
85+
} // namespace test
86+
} // namespace paddle

0 commit comments

Comments
 (0)