from ..utils.log import logger
from .argparser import strtobool
from .integrations import get_reporting_integration_callbacks
-from .plugins.timer import get_timers, set_timers
+from .plugins.timer import RuntimeTimer, get_timers, set_timers
from .plugins.unified_checkpoint import (
    load_unified_checkpoint,
    load_unified_optimizer,
@@ -304,6 +304,7 @@ def __init__(
        if not args.skip_profile_timer:
            set_timers()
        self.timers = get_timers()
+        self.runtime_timer = RuntimeTimer("RuntimeTimer")

        self.model_wrapped = model
        self.model = model
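The hunks that follow wrap checkpoint loading and saving with the `RuntimeTimer` instantiated above. Its implementation lives in `.plugins.timer` and is not shown in this diff; the trainer code only relies on a `start(name)` / `stop()` / `log()` interface. A minimal sketch of that assumed interface, purely illustrative and not the actual PaddleNLP class:

```python
# Hypothetical sketch only: the real RuntimeTimer ships in
# paddlenlp/trainer/plugins/timer.py and is not part of this diff.
# The trainer changes below only rely on start(name), stop(), and log().
import time


class MinimalRuntimeTimer:
    def __init__(self, name):
        self.name = name
        self.elapsed = {}      # accumulated seconds per timed phase
        self._phase = None
        self._start = None

    def start(self, phase):
        # Begin timing a named phase, e.g. "checkpoint saving time".
        self._phase = phase
        self._start = time.time()

    def stop(self):
        # Accumulate elapsed time for the current phase, if one is running.
        if self._phase is None:
            return
        self.elapsed[self._phase] = self.elapsed.get(self._phase, 0.0) + (time.time() - self._start)
        self._phase = None

    def log(self):
        # Render one line suitable for logger.info(), then reset the counters.
        msg = ", ".join(f"{k}: {v:.2f}s" for k, v in self.elapsed.items())
        self.elapsed = {}
        return f"[{self.name}] {msg}"
```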
@@ -506,6 +507,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
                `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
                of [`Trainer`]. Only load model state dict.
        """
+        self.runtime_timer.start("checkpoint loading time")
        resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint

        # Load potential model checkpoint
@@ -531,10 +533,12 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
                    safe_serialization=True,
                )
                logger.info(f"Loading model from {resume_from_checkpoint} using unified checkpoint.")
+                self.runtime_timer.stop()
                return

        if isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM):
            self._load_from_peft_checkpoint(resume_from_checkpoint)
+            self.runtime_timer.stop()
            return

        weight_name = PADDLE_WEIGHTS_NAME
@@ -584,6 +588,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):

        elif resume_from_checkpoint is not None:
            logger.info(f"not loading ckpt :{self.args.dataset_rank}")
+            self.runtime_timer.stop()

    def _wrap_model_and_load_sharded_checkpoint(self, resume_from_checkpoint):
        # In the sharded mode, should invoke _load_from_checkpoint after _wrap_model.
@@ -639,7 +644,6 @@ def train(

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()
-
        if not self.args.should_load_sharding_stage1_model:
            self._load_from_checkpoint(resume_from_checkpoint)

@@ -695,6 +699,7 @@ def train(

        if self.args.should_load_sharding_stage1_model:
            model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint)
+
        elif self.args.should_save_sharding_stage1_model:
            # In the non-sharded mode, should invoke _load_from_checkpoint before _wrap_model.
            # In this mode, the rank0 load all params and the _wrap_model implicitly broadcast params from rank0 to the other ranks.
@@ -718,6 +723,8 @@ def train(
        self.create_optimizer_and_scheduler(num_training_steps=max_steps)
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

+        logger.info(f"{self.runtime_timer.log()}")
+
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {num_examples:,}")
        logger.info(f"  Num Epochs = {num_train_epochs}")
@@ -1239,6 +1246,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
                paddle.device.cuda.synchronize()

            self._save_checkpoint(model, metrics=metrics)
+            logger.info(f"{self.runtime_timer.log()}")
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

    def _get_learning_rate(self):
@@ -2040,7 +2048,7 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op

    def _save_checkpoint(self, model, metrics=None):
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
-
+        self.runtime_timer.start("checkpoint saving time")
        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

@@ -2086,6 +2094,7 @@ def _save_checkpoint(self, model, metrics=None):
        if self.do_grad_scaling:
            paddle.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

+        self.runtime_timer.stop()
        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
@@ -2304,10 +2313,13 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_

    def _load_optimizer_and_scheduler(self, checkpoint):
        """If optimizer and scheduler states exist, load them."""
+        self.runtime_timer.start("checkpoint loading time")
        if checkpoint is None:
+            self.runtime_timer.stop()
            return

        if (not self.args.should_load_sharding_stage1_model) and self.args.ignore_load_lr_and_optim:
+            self.runtime_timer.stop()
            return

        opt_state_dict = None
@@ -2366,6 +2378,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
            self.scaler.load_state_dict(
                paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
            )
+        self.runtime_timer.stop()

    def log(self, logs: Dict[str, float], **kwargs) -> None:
        """