@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
- #include " paddle/fluid/operators/beam_search_decode_op.h "
15
+ #include < algorithm >
16
16
#include < string>
17
+
18
+ #include " paddle/fluid/operators/beam_search_decode_op.h"
17
19
#include " paddle/fluid/platform/device_context.h"
18
20
19
21
namespace paddle {
@@ -22,8 +24,11 @@ namespace operators {
22
24
struct BeamSearchDecodeFunctor {
23
25
BeamSearchDecodeFunctor (const LoDTensorArray& step_ids,
24
26
const LoDTensorArray& step_scores,
25
- LoDTensor* id_tensor, LoDTensor* score_tensor)
26
- : step_ids_origin_(step_ids),
27
+ LoDTensor* id_tensor, LoDTensor* score_tensor,
28
+ size_t beam_size, int end_id)
29
+ : beam_size_(beam_size),
30
+ end_id_ (end_id),
31
+ step_ids_origin_(step_ids),
27
32
step_scores_origin_(step_scores),
28
33
id_tensor_(id_tensor),
29
34
score_tensor_(score_tensor) {
@@ -37,9 +42,11 @@ struct BeamSearchDecodeFunctor {
37
42
// Copy all tensors in the input tensor array
38
43
for (auto & step_id : step_ids_origin_) {
39
44
framework::LoDTensor out;
40
- dev_ctx->Wait ();
41
- framework::TensorCopy (step_id, platform::CPUPlace (), *dev_ctx, &out);
42
- dev_ctx->Wait ();
45
+ if (step_id.numel () > 0 ) {
46
+ dev_ctx->Wait ();
47
+ framework::TensorCopy (step_id, platform::CPUPlace (), *dev_ctx, &out);
48
+ dev_ctx->Wait ();
49
+ }
43
50
44
51
out.set_lod (step_id.lod ());
45
52
step_ids_.push_back (out);
@@ -53,9 +60,12 @@ struct BeamSearchDecodeFunctor {
53
60
// Copy all tensors in the input tensor array
54
61
for (auto & step_score : step_scores_origin_) {
55
62
framework::LoDTensor out;
56
- dev_ctx->Wait ();
57
- framework::TensorCopy (step_score, platform::CPUPlace (), *dev_ctx, &out);
58
- dev_ctx->Wait ();
63
+ if (step_score.numel () > 0 ) {
64
+ dev_ctx->Wait ();
65
+ framework::TensorCopy (step_score, platform::CPUPlace (), *dev_ctx,
66
+ &out);
67
+ dev_ctx->Wait ();
68
+ }
59
69
60
70
out.set_lod (step_score.lod ());
61
71
step_scores_.push_back (out);
@@ -67,6 +77,8 @@ struct BeamSearchDecodeFunctor {
67
77
void operator ()() const ;
68
78
69
79
bool tensor_on_gpu_;
80
+ size_t beam_size_;
81
+ int end_id_;
70
82
const LoDTensorArray& step_ids_origin_;
71
83
const LoDTensorArray& step_scores_origin_;
72
84
LoDTensorArray step_ids_ = LoDTensorArray();
@@ -77,14 +89,14 @@ struct BeamSearchDecodeFunctor {
77
89
78
90
template <typename T>
79
91
void BeamSearchDecodeFunctor::operator ()() const {
80
- BeamSearchDecoder<T> beam_search_decoder;
92
+ BeamSearchDecoder<T> beam_search_decoder (beam_size_, end_id_) ;
81
93
// Check if the tensor is on GPU. If so, use the CPU copy instead
82
94
if (tensor_on_gpu_) {
83
- beam_search_decoder.PackAllSteps (step_ids_, step_scores_, id_tensor_,
84
- score_tensor_);
95
+ beam_search_decoder.Backtrace (step_ids_, step_scores_, id_tensor_,
96
+ score_tensor_);
85
97
} else {
86
- beam_search_decoder.PackAllSteps (step_ids_origin_, step_scores_origin_,
87
- id_tensor_, score_tensor_);
98
+ beam_search_decoder.Backtrace (step_ids_origin_, step_scores_origin_,
99
+ id_tensor_, score_tensor_);
88
100
}
89
101
}
90
102
@@ -122,13 +134,17 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
122
134
" Level of LodTensor should be 2" );
123
135
}
124
136
137
+ size_t beam_size = ctx.Attr <int >(" beam_size" );
138
+ int end_id = ctx.Attr <int >(" end_id" );
139
+
125
140
// prepare output
126
141
LoDTensor* sentenceIds = ctx.Output <LoDTensor>(" SentenceIds" );
127
142
LoDTensor* sentenceScores = ctx.Output <LoDTensor>(" SentenceScores" );
128
143
129
144
framework::VisitDataType (
130
145
framework::ToDataType (scores->at (0 ).type ()),
131
- BeamSearchDecodeFunctor (*ids, *scores, sentenceIds, sentenceScores));
146
+ BeamSearchDecodeFunctor (*ids, *scores, sentenceIds, sentenceScores,
147
+ beam_size, end_id));
132
148
}
133
149
};
134
150
@@ -137,18 +153,32 @@ class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker {
137
153
void Make () override {
138
154
AddInput (" Ids" ,
139
155
" (LodTensorArray)"
140
- " score of the candidate words in each step " );
156
+ " The LodTensorArray containing the selected ids of all steps " );
141
157
AddInput (" Scores" ,
142
158
" (LodTensorArray)"
143
- " score of the candidate words in each step" );
144
- AddOutput (" SentenceIds" ,
145
- " (LodTensor)"
146
- " All possible result sentences of word ids" );
147
- AddOutput (" SentenceScores" ,
148
- " (LodTensor)"
149
- " All possible result sentences of word scores" );
159
+ " The LodTensorArray containing the selected scores of all steps" );
160
+ AddOutput (
161
+ " SentenceIds" ,
162
+ " (LodTensor)"
163
+ " An LodTensor containing all generated id sequences for all source "
164
+ " sentences" );
165
+ AddOutput (
166
+ " SentenceScores" ,
167
+ " (LodTensor)"
168
+ " An LodTensor containing scores corresponding to Output(SentenceIds)" );
169
+ AddAttr<int >(" beam_size" , " beam size for beam search" );
170
+ AddAttr<int >(" end_id" ,
171
+ " the token id which indicates the end of a sequence" );
150
172
AddComment (R"DOC(
151
- Pack the result of Beam search op into SentenceIds and SentenceScores.
173
+ Beam Search Decode Operator. This Operator constructs the full hypotheses for
174
+ each source sentence by walking back along the LoDTensorArray Input(ids)
175
+ whose lods can be used to restore the path in the beam search tree.
176
+
177
+ The Output(SentenceIds) and Output(SentenceScores) separately contain the
178
+ generated id sequences and the corresponding scores. The shapes and lods of the
179
+ two LodTensor are same. The lod level is 2 and the two levels separately
180
+ indicate how many hypotheses each source sentence has and how many ids each
181
+ hypothesis has.
152
182
)DOC" );
153
183
}
154
184
};
@@ -172,10 +202,12 @@ class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
172
202
void operator ()(const framework::OpDesc& op_desc,
173
203
framework::BlockDesc* block) const override {
174
204
for (auto & o : op_desc.Output (" SentenceIds" )) {
175
- block->Var (o)->SetType (framework::proto::VarType::LOD_TENSOR);
205
+ auto & sentence_ids = block->FindRecursiveOrCreateVar (o);
206
+ sentence_ids.SetType (framework::proto::VarType::LOD_TENSOR);
176
207
}
177
208
for (auto & o : op_desc.Output (" SentenceScores" )) {
178
- block->Var (o)->SetType (framework::proto::VarType::LOD_TENSOR);
209
+ auto & sentence_scores = block->FindRecursiveOrCreateVar (o);
210
+ sentence_scores.SetType (framework::proto::VarType::LOD_TENSOR);
179
211
}
180
212
}
181
213
};
0 commit comments