Skip to content

Commit a410627

Browse files
authored
BeamSearchDecodeOp (#5498)
* init trieconcat_op * add basic implementation * add test * add more test * update unit test * add PackAllSteps test * fix PackAllSteps * all test passed * clean code * remove state inside helper * rename prob to score * optimize RemoveFromEnd * use deconstructor to delete BeamNode recursively * optimize interface * add comment to interface * optimizer data structure * use template to define the type of score * use template parameter for BeamHelper * change father to parent * rename TrieConcat to BeamSearchOutConcat * use LoDTensorArray * rename BeamSearchOutConcat to BeamSearchDecode * refine code * remain all candidate sentence in beam_search_decode_op, do not consider endid * use unique_ptr * fix compare bug * fix lod compile problem
1 parent 93c6e52 commit a410627

File tree

5 files changed

+613
-1
lines changed

5 files changed

+613
-1
lines changed

paddle/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
214214
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
215215
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
216216
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
217+
cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor)
217218
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
218219
cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc
219220
rnn/recurrent_op_utils.cc
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/beam_search_decode_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
class BeamSearchDecodeOp : public framework::OperatorBase {
21+
public:
22+
BeamSearchDecodeOp(const std::string& type,
23+
const framework::VariableNameMap& inputs,
24+
const framework::VariableNameMap& outputs,
25+
const framework::AttributeMap& attrs)
26+
: OperatorBase(type, inputs, outputs, attrs) {}
27+
void Run(const framework::Scope& scope,
28+
const platform::DeviceContext& dev_ctx) const override {
29+
framework::ExecutionContext ctx(*this, scope, dev_ctx);
30+
const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids");
31+
const LoDTensorArray* scores = ctx.Input<LoDTensorArray>("Scores");
32+
const size_t step_num = ids->size();
33+
PADDLE_ENFORCE_GT(step_num, 0UL,
34+
"beam search steps should be larger than 0");
35+
const size_t source_num = ids->at(0).lod().at(0).size() - 1;
36+
PADDLE_ENFORCE_GT(source_num, 0UL, "source num should be larger than 0");
37+
38+
for (size_t i = 0; i < step_num; ++i) {
39+
PADDLE_ENFORCE_EQ(ids->at(i).lod().size(), 2UL,
40+
"Level of LodTensor should be 2");
41+
}
42+
43+
// prepare output
44+
LoDTensor* sentenceIds = ctx.Output<LoDTensor>("SentenceIds");
45+
LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores");
46+
47+
BeamSearchDecoder<float> beam_search_decoder;
48+
beam_search_decoder.PackAllSteps(*ids, *scores, sentenceIds,
49+
sentenceScores);
50+
}
51+
};
52+
53+
class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  BeamSearchDecodeOpProtoMaker(framework::OpProto* proto,
                               framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // FIX: the "Ids" description was a copy-paste of the "Scores" one
    // ("score of the candidate words..."); it now describes the ids.
    AddInput("Ids",
             "(LodTensorArray)"
             "ids of the candidate words in each step");
    AddInput("Scores",
             "(LodTensorArray)"
             "score of the candidate words in each step");
    AddOutput("SentenceIds",
              "(LodTensor)"
              "All possible result sentences of word ids");
    AddOutput("SentenceScores",
              "(LodTensor)"
              "All possible result sentences of word scores");
    AddComment(R"DOC(
Pack the result of Beam search op into SentenceIds and SentenceScores.
)DOC");
  }
};
75+
76+
// Shape inference only validates that the operator is fully wired up; the
// output dims depend on runtime LoD contents and are set in Run().
class BeamSearchDecodeInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* context) const override {
    // FIX: corrected "must has" -> "must have" in the enforce messages.
    PADDLE_ENFORCE(context->HasInput("Ids"),
                   "BeamSearchDecodeOp must have input Ids");
    PADDLE_ENFORCE(context->HasInput("Scores"),
                   "BeamSearchDecodeOp must have input Scores");
    PADDLE_ENFORCE(context->HasOutput("SentenceIds"),
                   "BeamSearchDecodeOp must have output SentenceIds");
    PADDLE_ENFORCE(context->HasOutput("SentenceScores"),
                   "BeamSearchDecodeOp must have output SentenceScores");
  }
};
89+
90+
// Declares the variable type of both outputs so the framework can allocate
// them before Run() executes.
class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDescBind& op_desc,
                  framework::BlockDescBind* block) const override {
    // Every variable this op writes is a plain LoDTensor.
    for (const char* output_arg : {"SentenceIds", "SentenceScores"}) {
      for (auto& var_name : op_desc.Output(output_arg)) {
        block->Var(var_name)->SetType(framework::VarDesc::LOD_TENSOR);
      }
    }
  }
};
102+
103+
} // namespace operators
104+
} // namespace paddle
105+
106+
// Register the operator together with its proto maker, shape inference and
// var-type inference; EmptyGradOpMaker means no gradient op is generated.
REGISTER_OPERATOR(beam_search_decode, paddle::operators::BeamSearchDecodeOp,
                  paddle::operators::BeamSearchDecodeOpProtoMaker,
                  paddle::operators::BeamSearchDecodeInferShape,
                  paddle::operators::BeamSearchDecodeInferVarType,
                  paddle::framework::EmptyGradOpMaker);
Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <algorithm>
#include <cstdint>
#include <memory>
#include <unordered_set>
#include <vector>

#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
23+
using LoDTensor = framework::LoDTensor;
24+
using LoDTensorArray = framework::LoDTensorArray;
25+
26+
// All LoDs in this file have two levels.
// The first is the source level, the second is the sentence level.
// The source level describes how many candidate words each source has;
// the sentence level describes which prefix each candidate belongs to.
30+
const size_t kSourceLevel = 0;
31+
const size_t kSentenceLevel = 1;
32+
33+
template <typename T>
34+
struct BeamNode {
35+
BeamNode(int64_t word_id, T score) : word_id_(word_id), score_(score) {}
36+
37+
~BeamNode() {
38+
if (parent_) {
39+
parent_->DropKid(this);
40+
if (parent_->kids_.size() == 0UL) {
41+
delete parent_;
42+
}
43+
}
44+
VLOG(3) << "Delete BeamNode root with word_id:" << this->word_id_;
45+
}
46+
47+
void AppendTo(BeamNode* parent) {
48+
parent_ = parent;
49+
parent->kids_.insert(this);
50+
}
51+
52+
void DropKid(BeamNode* kid) { kids_.erase(kid); }
53+
54+
BeamNode* parent_ = nullptr;
55+
std::unordered_set<BeamNode*> kids_;
56+
int64_t word_id_;
57+
T score_;
58+
};
59+
60+
// Owning list of beam-tree leaves for one source sequence; unique_ptr makes
// leaf ownership explicit while interior nodes stay alive through the
// BeamNode parent/kid links.
template <typename T>
using BeamNodeVector = std::vector<std::unique_ptr<BeamNode<T>>>;

// A fully decoded candidate: the word ids of one sentence plus per-word
// scores, index-aligned with word_ids.
template <typename T>
struct Sentence {
  std::vector<int64_t> word_ids;
  std::vector<T> scores;
};

// All candidate sentences decoded for one source sequence.
template <typename T>
using SentenceVector = std::vector<Sentence<T>>;
71+
72+
// Stateless helper that converts the per-step beam-search output into flat
// sentence LoDTensors. T is the score type.
template <typename T>
struct BeamSearchDecoder {
  /**
   * make a BeamNode and all its related prefix BeamNodes into a Sentence.
   */
  Sentence<T> MakeSentence(const BeamNode<T>* node) const;

  /**
   * Merge one decoding step into the beam tree.
   * Param:
   *  cur_ids: LoDTensor of one step for word ID
   *  cur_scores: LoDTensor of one step for word score
   *  prefixes_list: prefixes for each source sentence (mutated: finished
   *    prefixes are moved into sentence_vector_list).
   *  sentence_vector_list: result sentence_vector for each source sentence.
   * Return:
   *  a new prefixes list for each source of current step
   */
  std::vector<BeamNodeVector<T>> PackTwoSteps(
      const LoDTensor& cur_ids, const LoDTensor& cur_scores,
      std::vector<BeamNodeVector<T>>& prefixes_list,
      std::vector<SentenceVector<T>>* sentence_vector_list) const;

  /**
   * convert the result sentence_vector for each source sentence into two
   * LoDTensors.
   * One is all candidate sentences with word id, one is all candidate
   * sentences with word score.
   * Param:
   *  sentence_vector_list: sentence_vector for each source sentence.
   *  id_tensor: result LoDTensor for sentences of id.
   *  score_tensor: result LoDTensor for sentences of score.
   */
  void ConvertSentenceVectorToLodTensor(
      std::vector<SentenceVector<T>> sentence_vector_list, LoDTensor* id_tensor,
      LoDTensor* score_tensor) const;

  /**
   * Pack all steps of id/score LoDTensors into sentence LoDTensors.
   * Its main logic is:
   * ```python
   * prefix
   * result_sentence
   * result_lod_tensor
   *
   * for (step in steps):
   *   prefix = PackTwoSteps(prefix, step, &result_sentence)
   * ConvertSentenceVector<T>ToLodTensor(result_sentence, &result_lod_tensor)
   * ```
   */
  void PackAllSteps(const LoDTensorArray& step_ids,
                    const LoDTensorArray& step_scores, LoDTensor* id_tensor,
                    LoDTensor* score_tensor) const;
};
124+
125+
template <typename T>
126+
Sentence<T> BeamSearchDecoder<T>::MakeSentence(const BeamNode<T>* node) const {
127+
Sentence<T> sentence;
128+
while (node != nullptr) {
129+
sentence.word_ids.emplace_back(node->word_id_);
130+
sentence.scores.emplace_back(node->score_);
131+
node = node->parent_;
132+
}
133+
134+
std::reverse(std::begin(sentence.word_ids), std::end(sentence.word_ids));
135+
std::reverse(std::begin(sentence.scores), std::end(sentence.scores));
136+
137+
return sentence;
138+
}
139+
140+
// Merge the candidates of one decoding step into the beam tree.
// Ownership note: the returned BeamNodeVector owns the new leaves; interior
// (prefix) nodes are kept alive through the parent links of their kids.
template <typename T>
std::vector<BeamNodeVector<T>> BeamSearchDecoder<T>::PackTwoSteps(
    const LoDTensor& cur_ids, const LoDTensor& cur_scores,
    std::vector<BeamNodeVector<T>>& prefixes_list,
    std::vector<SentenceVector<T>>* sentence_vector_list) const {
  std::vector<BeamNodeVector<T>> result;

  // One iteration per source sequence in the batch.
  for (size_t src_idx = 0; src_idx < cur_ids.lod()[kSourceLevel].size() - 1;
       ++src_idx) {
    size_t src_start = cur_ids.lod().at(kSourceLevel)[src_idx];
    size_t src_end = cur_ids.lod().at(kSourceLevel)[src_idx + 1];

    BeamNodeVector<T> beam_nodes;

    // if prefixes size is 0, it means this is the first step. In this step,
    // all candidate id is the start of candidate sentences.
    if (prefixes_list.empty()) {
      // In the first step every candidate is its own prefix, so the two LoD
      // levels must describe the same number of elements.
      PADDLE_ENFORCE_EQ(cur_ids.lod().at(kSourceLevel).back(),
                        cur_ids.lod().at(kSentenceLevel).back(),
                        "in the first step");
      for (size_t id_idx = src_start; id_idx < src_end; ++id_idx) {
        beam_nodes.push_back(std::unique_ptr<BeamNode<T>>(new BeamNode<T>(
            cur_ids.data<int64_t>()[id_idx], cur_scores.data<T>()[id_idx])));
      }
    } else {
      BeamNodeVector<T>& prefixes = prefixes_list[src_idx];
      SentenceVector<T>& sentence_vector = (*sentence_vector_list)[src_idx];

      PADDLE_ENFORCE_EQ(src_end - src_start, prefixes.size(),
                        "prefix and candidate set number should be the same");

      auto candidate_offset = cur_ids.lod()[kSentenceLevel];
      for (size_t prefix_idx = 0; prefix_idx < prefixes.size(); ++prefix_idx) {
        std::unique_ptr<BeamNode<T>>& prefix = prefixes[prefix_idx];
        // Candidates extending this prefix occupy [candidate_start,
        // candidate_end) in the sentence-level LoD.
        size_t candidate_start = candidate_offset[src_start + prefix_idx];
        size_t candidate_end = candidate_offset[src_start + prefix_idx + 1];
        if (candidate_start == candidate_end) {
          VLOG(3) << "this sentence has no more candidate, "
                     "add to result sentence and rm it from beam tree";
          sentence_vector.push_back(MakeSentence(prefix.get()));
          // reset() deletes the leaf; ~BeamNode then prunes the now-childless
          // ancestors recursively (see BeamNode destructor).
          prefix.reset();
        } else {
          for (size_t candidate_idx = candidate_start;
               candidate_idx < candidate_end; ++candidate_idx) {
            auto* candidate =
                new BeamNode<T>(cur_ids.data<int64_t>()[candidate_idx],
                                cur_scores.data<T>()[candidate_idx]);
            candidate->AppendTo(prefix.get());
            beam_nodes.push_back(std::unique_ptr<BeamNode<T>>(candidate));
          }
          // release(): the prefix node now lives on as an interior node,
          // owned through its kids' parent links rather than the unique_ptr.
          prefix.release();
        }
      }
    }
    result.push_back(std::move(beam_nodes));
  }
  return result;
}
198+
199+
template <typename T>
200+
void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor(
201+
std::vector<SentenceVector<T>> sentence_vector_list, LoDTensor* id_tensor,
202+
LoDTensor* score_tensor) const {
203+
size_t src_num = sentence_vector_list.size();
204+
205+
PADDLE_ENFORCE_NE(src_num, 0, "src_num should not be 0");
206+
207+
std::vector<size_t> source_level_lod = {0};
208+
std::vector<size_t> sentence_level_lod = {0};
209+
std::vector<int64_t> id_data;
210+
std::vector<T> score_data;
211+
212+
for (size_t src_idx = 0; src_idx < src_num; ++src_idx) {
213+
for (Sentence<T>& sentence : sentence_vector_list[src_idx]) {
214+
id_data.insert(id_data.end(), sentence.word_ids.begin(),
215+
sentence.word_ids.end());
216+
score_data.insert(score_data.end(), sentence.scores.begin(),
217+
sentence.scores.end());
218+
sentence_level_lod.push_back(sentence_level_lod.back() +
219+
sentence.word_ids.size());
220+
}
221+
source_level_lod.push_back(source_level_lod.back() +
222+
sentence_vector_list[src_idx].size());
223+
}
224+
225+
auto cpu_place = new paddle::platform::CPUPlace();
226+
paddle::platform::CPUDeviceContext cpu_ctx(*cpu_place);
227+
228+
framework::LoD lod;
229+
lod.push_back(source_level_lod);
230+
lod.push_back(sentence_level_lod);
231+
232+
id_tensor->set_lod(lod);
233+
id_tensor->Resize({static_cast<int64_t>(id_data.size())});
234+
id_tensor->mutable_data<int64_t>(paddle::platform::CPUPlace());
235+
id_tensor->CopyFromVector<int64_t>(id_data, cpu_ctx);
236+
237+
score_tensor->set_lod(lod);
238+
score_tensor->Resize({static_cast<int64_t>(score_data.size())});
239+
score_tensor->mutable_data<T>(paddle::platform::CPUPlace());
240+
score_tensor->CopyFromVector<T>(score_data, cpu_ctx);
241+
}
242+
243+
template <typename T>
244+
void BeamSearchDecoder<T>::PackAllSteps(const LoDTensorArray& step_ids,
245+
const LoDTensorArray& step_scores,
246+
LoDTensor* id_tensor,
247+
LoDTensor* score_tensor) const {
248+
PADDLE_ENFORCE(!step_ids.empty(), "step num should be larger than 0");
249+
PADDLE_ENFORCE_EQ(step_ids.size(), step_scores.size(),
250+
"step_ids and step_scores should be the same");
251+
const size_t step_num = step_ids.size();
252+
const size_t src_num = step_ids.at(0).lod().at(kSourceLevel).size() - 1;
253+
254+
PADDLE_ENFORCE_GT(src_num, 0UL, "source num should be larger than 0");
255+
256+
// previous prefixes for each step,
257+
// the init length is 0, means this is the first step.
258+
std::vector<BeamNodeVector<T>> beamnode_vector_list(0);
259+
std::vector<SentenceVector<T>> sentence_vector_list(src_num);
260+
261+
// pack all steps for one batch first, then another batch
262+
for (size_t step_id = 0; step_id < step_num; ++step_id) {
263+
beamnode_vector_list =
264+
PackTwoSteps(step_ids.at(step_id), step_scores.at(step_id),
265+
beamnode_vector_list, &sentence_vector_list);
266+
}
267+
// append last beam_node to result
268+
for (size_t src_idx = 0; src_idx < src_num; ++src_idx) {
269+
for (auto& beam_node : beamnode_vector_list.at(src_idx)) {
270+
sentence_vector_list[src_idx].push_back(MakeSentence(beam_node.get()));
271+
beam_node.reset();
272+
}
273+
}
274+
275+
ConvertSentenceVectorToLodTensor(sentence_vector_list, id_tensor,
276+
score_tensor);
277+
}
278+
279+
} // namespace operators
280+
} // namespace paddle

0 commit comments

Comments
 (0)