
Commit 5466eff

Authored by kavyasrinet
Adding the distributed implementation for machine translation (#7751)
* Adding the distributed implementation for machine translation
* re-running CI
* Updated the code style
1 parent 917b10b commit 5466eff

File tree

1 file changed: 157 additions, 0 deletions

@@ -0,0 +1,157 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import paddle.v2.fluid.core as core
import paddle.v2.fluid.framework as framework
import paddle.v2.fluid.layers as layers
from paddle.v2.fluid.executor import Executor
import os

dict_size = 30000
source_dict_dim = target_dict_dim = dict_size
src_dict, trg_dict = paddle.dataset.wmt14.get_dict(dict_size)
hidden_dim = 32
word_dim = 16
IS_SPARSE = True
batch_size = 10
max_length = 50
topk_size = 50
trg_dic_size = 10000

decoder_size = hidden_dim


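# Builds the sequence-to-sequence network: an LSTM encoder over the source
# sentence whose last step initializes a DynamicRNN decoder that predicts the
# next target word from the current target word and the decoder state.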
def encoder_decoder():
    # encoder
    src_word_id = layers.data(
        name="src_word_id", shape=[1], dtype='int64', lod_level=1)
    src_embedding = layers.embedding(
        input=src_word_id,
        size=[dict_size, word_dim],
        dtype='float32',
        is_sparse=IS_SPARSE,
        param_attr=fluid.ParamAttr(name='vemb'))

    fc1 = fluid.layers.fc(input=src_embedding, size=hidden_dim * 4, act='tanh')
    lstm_hidden0, lstm_0 = layers.dynamic_lstm(input=fc1, size=hidden_dim * 4)
    encoder_out = layers.sequence_last_step(input=lstm_hidden0)

    # decoder
    trg_language_word = layers.data(
        name="target_language_word", shape=[1], dtype='int64', lod_level=1)
    trg_embedding = layers.embedding(
        input=trg_language_word,
        size=[dict_size, word_dim],
        dtype='float32',
        is_sparse=IS_SPARSE,
        param_attr=fluid.ParamAttr(name='vemb'))

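    # Decoder: at each step the DynamicRNN combines the current target word
    # embedding with its memory (initialized from the encoder output), uses
    # the new hidden state to update the memory, and emits a softmax over the
    # target vocabulary.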
    rnn = fluid.layers.DynamicRNN()
    with rnn.block():
        current_word = rnn.step_input(trg_embedding)
        mem = rnn.memory(init=encoder_out)
        fc1 = fluid.layers.fc(input=[current_word, mem],
                              size=decoder_size,
                              act='tanh')
        out = fluid.layers.fc(input=fc1, size=target_dict_dim, act='softmax')
        rnn.update_memory(mem, fc1)
        rnn.output(out)

    return rnn()


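# Packs a batch of variable-length integer sequences into a single LoDTensor:
# the sequences are concatenated into one column and the cumulative lengths
# are recorded as the level-of-detail (LoD) offsets.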
def to_lodtensor(data, place):
    seq_lens = [len(seq) for seq in data]
    cur_len = 0
    lod = [cur_len]
    for l in seq_lens:
        cur_len += l
        lod.append(cur_len)
    flattened_data = np.concatenate(data, axis=0).astype("int64")
    flattened_data = flattened_data.reshape([len(flattened_data), 1])
    res = core.LoDTensor()
    res.set(flattened_data, place)
    res.set_lod([lod])
    return res


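# Builds the training program, transpiles it for distributed execution, and
# then runs either the parameter-server loop or the trainer loop depending on
# the TRAINING_ROLE environment variable.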
def main():
    rnn_out = encoder_decoder()
    label = layers.data(
        name="target_language_next_word", shape=[1], dtype='int64', lod_level=1)
    cost = layers.cross_entropy(input=rnn_out, label=label)
    avg_cost = fluid.layers.mean(x=cost)

    optimizer = fluid.optimizer.Adagrad(learning_rate=1e-4)
    optimize_ops, params_grads = optimizer.minimize(avg_cost)

    train_data = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt14.train(dict_size), buf_size=1000),
        batch_size=batch_size)

    place = core.CPUPlace()
    exe = Executor(place)

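    # The distributed setup below is driven entirely by environment variables.
    # An illustrative (hypothetical) single-machine configuration with one
    # parameter server and two trainers might look like:
    #   PSERVERS=127.0.0.1:6174          parameter server endpoint list
    #   SERVER_ENDPOINT=127.0.0.1:6174   endpoint owned by this pserver process
    #   TRAINING_ROLE=PSERVER            (or TRAINER for the trainer processes)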
    t = fluid.DistributeTranspiler()
    # all parameter server endpoints list for splitting parameters
    pserver_endpoints = os.getenv("PSERVERS")
    # server endpoint for current node
    current_endpoint = os.getenv("SERVER_ENDPOINT")
    # run as trainer or parameter server
    training_role = os.getenv(
        "TRAINING_ROLE", "TRAINER")  # get the training role: trainer/pserver
    t.transpile(
        optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)

    if training_role == "PSERVER":
        if not current_endpoint:
            print("need env SERVER_ENDPOINT")
            exit(1)
        # Build and run the parameter server program for this endpoint.
        pserver_prog = t.get_pserver_program(current_endpoint)
        pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
        exe.run(pserver_startup)
        exe.run(pserver_prog)
    elif training_role == "TRAINER":
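        # Run the transpiled trainer program, which exchanges gradients and
        # parameters with the parameter servers during training.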
        trainer_prog = t.get_trainer_program()
        exe.run(framework.default_startup_program())

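        # Trainer loop: feed LoD tensors for the source words, the target
        # words, and the next-target-word labels, and fetch the average cost
        # for each mini-batch.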
        batch_id = 0
        for pass_id in xrange(2):
            for data in train_data():
                word_data = to_lodtensor(map(lambda x: x[0], data), place)
                trg_word = to_lodtensor(map(lambda x: x[1], data), place)
                trg_word_next = to_lodtensor(map(lambda x: x[2], data), place)
                outs = exe.run(trainer_prog,
                               feed={
                                   'src_word_id': word_data,
                                   'target_language_word': trg_word,
                                   'target_language_next_word': trg_word_next
                               },
                               fetch_list=[avg_cost])
                avg_cost_val = np.array(outs[0])
                print('pass_id=' + str(pass_id) + ' batch=' + str(batch_id) +
                      " avg_cost=" + str(avg_cost_val))
                if batch_id > 3:
                    exit(0)
                batch_id += 1
    else:
        print("environment var TRAINING_ROLE should be TRAINER or PSERVER")


if __name__ == '__main__':
    main()
