@@ -260,10 +260,13 @@ def __init__(
             logger.info("Using half precision")
             self.do_grad_scaling = True
             self.amp_dtype = "float16" if args.fp16 else "bfloat16"
+            # fix for load saved fp16 or bf16 ckpt, decorate model first.
+            if self.args.fp16_opt_level == "O2":
+                paddle.amp.decorate(models=model, level=self.args.fp16_opt_level, dtype=self.amp_dtype)

             if self.sharding is not None:
                 self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
-                if self.amp_dtype == "float16":
+                if self.amp_dtype == "float16" or self.amp_dtype == "bfloat16":
                     if ShardingOption.SHARD_OP in self.args.sharding:
                         self.scaler = fleet.distributed_scaler(self.scaler)
                     else:
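Note (editorial): the O2 branch added above follows Paddle's standard pure-fp16/bf16 flow, where the model is decorated before a half-precision checkpoint is loaded into it and a GradScaler guards against underflow. A minimal standalone sketch (not part of the diff), assuming a GPU build of Paddle and a toy Linear model in place of the Trainer's real model:

import paddle

model = paddle.nn.Linear(4, 4)
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())

# Decorate first so parameters are already fp16 before any saved
# fp16 state_dict would be loaded into them.
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level="O2", dtype="float16"
)

scaler = paddle.amp.GradScaler(init_loss_scaling=2**15)
x = paddle.randn([2, 4])
with paddle.amp.auto_cast(level="O2"):
    loss = model(x).mean()
scaled = scaler.scale(loss)         # scale the loss so fp16 gradients do not underflow
scaled.backward()
scaler.minimize(optimizer, scaled)  # unscale, skip the step on inf/nan, adjust the loss scale
optimizer.clear_grad()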
@@ -291,7 +294,9 @@ def __init__(
         if args.recompute:

             def fn(layer):
-                if hasattr(layer, "enable_recompute") and layer.enable_recompute is False:
+                if hasattr(layer, "enable_recompute") and (
+                    layer.enable_recompute is False or layer.enable_recompute == 0
+                ):
                     layer.enable_recompute = True

             model.apply(fn)
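Note (editorial): the widened condition above catches configs that store enable_recompute as 0/1 rather than as a bool. paddle.nn.Layer.apply walks the module tree recursively, so the hook flips the flag on every sublayer that carries it. A small standalone sketch (not part of the diff), with the attribute set by hand on toy layers:

import paddle

model = paddle.nn.Sequential(paddle.nn.Linear(4, 4), paddle.nn.Linear(4, 4))
model[0].enable_recompute = False  # bool-style flag
model[1].enable_recompute = 0      # int-style flag

def fn(layer):
    # Same predicate as the diff: treat False and 0 alike.
    if hasattr(layer, "enable_recompute") and (
        layer.enable_recompute is False or layer.enable_recompute == 0
    ):
        layer.enable_recompute = True

model.apply(fn)  # visits the container and every sublayer
print(model[0].enable_recompute, model[1].enable_recompute)  # True True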
@@ -453,9 +458,9 @@ def train(

         # delay_optimizer_creation = (
         #     self.sharding is not None
-        #     and ShardingOption.SHARD_OP not in self.args.sharding
+        #     and ShardingOption.SHARD_OP in self.args.sharding
         # )
-        delay_optimizer_creation = self.sharding is None
+        delay_optimizer_creation = False

         if not delay_optimizer_creation:
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)
@@ -757,7 +762,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
             tr_loss.subtract_(tr_loss)

             logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 8)
-            logs["learning_rate"] = self._get_learning_rate()
+            logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
             logs["global_step"] = int(self.state.global_step)

             total_train_batch_size = (
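Note (editorial): the logging change above only affects how the learning rate is displayed, trimming it to three significant digits before it lands in the log dict; the scheduler value itself is untouched. Plain-Python illustration (not part of the diff):

raw_lr = 2.9999999999999997e-05      # what a scheduler typically returns
logged_lr = float("{0:.3e}".format(raw_lr))
print(logged_lr)                     # 3e-05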
@@ -967,6 +972,8 @@ def apply_decay_param_fun(x):
                 return x in decay_parameters

             optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+            if hasattr(optimizer_cls, "_create_master_weight") and self.args.fp16_opt_level == "O2":
+                optimizer_kwargs["multi_precision"] = True

             if ShardingOption.SHARD_OP in self.args.sharding:
                 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import (
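Note (editorial): multi_precision asks optimizers that support it (Paddle's AdamW and Momentum expose _create_master_weight) to keep fp32 master copies of fp16 parameters, the usual recipe for numerically stable O2 training. A minimal standalone sketch (not part of the diff), assuming a GPU build of Paddle and a toy model in place of the Trainer's:

import paddle

model = paddle.nn.Linear(4, 4)
model = paddle.amp.decorate(models=model, level="O2", dtype="float16")

optimizer_cls = paddle.optimizer.AdamW
optimizer_kwargs = {"learning_rate": 3e-5}
# Only request master weights when the optimizer implements them,
# mirroring the hasattr guard in the diff.
if hasattr(optimizer_cls, "_create_master_weight"):
    optimizer_kwargs["multi_precision"] = True

optimizer = optimizer_cls(parameters=model.parameters(), **optimizer_kwargs)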
@@ -1575,9 +1582,15 @@ def evaluation_loop(
         logger.info(f"***** Running {description} *****")
         if has_length(dataloader):
             logger.info(f"  Num examples = {self.num_examples(dataloader)}")
-            logger.info(f"  Total prediction steps = {len(dataloader)}")
+            if max_eval_iters > 0:
+                logger.info(f"  Total prediction steps = {max_eval_iters}")
+            else:
+                logger.info(f"  Total prediction steps = {len(dataloader)}")
         else:
             logger.info("  Num examples: Unknown")
+            if max_eval_iters > 0:
+                logger.info(f"  Total prediction steps = {max_eval_iters}")
+
         logger.info(f"  Pre device batch size = {batch_size}")
         logger.info(f"  Total Batch size = {batch_size * self.args.world_size}")

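Note (editorial): the extra branches above only change what is logged when max_eval_iters caps evaluation; the loop itself is presumably cut at the same bound elsewhere in evaluation_loop. A small schematic sketch (not part of the diff) of that cap-and-log pattern, with a stand-in dataloader and a hypothetical max_eval_iters value:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("eval")

dataloader = [object()] * 100  # stand-in for an eval DataLoader with 100 batches
max_eval_iters = 20            # hypothetical cap; <= 0 means "run the full loader"

if max_eval_iters > 0:
    logger.info(f"  Total prediction steps = {max_eval_iters}")
else:
    logger.info(f"  Total prediction steps = {len(dataloader)}")

for step, batch in enumerate(dataloader):
    if max_eval_iters > 0 and step >= max_eval_iters:
        break  # stop early once the cap is reached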