Skip to content

Commit df300ff

Browse files
authored
Merge pull request #11056 from nickyfantasy/refract_machine_translation_test
Simplify and make clear function names on Machine Translation example
2 parents 83fb834 + 6f3c7d9 commit df300ff

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def encoder(is_sparse):
5353
return encoder_out
5454

5555

56-
def decoder_train(context, is_sparse):
56+
def train_decoder(context, is_sparse):
5757
# decoder
5858
trg_language_word = pd.data(
5959
name="target_language_word", shape=[1], dtype='int64', lod_level=1)
@@ -81,7 +81,7 @@ def decoder_train(context, is_sparse):
8181
return rnn()
8282

8383

84-
def decoder_decode(context, is_sparse):
84+
def decode(context, is_sparse):
8585
init_state = context
8686
array_len = pd.fill_constant(shape=[1], dtype='int64', value=max_length)
8787
counter = pd.zeros(shape=[1], dtype='int64', force_cpu=True)
@@ -150,7 +150,7 @@ def decoder_decode(context, is_sparse):
150150

151151
def train_program(is_sparse):
152152
context = encoder(is_sparse)
153-
rnn_out = decoder_train(context, is_sparse)
153+
rnn_out = train_decoder(context, is_sparse)
154154
label = pd.data(
155155
name="target_language_next_word", shape=[1], dtype='int64', lod_level=1)
156156
cost = pd.cross_entropy(input=rnn_out, label=label)
@@ -201,7 +201,7 @@ def decode_main(use_cuda, is_sparse):
201201
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
202202

203203
context = encoder(is_sparse)
204-
translation_ids, translation_scores = decoder_decode(context, is_sparse)
204+
translation_ids, translation_scores = decode(context, is_sparse)
205205

206206
exe = Executor(place)
207207
exe.run(framework.default_startup_program())

0 commit comments

Comments
 (0)