
Commit 0ce21ea

[PPDiffusers] update ldm train code (#3772)
* update ldm train
* add rank==0
* add rank==0
* update trainer
* update
* update
* update
* update
* update trainer code
* update generate pipelines
* update support bf16
1 parent 051d4e7 commit 0ce21ea

27 files changed: +1953 −1852 lines

paddlenlp/trainer/trainer.py

Lines changed: 37 additions & 19 deletions
@@ -177,7 +177,7 @@ class Trainer:
     def __init__(
         self,
         model: Union[PretrainedModel, nn.Layer] = None,
-        criterion: Union[nn.Layer] = None,
+        criterion: nn.Layer = None,
         args: TrainingArguments = None,
         data_collator: Optional[DataCollator] = None,
         train_dataset: Optional[Dataset] = None,
@@ -241,6 +241,7 @@ def __init__(
         self.state = TrainerState()
         self.control = TrainerControl()
         self._signature_columns = None
+        self.optimizer_grouped_parameters = None
 
         if (self.sharding is not None) and (self.optimizer is not None
                                             or self.lr_scheduler is not None):
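Note: the new `optimizer_grouped_parameters` attribute defaults to `None`. When a training script fills it in through `set_optimizer_grouped_parameters` (added later in this file), `create_optimizer` builds the optimizer from those groups instead of from `self.model.parameters()`; see the hunks below.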
@@ -710,9 +711,11 @@ def train(
 
                     self.control = self.callback_handler.on_step_end(
                         args, self.state, self.control)
-
-                    self._maybe_log_save_evaluate(tr_loss, model, epoch,
-                                                  ignore_keys_for_eval)
+                    self._maybe_log_save_evaluate(tr_loss,
+                                                  model,
+                                                  epoch,
+                                                  ignore_keys_for_eval,
+                                                  inputs=inputs)
                 else:
                     self.control = self.callback_handler.on_substep_end(
                         args, self.state, self.control)
@@ -730,8 +733,11 @@ def train(
 
             self.control = self.callback_handler.on_epoch_end(
                 args, self.state, self.control)
-            self._maybe_log_save_evaluate(tr_loss, model, epoch,
-                                          ignore_keys_for_eval)
+            self._maybe_log_save_evaluate(tr_loss,
+                                          model,
+                                          epoch,
+                                          ignore_keys_for_eval,
+                                          inputs=inputs)
 
             if self.control.should_training_stop:
                 break
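Both call sites above now pass the current batch along as `inputs=inputs`; combined with the `log`/`on_log` changes further down, this lets log-time callbacks see the latest batch instead of only the scalar logs.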
@@ -805,7 +811,7 @@ def _set_state_dict_in_model(self, state_dict):
         self.model.set_state_dict(state_dict)
 
     def _maybe_log_save_evaluate(self, tr_loss, model, epoch,
-                                 ignore_keys_for_eval):
+                                 ignore_keys_for_eval, **kwargs):
         if self.control.should_log:
 
             logs: Dict[str, float] = {}
@@ -836,7 +842,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch,
             self._globalstep_last_logged = self.state.global_step
             self._globalstep_last_start_time = time.time()
 
-            self.log(logs)
+            self.log(logs, **kwargs)
 
         metrics = None
         if self.control.should_evaluate:
@@ -1024,11 +1030,16 @@ def create_optimizer(self, lr_scheduler=None):
         Trainer's init through `optimizers`, or subclass and override this method in a subclass.
         """
         if self.optimizer is None:
-            decay_parameters = [
-                p.name for n, p in self.model.named_parameters()
-                if not any(nd in n for nd in ["bias", "norm"])
-            ]
-            apply_decay_param_fun = lambda x: x in decay_parameters
+            if self.optimizer_grouped_parameters is not None:
+                params = self.optimizer_grouped_parameters
+                apply_decay_param_fun = None
+            else:
+                params = self.model.parameters()
+                decay_parameters = [
+                    p.name for n, p in self.model.named_parameters()
+                    if not any(nd in n for nd in ["bias", "norm"])
+                ]
+                apply_decay_param_fun = lambda x: x in decay_parameters
 
             optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                 self.args)
@@ -1038,22 +1049,24 @@ def create_optimizer(self, lr_scheduler=None):
                 self.optimizer = DygraphShardingOptimizer(
                     hcg=fleet.get_hybrid_communicate_group(),
                     user_defined_strategy=None,
-                    params=self.model.parameters(),
+                    params=params,
                     inner_optimizer_class=optimizer_cls,
                     learning_rate=self.lr_scheduler
                     if lr_scheduler is None else lr_scheduler,
                     apply_decay_param_fun=apply_decay_param_fun,
                     weight_decay=self.args.weight_decay,
-                    grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm),
+                    grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm)
+                    if self.args.max_grad_norm > 0 else None,
                     **optimizer_kwargs)
             else:
                 self.optimizer = optimizer_cls(
                     learning_rate=self.lr_scheduler
                     if lr_scheduler is None else lr_scheduler,
                     apply_decay_param_fun=apply_decay_param_fun,
-                    parameters=self.model.parameters(),
+                    parameters=params,
                     weight_decay=self.args.weight_decay,
-                    grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm),
+                    grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm)
+                    if self.args.max_grad_norm > 0 else None,
                     **optimizer_kwargs)
 
         return self.optimizer
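One behavioral consequence worth noting: a non-positive `max_grad_norm` now disables gradient clipping altogether instead of building `ClipGradByGlobalNorm` with a zero or negative norm. A minimal sketch of opting out from a training script (the `output_dir` value is illustrative):

from paddlenlp.trainer import TrainingArguments

# With this commit, a non-positive max_grad_norm means no grad_clip is passed
# to the optimizer at all.
args = TrainingArguments(
    output_dir="./ldm-checkpoints",  # illustrative path
    max_grad_norm=-1.0,              # skips ClipGradByGlobalNorm in create_optimizer
)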
@@ -1429,6 +1442,10 @@ def _save_checkpoint(self, model, metrics=None):
         if self.args.should_save:
             self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
 
+    def set_optimizer_grouped_parameters(self,
+                                         optimizer_grouped_parameters=None):
+        self.optimizer_grouped_parameters = optimizer_grouped_parameters
+
     def _sorted_checkpoints(self,
                             output_dir=None,
                             checkpoint_prefix=PREFIX_CHECKPOINT_DIR,
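A minimal sketch of how a training script might use the new hook, assuming `trainer` is an already-constructed `Trainer` (or subclass) and that the per-group dict format shown (a `params` list plus an optional `weight_decay` override) is accepted by the Paddle optimizer in use; the grouping rule itself is illustrative, not part of this commit:

# Hypothetical usage: register custom parameter groups before train().
decay_params, no_decay_params = [], []
for name, param in trainer.model.named_parameters():
    if any(nd in name for nd in ["bias", "norm"]):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

trainer.set_optimizer_grouped_parameters([
    {"params": decay_params, "weight_decay": 0.01},
    {"params": no_decay_params, "weight_decay": 0.0},
])
trainer.train()  # create_optimizer() now receives these groups as `params`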
@@ -1553,7 +1570,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 paddle.load(os.path.join(checkpoint, SCALER_NAME),
                             return_numpy=True))
 
-    def log(self, logs: Dict[str, float]) -> None:
+    def log(self, logs: Dict[str, float], **kwargs) -> None:
         """
         Log `logs` on the various objects watching training.
 
@@ -1569,7 +1586,8 @@ def log(self, logs: Dict[str, float]) -> None:
         output = {**logs, **{"step": self.state.global_step}}
         self.state.log_history.append(output)
         self.control = self.callback_handler.on_log(self.args, self.state,
-                                                     self.control, logs)
+                                                     self.control, logs,
+                                                     **kwargs)
 
     def evaluate(
         self,

paddlenlp/trainer/trainer_callback.py

Lines changed: 7 additions & 2 deletions
@@ -410,9 +410,14 @@ def on_save(self, args: TrainingArguments, state: TrainerState,
         return self.call_event("on_save", args, state, control)
 
     def on_log(self, args: TrainingArguments, state: TrainerState,
-               control: TrainerControl, logs):
+               control: TrainerControl, logs, **kwargs):
         control.should_log = False
-        return self.call_event("on_log", args, state, control, logs=logs)
+        return self.call_event("on_log",
+                               args,
+                               state,
+                               control,
+                               logs=logs,
+                               **kwargs)
 
     def on_prediction_step(self, args: TrainingArguments, state: TrainerState,
                            control: TrainerControl):
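Because `on_log` now forwards arbitrary keyword arguments to every callback, a custom callback can pick up whatever the trainer passed along, such as the `inputs` batch forwarded from the training loop above. A minimal sketch, with the callback name and the use it makes of `inputs` purely illustrative:

from paddlenlp.trainer import TrainerCallback

class InputsInspectorCallback(TrainerCallback):
    # Hypothetical callback: `inputs` arrives through the new **kwargs chain
    # (train loop -> _maybe_log_save_evaluate -> log -> on_log).
    def on_log(self, args, state, control, logs=None, **kwargs):
        inputs = kwargs.get("inputs")
        if inputs is not None:
            print(f"step {state.global_step}: received a batch of type {type(inputs).__name__}")

It would be registered the usual way, e.g. `trainer.add_callback(InputsInspectorCallback())`.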
