@@ -53,7 +53,7 @@ def encoder(is_sparse):
53
53
return encoder_out
54
54
55
55
56
- def decoder_train (context , is_sparse ):
56
+ def train_decoder (context , is_sparse ):
57
57
# decoder
58
58
trg_language_word = pd .data (
59
59
name = "target_language_word" , shape = [1 ], dtype = 'int64' , lod_level = 1 )
@@ -81,7 +81,7 @@ def decoder_train(context, is_sparse):
81
81
return rnn ()
82
82
83
83
84
- def decoder_decode (context , is_sparse ):
84
+ def decode (context , is_sparse ):
85
85
init_state = context
86
86
array_len = pd .fill_constant (shape = [1 ], dtype = 'int64' , value = max_length )
87
87
counter = pd .zeros (shape = [1 ], dtype = 'int64' , force_cpu = True )
@@ -150,7 +150,7 @@ def decoder_decode(context, is_sparse):
150
150
151
151
def train_program (is_sparse ):
152
152
context = encoder (is_sparse )
153
- rnn_out = decoder_train (context , is_sparse )
153
+ rnn_out = train_decoder (context , is_sparse )
154
154
label = pd .data (
155
155
name = "target_language_next_word" , shape = [1 ], dtype = 'int64' , lod_level = 1 )
156
156
cost = pd .cross_entropy (input = rnn_out , label = label )
@@ -201,7 +201,7 @@ def decode_main(use_cuda, is_sparse):
201
201
place = fluid .CUDAPlace (0 ) if use_cuda else fluid .CPUPlace ()
202
202
203
203
context = encoder (is_sparse )
204
- translation_ids , translation_scores = decoder_decode (context , is_sparse )
204
+ translation_ids , translation_scores = decode (context , is_sparse )
205
205
206
206
exe = Executor (place )
207
207
exe .run (framework .default_startup_program ())
0 commit comments