 from ..utils.import_utils import is_datasets_available
 from ..utils.log import logger
 from .integrations import get_reporting_integration_callbacks
+from .plugins.timer import get_timers, set_timers
 from .trainer_callback import (
     CallbackHandler,
     DefaultFlowCallback,
@@ -250,6 +251,9 @@ def __init__(
         self.train_dataset = train_dataset
         self.eval_dataset = eval_dataset
         self.tokenizer = tokenizer
+        if not args.skip_profile_timer:
+            set_timers()
+        self.timers = get_timers()
 
         self.model_wrapped = model
         self.model = model
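The `plugins.timer` helper is exercised in this diff only through `set_timers()`, `get_timers()`, per-name timers with `start()`/`stop()`/`elapsed()`, and a `log(names, reset=True)` call (see `_print_timer` further down). A minimal, self-contained sketch of a helper exposing that surface might look as follows; the class names and internals here are assumptions for illustration, not the actual `paddlenlp.trainer.plugins.timer` implementation.

```python
# Hypothetical sketch of a timers helper matching the calls used in this diff
# (set_timers/get_timers, timers("name").start()/stop(), timers.log(...)).
import time


class _Timer:
    """Accumulates wall-clock time between start() and stop() calls."""

    def __init__(self, name):
        self.name = name
        self._elapsed = 0.0
        self._started = None

    def start(self):
        self._started = time.time()

    def stop(self):
        if self._started is not None:
            self._elapsed += time.time() - self._started
            self._started = None

    def elapsed(self, reset=True):
        value = self._elapsed
        if reset:
            self._elapsed = 0.0
        return value


class Timers:
    """Registry of named timers; timers("x") creates the timer on first use."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        return self.timers.setdefault(name, _Timer(name))

    def log(self, names, reset=True):
        # Render "name: elapsed-in-ms" pairs, optionally clearing the counters.
        return " | ".join(
            f"{name}: {self.timers[name].elapsed(reset=reset) * 1000.0:.2f}" for name in names
        )


_GLOBAL_TIMERS = None


def set_timers():
    global _GLOBAL_TIMERS
    _GLOBAL_TIMERS = Timers()


def get_timers():
    return _GLOBAL_TIMERS
```

With a helper shaped like this, the `self.timers and self.timers("read-data").start()` pattern used throughout the training loop degrades to a cheap no-op when profiling is disabled, since `get_timers()` returns `None` unless `set_timers()` was called (which the new `skip_profile_timer` argument gates).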
@@ -410,7 +414,7 @@ def load_state_dict_from_checkpoint(self, resume_from_checkpoint=None):
         if resume_from_checkpoint is None:
             raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")
 
-        if resume_from_checkpoint is not None:
+        if resume_from_checkpoint is not None and self.args.dataset_rank == 0:
             if isinstance(self.model, LoRAModel):
                 weight_name = LORA_WEIGHTS_NAME
             elif isinstance(self.model, PrefixModelForCausalLM):
@@ -435,6 +439,8 @@ def load_state_dict_from_checkpoint(self, resume_from_checkpoint=None):
 
             # release memory
             del state_dict
+        elif resume_from_checkpoint is not None:
+            logger.info(f"not loading ckpt :{self.args.dataset_rank}")
 
     def train(
         self,
@@ -466,7 +472,7 @@ def train(
         if resume_from_checkpoint is None:
             raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
 
-        if resume_from_checkpoint is not None:
+        if resume_from_checkpoint is not None and self.args.dataset_rank == 0:
             if isinstance(self.model, LoRAModel):
                 weight_name = LORA_WEIGHTS_NAME
             elif isinstance(self.model, PrefixModelForCausalLM):
@@ -490,6 +496,8 @@ def train(
 
             # release memory
             del state_dict
+        elif resume_from_checkpoint is not None:
+            logger.info(f"not loading ckpt :{self.args.dataset_rank}")
 
         train_dataloader = self.get_train_dataloader()
 
@@ -629,6 +637,12 @@ def train(
         steps_in_epoch = (
             len(epoch_iterator) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps
         )
+        if len_dataloader is not None:
+            if self.args.gradient_accumulation_steps > len(epoch_iterator):
+                logger.warning(
+                    f"changing accumulation step from `{self.args.gradient_accumulation_steps}` to `{len(epoch_iterator)}` to avoid, cross epoch accumulate"
+                )
+                self.args.gradient_accumulation_steps = len(epoch_iterator)
 
         self.callback_handler.model = self.model
         self.callback_handler.optimizer = self.optimizer
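The clamp added above keeps a single optimizer step from spanning two epochs: when `gradient_accumulation_steps` exceeds the number of micro-batches the dataloader can provide in one epoch, the accumulation window would otherwise spill over into the next epoch. A toy illustration of the intent (numbers are arbitrary):

```python
# Illustrative only: with 5 micro-batches per epoch and 8 accumulation steps,
# no optimizer step would ever complete inside one epoch, so the trainer now
# lowers the accumulation window to the epoch length instead.
gradient_accumulation_steps = 8
micro_batches_per_epoch = 5

if gradient_accumulation_steps > micro_batches_per_epoch:
    gradient_accumulation_steps = micro_batches_per_epoch

assert gradient_accumulation_steps == 5
```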
@@ -651,18 +665,22 @@ def train(
 
             npu_accelerate_plugin(self.optimizer)
 
+        self.timers and self.timers("read-data").start()
+
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance(
                 train_dataloader.batch_sampler, DistributedBatchSampler
             ):
                 train_dataloader.batch_sampler.set_epoch(epoch)
 
-            step = -1
+            step_control = 0  # used in loop control, reset to 0 after every step
             self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
 
             for step, inputs in enumerate(epoch_iterator):
+                self.timers and self.timers("read-data").stop()
                 os.environ["TRAINER_GLOBAL_STEP"] = str(self.state.global_step)
                 self.callback_handler.on_load_data_end(args, self.state, self.control, inputs=inputs)
+
                 # Skip past any already trained steps if resuming training
                 # for paddlenlp.utils.batch_sampler.DistributedBatchSampler
                 # We use consumed_samples to reset the status
@@ -687,8 +705,9 @@ def train(
                         steps_trained_progress_bar.close()
                         steps_trained_progress_bar = None
 
-                if step % args.gradient_accumulation_steps == 0:
+                if step_control % args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
+                    self.timers and self.timers("forward-backward").start()
 
                 dp_enabled = (
                     self.args.data_parallel_degree > 1 if self.args.use_hybrid_parallel else args.local_rank != -1
@@ -706,14 +725,13 @@ def train(
                 availiable_no_sync = dp_enabled and not forbidden_no_sync
 
                 is_no_sync = (
-                    ((step + 1) % args.gradient_accumulation_steps != 0)
+                    ((step_control + 1) % args.gradient_accumulation_steps != 0)
                     and availiable_no_sync
                     and args._no_sync_in_gradient_accumulation
                 ) or (args.recompute and availiable_no_sync)
                 # sharding
                 # stage1. the same as ddp
                 # stage2. manualy collect gradient on dp group
-
                 if is_no_sync:
                     # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                     with model.no_sync():
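`is_no_sync` implements the usual data-parallel accumulation trick: gradient synchronization is skipped for every micro-batch except the last one in the accumulation window (or whenever recompute forces gradients to be collected manually later). A condensed, framework-agnostic sketch of the pattern, independent of Paddle's actual classes:

```python
# Generic accumulate-with-no_sync pattern (sketch, not the PaddleNLP code):
# only the last micro-batch in each window synchronizes gradients across ranks.
def run_accumulation_window(model, micro_batches, accumulation_steps, train_step):
    for step_control, batch in enumerate(micro_batches):
        last_in_window = (step_control + 1) % accumulation_steps == 0
        if not last_in_window and hasattr(model, "no_sync"):
            with model.no_sync():      # skip the all-reduce for this backward pass
                train_step(model, batch)
        else:
            train_step(model, batch)   # final micro-batch: gradients get synchronized
```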
@@ -723,15 +741,18 @@ def train(
 
                 tr_loss += tr_loss_step
 
-                if (step + 1) % args.gradient_accumulation_steps == 0 or (
+                if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
                     steps_in_epoch <= args.gradient_accumulation_steps
                     and (step + 1) == steps_in_epoch
                 ):
+                    self.timers and self.timers("forward-backward").stop()
                     # Maunally collect gradients when group_sharded_parallel can't accept dp_group
                     # Case 1: Use sharding stage 2/3 with dp
                     # Case 2: Use recompute and dp
                     # local_rank != -1 don't means dp in networks.
+                    self.timers and self.timers("all-reduce").start()
+
                     if self.sharding and ShardingOption.SHARD_OP not in self.args.sharding:
                         if self.args.data_parallel_degree > 1 and not is_dp_group_support_in_group_sharded_parallel():
                             fused_allreduce_gradients(model.parameters(), fleet.get_hybrid_communicate_group())
@@ -763,15 +784,18 @@ def train(
 
                         if self.optimizer._dp_enable:
                             fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)
+                    self.timers and self.timers("all-reduce").stop()
+                    self.timers and self.timers("optimizer-step").start()
 
                     # pipeline parallel mode, handle gradient merge here
                     if args.pipeline_parallel_degree > 1 and enable_delay_scale_loss:
                         for p in model._layers.parameters():
-                            if hasattr(p, "main_grad") and p.main_grad is not None:
-                                assert p.grad is None
-                                p.main_grad.scale_(1.0 / self.args.gradient_accumulation_steps)
-                            elif p.grad is not None:
-                                p.grad.scale_(1.0 / self.args.gradient_accumulation_steps)
+                            with paddle.no_grad():
+                                if hasattr(p, "main_grad") and p.main_grad is not None:
+                                    assert p.grad is None
+                                    p.main_grad.scale_(1.0 / self.args.gradient_accumulation_steps)
+                                elif p.grad is not None:
+                                    p.grad.scale_(1.0 / self.args.gradient_accumulation_steps)
 
                     # Optimizer step
                     self.callback_handler.on_optimizer_begin(
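Wrapping the rescaling in `paddle.no_grad()` changes only autograd bookkeeping, not the math: with delayed loss scaling the accumulated `main_grad` holds a sum over `gradient_accumulation_steps` micro-batches, and multiplying by `1 / gradient_accumulation_steps` turns that sum into the mean. The arithmetic in plain Python, with toy numbers:

```python
# Toy illustration of the 1/gradient_accumulation_steps rescale: the summed
# per-micro-batch gradients become their mean before the optimizer step.
accumulation_steps = 4
micro_batch_grads = [0.2, 0.4, 0.1, 0.3]

main_grad = sum(micro_batch_grads)        # accumulated sum over the window
main_grad *= 1.0 / accumulation_steps     # now equal to the average gradient

assert abs(main_grad - sum(micro_batch_grads) / accumulation_steps) < 1e-12
```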
@@ -793,6 +817,8 @@ def train(
                     else:
                         self.optimizer.step()
 
+                    self.timers and self.timers("optimizer-step").stop()
+
                     if optimizer_was_run:
                         self.lr_scheduler.step()
 
@@ -802,15 +828,18 @@ def train(
                     )
 
                     self.state.global_step += 1
-                    self.state.epoch = epoch + (step + 1) / steps_in_epoch
-
+                    self.state.epoch = epoch + self.state.global_step / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                     self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs)
+                    self._print_timer()
+                    step_control = 0
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
+                    step_control += 1
 
                 if self.control.should_epoch_stop or self.control.should_training_stop:
                     break
+                self.timers and self.timers("read-data").start()
 
             if step < 0:
                 logger.warning(
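Taken together, the `step_control` changes decouple the accumulation counter from the dataloader index `step`: `step_control` counts micro-batches inside the current accumulation window, is reset to zero once an optimizer step actually runs, and is incremented on every substep otherwise. A stripped-down view of that control flow (a simplification of the loop above, not the full trainer logic):

```python
# Simplified sketch of the accumulation control flow used above: step_control
# tracks position inside the current window, step stays the dataloader index.
def epoch_loop(micro_batches, accumulation_steps, micro_step, optimizer_step):
    step_control = 0
    for step, batch in enumerate(micro_batches):
        micro_step(batch)  # forward/backward for one micro-batch
        is_last_in_epoch = (step + 1) == len(micro_batches)
        if (step_control + 1) % accumulation_steps == 0 or is_last_in_epoch:
            optimizer_step()
            step_control = 0      # window finished, start a new one
        else:
            step_control += 1     # still accumulating gradients
```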
@@ -905,7 +934,33 @@ def _get_train_sampler(self) -> Optional[paddle.io.Sampler]:
 
     def _set_state_dict_in_model(self, state_dict):
         # TODO @ZHUI paddle need return the results of set_state_dict.
-        self.model.set_state_dict(state_dict)
+        logger.info(f"set state-dict :{self.model.set_state_dict(state_dict)}")
+
+    def _print_timer(self):
+        """print timer and clear states"""
+        paddle_timer_info = ""
+        try:
+            from paddle.distributed.fleet.utils.timer_helper import (
+                get_timers as paddle_get_timers,
+            )
+
+            paddle_pipeline_timers = paddle_get_timers()
+            for name, timer in paddle_pipeline_timers.timers.items():
+                elapsed_time = timer.elapsed(reset=False) * 1000.0
+                paddle_timer_info += f" | {name}: {elapsed_time:.2f}"
+            paddle_pipeline_timers.log(paddle_pipeline_timers.timers.keys(), reset=True)
+        except ImportError:  # paddle version too old, timer not support
+            logger.warning(f"paddle version:{paddle.__git_commit__} does not support pipeline timer")
+        except AssertionError:  # paddle timer not enabled
+            pass
+
+        if self.timers is not None:
+            timer_info = self.timers.log(self.timers.timers.keys(), reset=True)
+        else:
+            timer_info = ""
+
+        if timer_info or paddle_timer_info:
+            logger.info(f"[Profile global_step: {self.state.global_step}] {timer_info} {paddle_timer_info}")
 
     def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs):
         if self.control.should_log:
@@ -1615,7 +1670,6 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle
         self._pp_data_buffer = []
 
         model.train()
-
         # hack pipeline-layers
         # since the pipeline layer will check input is valid every iter.
         # in same case, for example, batch size warmup, we need dynamic change gradient_accumulation_steps to implement.
@@ -1872,6 +1926,10 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             self.lr_scheduler.set_state_dict(paddle.load(os.path.join(checkpoint, SCHEDULER_NAME)))
             if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
                 self.scaler.load_state_dict(paddle.load(os.path.join(checkpoint, SCALER_NAME), return_numpy=True))
+        else:
+            raise ValueError(
+                f"optimizer-state-dict not found, opt:{os.path.join(checkpoint, optimizer_name)} scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}"
+            )
 
     def log(self, logs: Dict[str, float], **kwargs) -> None:
         """
@@ -1883,9 +1941,21 @@ def log(self, logs: Dict[str, float], **kwargs) -> None:
             logs (`Dict[str, float]`):
                 The values to log.
         """
+
+        try:
+            from paddle.distributed.fleet.utils.timer_helper import (
+                get_timers as paddle_get_timers,
+            )
+
+            paddle_pipeline_timers = paddle_get_timers()
+        except ImportError:  # paddle version too old, timer not support
+            logger.warning(f"paddle version:{paddle.__git_commit__} does not support pipeline timer")
+        except AssertionError:
+            paddle_pipeline_timers = None
+        kwargs.update(timer=self.timers, paddle_pipeline_timers=paddle_pipeline_timers)
+
         if self.state.epoch is not None:
             logs["epoch"] = round(self.state.epoch, 4)
-
         output = {**logs, **{"step": self.state.global_step}}
         self.state.log_history.append(output)
         self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs, **kwargs)
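Because `log()` now forwards the trainer timers through `kwargs`, a custom callback can inspect them when `on_log` fires. A hypothetical example, assuming `TrainerCallback` is importable from `paddlenlp.trainer.trainer_callback` (the module the relative imports at the top of this diff point to) and that the callback handler passes the extra kwargs through:

```python
# Hypothetical callback reading the timer kwargs added to log() above.
from paddlenlp.trainer.trainer_callback import TrainerCallback


class TimerLoggingCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        timer = kwargs.get("timer")                                    # trainer-level timers
        paddle_pipeline_timers = kwargs.get("paddle_pipeline_timers")  # fleet pipeline timers
        if timer is not None:
            print(f"step {state.global_step}: trainer timers: {sorted(timer.timers)}")
        if paddle_pipeline_timers is not None:
            print(f"step {state.global_step}: paddle pipeline timers attached")
```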