Commit 318ed29

[Trainer] fix multi_precision (#4208)
* fix multi_precision
1 parent 5e3c3a5 commit 318ed29

File tree

1 file changed: 19 additions, 6 deletions

paddlenlp/trainer/trainer.py

Lines changed: 19 additions & 6 deletions
@@ -260,10 +260,13 @@ def __init__(
             logger.info("Using half precision")
             self.do_grad_scaling = True
             self.amp_dtype = "float16" if args.fp16 else "bfloat16"
+            # fix for load saved fp16 or bf16 ckpt, decorate model first.
+            if self.args.fp16_opt_level == "O2":
+                paddle.amp.decorate(models=model, level=self.args.fp16_opt_level, dtype=self.amp_dtype)

             if self.sharding is not None:
                 self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
-                if self.amp_dtype == "float16":
+                if self.amp_dtype == "float16" or self.amp_dtype == "bfloat16":
                     if ShardingOption.SHARD_OP in self.args.sharding:
                         self.scaler = fleet.distributed_scaler(self.scaler)
                     else:
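The decoration added above can be tried outside the Trainer. A minimal sketch, assuming a toy paddle.nn.Linear stands in for the real model: decorating at O2 casts the parameters to the requested low-precision dtype, so a checkpoint saved in fp16/bf16 can later be loaded without dtype mismatches.

import paddle

# Toy stand-in for the Trainer's model (assumption; not the PaddleNLP code itself).
model = paddle.nn.Linear(8, 8)
# O2 decoration casts parameters to the AMP dtype before any state dict is loaded.
model = paddle.amp.decorate(models=model, level="O2", dtype="float16")
print(model.weight.dtype)  # expected: paddle.float16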
@@ -291,7 +294,9 @@ def __init__(
         if args.recompute:

             def fn(layer):
-                if hasattr(layer, "enable_recompute") and layer.enable_recompute is False:
+                if hasattr(layer, "enable_recompute") and (
+                    layer.enable_recompute is False or layer.enable_recompute == 0
+                ):
                     layer.enable_recompute = True

             model.apply(fn)
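A minimal sketch of the broadened recompute check, assuming a hypothetical layer whose enable_recompute flag was restored as the integer 0 rather than False; the widened condition now switches it on as well.

import paddle

class Block(paddle.nn.Layer):  # hypothetical layer, not from PaddleNLP
    def __init__(self):
        super().__init__()
        self.enable_recompute = 0  # e.g. an int flag coming from a saved config

def fn(layer):
    # Same condition as the patched Trainer: treat 0 like False.
    if hasattr(layer, "enable_recompute") and (
        layer.enable_recompute is False or layer.enable_recompute == 0
    ):
        layer.enable_recompute = True

block = Block()
block.apply(fn)                # Layer.apply visits the layer and its sublayers
print(block.enable_recompute)  # True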
@@ -453,9 +458,9 @@ def train(

         # delay_optimizer_creation = (
         #     self.sharding is not None
-        #     and ShardingOption.SHARD_OP not in self.args.sharding
+        #     and ShardingOption.SHARD_OP in self.args.sharding
         # )
-        delay_optimizer_creation = self.sharding is None
+        delay_optimizer_creation = False

         if not delay_optimizer_creation:
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)
@@ -757,7 +762,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
             tr_loss.subtract_(tr_loss)

             logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 8)
-            logs["learning_rate"] = self._get_learning_rate()
+            logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
             logs["global_step"] = int(self.state.global_step)

             total_train_batch_size = (
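The logging change only affects how the learning rate is rendered; a small sketch of the formatting with an assumed value:

# Round to three significant digits in scientific notation, but keep a float for the log dict.
lr = 4.999876e-05
print(float("{0:.3e}".format(lr)))  # 5e-05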
@@ -967,6 +972,8 @@ def apply_decay_param_fun(x):
                 return x in decay_parameters

             optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+            if hasattr(optimizer_cls, "_create_master_weight") and self.args.fp16_opt_level == "O2":
+                optimizer_kwargs["multi_precision"] = True

             if ShardingOption.SHARD_OP in self.args.sharding:
                 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import (
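A minimal sketch of what the new kwarg does, assuming paddle.optimizer.AdamW as the resolved optimizer class and a toy decorated model: when the class supports master weights and training runs at O2, multi_precision=True makes the optimizer keep fp32 master copies of the fp16/bf16 parameters.

import paddle

# Toy model decorated at O2 (assumption, stands in for the Trainer's model).
model = paddle.amp.decorate(models=paddle.nn.Linear(8, 8), level="O2", dtype="float16")
optimizer_cls = paddle.optimizer.AdamW      # assumed optimizer class for illustration
optimizer_kwargs = {"learning_rate": 1e-4}  # assumed kwargs for illustration
if hasattr(optimizer_cls, "_create_master_weight"):
    optimizer_kwargs["multi_precision"] = True  # fp32 master weights for fp16/bf16 params
opt = optimizer_cls(parameters=model.parameters(), **optimizer_kwargs)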
@@ -1575,9 +1582,15 @@ def evaluation_loop(
         logger.info(f"***** Running {description} *****")
         if has_length(dataloader):
             logger.info(f" Num examples = {self.num_examples(dataloader)}")
-            logger.info(f" Total prediction steps = {len(dataloader)}")
+            if max_eval_iters > 0:
+                logger.info(f" Total prediction steps = {max_eval_iters}")
+            else:
+                logger.info(f" Total prediction steps = {len(dataloader)}")
         else:
             logger.info(" Num examples: Unknown")
+            if max_eval_iters > 0:
+                logger.info(f" Total prediction steps = {max_eval_iters}")

         logger.info(f" Pre device batch size = {batch_size}")
         logger.info(f" Total Batch size = {batch_size * self.args.world_size}")
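A small sketch of the new step-count logic, assuming max_eval_iters is the cap on evaluation iterations: when it is set (> 0) it overrides the dataloader length in the logged message.

# Hypothetical values for illustration only.
max_eval_iters = 50
dataloader_len = 1000
steps = max_eval_iters if max_eval_iters > 0 else dataloader_len
print(f" Total prediction steps = {steps}")  # 50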
