Skip to content

Commit 65c859d

Browse files
authored
beam_search_decode support multi data type (#5847)
* beam_search_decode support multi data type * add VisitDataType for beam search decode * use Specialization to handle bool * move Specialization of BeamSearchDecodeFunctor out of class
1 parent 3a76062 commit 65c859d

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

paddle/operators/beam_search_decode_op.cc

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,36 @@ limitations under the License. */
1717
namespace paddle {
1818
namespace operators {
1919

20+
struct BeamSearchDecodeFunctor {
21+
BeamSearchDecodeFunctor(const LoDTensorArray& step_ids,
22+
const LoDTensorArray& step_scores,
23+
LoDTensor* id_tensor, LoDTensor* score_tensor)
24+
: step_ids_(step_ids),
25+
step_scores_(step_scores),
26+
id_tensor_(id_tensor),
27+
score_tensor_(score_tensor) {}
28+
29+
template <typename T>
30+
void operator()() const;
31+
32+
const LoDTensorArray& step_ids_;
33+
const LoDTensorArray& step_scores_;
34+
LoDTensor* id_tensor_;
35+
LoDTensor* score_tensor_;
36+
};
37+
38+
template <typename T>
39+
void BeamSearchDecodeFunctor::operator()() const {
40+
BeamSearchDecoder<T> beam_search_decoder;
41+
beam_search_decoder.PackAllSteps(step_ids_, step_scores_, id_tensor_,
42+
score_tensor_);
43+
}
44+
45+
template <>
46+
void BeamSearchDecodeFunctor::operator()<bool>() const {
47+
PADDLE_THROW("beam search decode op does not support bool!");
48+
}
49+
2050
class BeamSearchDecodeOp : public framework::OperatorBase {
2151
public:
2252
BeamSearchDecodeOp(const std::string& type,
@@ -45,9 +75,9 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
4575
LoDTensor* sentenceIds = ctx.Output<LoDTensor>("SentenceIds");
4676
LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores");
4777

48-
BeamSearchDecoder<float> beam_search_decoder;
49-
beam_search_decoder.PackAllSteps(*ids, *scores, sentenceIds,
50-
sentenceScores);
78+
framework::VisitDataType(
79+
framework::ToDataType(scores->at(0).type()),
80+
BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores));
5181
}
5282
};
5383

python/paddle/v2/fluid/tests/test_beam_search_decode_op.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@ def test_get_set(self):
3535
self.append_lod_tensor(
3636
scores, [[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]],
3737
np.array(
38-
[1, 2, 3, 4, 5, 6], dtype="float32"))
38+
[1, 2, 3, 4, 5, 6], dtype="float64"))
3939
self.append_lod_tensor(
4040
scores, [[0, 3, 6], [0, 1, 1, 3, 5, 5, 6]],
4141
np.array(
42-
[0, 1, 2, 3, 4, 5], dtype="float32"))
42+
[0, 1, 2, 3, 4, 5], dtype="float64"))
4343
self.append_lod_tensor(
4444
scores, [[0, 3, 6], [0, 0, 1, 2, 3, 4, 5]],
4545
np.array(
46-
[0, 1, 2, 3, 4], dtype="float32"))
46+
[0, 1, 2, 3, 4], dtype="float64"))
4747

4848
sentence_ids = self.scope.var("sentence_ids").get_tensor()
4949
sentence_scores = self.scope.var("sentence_scores").get_tensor()

0 commit comments

Comments
 (0)