@@ -158,7 +158,8 @@ def do_train(args):
         amp_level = 'O2' if args.use_pure_fp16 else 'O1'
         scaler = paddle.amp.GradScaler(
             enable=True, init_loss_scaling=args.scale_loss)
-        transformer = paddle.amp.decorate(models=transformer, level=amp_level)
+        transformer = paddle.amp.decorate(
+            models=transformer, level=amp_level, save_dtype='float32')
 
     # for distributed training
     if trainer_count > 1:
@@ -260,15 +261,23 @@ def do_train(args):
             with paddle.no_grad():
                 for input_data in eval_loader:
                     (src_word, trg_word, lbl_word) = input_data
-                    with paddle.amp.auto_cast(
-                            custom_black_list={
-                                'scale', 'reduce_sum', 'elementwise_div'
-                            } if amp_level == 'O2' else {},
-                            level=amp_level):
+                    if args.use_amp:
+                        with paddle.amp.auto_cast(
+                                custom_black_list={
+                                    'scale', 'reduce_sum', 'elementwise_div'
+                                } if amp_level == 'O2' else {},
+                                level=amp_level):
+                            logits = transformer(
+                                src_word=src_word, trg_word=trg_word)
+                            sum_cost, avg_cost, token_num = criterion(
+                                logits, lbl_word)
+
+                    else:
                         logits = transformer(
                             src_word=src_word, trg_word=trg_word)
                         sum_cost, avg_cost, token_num = criterion(logits,
                                                                   lbl_word)
+
                     total_sum_cost += sum_cost.numpy()
                     total_token_num += token_num.numpy()
                     total_avg_cost = total_sum_cost / total_token_num
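
For context, the two changes above follow the usual paddle.amp dygraph pattern: decorate the model with save_dtype='float32' so checkpoints stay in fp32 even under pure-fp16 (O2) training, and only enter auto_cast at evaluation time when AMP is actually enabled. The sketch below is a minimal, self-contained illustration of that pattern under Paddle 2.x; the Linear model, Adam optimizer, and random data are hypothetical stand-ins for the Transformer, criterion, and loaders in the real script, and only the GradScaler/decorate/auto_cast calls mirror the diff.

import paddle

# Hypothetical stand-ins for the Transformer model and its optimizer.
model = paddle.nn.Linear(10, 10)
optimizer = paddle.optimizer.Adam(parameters=model.parameters())

amp_level = 'O2'  # pure fp16; 'O1' keeps parameters in fp32
scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=2.**15)

# save_dtype='float32' makes the saved/decorated parameters fp32 even though
# O2 casts the live parameters to fp16, so checkpoints load without AMP.
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level=amp_level, save_dtype='float32')

data = paddle.rand([4, 10])

# Training step: forward under auto_cast, scale the loss before backward.
with paddle.amp.auto_cast(
        custom_black_list={'scale', 'reduce_sum', 'elementwise_div'}
        if amp_level == 'O2' else set(),
        level=amp_level):
    loss = paddle.mean(model(data))
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
optimizer.clear_grad()

# Eval step: enter auto_cast only when AMP is enabled, mirroring the
# `if args.use_amp:` guard added around the validation forward pass.
model.eval()
with paddle.no_grad():
    with paddle.amp.auto_cast(level=amp_level):
        eval_loss = paddle.mean(model(data))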