
Commit b07584d

test=release/1.4, refine test_imperative_transformer (#16737)
1 parent cb9c59b commit b07584d

1 file changed: +36 −24 lines changed

python/paddle/fluid/tests/unittests/test_imperative_transformer.py

Lines changed: 36 additions & 24 deletions
@@ -116,7 +116,7 @@ class ModelHyperParams(object):
     # to process after each sub-layer
     postprocess_cmd = "da" # dropout + residual connection
     # random seed used in dropout for CE.
-    dropout_seed = 1
+    dropout_seed = None
     # the flag indicating whether to share embedding and softmax weights.
     # vocabularies in source and target should be same for weight sharing.
     weight_sharing = True
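
Changing dropout_seed from 1 to None removes the fixed dropout seed, so dropout masks are drawn fresh on every run; a fixed integer is what made runs bit-reproducible for CE. A minimal sketch of how such a seed would reach a dropout layer, assuming it is forwarded to fluid.layers.dropout (whose optional seed argument this Paddle release accepts) and reusing the ModelHyperParams class from this test:

import paddle.fluid as fluid

# Minimal sketch: a fixed integer seed makes the dropout mask reproducible
# across runs; seed=None (the new value) draws a fresh mask each run.
x = fluid.layers.data(name='x', shape=[32], dtype='float32')
out = fluid.layers.dropout(
    x, dropout_prob=0.1, seed=ModelHyperParams.dropout_seed)
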
@@ -166,15 +166,21 @@ def create_data(is_static=False):
         ]
     else:
         enc_inputs = [
-            to_variable(src_word_np), to_variable(src_pos_np),
-            to_variable(src_slf_attn_bias_np)
+            to_variable(
+                src_word_np, name='src_word'), to_variable(
+                    src_pos_np, name='src_pos'), to_variable(
+                        src_slf_attn_bias_np, name='src_slf_attn_bias')
         ]
         dec_inputs = [
-            to_variable(trg_word_np), to_variable(trg_pos_np),
-            to_variable(trg_slf_attn_bias_np), to_variable(trg_src_attn_bias_np)
+            to_variable(
+                trg_word_np, name='trg_word'), to_variable(
+                    trg_pos_np, name='trg_pos'), to_variable(
+                        trg_slf_attn_bias_np, name='trg_slf_attn_bias'),
+            to_variable(
+                trg_src_attn_bias_np, name='trg_src_attn_bias')
         ]
-        label = to_variable(lbl_word_np)
-        weight = to_variable(lbl_weight_np)
+        label = to_variable(lbl_word_np, name='lbl_word')
+        weight = to_variable(lbl_weight_np, name='lbl_weight')
     return enc_inputs, dec_inputs, label, weight

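The refactored create_data passes an explicit name to every to_variable call, so each dygraph variable carries a stable, role-describing name instead of an autogenerated one. A minimal standalone sketch of the pattern, assuming the release/1.4 dygraph API:

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable

# Minimal sketch: wrap a numpy array as a dygraph variable with an
# explicit name, matching the style used in create_data() above.
with fluid.dygraph.guard():
    src_word_np = np.random.randint(1, 100, size=(4, 8, 1)).astype('int64')
    src_word = to_variable(src_word_np, name='src_word')
    print(src_word.name)  # 'src_word' instead of an autogenerated name
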
@@ -211,7 +217,7 @@ def make_all_inputs(input_fields):
 # The placeholder for batch_size in compile time. Must be -1 currently to be
 # consistent with some ops' infer-shape output in compile time, such as the
 # sequence_expand op used in beamsearch decoder.
-batch_size = 32
+batch_size = -1
 # The placeholder for squence length in compile time.
 seq_len = ModelHyperParams.max_length
 # Here list the data shapes and data types of all inputs.
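
Restoring batch_size to -1 matches the comment above it: at compile time the batch dimension must be left unknown so shape inference agrees with ops such as sequence_expand. A minimal sketch of what a -1 batch dimension looks like in a static-graph data layer; the layer name is illustrative:

import paddle.fluid as fluid

# Minimal sketch: -1 marks the batch dimension as unknown at compile
# time; the concrete size is bound only when a batch is fed at run time.
src_word = fluid.layers.data(
    name='src_word',
    shape=[-1, ModelHyperParams.max_length, 1],
    dtype='int64',
    append_batch_size=False)
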
@@ -304,35 +310,40 @@ def make_all_inputs(input_fields):
 
 batch_num = 5
 
-np.random.seed = 1
+np.random.seed = 90
 src_word_np = np.random.randint(
     1,
     ModelHyperParams.src_vocab_size - 1,
-    size=(batch_size, seq_len, 1),
+    size=(TrainTaskConfig.batch_size, seq_len, 1),
     dtype='int64')
 src_pos_np = np.random.randint(
-    1, seq_len, size=(batch_size, seq_len, 1), dtype='int64')
-src_slf_attn_bias_np = np.random.randn(batch_size, ModelHyperParams.n_head,
-                                       seq_len, seq_len).astype('float32')
+    1, seq_len, size=(TrainTaskConfig.batch_size, seq_len, 1), dtype='int64')
+src_slf_attn_bias_np = np.random.randn(TrainTaskConfig.batch_size,
+                                       ModelHyperParams.n_head, seq_len,
+                                       seq_len).astype('float32')
 
 trg_word_np = np.random.randint(
     1,
     ModelHyperParams.src_vocab_size - 1,
-    size=(batch_size, seq_len, 1),
+    size=(TrainTaskConfig.batch_size, seq_len, 1),
     dtype='int64')
 trg_pos_np = np.random.randint(
-    1, seq_len, size=(batch_size, seq_len, 1), dtype='int64')
-trg_slf_attn_bias_np = np.random.randn(batch_size, ModelHyperParams.n_head,
-                                       seq_len, seq_len).astype('float32')
-trg_src_attn_bias_np = np.random.randn(batch_size, ModelHyperParams.n_head,
-                                       seq_len, seq_len).astype('float32')
+    1, seq_len, size=(TrainTaskConfig.batch_size, seq_len, 1), dtype='int64')
+trg_slf_attn_bias_np = np.random.randn(TrainTaskConfig.batch_size,
+                                       ModelHyperParams.n_head, seq_len,
+                                       seq_len).astype('float32')
+trg_src_attn_bias_np = np.random.randn(TrainTaskConfig.batch_size,
+                                       ModelHyperParams.n_head, seq_len,
+                                       seq_len).astype('float32')
 
 lbl_word_np = np.random.randint(
     1,
     ModelHyperParams.src_vocab_size - 1,
-    size=(batch_size * seq_len, 1),
+    size=(TrainTaskConfig.batch_size * seq_len, 1),
     dtype='int64')
-lbl_weight_np = np.random.randn(batch_size * seq_len, 1).astype('float32')
+
+lbl_weight_np = np.random.randn(TrainTaskConfig.batch_size * seq_len,
+                                1).astype('float32')
 
 pos_inp1 = position_encoding_init(ModelHyperParams.max_length,
                                   ModelHyperParams.d_model)
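
One quirk the diff keeps: np.random.seed = 90 (like the old = 1) assigns an integer to the attribute rather than calling the function, so NumPy's global generator is never actually seeded here and the randint/randn calls above still draw unseeded values. Seeding would be the call form, as in this sketch:

import numpy as np

np.random.seed(90)  # the call form actually seeds the global generator
first = np.random.randint(1, 100, size=(2, 3, 1))
np.random.seed(90)
second = np.random.randint(1, 100, size=(2, 3, 1))
assert (first == second).all()  # identical draws after reseeding
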
@@ -447,7 +458,7 @@ def forward(self, queries, keys, values, attn_bias):
             x=v, shape=[0, 0, self._n_head, self._d_value], inplace=False)
         transpose_v = fluid.layers.transpose(x=reshaped_v, perm=[0, 2, 1, 3])
 
-        #scale dot product attention
+        # scale dot product attention
         product = fluid.layers.matmul(
             x=transpose_q,
             y=transpose_k,
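
The reformatted comment marks the scaled dot-product attention step: queries are matmul'd with transposed keys, scaled, offset by the attention bias (large negative values mask padding), softmaxed, and applied to the values. A minimal NumPy sketch of that computation, not the fluid implementation used in the test:

import numpy as np

def scaled_dot_product_attention(q, k, v, attn_bias, d_key):
    # q, k, v: [batch, n_head, seq_len, d_*]; attn_bias masks padding.
    product = np.matmul(q, k.transpose(0, 1, 3, 2)) * d_key ** -0.5
    product += attn_bias  # large negative bias drives masked weights to ~0
    weights = np.exp(product - product.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return np.matmul(weights, v)
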
@@ -971,13 +982,15 @@ def test_transformer_float32(self):
                 enc_inputs, dec_inputs, label, weights = create_data()
                 dy_sum_cost, dy_avg_cost, dy_predict, dy_token_num = transformer(
                     enc_inputs, dec_inputs, label, weights)
+
                 if i == 0:
                     for param in transformer.parameters():
                         dy_param_init[param.name] = param._numpy()
 
                 dy_avg_cost._backward()
                 optimizer.minimize(dy_avg_cost)
                 transformer.clear_gradients()
+
                 if i == batch_num - 1:
                     for param in transformer.parameters():
                         dy_param_updated[param.name] = param._numpy()
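
The added blank lines separate the phases of the dygraph half of the test: forward pass, parameter snapshot, and backward/update. The per-iteration pattern, reusing names from the test above (and the private _backward()/_numpy() methods of this release), is roughly this sketch:

for i in range(batch_num):
    enc_inputs, dec_inputs, label, weights = create_data()
    _, dy_avg_cost, _, _ = transformer(enc_inputs, dec_inputs, label, weights)

    dy_avg_cost._backward()          # accumulate gradients from the loss
    optimizer.minimize(dy_avg_cost)  # apply the parameter update
    transformer.clear_gradients()    # reset grads before the next batch
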
@@ -1024,7 +1037,6 @@ def test_transformer_float32(self):
             static_param_name_list = list()
             static_sum_cost, static_avg_cost, static_predict, static_token_num = transformer(
                 enc_inputs, dec_inputs, label, weights)
-
             optimizer.minimize(static_avg_cost)
             for param in transformer.parameters():
                 static_param_name_list.append(param.name)
@@ -1042,8 +1054,8 @@ def test_transformer_float32(self):
                 static_sum_cost, static_avg_cost, static_predict,
                 static_token_num
             ]
-            fetch_list.extend(static_param_name_list)
 
+            fetch_list.extend(static_param_name_list)
             out = exe.run(fluid.default_main_program(),
                           feed=feed_dict,
                           fetch_list=fetch_list)
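
Moving fetch_list.extend below the blank line does not change behavior: exe.run still returns results in fetch_list order, so the first four entries are the cost/prediction tensors and the remainder line up with static_param_name_list, which is how the static parameter values can be compared against the dygraph snapshots. A sketch of the unpacking, with illustrative names:

# Minimal sketch: results align with fetch_list order, so parameter
# values start after the four cost/prediction outputs.
static_param_value = {}
for idx, name in enumerate(static_param_name_list):
    static_param_value[name] = out[4 + idx]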
