     PREFIX_CHECKPOINT_DIR,
     EvalLoopOutput,
     EvalPrediction,
+    IntervalStrategy,
     IterableDatasetShard,
     OptimizerNames,
     PredictionOutput,
     get_scheduler,
     has_length,
     set_seed,
+    should_skip_data,
     speed_metrics,
 )
 from .training_args import TrainingArguments
@@ -287,9 +289,16 @@ def __init__(
 
         # Seed must be set before instantiating the model when using model
         set_seed(seed=self.args.seed)
-
+        self._skip_global_steps = 0  # total skip global steps
+        self._skip_steps_since_last_logged = 0  # skip steps since last logged
         if model is None:
-            raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
+            logger.warning("Model is None.")
+            self.model = None
+            self.train_dataset = train_dataset
+            self.tokenizer = tokenizer
+            default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
+            self.data_collator = data_collator if data_collator is not None else default_collator
+            return
 
         if self.args.to_static:
             model = paddle.jit.to_static(model)
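
For context, a minimal usage sketch of the new early-return path. This is illustrative only: the names training_args, train_ds, and tokenizer are assumptions, not part of this diff. A Trainer built without a model now skips all model-dependent setup but still keeps its dataset and a tokenizer-aware data collator.

    # Illustrative sketch only (not from this PR): build a Trainer without a model
    # to reuse its data handling. `training_args`, `train_ds`, and `tokenizer` are
    # assumed to exist in the caller's scope.
    data_only_trainer = Trainer(
        model=None,
        args=training_args,
        train_dataset=train_ds,
        tokenizer=tokenizer,
    )
    # With the change above, this logs "Model is None.", keeps train_dataset and
    # data_collator, and returns before any model-specific initialization.
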
@@ -945,6 +954,7 @@ def _inner_training_loop(
             step_control = 0  # used in loop control, reset to 0 after every step
             self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
 
+            step = -1
             for step, inputs in enumerate(epoch_iterator):
                 if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1:
                     inputs = split_inputs_sequence_dim(inputs)
@@ -981,6 +991,44 @@ def _inner_training_loop(
                         steps_trained_progress_bar.close()
                         steps_trained_progress_bar = None
 
+                if should_skip_data(self.state.global_step, self.args.skip_data_intervals):
+                    # skip this step
+
+                    if (step_control + 1) % self.args.gradient_accumulation_steps == 0 or (
+                        # last step in epoch but step is always smaller than gradient_accumulation_steps
+                        steps_in_epoch <= args.gradient_accumulation_steps
+                        and (step + 1) == steps_in_epoch
+                    ):
+                        # update current global step and skip step
+                        self.state.global_step += 1
+                        self._skip_global_steps += 1
+                        self._skip_steps_since_last_logged += 1
+
+                        self.state.epoch = epoch + (step + 1) / steps_in_epoch
+
+                        if self.state.global_step == 1 and self.args.logging_first_step:
+                            self.control.should_log = True
+                        if (
+                            self.args.logging_strategy == IntervalStrategy.STEPS
+                            and self.state.global_step % self.args.logging_steps == 0
+                        ):
+                            self.control.should_log = True
+
+                        self.control.should_evaluate = False
+                        self.control.should_save = False
+
+                        # log loss and memory usage
+                        self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs)
+                        self._print_timer()
+                        step_control = 0
+                    else:
+                        step_control += 1
+                    if self.state.global_step >= self.state.max_steps:
+                        break
+
+                    self.timers and self.timers("read-data").start()
+                    continue
+
                 if step_control % args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
                     self.timers and self.timers("forward-backward").start()
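
Note that should_skip_data is imported from the trainer utilities above, but its implementation is not part of this diff. A minimal sketch of the shape such a helper could take, assuming skip_data_intervals is a list of (start, end) pairs of global steps; the exact convention (1-based steps, inclusive bounds) is an assumption here, not confirmed by this hunk:

    def should_skip_data(global_step, skip_data_intervals):
        """Sketch only: return True if the current step falls inside any skip interval.

        Assumes `skip_data_intervals` is an iterable of (start, end) pairs of
        global steps, treated here as 1-based and inclusive.
        """
        if not skip_data_intervals:
            return False
        for start, end in skip_data_intervals:
            if start <= global_step + 1 <= end:
                return True
        return False
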
@@ -1202,7 +1250,13 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
         )
 
         self._total_loss_scalar += tr_loss.item()
-        train_loss = self._total_loss_scalar / self.state.global_step
+
+        # In case all steps were skipped, the total loss is set to 0.
+        if self.state.global_step == self._skip_global_steps:
+            logger.info("All steps were skipped, the total loss is set to 0.")
+            train_loss = 0.0
+        else:
+            train_loss = self._total_loss_scalar / (self.state.global_step - self._skip_global_steps)
 
         metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
 
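
To make the new averaging concrete, a small worked example with made-up numbers:

    # Made-up numbers: 1000 global steps, 200 of them skipped, 640.0 accumulated loss.
    total_loss_scalar = 640.0
    global_step, skip_global_steps = 1000, 200
    train_loss = total_loss_scalar / (global_step - skip_global_steps)  # 640.0 / 800 = 0.8
    # The old formula (total_loss_scalar / global_step) would report 0.64, diluted by
    # steps that contributed no loss; if every step were skipped, train_loss is now 0.0.
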
@@ -1321,15 +1375,20 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
         if self.control.should_log:
 
             logs: Dict[str, float] = {}
-
+            num_steps = self.state.global_step - self._globalstep_last_logged - self._skip_steps_since_last_logged
+            self._skip_steps_since_last_logged = 0
             # all_gather + mean() to get average loss over all processes
             avg_loss = self._nested_gather(tr_loss).mean()
             tr_loss_scalar = self._get_item_from_loss(avg_loss)
 
             # reset tr_loss to zero
             tr_loss.subtract_(tr_loss)
+            # set loss to zero if all steps are skipped since last log
+            if num_steps == 0:
+                logs["loss"] = 0.0
+            else:
+                logs["loss"] = round(tr_loss_scalar / num_steps, 8)
 
-            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 8)
             logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
             logs["global_step"] = int(self.state.global_step)
             if in_auto_parallel_align_mode():
@@ -1352,7 +1411,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
             total_train_batch_size = (
                 self.args.train_batch_size * self.args.gradient_accumulation_steps * self.args.dataset_world_size
             )
-            num_steps = self.state.global_step - self._globalstep_last_logged
+
             seq_length = None
             model_flops = None
             if getattr(self, "is_pretraining", False) and hasattr(self.model, "config"):
@@ -1362,16 +1421,18 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
                 except NotImplementedError:
                     model_flops = None
 
-            logs.update(
-                speed_metrics(
-                    "interval",
-                    self._globalstep_last_start_time,
-                    num_samples=total_train_batch_size * num_steps,
-                    num_steps=num_steps,
-                    seq_length=seq_length,
-                    model_flops=model_flops,
+            # Do not log speed metrics if all steps are skipped since last log.
+            if num_steps > 0:
+                logs.update(
+                    speed_metrics(
+                        "interval",
+                        self._globalstep_last_start_time,
+                        num_samples=total_train_batch_size * num_steps,
+                        num_steps=num_steps,
+                        seq_length=seq_length,
+                        model_flops=model_flops,
+                    )
                 )
-            )
 
             self._total_loss_scalar += tr_loss_scalar
             self._globalstep_last_logged = self.state.global_step
@@ -3255,7 +3316,7 @@ def _set_signature_columns_if_needed(self):
             self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
 
     def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
-        if not self.args.remove_unused_columns:
+        if not self.args.remove_unused_columns or self.model is None:
             return dataset
         if self._signature_columns is None:
             # Inspect model forward signature to keep only the arguments it accepts.