Skip to content

Commit 7345de3

Browse files
authored
Beam search decode op python (#5631)
* fix lod_tensor_array * init test beam search decode op * add test_beam_search_decode_op
1 parent 85b839f commit 7345de3

File tree

3 files changed

+93
-0
lines changed

3 files changed

+93
-0
lines changed

paddle/operators/beam_search_decode_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
2727
void Run(const framework::Scope& scope,
2828
const platform::DeviceContext& dev_ctx) const override {
2929
framework::ExecutionContext ctx(*this, scope, dev_ctx);
30+
3031
const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids");
3132
const LoDTensorArray* scores = ctx.Input<LoDTensorArray>("Scores");
3233
const size_t step_num = ids->size();

python/paddle/v2/framework/layers.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,23 @@ def batch_norm(input,
839839
return helper.append_activation(batch_norm_out)
840840

841841

842+
def beam_search_decode(ids, scores, main_program=None, startup_program=None):
843+
helper = LayerHelper('beam_search_decode', **locals())
844+
sentence_ids = helper.create_tmp_variable(dtype=ids.data_type)
845+
sentence_scores = helper.create_tmp_variable(dtype=ids.data_type)
846+
847+
helper.append_op(
848+
type="beam_search_decode",
849+
inputs={"Ids": ids,
850+
"Scores": scores},
851+
outputs={
852+
"SentenceIds": sentence_ids,
853+
"SentenceScores": sentence_scores
854+
})
855+
856+
return sentence_ids, sentence_scores
857+
858+
842859
class BlockGuard(object):
843860
"""
844861
BlockGuard class.
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import unittest
2+
3+
import numpy as np
4+
import paddle.v2.framework.core as core
5+
from paddle.v2.framework.op import Operator
6+
7+
8+
class TestBeamSearchDecodeOp(unittest.TestCase):
9+
def setUp(self):
10+
self.scope = core.Scope()
11+
self.cpu_place = core.CPUPlace()
12+
13+
def append_lod_tensor(self, tensor_array, lod, data):
14+
lod_tensor = core.LoDTensor()
15+
lod_tensor.set_lod(lod)
16+
lod_tensor.set(data, self.cpu_place)
17+
tensor_array.append(lod_tensor)
18+
19+
def test_get_set(self):
20+
ids = self.scope.var("ids").get_lod_tensor_array()
21+
self.append_lod_tensor(
22+
ids, [[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]],
23+
np.array(
24+
[1, 2, 3, 4, 5, 6], dtype="int64"))
25+
self.append_lod_tensor(
26+
ids, [[0, 3, 6], [0, 1, 1, 3, 5, 5, 6]],
27+
np.array(
28+
[0, 1, 2, 3, 4, 5], dtype="int64"))
29+
self.append_lod_tensor(
30+
ids, [[0, 3, 6], [0, 0, 1, 2, 3, 4, 5]],
31+
np.array(
32+
[0, 1, 2, 3, 4], dtype="int64"))
33+
34+
scores = self.scope.var("scores").get_lod_tensor_array()
35+
self.append_lod_tensor(
36+
scores, [[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]],
37+
np.array(
38+
[1, 2, 3, 4, 5, 6], dtype="float32"))
39+
self.append_lod_tensor(
40+
scores, [[0, 3, 6], [0, 1, 1, 3, 5, 5, 6]],
41+
np.array(
42+
[0, 1, 2, 3, 4, 5], dtype="float32"))
43+
self.append_lod_tensor(
44+
scores, [[0, 3, 6], [0, 0, 1, 2, 3, 4, 5]],
45+
np.array(
46+
[0, 1, 2, 3, 4], dtype="float32"))
47+
48+
sentence_ids = self.scope.var("sentence_ids").get_tensor()
49+
sentence_scores = self.scope.var("sentence_scores").get_tensor()
50+
51+
beam_search_decode_op = Operator(
52+
"beam_search_decode",
53+
# inputs
54+
Ids="ids",
55+
Scores="scores",
56+
# outputs
57+
SentenceIds="sentence_ids",
58+
SentenceScores="sentence_scores")
59+
60+
ctx = core.DeviceContext.create(self.cpu_place)
61+
beam_search_decode_op.run(self.scope, ctx)
62+
63+
expected_lod = [[0, 4, 8], [0, 1, 3, 6, 9, 10, 13, 16, 19]]
64+
self.assertEqual(sentence_ids.lod(), expected_lod)
65+
self.assertEqual(sentence_scores.lod(), expected_lod)
66+
67+
expected_data = np.array(
68+
[2, 1, 0, 3, 1, 0, 3, 2, 1, 5, 4, 3, 2, 4, 4, 3, 6, 5, 4], "int64")
69+
self.assertTrue(np.array_equal(np.array(sentence_ids), expected_data))
70+
self.assertTrue(
71+
np.array_equal(np.array(sentence_scores), expected_data))
72+
73+
74+
if __name__ == '__main__':
75+
unittest.main()

0 commit comments

Comments
 (0)