Skip to content

Commit 8a8df06

Browse files
committed
Merge pull request #11238 from guoshengCS/fix-beam_search
Fix and enhance beam_search_op and beam_search_decode_op
1 parent 4b8d65a commit 8a8df06

File tree

12 files changed

+545
-507
lines changed

12 files changed

+545
-507
lines changed

paddle/fluid/operators/beam_search_decode_op.cc

Lines changed: 58 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/operators/beam_search_decode_op.h"
15+
#include <algorithm>
1616
#include <string>
17+
18+
#include "paddle/fluid/operators/beam_search_decode_op.h"
1719
#include "paddle/fluid/platform/device_context.h"
1820

1921
namespace paddle {
@@ -22,8 +24,11 @@ namespace operators {
2224
struct BeamSearchDecodeFunctor {
2325
BeamSearchDecodeFunctor(const LoDTensorArray& step_ids,
2426
const LoDTensorArray& step_scores,
25-
LoDTensor* id_tensor, LoDTensor* score_tensor)
26-
: step_ids_origin_(step_ids),
27+
LoDTensor* id_tensor, LoDTensor* score_tensor,
28+
size_t beam_size, int end_id)
29+
: beam_size_(beam_size),
30+
end_id_(end_id),
31+
step_ids_origin_(step_ids),
2732
step_scores_origin_(step_scores),
2833
id_tensor_(id_tensor),
2934
score_tensor_(score_tensor) {
@@ -37,9 +42,11 @@ struct BeamSearchDecodeFunctor {
3742
// Copy all tensors in the input tensor array
3843
for (auto& step_id : step_ids_origin_) {
3944
framework::LoDTensor out;
40-
dev_ctx->Wait();
41-
framework::TensorCopy(step_id, platform::CPUPlace(), *dev_ctx, &out);
42-
dev_ctx->Wait();
45+
if (step_id.numel() > 0) {
46+
dev_ctx->Wait();
47+
framework::TensorCopy(step_id, platform::CPUPlace(), *dev_ctx, &out);
48+
dev_ctx->Wait();
49+
}
4350

4451
out.set_lod(step_id.lod());
4552
step_ids_.push_back(out);
@@ -53,9 +60,12 @@ struct BeamSearchDecodeFunctor {
5360
// Copy all tensors in the input tensor array
5461
for (auto& step_score : step_scores_origin_) {
5562
framework::LoDTensor out;
56-
dev_ctx->Wait();
57-
framework::TensorCopy(step_score, platform::CPUPlace(), *dev_ctx, &out);
58-
dev_ctx->Wait();
63+
if (step_score.numel() > 0) {
64+
dev_ctx->Wait();
65+
framework::TensorCopy(step_score, platform::CPUPlace(), *dev_ctx,
66+
&out);
67+
dev_ctx->Wait();
68+
}
5969

6070
out.set_lod(step_score.lod());
6171
step_scores_.push_back(out);
@@ -67,6 +77,8 @@ struct BeamSearchDecodeFunctor {
6777
void operator()() const;
6878

6979
bool tensor_on_gpu_;
80+
size_t beam_size_;
81+
int end_id_;
7082
const LoDTensorArray& step_ids_origin_;
7183
const LoDTensorArray& step_scores_origin_;
7284
LoDTensorArray step_ids_ = LoDTensorArray();
@@ -77,14 +89,14 @@ struct BeamSearchDecodeFunctor {
7789

7890
template <typename T>
7991
void BeamSearchDecodeFunctor::operator()() const {
80-
BeamSearchDecoder<T> beam_search_decoder;
92+
BeamSearchDecoder<T> beam_search_decoder(beam_size_, end_id_);
8193
// Check if the tensor is on GPU. If so, use the CPU copy instead
8294
if (tensor_on_gpu_) {
83-
beam_search_decoder.PackAllSteps(step_ids_, step_scores_, id_tensor_,
84-
score_tensor_);
95+
beam_search_decoder.Backtrace(step_ids_, step_scores_, id_tensor_,
96+
score_tensor_);
8597
} else {
86-
beam_search_decoder.PackAllSteps(step_ids_origin_, step_scores_origin_,
87-
id_tensor_, score_tensor_);
98+
beam_search_decoder.Backtrace(step_ids_origin_, step_scores_origin_,
99+
id_tensor_, score_tensor_);
88100
}
89101
}
90102

@@ -122,13 +134,17 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
122134
"Level of LodTensor should be 2");
123135
}
124136

137+
size_t beam_size = ctx.Attr<int>("beam_size");
138+
int end_id = ctx.Attr<int>("end_id");
139+
125140
// prepare output
126141
LoDTensor* sentenceIds = ctx.Output<LoDTensor>("SentenceIds");
127142
LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores");
128143

129144
framework::VisitDataType(
130145
framework::ToDataType(scores->at(0).type()),
131-
BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores));
146+
BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores,
147+
beam_size, end_id));
132148
}
133149
};
134150

@@ -137,18 +153,32 @@ class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker {
137153
void Make() override {
138154
AddInput("Ids",
139155
"(LodTensorArray)"
140-
"score of the candidate words in each step");
156+
"The LodTensorArray containing the selected ids of all steps");
141157
AddInput("Scores",
142158
"(LodTensorArray)"
143-
"score of the candidate words in each step");
144-
AddOutput("SentenceIds",
145-
"(LodTensor)"
146-
"All possible result sentences of word ids");
147-
AddOutput("SentenceScores",
148-
"(LodTensor)"
149-
"All possible result sentences of word scores");
159+
"The LodTensorArray containing the selected scores of all steps");
160+
AddOutput(
161+
"SentenceIds",
162+
"(LodTensor)"
163+
"An LodTensor containing all generated id sequences for all source "
164+
"sentences");
165+
AddOutput(
166+
"SentenceScores",
167+
"(LodTensor)"
168+
"An LodTensor containing scores corresponding to Output(SentenceIds)");
169+
AddAttr<int>("beam_size", "beam size for beam search");
170+
AddAttr<int>("end_id",
171+
"the token id which indicates the end of a sequence");
150172
AddComment(R"DOC(
151-
Pack the result of Beam search op into SentenceIds and SentenceScores.
173+
Beam Search Decode Operator. This Operator constructs the full hypotheses for
174+
each source sentence by walking back along the LoDTensorArray Input(ids)
175+
whose lods can be used to restore the path in the beam search tree.
176+
177+
The Output(SentenceIds) and Output(SentenceScores) separately contain the
178+
generated id sequences and the corresponding scores. The shapes and lods of the
179+
two LodTensor are same. The lod level is 2 and the two levels separately
180+
indicate how many hypotheses each source sentence has and how many ids each
181+
hypothesis has.
152182
)DOC");
153183
}
154184
};
@@ -172,10 +202,12 @@ class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
172202
void operator()(const framework::OpDesc& op_desc,
173203
framework::BlockDesc* block) const override {
174204
for (auto& o : op_desc.Output("SentenceIds")) {
175-
block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
205+
auto& sentence_ids = block->FindRecursiveOrCreateVar(o);
206+
sentence_ids.SetType(framework::proto::VarType::LOD_TENSOR);
176207
}
177208
for (auto& o : op_desc.Output("SentenceScores")) {
178-
block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
209+
auto& sentence_scores = block->FindRecursiveOrCreateVar(o);
210+
sentence_scores.SetType(framework::proto::VarType::LOD_TENSOR);
179211
}
180212
}
181213
};

0 commit comments

Comments
 (0)