Commit b9d9d1d

RuntimeTimer for the toolkit (#7913) (#7921)
* RuntimeTimer for the toolkit
* reformat
* fix timer and load checkpoints
* remove reset
1 parent: 9c5ff0d · commit: b9d9d1d

3 files changed: 48 additions & 8 deletions

paddlenlp/trainer/__init__.py

Lines changed: 4 additions & 3 deletions
@@ -13,11 +13,12 @@
 # limitations under the License.
 
 from .argparser import *
-from .training_args import *
 from .compression_args import *
+from .plugins.timer import *
 from .trainer import *
 from .trainer_callback import *
-from .trainer_utils import *
 from .trainer_compress import *
-from .training_args_seq2seq import *
 from .trainer_seq2seq import *
+from .trainer_utils import *
+from .training_args import *
+from .training_args_seq2seq import *
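
Beyond putting the imports in alphabetical order, the hunk above also re-exports the timer plugin from the top-level trainer package. A minimal sketch of what that enables (assuming plugins.timer does not restrict its exports with an __all__ list; the fully qualified import at the bottom works regardless):

    # Short path, available once the wildcard re-export above is in place:
    from paddlenlp.trainer import RuntimeTimer, get_timers, set_timers

    # Fully qualified path, unaffected by the re-export:
    from paddlenlp.trainer.plugins.timer import RuntimeTimer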

paddlenlp/trainer/plugins/timer.py

Lines changed: 28 additions & 2 deletions
@@ -31,14 +31,14 @@ def __init__(self, name):
 
     def start(self):
         """Start the timer."""
-        assert not self.started_, "timer has already started"
+        assert not self.started_, f"{self.name} timer has already started"
         paddle.device.synchronize()
         self.start_time = time.time()
         self.started_ = True
 
     def stop(self):
         """Stop the timers."""
-        assert self.started_, "timer is not started."
+        assert self.started_, f"{self.name} timer is not started."
         paddle.device.synchronize()
         self.elapsed_ += time.time() - self.start_time
         self.started_ = False
@@ -65,6 +65,32 @@ def elapsed(self, reset=True):
         return elapsed_
 
 
+class RuntimeTimer:
+    """A timer that can be dynamically adjusted during runtime."""
+
+    def __init__(self, name):
+        self.timer = _Timer(name)
+
+    def start(self, name):
+        """Start the RuntimeTimer."""
+        self.timer.name = name
+        self.timer.start()
+
+    def stop(self):
+        """Stop the RuntimeTimer."""
+        self.timer.stop()
+
+    def log(self):
+        """Log, stop and reset the RuntimeTimer."""
+        runtime = self.timer.elapsed(reset=True)
+        if self.timer.started_ is True:
+            self.timer.stop()
+        self.timer.reset()
+
+        string = "[timelog] {}: {:.2f}s ({}) ".format(self.timer.name, runtime, time.strftime("%Y-%m-%d %H:%M:%S"))
+        return string
+
+
 class Timers:
     """Group of timers."""
 
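For reference, the RuntimeTimer added here is a thin wrapper around the existing _Timer that renames the timed phase on each start() call, so a single instance can time different stages one after another. A minimal usage sketch (assuming a working paddle/paddlenlp install, since _Timer calls paddle.device.synchronize() around its measurements):

    import time

    from paddlenlp.trainer.plugins.timer import RuntimeTimer

    runtime_timer = RuntimeTimer("RuntimeTimer")

    runtime_timer.start("checkpoint loading time")  # renames the underlying _Timer, then starts it
    time.sleep(0.5)                                 # stand-in for the work being timed
    runtime_timer.stop()

    # log() reads the elapsed time with reset=True, so each call reports a single phase
    print(runtime_timer.log())
    # e.g. [timelog] checkpoint loading time: 0.50s (2024-01-01 12:00:00)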
paddlenlp/trainer/trainer.py

Lines changed: 16 additions & 3 deletions
@@ -104,7 +104,7 @@
 from ..utils.log import logger
 from .argparser import strtobool
 from .integrations import get_reporting_integration_callbacks
-from .plugins.timer import get_timers, set_timers
+from .plugins.timer import RuntimeTimer, get_timers, set_timers
 from .plugins.unified_checkpoint import (
     load_unified_checkpoint,
     load_unified_optimizer,
@@ -304,6 +304,7 @@ def __init__(
         if not args.skip_profile_timer:
             set_timers()
         self.timers = get_timers()
+        self.runtime_timer = RuntimeTimer("RuntimeTimer")
 
         self.model_wrapped = model
         self.model = model
@@ -506,6 +507,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
                 `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
                 of [`Trainer`]. Only load model state dict.
         """
+        self.runtime_timer.start("checkpoint loading time")
        resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint
 
         # Load potential model checkpoint
@@ -531,10 +533,12 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
                     safe_serialization=True,
                 )
                 logger.info(f"Loading model from {resume_from_checkpoint} using unified checkpoint.")
+                self.runtime_timer.stop()
                 return
 
         if isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM):
             self._load_from_peft_checkpoint(resume_from_checkpoint)
+            self.runtime_timer.stop()
             return
 
         weight_name = PADDLE_WEIGHTS_NAME
@@ -584,6 +588,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
 
         elif resume_from_checkpoint is not None:
             logger.info(f"not loading ckpt :{self.args.dataset_rank}")
+            self.runtime_timer.stop()
 
     def _wrap_model_and_load_sharded_checkpoint(self, resume_from_checkpoint):
         # In the sharded mode, should invoke _load_from_checkpoint after _wrap_model.
@@ -639,7 +644,6 @@ def train(
 
         # memory metrics - must set up as early as possible
         self._memory_tracker.start()
-
         if not self.args.should_load_sharding_stage1_model:
             self._load_from_checkpoint(resume_from_checkpoint)
 
@@ -695,6 +699,7 @@ def train(
 
         if self.args.should_load_sharding_stage1_model:
             model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint)
+
         elif self.args.should_save_sharding_stage1_model:
             # In the non-sharded mode, should invoke _load_from_checkpoint before _wrap_model.
             # In this mode, the rank0 load all params and the _wrap_model implicitly broadcast params from rank0 to the other ranks.
@@ -718,6 +723,8 @@ def train(
         self.create_optimizer_and_scheduler(num_training_steps=max_steps)
         self._load_optimizer_and_scheduler(resume_from_checkpoint)
 
+        logger.info(f"{self.runtime_timer.log()}")
+
         logger.info("***** Running training *****")
         logger.info(f" Num examples = {num_examples:,}")
         logger.info(f" Num Epochs = {num_train_epochs}")
@@ -1239,6 +1246,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
                 paddle.device.cuda.synchronize()
 
             self._save_checkpoint(model, metrics=metrics)
+            logger.info(f"{self.runtime_timer.log()}")
             self.control = self.callback_handler.on_save(self.args, self.state, self.control)
 
     def _get_learning_rate(self):
@@ -2040,7 +2048,7 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
 
     def _save_checkpoint(self, model, metrics=None):
         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
-
+        self.runtime_timer.start("checkpoint saving time")
         # Save model checkpoint
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
 
@@ -2086,6 +2094,7 @@ def _save_checkpoint(self, model, metrics=None):
             if self.do_grad_scaling:
                 paddle.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
 
+        self.runtime_timer.stop()
         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
             metric_to_check = self.args.metric_for_best_model
@@ -2304,10 +2313,13 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
 
     def _load_optimizer_and_scheduler(self, checkpoint):
         """If optimizer and scheduler states exist, load them."""
+        self.runtime_timer.start("checkpoint loading time")
         if checkpoint is None:
+            self.runtime_timer.stop()
             return
 
         if (not self.args.should_load_sharding_stage1_model) and self.args.ignore_load_lr_and_optim:
+            self.runtime_timer.stop()
             return
 
         opt_state_dict = None
@@ -2366,6 +2378,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             self.scaler.load_state_dict(
                 paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
             )
+        self.runtime_timer.stop()
 
     def log(self, logs: Dict[str, float], **kwargs) -> None:
         """
