Skip to content

Commit d4c2164

Browse files
authored
Merge pull request #10895 from nickyfantasy/high_level_api_machine_translation
Simplify Machine Translation demo by using Trainer API
2 parents 87ff95d + 5b9d09d commit d4c2164

File tree

3 files changed

+327
-0
lines changed

3 files changed

+327
-0
lines changed

python/paddle/fluid/tests/book/high-level-api/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ add_subdirectory(understand_sentiment)
1313
add_subdirectory(label_semantic_roles)
1414
add_subdirectory(word2vec)
1515
add_subdirectory(recommender_system)
16+
add_subdirectory(machine_translation)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
2+
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
3+
4+
# default test
5+
foreach(src ${TEST_OPS})
6+
py_test(${src} SRCS ${src}.py)
7+
endforeach()
Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import contextlib
15+
16+
import numpy as np
17+
import paddle
18+
import paddle.fluid as fluid
19+
import paddle.fluid.framework as framework
20+
import paddle.fluid.layers as pd
21+
from paddle.fluid.executor import Executor
22+
from functools import partial
23+
import unittest
24+
import os
25+
26+
dict_size = 30000
27+
source_dict_dim = target_dict_dim = dict_size
28+
hidden_dim = 32
29+
word_dim = 16
30+
batch_size = 2
31+
max_length = 8
32+
topk_size = 50
33+
trg_dic_size = 10000
34+
beam_size = 2
35+
36+
decoder_size = hidden_dim
37+
38+
39+
def encoder(is_sparse):
40+
# encoder
41+
src_word_id = pd.data(
42+
name="src_word_id", shape=[1], dtype='int64', lod_level=1)
43+
src_embedding = pd.embedding(
44+
input=src_word_id,
45+
size=[dict_size, word_dim],
46+
dtype='float32',
47+
is_sparse=is_sparse,
48+
param_attr=fluid.ParamAttr(name='vemb'))
49+
50+
fc1 = pd.fc(input=src_embedding, size=hidden_dim * 4, act='tanh')
51+
lstm_hidden0, lstm_0 = pd.dynamic_lstm(input=fc1, size=hidden_dim * 4)
52+
encoder_out = pd.sequence_last_step(input=lstm_hidden0)
53+
return encoder_out
54+
55+
56+
def decoder_train(context, is_sparse):
57+
# decoder
58+
trg_language_word = pd.data(
59+
name="target_language_word", shape=[1], dtype='int64', lod_level=1)
60+
trg_embedding = pd.embedding(
61+
input=trg_language_word,
62+
size=[dict_size, word_dim],
63+
dtype='float32',
64+
is_sparse=is_sparse,
65+
param_attr=fluid.ParamAttr(name='vemb'))
66+
67+
rnn = pd.DynamicRNN()
68+
with rnn.block():
69+
current_word = rnn.step_input(trg_embedding)
70+
pre_state = rnn.memory(init=context)
71+
current_state = pd.fc(input=[current_word, pre_state],
72+
size=decoder_size,
73+
act='tanh')
74+
75+
current_score = pd.fc(input=current_state,
76+
size=target_dict_dim,
77+
act='softmax')
78+
rnn.update_memory(pre_state, current_state)
79+
rnn.output(current_score)
80+
81+
return rnn()
82+
83+
84+
def decoder_decode(context, is_sparse):
85+
init_state = context
86+
array_len = pd.fill_constant(shape=[1], dtype='int64', value=max_length)
87+
counter = pd.zeros(shape=[1], dtype='int64', force_cpu=True)
88+
89+
# fill the first element with init_state
90+
state_array = pd.create_array('float32')
91+
pd.array_write(init_state, array=state_array, i=counter)
92+
93+
# ids, scores as memory
94+
ids_array = pd.create_array('int64')
95+
scores_array = pd.create_array('float32')
96+
97+
init_ids = pd.data(name="init_ids", shape=[1], dtype="int64", lod_level=2)
98+
init_scores = pd.data(
99+
name="init_scores", shape=[1], dtype="float32", lod_level=2)
100+
101+
pd.array_write(init_ids, array=ids_array, i=counter)
102+
pd.array_write(init_scores, array=scores_array, i=counter)
103+
104+
cond = pd.less_than(x=counter, y=array_len)
105+
106+
while_op = pd.While(cond=cond)
107+
with while_op.block():
108+
pre_ids = pd.array_read(array=ids_array, i=counter)
109+
pre_state = pd.array_read(array=state_array, i=counter)
110+
pre_score = pd.array_read(array=scores_array, i=counter)
111+
112+
# expand the lod of pre_state to be the same with pre_score
113+
pre_state_expanded = pd.sequence_expand(pre_state, pre_score)
114+
115+
pre_ids_emb = pd.embedding(
116+
input=pre_ids,
117+
size=[dict_size, word_dim],
118+
dtype='float32',
119+
is_sparse=is_sparse)
120+
121+
# use rnn unit to update rnn
122+
current_state = pd.fc(input=[pre_state_expanded, pre_ids_emb],
123+
size=decoder_size,
124+
act='tanh')
125+
current_state_with_lod = pd.lod_reset(x=current_state, y=pre_score)
126+
# use score to do beam search
127+
current_score = pd.fc(input=current_state_with_lod,
128+
size=target_dict_dim,
129+
act='softmax')
130+
topk_scores, topk_indices = pd.topk(current_score, k=topk_size)
131+
selected_ids, selected_scores = pd.beam_search(
132+
pre_ids, topk_indices, topk_scores, beam_size, end_id=10, level=0)
133+
134+
pd.increment(x=counter, value=1, in_place=True)
135+
136+
# update the memories
137+
pd.array_write(current_state, array=state_array, i=counter)
138+
pd.array_write(selected_ids, array=ids_array, i=counter)
139+
pd.array_write(selected_scores, array=scores_array, i=counter)
140+
141+
pd.less_than(x=counter, y=array_len, cond=cond)
142+
143+
translation_ids, translation_scores = pd.beam_search_decode(
144+
ids=ids_array, scores=scores_array)
145+
146+
# return init_ids, init_scores
147+
148+
return translation_ids, translation_scores
149+
150+
151+
def set_init_lod(data, lod, place):
152+
res = fluid.LoDTensor()
153+
res.set(data, place)
154+
res.set_lod(lod)
155+
return res
156+
157+
158+
def to_lodtensor(data, place):
159+
seq_lens = [len(seq) for seq in data]
160+
cur_len = 0
161+
lod = [cur_len]
162+
for l in seq_lens:
163+
cur_len += l
164+
lod.append(cur_len)
165+
flattened_data = np.concatenate(data, axis=0).astype("int64")
166+
flattened_data = flattened_data.reshape([len(flattened_data), 1])
167+
res = fluid.LoDTensor()
168+
res.set(flattened_data, place)
169+
res.set_lod([lod])
170+
return res
171+
172+
173+
def train_program(is_sparse):
174+
context = encoder(is_sparse)
175+
rnn_out = decoder_train(context, is_sparse)
176+
label = pd.data(
177+
name="target_language_next_word", shape=[1], dtype='int64', lod_level=1)
178+
cost = pd.cross_entropy(input=rnn_out, label=label)
179+
avg_cost = pd.mean(cost)
180+
return avg_cost
181+
182+
183+
def train(use_cuda, is_sparse, is_local=True):
184+
EPOCH_NUM = 1
185+
186+
if use_cuda and not fluid.core.is_compiled_with_cuda():
187+
return
188+
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
189+
190+
train_reader = paddle.batch(
191+
paddle.reader.shuffle(
192+
paddle.dataset.wmt14.train(dict_size), buf_size=1000),
193+
batch_size=batch_size)
194+
195+
feed_order = [
196+
'src_word_id', 'target_language_word', 'target_language_next_word'
197+
]
198+
199+
def event_handler(event):
200+
if isinstance(event, fluid.EndStepEvent):
201+
print('pass_id=' + str(event.epoch) + ' batch=' + str(event.step))
202+
if event.step == 10:
203+
trainer.stop()
204+
205+
trainer = fluid.Trainer(
206+
train_func=partial(train_program, is_sparse),
207+
optimizer=fluid.optimizer.Adagrad(
208+
learning_rate=1e-4,
209+
regularization=fluid.regularizer.L2DecayRegularizer(
210+
regularization_coeff=0.1)),
211+
place=place)
212+
213+
trainer.train(
214+
reader=train_reader,
215+
num_epochs=EPOCH_NUM,
216+
event_handler=event_handler,
217+
feed_order=feed_order)
218+
219+
220+
def decode_main(use_cuda, is_sparse):
221+
222+
if use_cuda and not fluid.core.is_compiled_with_cuda():
223+
return
224+
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
225+
226+
context = encoder(is_sparse)
227+
translation_ids, translation_scores = decoder_decode(context, is_sparse)
228+
229+
exe = Executor(place)
230+
exe.run(framework.default_startup_program())
231+
232+
init_ids_data = np.array([1 for _ in range(batch_size)], dtype='int64')
233+
init_scores_data = np.array(
234+
[1. for _ in range(batch_size)], dtype='float32')
235+
init_ids_data = init_ids_data.reshape((batch_size, 1))
236+
init_scores_data = init_scores_data.reshape((batch_size, 1))
237+
init_lod = [i for i in range(batch_size)] + [batch_size]
238+
init_lod = [init_lod, init_lod]
239+
240+
train_data = paddle.batch(
241+
paddle.reader.shuffle(
242+
paddle.dataset.wmt14.train(dict_size), buf_size=1000),
243+
batch_size=batch_size)
244+
for _, data in enumerate(train_data()):
245+
init_ids = set_init_lod(init_ids_data, init_lod, place)
246+
init_scores = set_init_lod(init_scores_data, init_lod, place)
247+
248+
src_word_data = to_lodtensor(map(lambda x: x[0], data), place)
249+
250+
result_ids, result_scores = exe.run(
251+
framework.default_main_program(),
252+
feed={
253+
'src_word_id': src_word_data,
254+
'init_ids': init_ids,
255+
'init_scores': init_scores
256+
},
257+
fetch_list=[translation_ids, translation_scores],
258+
return_numpy=False)
259+
print result_ids.lod()
260+
break
261+
262+
263+
class TestMachineTranslation(unittest.TestCase):
264+
pass
265+
266+
267+
@contextlib.contextmanager
268+
def scope_prog_guard():
269+
prog = fluid.Program()
270+
startup_prog = fluid.Program()
271+
scope = fluid.core.Scope()
272+
with fluid.scope_guard(scope):
273+
with fluid.program_guard(prog, startup_prog):
274+
yield
275+
276+
277+
def inject_test_train(use_cuda, is_sparse):
278+
f_name = 'test_{0}_{1}_train'.format('cuda' if use_cuda else 'cpu', 'sparse'
279+
if is_sparse else 'dense')
280+
281+
def f(*args):
282+
with scope_prog_guard():
283+
train(use_cuda, is_sparse)
284+
285+
setattr(TestMachineTranslation, f_name, f)
286+
287+
288+
def inject_test_decode(use_cuda, is_sparse, decorator=None):
289+
f_name = 'test_{0}_{1}_decode'.format('cuda'
290+
if use_cuda else 'cpu', 'sparse'
291+
if is_sparse else 'dense')
292+
293+
def f(*args):
294+
with scope_prog_guard():
295+
decode_main(use_cuda, is_sparse)
296+
297+
if decorator is not None:
298+
f = decorator(f)
299+
300+
setattr(TestMachineTranslation, f_name, f)
301+
302+
303+
for _use_cuda_ in (False, True):
304+
for _is_sparse_ in (False, True):
305+
inject_test_train(_use_cuda_, _is_sparse_)
306+
307+
for _use_cuda_ in (False, True):
308+
for _is_sparse_ in (False, True):
309+
310+
_decorator_ = None
311+
if _use_cuda_:
312+
_decorator_ = unittest.skip(
313+
reason='Beam Search does not support CUDA!')
314+
315+
inject_test_decode(
316+
is_sparse=_is_sparse_, use_cuda=_use_cuda_, decorator=_decorator_)
317+
318+
if __name__ == '__main__':
319+
unittest.main()

0 commit comments

Comments
 (0)