
Commit 2455ec2

Fix Transformer amp (Fix #1574) (#1635)
1 parent 8993c32 commit 2455ec2


examples/machine_translation/transformer/train.py

Lines changed: 15 additions & 6 deletions
```diff
@@ -158,7 +158,8 @@ def do_train(args):
         amp_level = 'O2' if args.use_pure_fp16 else 'O1'
         scaler = paddle.amp.GradScaler(
             enable=True, init_loss_scaling=args.scale_loss)
-        transformer = paddle.amp.decorate(models=transformer, level=amp_level)
+        transformer = paddle.amp.decorate(
+            models=transformer, level=amp_level, save_dtype='float32')
 
         # for distributed training
         if trainer_count > 1:
```
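The first hunk addresses checkpoint saving under pure-fp16 training. At level 'O2', `paddle.amp.decorate` casts the model's parameters to float16, so a checkpoint written from `state_dict()` would also come out in float16; passing `save_dtype='float32'` makes `state_dict()` report float32 tensors (without changing the in-memory fp16 parameters), keeping saved weights loadable for ordinary fp32 inference or fine-tuning. Below is a minimal sketch of the pattern, not part of the commit: a toy `Linear` layer, an `Adam` optimizer, and the loss-scaling value stand in for the example's Transformer, its optimizer, and `args.scale_loss`, and a GPU build of PaddlePaddle is assumed so the fp16 cast actually applies.

```python
# Minimal sketch of the decorate + GradScaler pattern the diff touches.
# The Linear layer, Adam optimizer, and init_loss_scaling value are
# illustrative stand-ins; assumes a GPU build of PaddlePaddle.
import paddle

model = paddle.nn.Linear(4, 4)
optimizer = paddle.optimizer.Adam(parameters=model.parameters())
scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=128)

# O2 casts parameters to float16; save_dtype='float32' makes
# state_dict() report float32 tensors, so checkpoints stay fp32.
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level='O2',
    save_dtype='float32')

x = paddle.randn([2, 4])
with paddle.amp.auto_cast(level='O2'):
    loss = model(x).mean()

# Scaled backward + update, mirroring the script's GradScaler usage.
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)  # unscales, skips step on inf/nan
optimizer.clear_grad()

paddle.save(model.state_dict(), 'model.pdparams')  # fp32 weights
```

The scale → backward → minimize sequence mirrors the script's `GradScaler` usage: `minimize` unscales the gradients, skips the optimizer step if infs/nans are detected, and updates the loss scale.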
```diff
@@ -260,15 +261,23 @@ def do_train(args):
             with paddle.no_grad():
                 for input_data in eval_loader:
                     (src_word, trg_word, lbl_word) = input_data
-                    with paddle.amp.auto_cast(
-                            custom_black_list={
-                                'scale', 'reduce_sum', 'elementwise_div'
-                            } if amp_level == 'O2' else {},
-                            level=amp_level):
+                    if args.use_amp:
+                        with paddle.amp.auto_cast(
+                                custom_black_list={
+                                    'scale', 'reduce_sum', 'elementwise_div'
+                                } if amp_level == 'O2' else {},
+                                level=amp_level):
+                            logits = transformer(
+                                src_word=src_word, trg_word=trg_word)
+                            sum_cost, avg_cost, token_num = criterion(
+                                logits, lbl_word)
+
+                    else:
                         logits = transformer(
                             src_word=src_word, trg_word=trg_word)
                         sum_cost, avg_cost, token_num = criterion(logits,
                                                                   lbl_word)
+
                     total_sum_cost += sum_cost.numpy()
                     total_token_num += token_num.numpy()
             total_avg_cost = total_sum_cost / total_token_num
```
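The second hunk fixes evaluation when AMP is disabled. Previously the eval forward pass was always wrapped in `paddle.amp.auto_cast(..., level=amp_level)`, even for plain fp32 runs, presumably referencing AMP state (`amp_level`, the decorated model) that only exists when `--use_amp` is set; the fix enters `auto_cast` only under `args.use_amp` and keeps an unwrapped fp32 branch otherwise. Here is a standalone sketch of the guarded pattern, not part of the commit: `use_amp` and `amp_level` stand in for the script's `args.use_amp` and `amp_level`, a toy `Linear` replaces the Transformer, and a GPU build of PaddlePaddle is assumed.

```python
# Sketch of the guarded evaluation pattern from the diff; use_amp and
# amp_level mimic the script's flags, and the Linear layer stands in
# for the Transformer. Assumes a GPU build of PaddlePaddle.
import paddle

use_amp = True
amp_level = 'O2'  # 'O1' mixed precision, 'O2' pure fp16

model = paddle.nn.Linear(8, 8)
if use_amp:
    model = paddle.amp.decorate(
        models=model, level=amp_level, save_dtype='float32')

model.eval()
x = paddle.randn([2, 8])
with paddle.no_grad():
    if use_amp:
        # Under O2, keep numerically sensitive ops in float32 via the
        # custom black list, matching the training script.
        with paddle.amp.auto_cast(
                custom_black_list={'scale', 'reduce_sum',
                                   'elementwise_div'}
                if amp_level == 'O2' else {},
                level=amp_level):
            out = model(x)
    else:
        out = model(x)  # plain float32 path when AMP is off
```

Under 'O2' the black list pins `scale`, `reduce_sum`, and `elementwise_div` to float32, presumably to avoid reduced-precision issues when summing costs and normalizing by token count; at 'O1' the default op lists already cover this, so the script passes an empty collection.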
