
Commit a86de1b

trainer refactor (PaddlePaddle#7909)
1 parent: 2a5e5a6

File tree: 1 file changed (+76 -46 lines)


paddlenlp/trainer/trainer.py

Lines changed: 76 additions & 46 deletions
@@ -330,7 +330,8 @@ def __init__(
                 "Passing `optimizers` is not allowed if sharding is enabled."
                 "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
             )
-        if self.args.pipeline_parallel_degree > 1:
+
+        if self.args.pipeline_parallel_degree > 1 and self.args.use_hybrid_parallel:
             from paddle.distributed.fleet.meta_parallel import PipelineLayer
 
             assert (isinstance(model, LoRAModel) and isinstance(model.model, PipelineLayer)) or isinstance(
@@ -352,52 +353,9 @@ def __init__(
 
         self.do_grad_scaling = False
         self.enable_autocast_context_manager = False
-
         if args.fp16 or args.bf16:
-            logger.info("Using half precision")
-            self.enable_autocast_context_manager = True
-            self.do_grad_scaling = True if args.fp16 else False
-            self.amp_dtype = "float16" if args.fp16 else "bfloat16"
-            # fix for load saved fp16 or bf16 ckpt, decorate model first.
-            if self.args.fp16_opt_level == "O2":
-                paddle.amp.decorate(
-                    models=model,
-                    level=self.args.fp16_opt_level,
-                    dtype=self.amp_dtype,
-                    excluded_layers=QuantizationLinear,
-                )
-            # for pipeline mode and pure tensor parallel
-            if self.args.pipeline_parallel_degree > 1 or (
-                self.args.tensor_parallel_degree > 1 and self.sharding is None
-            ):
-                self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
-                if self.args.amp_master_grad:
-                    mix_precision_utils.MixPrecisionScaler(self.scaler)  # retun value has no use
-                self.scaler = fleet.distributed_scaler(self.scaler)
-            elif self.sharding is not None:
-                self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
-                if self.amp_dtype == "float16" or self.amp_dtype == "bfloat16":
-                    if ShardingOption.SHARD_OP in self.args.sharding:
-                        self.scaler = fleet.distributed_scaler(self.scaler)
-                        if self.args.amp_master_grad:
-                            mix_precision_utils.MixPrecisionScaler(self.scaler)  # retun value has no use
-                    else:
-                        # scaler for stage2 and stage3
-                        from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
-                            GroupShardedScaler,
-                        )
-
-                        if self.args.amp_master_grad:
-                            mix_precision_utils.MixPrecisionScaler(self.scaler)  # return value has no use
-
-                        self.scaler = GroupShardedScaler(self.scaler)
-                else:
-                    self.do_grad_scaling = False
-                    self.use_cuda_amp = False
-                    self.amp_dtype = None
-
-            else:
-                self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
+            # set do_grad_scaling, enable_autocast_context_manager
+            self._wrap_amp_model(args, model)
 
         if args.recompute:
 
@@ -422,6 +380,50 @@ def fn(layer):
         # very last
         self._memory_tracker.stop_and_update_metrics()
 
+    def _wrap_amp_model(self, args, model):
+        logger.info("Using half precision")
+        self.enable_autocast_context_manager = True
+        self.do_grad_scaling = True if args.fp16 else False
+        self.amp_dtype = "float16" if args.fp16 else "bfloat16"
+        # fix for load saved fp16 or bf16 ckpt, decorate model first.
+        if self.args.fp16_opt_level == "O2":
+            paddle.amp.decorate(
+                models=model,
+                level=self.args.fp16_opt_level,
+                dtype=self.amp_dtype,
+                excluded_layers=QuantizationLinear,
+            )
+        # for pipeline mode and pure tensor parallel
+        if self.args.pipeline_parallel_degree > 1 or (self.args.tensor_parallel_degree > 1 and self.sharding is None):
+            self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
+            if self.args.amp_master_grad:
+                mix_precision_utils.MixPrecisionScaler(self.scaler)  # retun value has no use
+            self.scaler = fleet.distributed_scaler(self.scaler)
+        elif self.sharding is not None:
+            self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
+            if self.amp_dtype == "float16" or self.amp_dtype == "bfloat16":
+                if ShardingOption.SHARD_OP in self.args.sharding:
+                    self.scaler = fleet.distributed_scaler(self.scaler)
+                    if self.args.amp_master_grad:
+                        mix_precision_utils.MixPrecisionScaler(self.scaler)  # retun value has no use
+                else:
+                    # scaler for stage2 and stage3
+                    from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
+                        GroupShardedScaler,
+                    )
+
+                    if self.args.amp_master_grad:
+                        mix_precision_utils.MixPrecisionScaler(self.scaler)  # return value has no use
+
+                    self.scaler = GroupShardedScaler(self.scaler)
+            else:
+                self.do_grad_scaling = False
+                self.use_cuda_amp = False
+                self.amp_dtype = None
+
+        else:
+            self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
+
     def add_callback(self, callback):
         """
         Add a callback to the current list of [`~TrainerCallback`].
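
The new `_wrap_amp_model` helper above only relocates the existing AMP setup out of `__init__`. For context, here is a minimal, self-contained sketch of the Paddle AMP ingredients it wires together (`paddle.amp.decorate`, `paddle.amp.GradScaler`, `paddle.amp.auto_cast`) in a plain single-device training step. This is illustrative only, not the Trainer's actual code path; the toy linear model, data, and hyperparameters are assumptions, and it presumes a float16-capable GPU.

```python
import paddle
from paddle import nn

# Illustrative single-device AMP sketch (assumes a float16-capable GPU); not PaddleNLP Trainer code.
model = nn.Linear(16, 1)
optimizer = paddle.optimizer.AdamW(learning_rate=1e-3, parameters=model.parameters())

# O2 decoration casts the model to the low-precision dtype up front,
# mirroring the `paddle.amp.decorate(...)` call inside `_wrap_amp_model`.
model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer, level="O2", dtype="float16")

# The scaler guards fp16 gradients against underflow, like `self.scaler` in the Trainer.
scaler = paddle.amp.GradScaler(init_loss_scaling=2**15)

x = paddle.randn([8, 16])
y = paddle.randn([8, 1])

with paddle.amp.auto_cast(level="O2", dtype="float16"):
    loss = nn.functional.mse_loss(model(x), y)

scaled = scaler.scale(loss)  # scale the loss before backward
scaled.backward()
scaler.step(optimizer)       # unscales gradients, then applies the optimizer step
scaler.update()              # adjusts the loss scale for the next iteration
optimizer.clear_grad()
```

The distributed wrappers seen in the hunk (`fleet.distributed_scaler`, `GroupShardedScaler`, `MixPrecisionScaler`) layer on top of this same scaler when hybrid parallelism or sharding is enabled.
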
@@ -747,6 +749,33 @@ def train(
         # so, the trainable numel is a little bigger than real.
         logger.info(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)")
 
+        return self._inner_training_loop(
+            args,
+            model,
+            train_dataloader,
+            len_dataloader,
+            max_steps,
+            num_train_epochs,
+            num_update_steps_per_epoch,
+            num_train_samples,
+            resume_from_checkpoint,
+            ignore_keys_for_eval,
+        )
+
+    def _inner_training_loop(
+        self,
+        args,
+        model,
+        train_dataloader,
+        len_dataloader,
+        max_steps,
+        num_train_epochs,
+        num_update_steps_per_epoch,
+        num_train_samples,
+        resume_from_checkpoint,
+        ignore_keys_for_eval,
+    ):
+
         start_time = time.time()
         self._globalstep_last_start_time = time.time()
         self.state.epoch = 0
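
The hunk above is an extract-method refactor: `train()` now ends by handing the schedule values it has already computed to the new `_inner_training_loop`, which takes them as explicit parameters instead of relying on shared locals. As a rough illustration of the pattern only (the `MiniTrainer` names below are hypothetical, not PaddleNLP code):

```python
class MiniTrainer:
    """Sketch of the train()/_inner_training_loop split; names here are hypothetical."""

    def train(self, dataset, batch_size=8, epochs=3):
        # Derive the schedule once...
        steps_per_epoch = max(1, len(dataset) // batch_size)
        max_steps = steps_per_epoch * epochs
        # ...then delegate, passing everything the loop needs explicitly,
        # so the loop can be read, tested, or overridden on its own.
        return self._inner_training_loop(dataset, batch_size, max_steps, steps_per_epoch)

    def _inner_training_loop(self, dataset, batch_size, max_steps, steps_per_epoch):
        for step in range(max_steps):
            start = (step % steps_per_epoch) * batch_size
            batch = dataset[start : start + batch_size]  # forward/backward/optimizer step would go here
        return {"global_step": max_steps}


print(MiniTrainer().train(list(range(100))))  # {'global_step': 36}
```
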
@@ -2274,6 +2303,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
             logger.warning("Trainer.model is not a `PretrainedModel`, not suppor for merge_tensor_parallel.")
             if state_dict is None:
                 state_dict = self.model.state_dict()
+
             paddle.save(
                 state_dict,
                 os.path.join(output_dir, _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix)),
