
Commit 9040378

[Trainer] Update trainer for more timer info
1 parent 48db6d4 commit 9040378

File tree

4 files changed: +230 -26 lines changed

paddlenlp/trainer/plugins/timer.py

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
# Copyright 2020-present the HuggingFace Inc. team.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import paddle

from paddlenlp.utils.log import logger


class _Timer:
    """Profile timer for recording time taken by forward/backward/reduce/step."""

    def __init__(self, name):
        self.name = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()

    def start(self):
        """Start the timer."""
        assert not self.started_, "timer has already started"
        paddle.device.cuda.synchronize()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        """Stop the timer."""
        assert self.started_, "timer is not started."
        paddle.device.cuda.synchronize()
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False

    def reset(self):
        """Reset timer."""
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset=True):
        """Calculate the elapsed time."""
        started_ = self.started_
        # If timing is in progress, stop it first.
        if self.started_:
            self.stop()
        # Get the elapsed time.
        elapsed_ = self.elapsed_
        # Reset the elapsed time.
        if reset:
            self.reset()
        # If timing was in progress, restart it.
        if started_:
            self.start()
        return elapsed_


class Timers:
    """Group of timers."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def write(self, names, writer, iteration, normalizer=1.0, reset=True):
        """Write timers to a tensorboard writer."""
        assert normalizer > 0.0
        for name in names:
            value = self.timers[name].elapsed(reset=reset) / normalizer
            writer.add_scalar("timers/" + name, value, iteration)

    def log(self, names, normalizer=1.0, reset=True):
        """Log a group of timers."""
        assert normalizer > 0.0
        # string = "time (ms) / rate"
        string = "time (ms)"

        time_dict = {}
        for name in names:
            time_dict[name] = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer

        # total_time = sum(list(time_dict.values()))
        # string += " | total_time : {:.2f} ".format(total_time)
        time_dict = sorted(time_dict.items(), key=lambda x: x[1], reverse=True)

        for time_tuple in time_dict:
            name, value = time_tuple
            # string += " | {} : {:.2f} ({:.2f}%) ".format(name, value, value * 100.0 / total_time)
            string += " | {} : {:.2f}".format(name, value)
        return string


_GLOBAL_TIMERS = None


def get_timers():
    global _GLOBAL_TIMERS
    return _GLOBAL_TIMERS


def set_timers():
    global _GLOBAL_TIMERS
    logger.info("enable PaddleNLP timer")
    _GLOBAL_TIMERS = Timers()


def disable_timers():
    global _GLOBAL_TIMERS
    logger.info("disable PaddleNLP timer")
    _GLOBAL_TIMERS = None
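
The plugin is self-contained, so a minimal usage sketch looks like the following (assuming a GPU build of Paddle, since _Timer.start()/stop() call paddle.device.cuda.synchronize(), and a PaddleNLP checkout that already contains this file; the sketch is illustrative, not part of the commit):

# Minimal usage sketch for the new timer plugin (assumes a CUDA-enabled Paddle build).
import time

from paddlenlp.trainer.plugins.timer import get_timers, set_timers

set_timers()                           # install the global Timers() instance
timers = get_timers()                  # returns it, or None if timers were never enabled

timers("forward-backward").start()
time.sleep(0.01)                       # stand-in for real work
timers("forward-backward").stop()

# log() returns a string like "time (ms) | forward-backward : 10.12" (values in ms)
print(timers.log(timers.timers.keys(), reset=True))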

paddlenlp/trainer/trainer.py

Lines changed: 87 additions & 17 deletions
@@ -65,6 +65,7 @@
 from ..utils.import_utils import is_datasets_available
 from ..utils.log import logger
 from .integrations import get_reporting_integration_callbacks
+from .plugins.timer import get_timers, set_timers
 from .trainer_callback import (
     CallbackHandler,
     DefaultFlowCallback,
@@ -250,6 +251,9 @@ def __init__(
         self.train_dataset = train_dataset
         self.eval_dataset = eval_dataset
         self.tokenizer = tokenizer
+        if not args.skip_profile_timer:
+            set_timers()
+        self.timers = get_timers()

         self.model_wrapped = model
         self.model = model
@@ -410,7 +414,7 @@ def load_state_dict_from_checkpoint(self, resume_from_checkpoint=None):
             if resume_from_checkpoint is None:
                 raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")

-        if resume_from_checkpoint is not None:
+        if resume_from_checkpoint is not None and self.args.dataset_rank == 0:
             if isinstance(self.model, LoRAModel):
                 weight_name = LORA_WEIGHTS_NAME
             elif isinstance(self.model, PrefixModelForCausalLM):
@@ -435,6 +439,8 @@ def load_state_dict_from_checkpoint(self, resume_from_checkpoint=None):

             # release memory
             del state_dict
+        elif resume_from_checkpoint is not None:
+            logger.info(f"not loading ckpt :{self.args.dataset_rank}")

     def train(
         self,
@@ -466,7 +472,7 @@ def train(
             if resume_from_checkpoint is None:
                 raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

-        if resume_from_checkpoint is not None:
+        if resume_from_checkpoint is not None and self.args.dataset_rank == 0:
             if isinstance(self.model, LoRAModel):
                 weight_name = LORA_WEIGHTS_NAME
             elif isinstance(self.model, PrefixModelForCausalLM):
@@ -490,6 +496,8 @@ def train(

             # release memory
             del state_dict
+        elif resume_from_checkpoint is not None:
+            logger.info(f"not loading ckpt :{self.args.dataset_rank}")

        train_dataloader = self.get_train_dataloader()

@@ -629,6 +637,12 @@ def train(
         steps_in_epoch = (
             len(epoch_iterator) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps
         )
+        if len_dataloader is not None:
+            if self.args.gradient_accumulation_steps > len(epoch_iterator):
+                logger.warning(
+                    f"changing accumulation step from `{self.args.gradient_accumulation_steps}` to `{len(epoch_iterator)}` to avoid cross-epoch accumulation"
+                )
+                self.args.gradient_accumulation_steps = len(epoch_iterator)

         self.callback_handler.model = self.model
         self.callback_handler.optimizer = self.optimizer
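
For intuition, a toy sketch of the clamp added above, with hypothetical numbers that are not taken from the diff:

# Hypothetical numbers: an epoch with only 8 batches cannot back an accumulation span of
# 32 micro-batches, so the clamp reduces gradient_accumulation_steps to the epoch length.
gradient_accumulation_steps = 32
batches_per_epoch = 8               # stand-in for len(epoch_iterator)
if gradient_accumulation_steps > batches_per_epoch:
    gradient_accumulation_steps = batches_per_epoch
print(gradient_accumulation_steps)  # -> 8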
@@ -651,18 +665,22 @@ def train(

             npu_accelerate_plugin(self.optimizer)

+        self.timers and self.timers("read-data").start()
+
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance(
                 train_dataloader.batch_sampler, DistributedBatchSampler
             ):
                 train_dataloader.batch_sampler.set_epoch(epoch)

-            step = -1
+            step_control = 0  # used in loop control, reset to 0 after every step
             self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

             for step, inputs in enumerate(epoch_iterator):
+                self.timers and self.timers("read-data").stop()
                 os.environ["TRAINER_GLOBAL_STEP"] = str(self.state.global_step)
                 self.callback_handler.on_load_data_end(args, self.state, self.control, inputs=inputs)
+
                 # Skip past any already trained steps if resuming training
                 # for paddlenlp.utils.batch_sampler.DistributedBatchSampler
                 # We use consumed_samples to reset the status
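
A note on the `self.timers and self.timers("read-data").start()` idiom used here and throughout the loop: `self.timers` is None when profiling is disabled (set_timers() was never called, e.g. because of skip_profile_timer), so the `and` short-circuits and the timer call is skipped without an explicit None check at every call site. A minimal, purely illustrative sketch:

# Sketch of the short-circuit guard; nothing here is taken from the diff itself.
timers = None                               # profiling disabled
timers and timers("read-data").start()      # evaluates to None; nothing runs, no error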
@@ -687,8 +705,9 @@ def train(
                     steps_trained_progress_bar.close()
                     steps_trained_progress_bar = None

-                if step % args.gradient_accumulation_steps == 0:
+                if step_control % args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
+                    self.timers and self.timers("forward-backward").start()

                 dp_enabled = (
                     self.args.data_parallel_degree > 1 if self.args.use_hybrid_parallel else args.local_rank != -1
@@ -706,14 +725,13 @@ def train(
                 availiable_no_sync = dp_enabled and not forbidden_no_sync

                 is_no_sync = (
-                    ((step + 1) % args.gradient_accumulation_steps != 0)
+                    ((step_control + 1) % args.gradient_accumulation_steps != 0)
                     and availiable_no_sync
                     and args._no_sync_in_gradient_accumulation
                 ) or (args.recompute and availiable_no_sync)
                 # sharding
                 # stage1. the same as ddp
                 # stage2. manually collect gradient on dp group
-
                 if is_no_sync:
                     # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                     with model.no_sync():
@@ -723,15 +741,18 @@ def train(

                 tr_loss += tr_loss_step

-                if (step + 1) % args.gradient_accumulation_steps == 0 or (
+                if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
                     steps_in_epoch <= args.gradient_accumulation_steps
                     and (step + 1) == steps_in_epoch
                 ):
+                    self.timers and self.timers("forward-backward").stop()
                     # Manually collect gradients when group_sharded_parallel can't accept dp_group
                     # Case 1: Use sharding stage 2/3 with dp
                     # Case 2: Use recompute and dp
                     # local_rank != -1 doesn't mean dp in networks.
+                    self.timers and self.timers("all-reduce").start()
+
                     if self.sharding and ShardingOption.SHARD_OP not in self.args.sharding:
                         if self.args.data_parallel_degree > 1 and not is_dp_group_support_in_group_sharded_parallel():
                             fused_allreduce_gradients(model.parameters(), fleet.get_hybrid_communicate_group())
@@ -763,15 +784,18 @@ def train(

                         if self.optimizer._dp_enable:
                             fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)
+                    self.timers and self.timers("all-reduce").stop()
+                    self.timers and self.timers("optimizer-step").start()

                     # pipeline parallel mode, handle gradient merge here
                     if args.pipeline_parallel_degree > 1 and enable_delay_scale_loss:
                         for p in model._layers.parameters():
-                            if hasattr(p, "main_grad") and p.main_grad is not None:
-                                assert p.grad is None
-                                p.main_grad.scale_(1.0 / self.args.gradient_accumulation_steps)
-                            elif p.grad is not None:
-                                p.grad.scale_(1.0 / self.args.gradient_accumulation_steps)
+                            with paddle.no_grad():
+                                if hasattr(p, "main_grad") and p.main_grad is not None:
+                                    assert p.grad is None
+                                    p.main_grad.scale_(1.0 / self.args.gradient_accumulation_steps)
+                                elif p.grad is not None:
+                                    p.grad.scale_(1.0 / self.args.gradient_accumulation_steps)

                     # Optimizer step
                     self.callback_handler.on_optimizer_begin(
@@ -793,6 +817,8 @@ def train(
                     else:
                         self.optimizer.step()

+                    self.timers and self.timers("optimizer-step").stop()
+
                     if optimizer_was_run:
                         self.lr_scheduler.step()
@@ -802,15 +828,18 @@ def train(
                     )

                     self.state.global_step += 1
-                    self.state.epoch = epoch + (step + 1) / steps_in_epoch
-
+                    self.state.epoch = epoch + self.state.global_step / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                     self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs)
+                    self._print_timer()
+                    step_control = 0
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
+                    step_control += 1

                 if self.control.should_epoch_stop or self.control.should_training_stop:
                     break
+                self.timers and self.timers("read-data").start()

             if step < 0:
                 logger.warning(
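
As context for the step/step_control split above: `step` remains the raw dataloader index, while `step_control` only tracks the position inside the current accumulation window and is reset right after each optimizer step. A toy sketch with hypothetical values, mirroring the boundary check in the loop:

# Toy illustration (hypothetical values, not from the diff).
gradient_accumulation_steps = 4
step_control = 0
for step in range(10):                      # stand-in for enumerate(epoch_iterator)
    if (step_control + 1) % gradient_accumulation_steps == 0:
        print(f"optimizer step after micro-batch {step}")   # fires at step 3 and 7
        step_control = 0
    else:
        step_control += 1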
@@ -905,7 +934,33 @@ def _get_train_sampler(self) -> Optional[paddle.io.Sampler]:

     def _set_state_dict_in_model(self, state_dict):
         # TODO @ZHUI paddle need return the results of set_state_dict.
-        self.model.set_state_dict(state_dict)
+        logger.info(f"set state-dict :{self.model.set_state_dict(state_dict)}")
+
+    def _print_timer(self):
+        """print timer and clear states"""
+        paddle_timer_info = ""
+        try:
+            from paddle.distributed.fleet.utils.timer_helper import (
+                get_timers as paddle_get_timers,
+            )
+
+            paddle_pipeline_timers = paddle_get_timers()
+            for name, timer in paddle_pipeline_timers.timers.items():
+                elapsed_time = timer.elapsed(reset=False) * 1000.0
+                paddle_timer_info += f" | {name}: {elapsed_time:.2f}"
+            paddle_pipeline_timers.log(paddle_pipeline_timers.timers.keys(), reset=True)
+        except ImportError:  # paddle version too old, timer not supported
+            logger.warning(f"paddle version:{paddle.__git_commit__} does not support pipeline timer")
+        except AssertionError:  # paddle timer not enabled
+            pass
+
+        if self.timers is not None:
+            timer_info = self.timers.log(self.timers.timers.keys(), reset=True)
+        else:
+            timer_info = ""
+
+        if timer_info or paddle_timer_info:
+            logger.info(f"[Profile global_step: {self.state.global_step}] {timer_info} {paddle_timer_info}")

     def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs):
         if self.control.should_log:
@@ -1615,7 +1670,6 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle
             self._pp_data_buffer = []

         model.train()
-
         # hack pipeline-layers
         # since the pipeline layer will check input is valid every iter.
         # in same case, for example, batch size warmup, we need dynamic change gradient_accumulation_steps to implement.
@@ -1872,6 +1926,10 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             self.lr_scheduler.set_state_dict(paddle.load(os.path.join(checkpoint, SCHEDULER_NAME)))
             if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
                 self.scaler.load_state_dict(paddle.load(os.path.join(checkpoint, SCALER_NAME), return_numpy=True))
+        else:
+            raise ValueError(
+                f"optimizer-state-dict not found, opt:{os.path.join(checkpoint, optimizer_name)} scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}"
+            )

     def log(self, logs: Dict[str, float], **kwargs) -> None:
         """
@@ -1883,9 +1941,21 @@ def log(self, logs: Dict[str, float], **kwargs) -> None:
             logs (`Dict[str, float]`):
                 The values to log.
         """
+
+        try:
+            from paddle.distributed.fleet.utils.timer_helper import (
+                get_timers as paddle_get_timers,
+            )
+
+            paddle_pipeline_timers = paddle_get_timers()
+        except ImportError:  # paddle version too old, timer not supported
+            logger.warning(f"paddle version:{paddle.__git_commit__} does not support pipeline timer")
+        except AssertionError:
+            paddle_pipeline_timers = None
+        kwargs.update(timer=self.timers, paddle_pipeline_timers=paddle_pipeline_timers)
+
         if self.state.epoch is not None:
             logs["epoch"] = round(self.state.epoch, 4)
-
         output = {**logs, **{"step": self.state.global_step}}
         self.state.log_history.append(output)
         self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs, **kwargs)
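
Because log() now forwards the timer objects through kwargs, downstream callbacks can inspect them on every log event. A hedged sketch follows; it assumes PaddleNLP's TrainerCallback exposes the usual HuggingFace-style on_log(args, state, control, logs=None, **kwargs) hook, so treat the class and signature as an illustration rather than the commit's own API:

# Hedged sketch: a custom callback that reads the timers forwarded by log().
from paddlenlp.trainer.trainer_callback import TrainerCallback


class TimerEchoCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        timers = kwargs.get("timer")                 # Timers instance, or None if profiling is disabled
        if timers is not None and timers.timers:
            # Timers.log() returns a "time (ms) | name : value" string; reset=False keeps the counters
            print(f"step {state.global_step}: {timers.log(timers.timers.keys(), reset=False)}")

Such a callback would then be passed to the Trainer via its callbacks argument, e.g. callbacks=[TimerEchoCallback()].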
