 from xtuner.v1.model.base import ModelItem, TransformerConfig, XTunerBaseModelConfig
 from xtuner.v1.model.compose.base import BaseComposeConfig
 from xtuner.v1.model.moe.moe import MoEConfig
-from xtuner.v1.model.utils import ModelForwardExtraLogInfo
 from xtuner.v1.patch import patch_default_save_plan
 from xtuner.v1.profiler import profiling_memory, profiling_time
 from xtuner.v1.profiler.prober import ProberList
@@ -701,44 +700,13 @@ def fit(self):
         train_begin = time.time()
         time_before_get_data = time.time()
         for data_batch in self._data_iter():
+            consumed_samples = len(data_batch)
+            time_before_train_step = time.time()
+
             ProberList.set_step(self._cur_step + 1)
             DEVICE_MODULE.reset_peak_memory_stats()
 
-            time_before_train_step = time.time()
-            data_time = time_before_train_step - time_before_get_data
-
-            seq_ctx_list: list[SequenceContext] = []
-            loss_ctx_input_list: list[CELossContextInputItem] = []
-            for data in data_batch:
-                seq_ctx = data["seq_ctx"].to(DEVICE)
-                loss_ctx_input = CELossContextInputItem(shifted_labels=data["shifted_labels"]).to(DEVICE)
-                if self.sp_mesh.size() > 1:
-                    seq_ctx = seq_ctx.split(sequence_parallel_mesh=self.sp_mesh)
-                    loss_ctx_input = loss_ctx_input.sp_split(self.sp_mesh)
-                seq_ctx_list.append(seq_ctx)
-                loss_ctx_input_list.append(loss_ctx_input)
-
-            del data_batch
-
-            LossContext = self.loss_cfg.loss_ctx_cls
-            batches_loss_kwargs = LossContext.build_batches_loss_kwargs(
-                loss_ctx_input_list,
-                self.loss_cfg,
-                cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list],
-                sp_mesh=self.sp_mesh,
-            )
-            engine_input = []
-            for seq_ctx, loss_kwargs in zip(seq_ctx_list, batches_loss_kwargs):
-                loss_ctx = LossContext(
-                    loss_cfg=self.loss_cfg,
-                    loss_kwargs=loss_kwargs,
-                )
-                engine_input.append(
-                    ModelItem(
-                        seq_ctx=seq_ctx,
-                        loss_ctx=loss_ctx,
-                    )
-                )
+            engine_input = self._prepare_model_input(data_batch)
 
             with self._maybe_profiling():
                 loss_log, other_log = self._engine.train_step(engine_input)
@@ -756,47 +724,30 @@ def fit(self):
 
             grad_norm = self._engine.clip_grad_norm(do_clip=self._do_clip, dtype=self._grad_norm_dtype)
             self._engine.step_optimizer(grad_norm)
+
             time_after_train_step = time.time()
             ProberList.after_step()
-            step_time = time_after_train_step - time_before_train_step
-            step_consumed_tokens = other_log["consumed_tokens"]
-            step_consumed_img_tokens = other_log.get("consumed_img_tokens", None)
 
-            extra_info = other_log.get("extra_info", {})
-            if isinstance(extra_info, ModelForwardExtraLogInfo):
-                extra_info_dict = extra_info.get()
-            else:
-                extra_info_updated = ModelForwardExtraLogInfo(extra_info)
-                extra_info_dict = extra_info_updated.get()
-            loss_log.update(extra_info_dict)
-
-            if "maxvio" in other_log:
-                loss_log["maxvio"] = other_log["maxvio"]
-                loss_log["efficient_attn_ratio"] = other_log["efficient_attn_ratio"]
+            data_time = time_before_train_step - time_before_get_data
+            step_time = time_after_train_step - time_before_train_step
 
             internal_metrics = self._maybe_pop_model_internal_metrics(engine_input)
 
             self._cur_step += 1
-
-            reduced_step_consumed_tokens = self._reduce_number_across_rank(step_consumed_tokens)
-            self._reduced_consumed_tokens += reduced_step_consumed_tokens
-
-            self._exp_consumed_tokens += step_consumed_tokens
+            self._reduced_consumed_tokens += self._reduce_number_across_rank(other_log["consumed_tokens"])
+            self._reduced_consumed_samples += self._reduce_number_across_rank(consumed_samples)
+            self._exp_consumed_tokens += other_log["consumed_tokens"]
             self._train_time = time_after_train_step - train_begin
 
             # TODO: This log should be move before lr_scheduler.step, but for CI BC, keep it temporarily
             self._log_step(
                 loss_log=loss_log,
-                step_consumed_tokens=step_consumed_tokens,
-                exp_consumed_tokens=self._exp_consumed_tokens,
-                step_consumed_img_tokens=step_consumed_img_tokens,
-                reduced_consumed_tokens=self._reduced_consumed_tokens,
-                data_time=data_time,
-                step_time=step_time,
-                train_time=self._train_time,
-                train_time_offset=self._train_time_offset,
+                step_consumed_tokens=other_log["consumed_tokens"],
+                step_consumed_img_tokens=other_log.get("consumed_img_tokens", None),
                 grad_norm=grad_norm.item(),
                 internal_metrics=internal_metrics,
+                data_time=data_time,
+                step_time=step_time,
             )
 
             self._lr_scheduler.step()
@@ -817,7 +768,44 @@ def fit(self):
         self._metrics_recorder.close()
         self.logger.info(f"Training finished in {time.time() - train_begin:.2f} seconds")
 
-    def _reduce_number_across_rank(self, rank_number: int) -> int:
+    def _prepare_model_input(self, data_batch) -> list[ModelItem]:
+        seq_ctx_list: list[SequenceContext] = []
+        loss_ctx_input_list: list[CELossContextInputItem] = []
+
+        for data in data_batch:
+            seq_ctx = data["seq_ctx"].to(DEVICE)
+            loss_ctx_input = CELossContextInputItem(shifted_labels=data["shifted_labels"]).to(DEVICE)
+            if self.sp_mesh.size() > 1:
+                seq_ctx = seq_ctx.split(sequence_parallel_mesh=self.sp_mesh)
+                loss_ctx_input = loss_ctx_input.sp_split(self.sp_mesh)
+            seq_ctx_list.append(seq_ctx)
+            loss_ctx_input_list.append(loss_ctx_input)
+
+        # TODO: Consider moving data_batch deletion to the caller for better memory management.
+        del data_batch
+
+        LossContext = self.loss_cfg.loss_ctx_cls
+        batches_loss_kwargs = LossContext.build_batches_loss_kwargs(
+            loss_ctx_input_list,
+            self.loss_cfg,
+            cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list],
+            sp_mesh=self.sp_mesh,
+        )
+        engine_input = []
+        for seq_ctx, loss_kwargs in zip(seq_ctx_list, batches_loss_kwargs):
+            loss_ctx = LossContext(
+                loss_cfg=self.loss_cfg,
+                loss_kwargs=loss_kwargs,
+            )
+            engine_input.append(
+                ModelItem(
+                    seq_ctx=seq_ctx,
+                    loss_ctx=loss_ctx,
+                )
+            )
+        return engine_input
+
+    def _reduce_number_across_rank(self, rank_number: int | float) -> int:
         _gathered_list = [None for _ in range(self.world_size)]
         dist.all_gather_object(_gathered_list, rank_number)
         reduced_number = sum(_gathered_list)  # type: ignore[arg-type]
@@ -1257,7 +1245,6 @@ def _data_iter(self):
                 data_iter = iter(self._dataloader)
                 data = next(data_iter)
 
-            self._reduced_consumed_samples += self._reduce_number_across_rank(len(data))
             yield data
 
     def _get_checkpoint_path(self, epoch: int, step: int, is_snapshot: bool = False) -> Path:
@@ -1434,24 +1421,21 @@ def _maybe_profiling(self):
 
     def _log_step(
         self,
-        loss_log: dict,
+        loss_log: LossLog,
         step_consumed_tokens: int,
-        exp_consumed_tokens: int,
-        reduced_consumed_tokens: int,
+        step_consumed_img_tokens: int | None,
+        grad_norm: float,
         data_time: float,
         step_time: float,
-        train_time: float,
-        train_time_offset: float,
-        grad_norm: float,
-        step_consumed_img_tokens: float | None,
         internal_metrics: InternalMetrics | None = None,
     ):
         """Log the training step information."""
-        e2e_train_time = train_time + train_time_offset
+        e2e_train_time = self._train_time + self._train_time_offset
+
         tgs = step_consumed_tokens / step_time
-        rank_consumed_tokens = reduced_consumed_tokens / self.world_size
+        rank_consumed_tokens = self._reduced_consumed_tokens / self.world_size
         e2e_tgs = rank_consumed_tokens / e2e_train_time
-        exp_tgs = exp_consumed_tokens / train_time
+        exp_tgs = self._exp_consumed_tokens / self._train_time
         lr = self._lr_scheduler.get_last_lr()[0]
 
         remaining_steps = self.total_step - self.cur_step
@@ -1481,7 +1465,7 @@ def _log_step(
             f"Epoch {self._cur_epoch} Step {self.cur_step}/{self.total_step} "
             f"data_time: {data_time:.4f} lr: {lr:.6e} time: {step_time:.4f} "
             f"text_tokens: {step_consumed_tokens} {img_tokens_str} "
-            f"reduced_consumed_tokens: {reduced_consumed_tokens} "
+            f"reduced_consumed_tokens: {self._reduced_consumed_tokens} "
             f"{loss_log_str} "
             f"grad_norm: {grad_norm:.8f} "
             f"max_memory: {max_memory / (1024**3):.2f} GB "
@@ -1497,11 +1481,11 @@ def _log_step(
             "lr": lr,
             "time/data_time": round(data_time, 4),
             "time/step_time": round(step_time, 4),
-            "time/train_time": round(train_time, 4),
+            "time/train_time": round(self._train_time, 4),
             "time/eta_seconds": round(eta_seconds, 1),
             "runtime_info/text_tokens": step_consumed_tokens,
             "runtime_info/est_global_batch_tokens": est_global_batch_tokens,
-            "runtime_info/reduced_consumed_tokens": reduced_consumed_tokens,
+            "runtime_info/reduced_consumed_tokens": self._reduced_consumed_tokens,
             "runtime_info/tgs": tgs,
             "runtime_info/exp_tgs": exp_tgs,
             "runtime_info/e2e_tgs": e2e_tgs,
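
For reference, a minimal standalone sketch of the all-gather-and-sum pattern that `_reduce_number_across_rank` applies to the per-rank token and sample counts above. It is not part of the commit; the helper name and the single-process "gloo" setup are assumptions made only so the snippet runs on its own.

```python
# Illustrative sketch only: mirrors the reduction used for consumed_tokens / consumed_samples.
# Assumption: a single-process "gloo" group, so the example runs without a distributed launcher.
import torch.distributed as dist


def reduce_number_across_rank(rank_number: int | float, world_size: int) -> int | float:
    gathered: list[object] = [None for _ in range(world_size)]
    # Collect the plain Python value contributed by every rank into `gathered`.
    dist.all_gather_object(gathered, rank_number)
    # Summing the gathered values yields the global count across all ranks.
    return sum(gathered)  # type: ignore[arg-type]


if __name__ == "__main__":
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
    print(reduce_number_across_rank(128, dist.get_world_size()))  # -> 128 with a single rank
    dist.destroy_process_group()
```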