@@ -701,44 +701,13 @@ def fit(self):
         train_begin = time.time()
         time_before_get_data = time.time()
         for data_batch in self._data_iter():
+            consumed_samples = len(data_batch)
+            time_before_train_step = time.time()
+
             ProberList.set_step(self._cur_step + 1)
             DEVICE_MODULE.reset_peak_memory_stats()
 
-            time_before_train_step = time.time()
-            data_time = time_before_train_step - time_before_get_data
-
-            seq_ctx_list: list[SequenceContext] = []
-            loss_ctx_input_list: list[CELossContextInputItem] = []
-            for data in data_batch:
-                seq_ctx = data["seq_ctx"].to(DEVICE)
-                loss_ctx_input = CELossContextInputItem(shifted_labels=data["shifted_labels"]).to(DEVICE)
-                if self.sp_mesh.size() > 1:
-                    seq_ctx = seq_ctx.split(sequence_parallel_mesh=self.sp_mesh)
-                    loss_ctx_input = loss_ctx_input.sp_split(self.sp_mesh)
-                seq_ctx_list.append(seq_ctx)
-                loss_ctx_input_list.append(loss_ctx_input)
-
-            del data_batch
-
-            LossContext = self.loss_cfg.loss_ctx_cls
-            batches_loss_kwargs = LossContext.build_batches_loss_kwargs(
-                loss_ctx_input_list,
-                self.loss_cfg,
-                cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list],
-                sp_mesh=self.sp_mesh,
-            )
-            engine_input = []
-            for seq_ctx, loss_kwargs in zip(seq_ctx_list, batches_loss_kwargs):
-                loss_ctx = LossContext(
-                    loss_cfg=self.loss_cfg,
-                    loss_kwargs=loss_kwargs,
-                )
-                engine_input.append(
-                    ModelItem(
-                        seq_ctx=seq_ctx,
-                        loss_ctx=loss_ctx,
-                    )
-                )
+            engine_input = self._prepare_model_input(data_batch)
 
             with self._maybe_profiling():
                 loss_log, other_log = self._engine.train_step(engine_input)
@@ -756,47 +725,30 @@ def fit(self):
 
             grad_norm = self._engine.clip_grad_norm(do_clip=self._do_clip, dtype=self._grad_norm_dtype)
             self._engine.step_optimizer(grad_norm)
+
             time_after_train_step = time.time()
             ProberList.after_step()
-            step_time = time_after_train_step - time_before_train_step
-            step_consumed_tokens = other_log["consumed_tokens"]
-            step_consumed_img_tokens = other_log.get("consumed_img_tokens", None)
 
-            extra_info = other_log.get("extra_info", {})
-            if isinstance(extra_info, ModelForwardExtraLogInfo):
-                extra_info_dict = extra_info.get()
-            else:
-                extra_info_updated = ModelForwardExtraLogInfo(extra_info)
-                extra_info_dict = extra_info_updated.get()
-            loss_log.update(extra_info_dict)
-
-            if "maxvio" in other_log:
-                loss_log["maxvio"] = other_log["maxvio"]
-                loss_log["efficient_attn_ratio"] = other_log["efficient_attn_ratio"]
+            data_time = time_before_train_step - time_before_get_data
+            step_time = time_after_train_step - time_before_train_step
 
             internal_metrics = self._maybe_pop_model_internal_metrics(engine_input)
 
             self._cur_step += 1
-
-            reduced_step_consumed_tokens = self._reduce_number_across_rank(step_consumed_tokens)
-            self._reduced_consumed_tokens += reduced_step_consumed_tokens
-
-            self._exp_consumed_tokens += step_consumed_tokens
+            self._reduced_consumed_tokens += self._reduce_number_across_rank(other_log["consumed_tokens"])
+            self._reduced_consumed_samples += self._reduce_number_across_rank(consumed_samples)
+            self._exp_consumed_tokens += other_log["consumed_tokens"]
             self._train_time = time_after_train_step - train_begin
 
             # TODO: This log should be move before lr_scheduler.step, but for CI BC, keep it temporarily
             self._log_step(
                 loss_log=loss_log,
-                step_consumed_tokens=step_consumed_tokens,
-                exp_consumed_tokens=self._exp_consumed_tokens,
-                step_consumed_img_tokens=step_consumed_img_tokens,
-                reduced_consumed_tokens=self._reduced_consumed_tokens,
-                data_time=data_time,
-                step_time=step_time,
-                train_time=self._train_time,
-                train_time_offset=self._train_time_offset,
+                step_consumed_tokens=other_log["consumed_tokens"],
+                step_consumed_img_tokens=other_log.get("consumed_img_tokens", None),
                 grad_norm=grad_norm.item(),
                 internal_metrics=internal_metrics,
+                data_time=data_time,
+                step_time=step_time,
             )
 
             self._lr_scheduler.step()
@@ -817,7 +769,44 @@ def fit(self):
         self._metrics_recorder.close()
         self.logger.info(f"Training finished in {time.time() - train_begin:.2f} seconds")
 
-    def _reduce_number_across_rank(self, rank_number: int) -> int:
+    def _prepare_model_input(self, data_batch) -> list[ModelItem]:
+        seq_ctx_list: list[SequenceContext] = []
+        loss_ctx_input_list: list[CELossContextInputItem] = []
+
+        for data in data_batch:
+            seq_ctx = data["seq_ctx"].to(DEVICE)
+            loss_ctx_input = CELossContextInputItem(shifted_labels=data["shifted_labels"]).to(DEVICE)
+            if self.sp_mesh.size() > 1:
+                seq_ctx = seq_ctx.split(sequence_parallel_mesh=self.sp_mesh)
+                loss_ctx_input = loss_ctx_input.sp_split(self.sp_mesh)
+            seq_ctx_list.append(seq_ctx)
+            loss_ctx_input_list.append(loss_ctx_input)
+
+        # TODO: Consider moving data_batch deletion to the caller for better memory management.
+        del data_batch
+
+        LossContext = self.loss_cfg.loss_ctx_cls
+        batches_loss_kwargs = LossContext.build_batches_loss_kwargs(
+            loss_ctx_input_list,
+            self.loss_cfg,
+            cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list],
+            sp_mesh=self.sp_mesh,
+        )
+        engine_input = []
+        for seq_ctx, loss_kwargs in zip(seq_ctx_list, batches_loss_kwargs):
+            loss_ctx = LossContext(
+                loss_cfg=self.loss_cfg,
+                loss_kwargs=loss_kwargs,
+            )
+            engine_input.append(
+                ModelItem(
+                    seq_ctx=seq_ctx,
+                    loss_ctx=loss_ctx,
+                )
+            )
+        return engine_input
+
+    def _reduce_number_across_rank(self, rank_number: int | float) -> int:
         _gathered_list = [None for _ in range(self.world_size)]
         dist.all_gather_object(_gathered_list, rank_number)
         reduced_number = sum(_gathered_list)  # type: ignore[arg-type]
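For reference, `_reduce_number_across_rank` above is a plain all-gather-and-sum over Python objects. Below is a minimal standalone sketch of that pattern; the single-rank "gloo" setup and the helper name `reduce_number_across_ranks` are illustrative assumptions for this example, not part of the change itself:

import os

import torch.distributed as dist


def reduce_number_across_ranks(rank_number, world_size):
    # Every rank contributes its local value and receives the full list back.
    gathered = [None for _ in range(world_size)]
    dist.all_gather_object(gathered, rank_number)
    # Summing the gathered values yields the same global total on every rank.
    return sum(gathered)


if __name__ == "__main__":
    # Single-process group so the sketch runs without torchrun.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(reduce_number_across_ranks(128, dist.get_world_size()))  # -> 128 with one rank
    dist.destroy_process_group()

Launched under torchrun with several ranks, each process would contribute its local count and every rank would end up with the same global sum, which is how both consumed tokens and consumed samples are accumulated in the training loop.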
@@ -1257,7 +1246,6 @@ def _data_iter(self):
             data_iter = iter(self._dataloader)
             data = next(data_iter)
 
-            self._reduced_consumed_samples += self._reduce_number_across_rank(len(data))
             yield data
 
     def _get_checkpoint_path(self, epoch: int, step: int, is_snapshot: bool = False) -> Path:
@@ -1434,24 +1422,21 @@ def _maybe_profiling(self):
 
     def _log_step(
         self,
-        loss_log: dict,
+        loss_log: LossLog,
         step_consumed_tokens: int,
-        exp_consumed_tokens: int,
-        reduced_consumed_tokens: int,
+        step_consumed_img_tokens: int | None,
+        grad_norm: float,
         data_time: float,
         step_time: float,
-        train_time: float,
-        train_time_offset: float,
-        grad_norm: float,
-        step_consumed_img_tokens: float | None,
         internal_metrics: InternalMetrics | None = None,
     ):
         """Log the training step information."""
-        e2e_train_time = train_time + train_time_offset
+        e2e_train_time = self._train_time + self._train_time_offset
+
         tgs = step_consumed_tokens / step_time
-        rank_consumed_tokens = reduced_consumed_tokens / self.world_size
+        rank_consumed_tokens = self._reduced_consumed_tokens / self.world_size
         e2e_tgs = rank_consumed_tokens / e2e_train_time
-        exp_tgs = exp_consumed_tokens / train_time
+        exp_tgs = self._exp_consumed_tokens / self._train_time
         lr = self._lr_scheduler.get_last_lr()[0]
 
         remaining_steps = self.total_step - self.cur_step
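As a quick reference for the throughput figures computed in `_log_step`, here is the same arithmetic with made-up numbers; every value below is an assumption chosen for illustration, not a measured result:

# Worked example of the tgs / e2e_tgs arithmetic above (illustrative values only).
step_consumed_tokens = 32_768        # tokens this rank processed in one step
step_time = 4.0                      # seconds spent in that step
reduced_consumed_tokens = 2_097_152  # tokens summed across all ranks so far
world_size = 8
e2e_train_time = 512.0               # _train_time + _train_time_offset, in seconds

tgs = step_consumed_tokens / step_time                       # 8192.0 tokens/s for this step
rank_consumed_tokens = reduced_consumed_tokens / world_size  # 262144.0 tokens per rank
e2e_tgs = rank_consumed_tokens / e2e_train_time              # 512.0 tokens/s end to end
print(f"tgs={tgs:.1f} e2e_tgs={e2e_tgs:.1f}")

`exp_tgs` follows the same pattern, dividing `self._exp_consumed_tokens` by `self._train_time`.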
@@ -1481,7 +1466,7 @@ def _log_step(
             f"Epoch {self._cur_epoch} Step {self.cur_step}/{self.total_step} "
             f"data_time: {data_time:.4f} lr: {lr:.6e} time: {step_time:.4f} "
             f"text_tokens: {step_consumed_tokens} {img_tokens_str} "
-            f"reduced_consumed_tokens: {reduced_consumed_tokens} "
+            f"reduced_consumed_tokens: {self._reduced_consumed_tokens} "
             f"{loss_log_str} "
             f"grad_norm: {grad_norm:.8f} "
             f"max_memory: {max_memory / (1024**3):.2f} GB "
@@ -1497,11 +1482,11 @@ def _log_step(
             "lr": lr,
             "time/data_time": round(data_time, 4),
             "time/step_time": round(step_time, 4),
-            "time/train_time": round(train_time, 4),
+            "time/train_time": round(self._train_time, 4),
             "time/eta_seconds": round(eta_seconds, 1),
             "runtime_info/text_tokens": step_consumed_tokens,
             "runtime_info/est_global_batch_tokens": est_global_batch_tokens,
-            "runtime_info/reduced_consumed_tokens": reduced_consumed_tokens,
+            "runtime_info/reduced_consumed_tokens": self._reduced_consumed_tokens,
             "runtime_info/tgs": tgs,
             "runtime_info/exp_tgs": exp_tgs,
             "runtime_info/e2e_tgs": e2e_tgs,