@@ -177,7 +177,7 @@ class Trainer:
     def __init__(
            self,
            model: Union[PretrainedModel, nn.Layer] = None,
-           criterion: Union[nn.Layer] = None,
+           criterion: nn.Layer = None,
            args: TrainingArguments = None,
            data_collator: Optional[DataCollator] = None,
            train_dataset: Optional[Dataset] = None,
@@ -241,6 +241,7 @@ def __init__(
         self.state = TrainerState()
         self.control = TrainerControl()
         self._signature_columns = None
+        self.optimizer_grouped_parameters = None

         if (self.sharding is not None) and (self.optimizer is not None
                                             or self.lr_scheduler is not None):
@@ -710,9 +711,11 @@ def train(

                     self.control = self.callback_handler.on_step_end(
                         args, self.state, self.control)
-
-                    self._maybe_log_save_evaluate(tr_loss, model, epoch,
-                                                  ignore_keys_for_eval)
+                    self._maybe_log_save_evaluate(tr_loss,
+                                                  model,
+                                                  epoch,
+                                                  ignore_keys_for_eval,
+                                                  inputs=inputs)
                 else:
                     self.control = self.callback_handler.on_substep_end(
                         args, self.state, self.control)
@@ -730,8 +733,11 @@ def train(

             self.control = self.callback_handler.on_epoch_end(
                 args, self.state, self.control)
-            self._maybe_log_save_evaluate(tr_loss, model, epoch,
-                                          ignore_keys_for_eval)
+            self._maybe_log_save_evaluate(tr_loss,
+                                          model,
+                                          epoch,
+                                          ignore_keys_for_eval,
+                                          inputs=inputs)

             if self.control.should_training_stop:
                 break
@@ -805,7 +811,7 @@ def _set_state_dict_in_model(self, state_dict):
         self.model.set_state_dict(state_dict)

     def _maybe_log_save_evaluate(self, tr_loss, model, epoch,
-                                 ignore_keys_for_eval):
+                                 ignore_keys_for_eval, **kwargs):
         if self.control.should_log:

             logs: Dict[str, float] = {}
@@ -836,7 +842,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch,
             self._globalstep_last_logged = self.state.global_step
             self._globalstep_last_start_time = time.time()

-            self.log(logs)
+            self.log(logs, **kwargs)

         metrics = None
         if self.control.should_evaluate:
@@ -1024,11 +1030,16 @@ def create_optimizer(self, lr_scheduler=None):
         Trainer's init through `optimizers`, or subclass and override this method in a subclass.
         """
         if self.optimizer is None:
-            decay_parameters = [
-                p.name for n, p in self.model.named_parameters()
-                if not any(nd in n for nd in ["bias", "norm"])
-            ]
-            apply_decay_param_fun = lambda x: x in decay_parameters
+            if self.optimizer_grouped_parameters is not None:
+                params = self.optimizer_grouped_parameters
+                apply_decay_param_fun = None
+            else:
+                params = self.model.parameters()
+                decay_parameters = [
+                    p.name for n, p in self.model.named_parameters()
+                    if not any(nd in n for nd in ["bias", "norm"])
+                ]
+                apply_decay_param_fun = lambda x: x in decay_parameters

             optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                 self.args)
@@ -1038,22 +1049,24 @@ def create_optimizer(self, lr_scheduler=None):
                 self.optimizer = DygraphShardingOptimizer(
                     hcg=fleet.get_hybrid_communicate_group(),
                     user_defined_strategy=None,
-                    params=self.model.parameters(),
+                    params=params,
                     inner_optimizer_class=optimizer_cls,
                     learning_rate=self.lr_scheduler
                     if lr_scheduler is None else lr_scheduler,
                     apply_decay_param_fun=apply_decay_param_fun,
                     weight_decay=self.args.weight_decay,
-                    grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm),
+                    grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm)
+                    if self.args.max_grad_norm > 0 else None,
                     **optimizer_kwargs)
             else:
                 self.optimizer = optimizer_cls(
                     learning_rate=self.lr_scheduler
                     if lr_scheduler is None else lr_scheduler,
                     apply_decay_param_fun=apply_decay_param_fun,
-                    parameters=self.model.parameters(),
+                    parameters=params,
                     weight_decay=self.args.weight_decay,
-                    grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm),
+                    grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm)
+                    if self.args.max_grad_norm > 0 else None,
                     **optimizer_kwargs)

         return self.optimizer
@@ -1429,6 +1442,10 @@ def _save_checkpoint(self, model, metrics=None):
         if self.args.should_save:
             self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

+    def set_optimizer_grouped_parameters(self,
+                                         optimizer_grouped_parameters=None):
+        self.optimizer_grouped_parameters = optimizer_grouped_parameters
+
     def _sorted_checkpoints(self,
                             output_dir=None,
                             checkpoint_prefix=PREFIX_CHECKPOINT_DIR,
@@ -1553,7 +1570,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 paddle.load(os.path.join(checkpoint, SCALER_NAME),
                             return_numpy=True))

-    def log(self, logs: Dict[str, float]) -> None:
+    def log(self, logs: Dict[str, float], **kwargs) -> None:
         """
         Log `logs` on the various objects watching training.

@@ -1569,7 +1586,8 @@ def log(self, logs: Dict[str, float]) -> None:
         output = {**logs, **{"step": self.state.global_step}}
         self.state.log_history.append(output)
         self.control = self.callback_handler.on_log(self.args, self.state,
-                                                    self.control, logs)
+                                                    self.control, logs,
+                                                    **kwargs)

     def evaluate(
         self,
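
A rough usage sketch of the new set_optimizer_grouped_parameters hook follows (not part of the patch). It assumes the common Paddle convention for parameter groups, a list of dicts with a "params" key and optional per-group options, and placeholder model / dataset objects set up as for any other Trainer run. Groups registered this way are forwarded unchanged by create_optimizer(), which then skips its default apply_decay_param_fun weight-decay filter.

```python
# Hypothetical usage of Trainer.set_optimizer_grouped_parameters().
# `model`, `train_dataset`, and `data_collator` are placeholders; the group
# format (list of dicts with a "params" key) follows the usual Paddle
# optimizer convention and may need adjusting for a specific optimizer.
from paddlenlp.trainer import Trainer, TrainingArguments

training_args = TrainingArguments(output_dir="./out", weight_decay=0.01)

no_decay = ["bias", "norm"]
decay_params = [
    p for n, p in model.named_parameters()
    if not any(nd in n for nd in no_decay)
]
other_params = [
    p for n, p in model.named_parameters()
    if any(nd in n for nd in no_decay)
]

trainer = Trainer(model=model,
                  args=training_args,
                  train_dataset=train_dataset,
                  data_collator=data_collator)

# Register custom groups before train(); create_optimizer() passes them
# straight through instead of building its own decay/no-decay split.
trainer.set_optimizer_grouped_parameters([
    {"params": decay_params, "weight_decay": training_args.weight_decay},
    {"params": other_params, "weight_decay": 0.0},
])
trainer.train()
```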
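The **kwargs plumbing through _maybe_log_save_evaluate(), log(), and on_log() lets a callback see the batch that was live at logging time, since it is forwarded as inputs=inputs. Below is an illustrative callback, not part of this PR, which assumes the callback handler relays these extra keyword arguments to each callback's on_log as the modified call site suggests.

```python
# Illustrative callback (not part of this PR): inspect the batch forwarded
# as `inputs` alongside the logged metrics.
from paddlenlp.trainer import TrainerCallback


class BatchShapeLogger(TrainerCallback):

    def on_log(self, args, state, control, logs=None, **kwargs):
        inputs = kwargs.get("inputs")
        if inputs is None:
            return
        # Report the tensor shapes of the current batch next to the metrics.
        shapes = {name: list(tensor.shape)
                  for name, tensor in inputs.items()
                  if hasattr(tensor, "shape")}
        print(f"step {state.global_step}: logs={logs}, batch shapes={shapes}")


# Attached like any other callback before training:
# trainer.add_callback(BatchShapeLogger())
```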