
Commit 36fcc95

Nmt decoder train (#6367)
* init decoder_trainer (condensed sketch below)
* can run
* fix lod
* add ShareLoD to cross_entropy_grad_op
* add avg_cost to fetch list
* modify learning rate
* can run
* optimize code
* add early exit
* fix print
* revert test_understand_sentiment_conv.py
* add act to fc
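The centerpiece of the commit is the DynamicRNN-based training decoder added to test_machine_translation.py. As an orientation aid, the pattern condensed from that diff into a standalone helper looks roughly like the sketch below; the function boundary and argument names are editorial, not part of the commit, and it assumes a paddle.v2.fluid build of this era.

import paddle.v2.fluid as fluid

def decoder_train(encoder_out, trg_embedding, decoder_size, target_dict_dim):
    # Unroll the decoder over the target sequence with DynamicRNN.
    rnn = fluid.layers.DynamicRNN()
    with rnn.block():
        current_word = rnn.step_input(trg_embedding)   # one target word per step
        mem = rnn.memory(init=encoder_out)             # decoder state, seeded by the encoder output
        hidden = fluid.layers.fc(input=[current_word, mem],
                                 size=decoder_size,
                                 act='tanh')
        out = fluid.layers.fc(input=hidden, size=target_dict_dim, act='softmax')
        rnn.update_memory(mem, hidden)                 # carry the new state to the next step
        rnn.output(out)                                # per-step prediction
    return rnn()                                       # sequence of predictions, later fed to cross_entropy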
1 parent 7d85b6d commit 36fcc95

5 files changed: +80, -58 lines

paddle/framework/op_desc.cc

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
     auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
     auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
     if (in_var->GetType() != VarDesc::LOD_TENSOR) {
-      VLOG(3) << "input " << in << "is not LodTensor";
+      VLOG(3) << "input " << in << " is not LodTensor";
       return;
     }
     PADDLE_ENFORCE_EQ(in_var->GetType(), VarDesc::LOD_TENSOR,

paddle/operators/concat_op.cc

Lines changed: 8 additions & 4 deletions
@@ -41,14 +41,18 @@ class ConcatOp : public framework::OperatorWithKernel {
       for (size_t j = 0; j < in_zero_dims_size; j++) {
         if (j == axis) {
           out_dims[axis] += ins[i][j];
-          continue;
+        } else {
+          PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
+                            "Input tensors should have the same "
+                            "elements except the specify axis.");
         }
-        PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
-                          "Input tensors should have the same "
-                          "elements except the specify axis.");
       }
     }
+    if (out_dims[axis] < 0) {
+      out_dims[axis] = -1;
+    }
     ctx->SetOutputDim("Out", out_dims);
+    ctx->ShareLoD("X", /*->*/ "Out");
   }
 };
 
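The reworked branch keeps the old per-dimension check, clamps an unknown (negative) concat dimension back to -1, and propagates LoD from the input to the output. A rough Python rendering of the revised shape-inference rule, for illustration only (the authoritative logic is the C++ above):

def concat_infer_shape(input_dims, axis):
    out_dims = list(input_dims[0])
    for dims in input_dims[1:]:
        for j, d in enumerate(dims):
            if j == axis:
                out_dims[axis] += d
            else:
                # non-concat dimensions must agree across inputs
                assert out_dims[j] == d, (
                    "Input tensors should have the same "
                    "elements except the specify axis.")
    if out_dims[axis] < 0:
        # any unknown (-1) input dim makes the concatenated dim unknown too
        out_dims[axis] = -1
    return out_dims

# e.g. concat_infer_shape([[2, 3], [4, 3]], axis=0)   -> [6, 3]
#      concat_infer_shape([[-1, 3], [-1, 3]], axis=0) -> [-1, 3]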

paddle/operators/cross_entropy_op.cc

Lines changed: 1 addition & 0 deletions
@@ -95,6 +95,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
                       "Input(Label) should be 1.");
     }
     ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+    ctx->ShareLoD("X", framework::GradVarName("X"));
   }
 
  protected:
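ShareLoD here declares that the gradient of X has the same sequence layout as X itself. A small illustration of why that is the natural choice, using plain numpy to stand in for a LoDTensor (this is not Paddle code, just the offset idea made explicit):

import numpy as np

# Two variable-length sequences stored the LoDTensor way: flat data plus offsets.
sequences = [[1, 2, 3], [4, 5]]
flat = np.array([w for seq in sequences for w in seq])   # shape [5]
lod = [0, 3, 5]                                          # seq0 = flat[0:3], seq1 = flat[3:5]

# Cross-entropy's gradient w.r.t. X is computed element by element over the same
# flat layout, so dX reuses the offsets unchanged -- which is what
# ctx->ShareLoD("X", framework::GradVarName("X")) records at compile time.
grad_flat = np.zeros_like(flat, dtype=np.float32)
grad_lod = list(lod)                                     # shared, not recomputed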

python/paddle/v2/fluid/layers.py

Lines changed: 2 additions & 1 deletion
@@ -430,7 +430,8 @@ def infer_and_check_dtype(op_proto, **kwargs):
             dtype = each.dtype
         elif dtype != each.dtype:
             raise ValueError(
-                "operator {0} must input same dtype".format(op_type))
+                "operator {0} must input same dtype. {1} vs {2}".format(
+                    op_type, dtype, each.dtype))
 
     return dtype
 
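For reference, the check that produces this error can be sketched as a standalone function; this is simplified relative to the real infer_and_check_dtype, which walks the op proto and keyword arguments:

from collections import namedtuple

Var = namedtuple('Var', ['dtype'])  # stand-in for a fluid Variable

def check_same_dtype(op_type, inputs):
    dtype = None
    for each in inputs:
        if dtype is None:
            dtype = each.dtype
        elif dtype != each.dtype:
            # Reporting both dtypes (the new message) makes the offending input obvious.
            raise ValueError(
                "operator {0} must input same dtype. {1} vs {2}".format(
                    op_type, dtype, each.dtype))
    return dtype

# check_same_dtype('cross_entropy', [Var('float32'), Var('int64')]) raises:
# ValueError: operator cross_entropy must input same dtype. float32 vs int64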

python/paddle/v2/fluid/tests/book/test_machine_translation.py

Lines changed: 68 additions & 52 deletions
@@ -1,59 +1,62 @@
 import numpy as np
 import paddle.v2 as paddle
-import paddle.v2.dataset.conll05 as conll05
+import paddle.v2.fluid as fluid
 import paddle.v2.fluid.core as core
 import paddle.v2.fluid.framework as framework
 import paddle.v2.fluid.layers as layers
-from paddle.v2.fluid.executor import Executor, g_scope
-from paddle.v2.fluid.optimizer import SGDOptimizer
-import paddle.v2.fluid as fluid
-import paddle.v2.fluid.layers as pd
+from paddle.v2.fluid.executor import Executor
 
 dict_size = 30000
 source_dict_dim = target_dict_dim = dict_size
 src_dict, trg_dict = paddle.dataset.wmt14.get_dict(dict_size)
-hidden_dim = 512
-word_dim = 512
+hidden_dim = 32
+word_dim = 16
 IS_SPARSE = True
-batch_size = 50
+batch_size = 10
 max_length = 50
 topk_size = 50
 trg_dic_size = 10000
 
-src_word_id = layers.data(name="src_word_id", shape=[1], dtype='int64')
-src_embedding = layers.embedding(
-    input=src_word_id,
-    size=[dict_size, word_dim],
-    dtype='float32',
-    is_sparse=IS_SPARSE,
-    param_attr=fluid.ParamAttr(name='vemb'))
-
-
-def encoder():
-
-    lstm_hidden0, lstm_0 = layers.dynamic_lstm(
-        input=src_embedding,
-        size=hidden_dim,
-        candidate_activation='sigmoid',
-        cell_activation='sigmoid')
-
-    lstm_hidden1, lstm_1 = layers.dynamic_lstm(
-        input=src_embedding,
-        size=hidden_dim,
-        candidate_activation='sigmoid',
-        cell_activation='sigmoid',
-        is_reverse=True)
-
-    bidirect_lstm_out = layers.concat([lstm_hidden0, lstm_hidden1], axis=0)
-
-    return bidirect_lstm_out
-
-
-def decoder_trainer(context):
-    '''
-    decoder with trainer
-    '''
-    pass
+decoder_size = hidden_dim
+
+
+def encoder_decoder():
+    # encoder
+    src_word_id = layers.data(
+        name="src_word_id", shape=[1], dtype='int64', lod_level=1)
+    src_embedding = layers.embedding(
+        input=src_word_id,
+        size=[dict_size, word_dim],
+        dtype='float32',
+        is_sparse=IS_SPARSE,
+        param_attr=fluid.ParamAttr(name='vemb'))
+
+    fc1 = fluid.layers.fc(input=src_embedding, size=hidden_dim * 4, act='tanh')
+    lstm_hidden0, lstm_0 = layers.dynamic_lstm(input=fc1, size=hidden_dim * 4)
+    encoder_out = layers.sequence_pool(input=lstm_hidden0, pool_type="last")
+
+    # decoder
+    trg_language_word = layers.data(
+        name="target_language_word", shape=[1], dtype='int64', lod_level=1)
+    trg_embedding = layers.embedding(
+        input=trg_language_word,
+        size=[dict_size, word_dim],
+        dtype='float32',
+        is_sparse=IS_SPARSE,
+        param_attr=fluid.ParamAttr(name='vemb'))
+
+    rnn = fluid.layers.DynamicRNN()
+    with rnn.block():
+        current_word = rnn.step_input(trg_embedding)
+        mem = rnn.memory(init=encoder_out)
+        fc1 = fluid.layers.fc(input=[current_word, mem],
+                              size=decoder_size,
+                              act='tanh')
+        out = fluid.layers.fc(input=fc1, size=target_dict_dim, act='softmax')
+        rnn.update_memory(mem, fc1)
+        rnn.output(out)
+
+    return rnn()
 
 
 def to_lodtensor(data, place):
@@ -72,13 +75,18 @@ def to_lodtensor(data, place):
 
 
 def main():
-    encoder_out = encoder()
-    # TODO(jacquesqiao) call here
-    decoder_trainer(encoder_out)
+    rnn_out = encoder_decoder()
+    label = layers.data(
+        name="target_language_next_word", shape=[1], dtype='int64', lod_level=1)
+    cost = layers.cross_entropy(input=rnn_out, label=label)
+    avg_cost = fluid.layers.mean(x=cost)
+
+    optimizer = fluid.optimizer.Adagrad(learning_rate=1e-4)
+    optimizer.minimize(avg_cost)
 
     train_data = paddle.batch(
         paddle.reader.shuffle(
-            paddle.dataset.wmt14.train(8000), buf_size=1000),
+            paddle.dataset.wmt14.train(dict_size), buf_size=1000),
         batch_size=batch_size)
 
     place = core.CPUPlace()
@@ -88,15 +96,23 @@ def main():
 
     batch_id = 0
     for pass_id in xrange(2):
-        print 'pass_id', pass_id
         for data in train_data():
-            print 'batch', batch_id
-            batch_id += 1
-            if batch_id > 10: break
             word_data = to_lodtensor(map(lambda x: x[0], data), place)
+            trg_word = to_lodtensor(map(lambda x: x[1], data), place)
+            trg_word_next = to_lodtensor(map(lambda x: x[2], data), place)
             outs = exe.run(framework.default_main_program(),
-                           feed={'src_word_id': word_data, },
-                           fetch_list=[encoder_out])
+                           feed={
+                               'src_word_id': word_data,
+                               'target_language_word': trg_word,
+                               'target_language_next_word': trg_word_next
+                           },
+                           fetch_list=[avg_cost])
+            avg_cost_val = np.array(outs[0])
+            print('pass_id=' + str(pass_id) + ' batch=' + str(batch_id) +
+                  " avg_cost=" + str(avg_cost_val))
+            if batch_id > 3:
+                exit(0)
+            batch_id += 1
 
 
 if __name__ == '__main__':
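The training loop packs each column of the wmt14 batch through to_lodtensor, whose body lies outside the hunks above. A minimal helper of that shape, shown as an assumption about what such a packer does rather than the file's actual code, would flatten the sequences and record cumulative offsets:

import numpy as np
import paddle.v2.fluid.core as core

def pack_to_lodtensor(seqs, place):
    # seqs: list of word-id lists, one per sentence in the batch (e.g. x[0] for
    # 'src_word_id', x[1] for 'target_language_word', x[2] for 'target_language_next_word').
    lod = [0]
    for seq in seqs:
        lod.append(lod[-1] + len(seq))               # cumulative offsets
    flat = np.array([w for seq in seqs for w in seq], dtype="int64")
    flat = flat.reshape([len(flat), 1])              # matches shape=[1] in layers.data
    tensor = core.LoDTensor()
    tensor.set(flat, place)                          # copy data onto the target place
    tensor.set_lod([lod])                            # single-level LoD
    return tensor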

0 commit comments
