Skip to content

Commit e7d44a2

Browse files
authored
NMT model (#7340)
Neural machine translation model: support beam search with the while op
1 parent d8b923a commit e7d44a2

File tree

11 files changed

+279
-58
lines changed

11 files changed

+279
-58
lines changed

doc/design/ops/sequence_decoder.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ The current `LoDTensor` is designed to store levels of variable-length sequences
2222
The integers in each level represent the begin and end (not inclusive) offset of a sequence **in the underlying tensor**,
2323
let's call this format the **absolute-offset LoD** for clarity.
2424

25-
The relative-offset LoD can retrieve any sequence very quickly but fails to represent empty sequences, for example, a two-level LoD is as follows
25+
The absolute-offset LoD can retrieve any sequence very quickly but fails to represent empty sequences, for example, a two-level LoD is as follows
2626
```python
2727
[[0, 3, 9]
2828
[0, 2, 3, 3, 3, 9]]
@@ -119,7 +119,7 @@ def generate():
119119
encoder_ctx_expanded = pd.lod_expand(encoder_ctx, target_word)
120120
decoder_input = pd.fc(
121121
act=pd.activation.Linear(),
122-
input=[target_word, encoder_ctx],
122+
input=[target_word, encoder_ctx_expanded],
123123
size=3 * decoder_dim)
124124
gru_out, cur_mem = pd.gru_step(
125125
decoder_input, mem=decoder_mem, size=decoder_dim)

paddle/framework/executor.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,9 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
116116

117117
for (auto& op_desc : block.AllOps()) {
118118
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
119-
VLOG(3) << op->DebugStringEx(local_scope);
119+
VLOG(4) << op->DebugStringEx(local_scope);
120120
op->Run(*local_scope, place_);
121+
VLOG(3) << op->DebugStringEx(local_scope);
121122
if (FLAGS_do_memory_benchmark) {
122123
VLOG(2) << "Memory used after operator " + op->Type() + " running: "
123124
<< memory::memory_usage(place_);

paddle/framework/lod_tensor.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,10 @@ LoD ToAbsOffset(const LoD &in) {
107107
// the lowest level stores relative offsets
108108
if (in.empty() || in.size() == 1) return in;
109109
LoD result = in;
110-
for (int level = result.size() - 2; level >= 0; level--) {
111-
for (auto &ele : result[level]) {
112-
ele = result[level + 1][ele];
110+
for (auto level = static_cast<int>(in.size() - 2); level >= 0; level--) {
111+
for (size_t i = 0; i < in[level].size(); ++i) {
112+
size_t index = in[level][i];
113+
result[level][i] = result[level + 1][index];
113114
}
114115
}
115116
return result;

paddle/operators/beam_search_op.cc

Lines changed: 74 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,18 @@ namespace operators {
2424
void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
2525
framework::LoDTensor *selected_ids,
2626
framework::LoDTensor *selected_scores) {
27+
auto abs_lod = framework::ToAbsOffset(ids_->lod());
28+
auto &high_level = abs_lod[lod_level_];
29+
2730
auto items = SelectTopBeamSizeItems();
28-
auto selected_items = ToMap(items);
31+
auto selected_items = ToMap(items, high_level.back());
32+
VLOG(3) << "selected_items:";
33+
for (size_t i = 0; i < selected_items.size(); ++i) {
34+
VLOG(3) << "offset:" << i;
35+
for (auto &item : selected_items[i]) {
36+
VLOG(3) << ItemToString(item);
37+
}
38+
}
2939
PruneEndidCandidates(pre_ids, &selected_items);
3040
// calculate the output tensor's height
3141
size_t num_instances = std::accumulate(
@@ -63,11 +73,12 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
6373
low_level.push_back(low_offset);
6474

6575
// fill lod
66-
auto abs_lod = framework::ToAbsOffset(ids_->lod());
67-
auto &high_level = abs_lod[lod_level_];
6876
framework::LoD lod(2);
6977
lod[0].assign(high_level.begin(), high_level.end());
7078
lod[1].assign(low_level.begin(), low_level.end());
79+
if (!framework::CheckLoD(lod)) {
80+
PADDLE_THROW("lod %s is not right", framework::LoDToString(lod));
81+
}
7182
selected_ids->set_lod(lod);
7283
selected_scores->set_lod(lod);
7384
}
@@ -90,13 +101,11 @@ int BeamSearch::PruneEndidCandidates(const framework::LoDTensor &pre_ids,
90101
}
91102

92103
std::vector<std::vector<BeamSearch::Item>> BeamSearch::ToMap(
93-
const std::vector<std::vector<Item>> &items) {
104+
const std::vector<std::vector<Item>> &items, size_t element_num) {
94105
std::vector<std::vector<Item>> result;
106+
result.resize(element_num);
95107
for (auto &entries : items) {
96108
for (const auto &item : entries) {
97-
if (item.offset >= result.size()) {
98-
result.resize(item.offset + 1);
99-
}
100109
result[item.offset].push_back(item);
101110
}
102111
}
@@ -122,6 +131,14 @@ BeamSearch::SelectTopBeamSizeItems() {
122131
}
123132
result.emplace_back(items);
124133
}
134+
VLOG(3) << "SelectTopBeamSizeItems result size " << result.size();
135+
for (auto &items : result) {
136+
VLOG(3) << "item set:";
137+
for (auto &item : items) {
138+
VLOG(3) << ItemToString(item);
139+
}
140+
}
141+
125142
return result;
126143
}
127144

@@ -159,6 +176,22 @@ bool BeamSearch::NextItemSet(std::vector<BeamSearch::Item> *items) {
159176
return true;
160177
}
161178

179+
std::ostream &operator<<(std::ostream &os, const BeamSearch::Item &item) {
180+
os << "{";
181+
os << "offset: " << item.offset << ", ";
182+
os << "id: " << item.id << ", ";
183+
os << "score: " << item.score << "";
184+
os << "}";
185+
186+
return os;
187+
}
188+
189+
std::string ItemToString(const BeamSearch::Item &item) {
190+
std::ostringstream stream;
191+
stream << item;
192+
return stream.str();
193+
}
194+
162195
class BeamSearchProtoAndCheckerMaker
163196
: public framework::OpProtoAndCheckerMaker {
164197
public:
@@ -186,8 +219,40 @@ class BeamSearchProtoAndCheckerMaker
186219
}
187220
};
188221

222+
class BeamSearchInferShape : public framework::InferShapeBase {
223+
public:
224+
void operator()(framework::InferShapeContext *context) const override {
225+
for (const std::string &arg :
226+
std::vector<std::string>({"pre_ids", "ids", "scores"})) {
227+
PADDLE_ENFORCE(context->HasInput(arg),
228+
"BeamSearch need input argument '%s'", arg);
229+
}
230+
for (const std::string &arg :
231+
std::vector<std::string>({"selected_ids", "selected_scores"})) {
232+
PADDLE_ENFORCE(context->HasOutput(arg),
233+
"BeamSearch need output argument '%s'", arg);
234+
}
235+
}
236+
};
237+
238+
class BeamSearchInferVarType : public framework::VarTypeInference {
239+
public:
240+
void operator()(const framework::OpDesc &op_desc,
241+
framework::BlockDesc *block) const override {
242+
for (auto &o : op_desc.Output("selected_ids")) {
243+
block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR);
244+
}
245+
for (auto &o : op_desc.Output("selected_scores")) {
246+
block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR);
247+
}
248+
}
249+
};
250+
189251
} // namespace operators
190252
} // namespace paddle
191253

192-
REGISTER_OP_WITHOUT_GRADIENT(beam_search, paddle::operators::BeamSearchOp,
193-
paddle::operators::BeamSearchProtoAndCheckerMaker);
254+
REGISTER_OPERATOR(beam_search, paddle::operators::BeamSearchOp,
255+
paddle::operators::BeamSearchProtoAndCheckerMaker,
256+
paddle::operators::BeamSearchInferShape,
257+
paddle::operators::BeamSearchInferVarType,
258+
paddle::framework::EmptyGradOpMaker);

paddle/operators/beam_search_op.h

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,6 @@ class BeamSearch {
136136
void operator()(const framework::LoDTensor& pre_ids,
137137
framework::LoDTensor* selected_ids,
138138
framework::LoDTensor* selected_scores);
139-
140-
protected:
141139
/*
142140
* The basic items help to sort.
143141
*/
@@ -155,6 +153,7 @@ class BeamSearch {
155153
score_t score;
156154
};
157155

156+
protected:
158157
/*
159158
* Delete all the records that follows the end token.
160159
*/
@@ -166,7 +165,7 @@ class BeamSearch {
166165
* NOTE low performance
167166
*/
168167
std::vector<std::vector<Item>> ToMap(
169-
const std::vector<std::vector<Item>>& inputs);
168+
const std::vector<std::vector<Item>>& inputs, size_t element_num);
170169

171170
/*
172171
* For each source, select top beam_size records.
@@ -187,6 +186,10 @@ class BeamSearch {
187186
int end_id_{0};
188187
};
189188

189+
std::ostream& operator<<(std::ostream& os, const BeamSearch::Item& item);
190+
191+
std::string ItemToString(const BeamSearch::Item& item);
192+
190193
class BeamSearchOp : public framework::OperatorBase {
191194
public:
192195
BeamSearchOp(const std::string& type,
@@ -203,7 +206,6 @@ class BeamSearchOp : public framework::OperatorBase {
203206

204207
void Run(const framework::Scope& scope,
205208
const platform::Place& dev_place) const override {
206-
LOG(INFO) << "run beam search op";
207209
auto ids_var = scope.FindVar(Input("ids"));
208210
auto scores_var = scope.FindVar(Input("scores"));
209211
auto pre_ids_var = scope.FindVar(Input("pre_ids"));
@@ -217,10 +219,8 @@ class BeamSearchOp : public framework::OperatorBase {
217219
size_t level = Attr<int>("level");
218220
size_t beam_size = Attr<int>("beam_size");
219221
int end_id = Attr<int>("end_id");
220-
LOG(INFO) << "init beam search";
221222
BeamSearch alg(ids, scores, level, beam_size, end_id);
222223

223-
LOG(INFO) << "after beam search";
224224
auto selected_ids_var = scope.FindVar(Output("selected_ids"));
225225
auto selected_scores_var = scope.FindVar(Output("selected_scores"));
226226
PADDLE_ENFORCE_NOT_NULL(selected_ids_var);
@@ -229,9 +229,7 @@ class BeamSearchOp : public framework::OperatorBase {
229229
*selected_ids_var->GetMutable<framework::LoDTensor>();
230230
auto& selected_scores_tensor =
231231
*selected_scores_var->GetMutable<framework::LoDTensor>();
232-
LOG(INFO) << "run beam search";
233232
alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor);
234-
LOG(INFO) << "finish beam search";
235233
}
236234
};
237235

paddle/operators/sequence_expand_op.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
3232
const T* x_data = x->data<T>();
3333
auto x_dims = x->dims();
3434
auto* y = context.Input<LoDTensor>("Y");
35+
PADDLE_ENFORCE(!y->lod().empty(), "y should have lod");
3536
PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims[0]),
3637
y->lod().back().size() - 1,
3738
"The size of last lod level in Input(Y)"

paddle/operators/top_k_op.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ namespace paddle {
2222
namespace operators {
2323

2424
using Tensor = framework::Tensor;
25+
using LoDTensor = framework::LoDTensor;
2526

2627
template <typename T, int MajorType = Eigen::RowMajor,
2728
typename IndexType = Eigen::DenseIndex>
@@ -33,9 +34,9 @@ class TopkKernel : public framework::OpKernel<T> {
3334
void Compute(const framework::ExecutionContext& ctx) const override {
3435
// Get the top k elements of each row of input tensor
3536
// FIXME: only deal with matrix(2d tensor).
36-
auto* input = ctx.Input<Tensor>("X");
37-
auto* output = ctx.Output<Tensor>("Out");
38-
auto* indices = ctx.Output<Tensor>("Indices");
37+
auto* input = ctx.Input<LoDTensor>("X");
38+
auto* output = ctx.Output<LoDTensor>("Out");
39+
auto* indices = ctx.Output<LoDTensor>("Indices");
3940
// k is determined by Attr
4041
const size_t k = static_cast<int>(ctx.Attr<int>("k"));
4142

python/paddle/v2/fluid/layer_helper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ def input_dtype(self, input_param_name='input'):
100100
if dtype is None:
101101
dtype = each.dtype
102102
elif dtype != each.dtype:
103-
raise ValueError("Data Type mismatch")
103+
raise ValueError("Data Type mismatch: %d to %d" %
104+
(dtype, each.dtype))
104105
return dtype
105106

106107
def create_parameter(self,

python/paddle/v2/fluid/layers/control_flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ def topk(input, k):
769769
array = fluid.layers.topk(x, k)
770770
"""
771771
helper = LayerHelper('topk', **locals())
772-
topk_out = helper.create_tmp_variable(dtype=input.data_type)
772+
topk_out = helper.create_tmp_variable(dtype=input.dtype)
773773
topk_indices = helper.create_tmp_variable(dtype='int64')
774774
helper.append_op(
775775
type='top_k',

python/paddle/v2/fluid/layers/nn.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
'transpose',
6262
'im2sequence',
6363
'nce',
64+
'beam_search',
6465
]
6566

6667

@@ -163,10 +164,8 @@ def fc(input,
163164
tmp = helper.create_tmp_variable(dtype)
164165
helper.append_op(
165166
type="mul",
166-
inputs={
167-
"X": input_var,
168-
"Y": w,
169-
},
167+
inputs={"X": input_var,
168+
"Y": w},
170169
outputs={"Out": tmp},
171170
attrs={"x_num_col_dims": num_flatten_dims,
172171
"y_num_col_dims": 1})
@@ -1551,6 +1550,38 @@ def sequence_expand(x, y, name=None):
15511550
return tmp
15521551

15531552

1553+
def beam_search(pre_ids, ids, scores, beam_size, end_id, level=0):
1554+
'''
1555+
This function implements the beam search algorithm.
1556+
'''
1557+
helper = LayerHelper('beam_search', **locals())
1558+
score_type = scores.dtype
1559+
id_type = ids.dtype
1560+
1561+
selected_scores = helper.create_tmp_variable(dtype=score_type)
1562+
selected_ids = helper.create_tmp_variable(dtype=id_type)
1563+
1564+
helper.append_op(
1565+
type='beam_search',
1566+
inputs={
1567+
'pre_ids': pre_ids,
1568+
'ids': ids,
1569+
'scores': scores,
1570+
},
1571+
outputs={
1572+
'selected_ids': selected_ids,
1573+
'selected_scores': selected_scores,
1574+
},
1575+
attrs={
1576+
# TODO(ChunweiYan) to assure other value support
1577+
'level': level,
1578+
'beam_size': beam_size,
1579+
'end_id': end_id,
1580+
})
1581+
1582+
return selected_ids, selected_scores
1583+
1584+
15541585
def lstm_unit(x_t,
15551586
hidden_t_prev,
15561587
cell_t_prev,

0 commit comments

Comments (0)