@@ -116,7 +116,7 @@ class ModelHyperParams(object):
     # to process after each sub-layer
     postprocess_cmd = "da"  # dropout + residual connection
     # random seed used in dropout for CE.
-    dropout_seed = 1
+    dropout_seed = None
     # the flag indicating whether to share embedding and softmax weights.
     # vocabularies in source and target should be same for weight sharing.
     weight_sharing = True
@@ -166,15 +166,21 @@ def create_data(is_static=False):
         ]
     else:
         enc_inputs = [
-            to_variable(src_word_np), to_variable(src_pos_np),
-            to_variable(src_slf_attn_bias_np)
+            to_variable(
+                src_word_np, name='src_word'), to_variable(
+                    src_pos_np, name='src_pos'), to_variable(
+                        src_slf_attn_bias_np, name='src_slf_attn_bias')
         ]
         dec_inputs = [
-            to_variable(trg_word_np), to_variable(trg_pos_np),
-            to_variable(trg_slf_attn_bias_np), to_variable(trg_src_attn_bias_np)
+            to_variable(
+                trg_word_np, name='trg_word'), to_variable(
+                    trg_pos_np, name='trg_pos'), to_variable(
+                        trg_slf_attn_bias_np, name='trg_slf_attn_bias'),
+            to_variable(
+                trg_src_attn_bias_np, name='trg_src_attn_bias')
         ]
-        label = to_variable(lbl_word_np)
-        weight = to_variable(lbl_weight_np)
+        label = to_variable(lbl_word_np, name='lbl_word')
+        weight = to_variable(lbl_weight_np, name='lbl_weight')
         return enc_inputs, dec_inputs, label, weight
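Aside (editorial note, not part of the patch): the hunk above attaches an explicit `name` to every `to_variable` result, which makes the dygraph tensors easier to identify when the test later compares them against the static-graph run. A minimal sketch of the pattern with dummy data, assuming the same `name` keyword shown in the diff:

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable

with fluid.dygraph.guard():
    # Dummy int64 data standing in for src_word_np; the shape is arbitrary.
    src_word_np = np.random.randint(1, 100, size=(4, 8, 1), dtype='int64')
    # Named imperative variable, as in the patched create_data().
    src_word = to_variable(src_word_np, name='src_word')
    print(src_word.name, src_word.shape)
```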
@@ -211,7 +217,7 @@ def make_all_inputs(input_fields):
 # The placeholder for batch_size in compile time. Must be -1 currently to be
 # consistent with some ops' infer-shape output in compile time, such as the
 # sequence_expand op used in beamsearch decoder.
-batch_size = 32
+batch_size = -1
 # The placeholder for squence length in compile time.
 seq_len = ModelHyperParams.max_length
 # Here list the data shapes and data types of all inputs.
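For reference (not part of the patch): a `-1` batch dimension is how the old fluid static-graph API declares a batch size that is unknown at compile time, which is why the placeholder above must be `-1` rather than a concrete value. A small illustrative sketch with a made-up layer name:

```python
import paddle.fluid as fluid

seq_len = 8  # dummy compile-time sequence length for the sketch
# -1 in the batch dimension means "resolved at runtime", matching the
# infer-shape output of ops such as sequence_expand in the beam-search decoder.
src_word = fluid.layers.data(
    name='src_word',
    shape=[-1, seq_len, 1],
    dtype='int64',
    append_batch_size=False)
print(src_word.shape)  # (-1, 8, 1)
```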
@@ -304,35 +310,40 @@ def make_all_inputs(input_fields):
 
 batch_num = 5
 
-np.random.seed = 1
+np.random.seed = 90
 src_word_np = np.random.randint(
     1,
     ModelHyperParams.src_vocab_size - 1,
-    size=(batch_size, seq_len, 1),
+    size=(TrainTaskConfig.batch_size, seq_len, 1),
     dtype='int64')
 src_pos_np = np.random.randint(
-    1, seq_len, size=(batch_size, seq_len, 1), dtype='int64')
-src_slf_attn_bias_np = np.random.randn(batch_size, ModelHyperParams.n_head,
-                                       seq_len, seq_len).astype('float32')
+    1, seq_len, size=(TrainTaskConfig.batch_size, seq_len, 1), dtype='int64')
+src_slf_attn_bias_np = np.random.randn(TrainTaskConfig.batch_size,
+                                       ModelHyperParams.n_head, seq_len,
+                                       seq_len).astype('float32')
 
 trg_word_np = np.random.randint(
     1,
     ModelHyperParams.src_vocab_size - 1,
-    size=(batch_size, seq_len, 1),
+    size=(TrainTaskConfig.batch_size, seq_len, 1),
     dtype='int64')
 trg_pos_np = np.random.randint(
-    1, seq_len, size=(batch_size, seq_len, 1), dtype='int64')
-trg_slf_attn_bias_np = np.random.randn(batch_size, ModelHyperParams.n_head,
-                                       seq_len, seq_len).astype('float32')
-trg_src_attn_bias_np = np.random.randn(batch_size, ModelHyperParams.n_head,
-                                       seq_len, seq_len).astype('float32')
+    1, seq_len, size=(TrainTaskConfig.batch_size, seq_len, 1), dtype='int64')
+trg_slf_attn_bias_np = np.random.randn(TrainTaskConfig.batch_size,
+                                       ModelHyperParams.n_head, seq_len,
+                                       seq_len).astype('float32')
+trg_src_attn_bias_np = np.random.randn(TrainTaskConfig.batch_size,
+                                       ModelHyperParams.n_head, seq_len,
+                                       seq_len).astype('float32')
 
 lbl_word_np = np.random.randint(
     1,
     ModelHyperParams.src_vocab_size - 1,
-    size=(batch_size * seq_len, 1),
+    size=(TrainTaskConfig.batch_size * seq_len, 1),
     dtype='int64')
-lbl_weight_np = np.random.randn(batch_size * seq_len, 1).astype('float32')
+
+lbl_weight_np = np.random.randn(TrainTaskConfig.batch_size * seq_len,
+                                1).astype('float32')
 
 pos_inp1 = position_encoding_init(ModelHyperParams.max_length,
                                   ModelHyperParams.d_model)
@@ -447,7 +458,7 @@ def forward(self, queries, keys, values, attn_bias):
             x=v, shape=[0, 0, self._n_head, self._d_value], inplace=False)
         transpose_v = fluid.layers.transpose(x=reshaped_v, perm=[0, 2, 1, 3])
 
-        #scale dot product attention
+        # scale dot product attention
         product = fluid.layers.matmul(
             x=transpose_q,
             y=transpose_k,
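Editorial sketch (not the exact code in this file): the comment fixed above marks the scaled dot-product attention step; the computation that typically follows looks roughly like the snippet below, where the function name and dropout handling are assumptions:

```python
import paddle.fluid as fluid

def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate=0.0):
    # Q * K^T scaled by 1/sqrt(d_key); attn_bias adds large negative values
    # at padded positions so the softmax ignores them.
    product = fluid.layers.matmul(
        x=q, y=k, transpose_y=True, alpha=d_key**-0.5)
    if attn_bias is not None:
        product += attn_bias
    weights = fluid.layers.softmax(product)
    if dropout_rate:
        weights = fluid.layers.dropout(weights, dropout_prob=dropout_rate)
    # Weighted sum over the value projections.
    return fluid.layers.matmul(weights, v)
```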
@@ -971,13 +982,15 @@ def test_transformer_float32(self):
                 enc_inputs, dec_inputs, label, weights = create_data()
                 dy_sum_cost, dy_avg_cost, dy_predict, dy_token_num = transformer(
                     enc_inputs, dec_inputs, label, weights)
+
                 if i == 0:
                     for param in transformer.parameters():
                         dy_param_init[param.name] = param._numpy()
 
                 dy_avg_cost._backward()
                 optimizer.minimize(dy_avg_cost)
                 transformer.clear_gradients()
+
                 if i == batch_num - 1:
                     for param in transformer.parameters():
                         dy_param_updated[param.name] = param._numpy()
@@ -1024,7 +1037,6 @@ def test_transformer_float32(self):
             static_param_name_list = list()
             static_sum_cost, static_avg_cost, static_predict, static_token_num = transformer(
                 enc_inputs, dec_inputs, label, weights)
-
             optimizer.minimize(static_avg_cost)
             for param in transformer.parameters():
                 static_param_name_list.append(param.name)
@@ -1042,8 +1054,8 @@ def test_transformer_float32(self):
                 static_sum_cost, static_avg_cost, static_predict,
                 static_token_num
             ]
-            fetch_list.extend(static_param_name_list)
 
+            fetch_list.extend(static_param_name_list)
             out = exe.run(fluid.default_main_program(),
                           feed=feed_dict,
                           fetch_list=fetch_list)
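Closing note (illustrative only): moving `fetch_list.extend(static_param_name_list)` right before `exe.run` keeps the cost fetches and the parameter fetches in a single list. The feed/fetch pattern itself is the standard fluid static-graph flow; a self-contained toy example with a made-up one-layer network:

```python
import numpy as np
import paddle.fluid as fluid

# Tiny static-graph program standing in for the transformer above.
x = fluid.layers.data(name='x', shape=[2], dtype='float32')
y = fluid.layers.fc(input=x, size=1)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())  # initialize parameters first

out = exe.run(fluid.default_main_program(),
              feed={'x': np.ones((4, 2), dtype='float32')},
              fetch_list=[y])
print(out[0].shape)  # one fetched ndarray per entry in fetch_list
```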