@@ -758,25 +758,16 @@ def __run_pretrain_routine(self, model):
758758 # ---------------------------
759759 # CORE TRAINING LOOP
760760 # ---------------------------
761-
762761 self .__train ()
763762
764763 def __train (self ):
765764 # run all epochs
766765 for epoch_nb in range (self .current_epoch , self .max_nb_epochs ):
767- # update the lr scheduler
768- if self .lr_schedulers is not None :
769- for lr_scheduler in self .lr_schedulers :
770- lr_scheduler .step ()
771-
766+ # get model
772767 model = self .__get_model ()
773- model .current_epoch = epoch_nb
774-
775- # hook
776- if self .__is_function_implemented ('on_epoch_start' ):
777- model = self .__get_model ()
778- model .on_epoch_start ()
779768
769+ # update training progress in trainer and model
770+ model .current_epoch = epoch_nb
780771 self .current_epoch = epoch_nb
781772 self .total_batches = self .nb_tng_batches + self .nb_val_batches
782773 self .batch_loss_value = 0 # accumulated grads
@@ -786,92 +777,103 @@ def __train(self):
786777 self .prog_bar = tqdm .tqdm (range (self .total_batches ),
787778 position = self .process_position )
788779
789- for batch_nb , data_batch in enumerate (self .tng_dataloader ):
790- self .batch_nb = batch_nb
791- self .global_step += 1
792-
793- model = self .__get_model ()
794- model .global_step = self .global_step
795-
796- # stop when the flag is changed or we've gone past the amount
797- # requested in the batches
798- self .total_batch_nb += 1
799- met_batch_limit = batch_nb > self .nb_tng_batches
800- if met_batch_limit :
801- break
802-
803- # ---------------
804- # RUN TRAIN STEP
805- # ---------------
806- batch_result = self .__run_tng_batch (data_batch , batch_nb )
807- early_stop_epoch = batch_result == - 1
808-
809- # ---------------
810- # RUN VAL STEP
811- # ---------------
812- is_val_check_batch = (batch_nb + 1 ) % self .val_check_batch == 0
813- if self .fast_dev_run or is_val_check_batch or early_stop_epoch :
814- self .__run_validation ()
815-
816- # when batch should be saved
817- if (batch_nb + 1 ) % self .log_save_interval == 0 or early_stop_epoch :
818- if self .proc_rank == 0 and self .experiment is not None :
819- self .experiment .save ()
820-
821- # when metrics should be logged
822- if batch_nb % self .add_log_row_interval == 0 or early_stop_epoch :
823- # count items in memory
824- # nb_params, nb_tensors = count_mem_items()
825-
826- model = self .__get_model ()
827- metrics = self .__tng_tqdm_dic
828-
829- # add gpu memory
830- if self .on_gpu :
831- mem_map = get_gpu_memory_map ()
832- metrics .update (mem_map )
833-
834- # add norms
835- if self .track_grad_norm > 0 :
836- model = self .__get_model ()
837- grad_norm_dic = model .grad_norm (self .track_grad_norm )
838- metrics .update (grad_norm_dic )
839-
840- if self .__is_function_implemented ('on_tng_metrics' ):
841- model .on_tng_metrics (metrics )
842-
843- # log metrics
844- scalar_metrics = self .__metrics_to_scalars (
845- metrics , blacklist = self .__log_vals_blacklist ())
846- if self .proc_rank == 0 and self .experiment is not None :
847- self .experiment .log (scalar_metrics , global_step = self .global_step )
848- self .experiment .save ()
849-
850- # hook
851- if self .__is_function_implemented ('on_batch_end' ):
852- model = self .__get_model ()
853- model .on_batch_end ()
854-
855- # end epoch early
856- if early_stop_epoch :
857- break
780+ # -----------------
781+ # RUN TNG EPOCH
782+ # -----------------
783+ self .run_tng_epoch ()
858784
859- # hook
860- if self .__is_function_implemented ( 'on_epoch_end' ) :
861- model = self .__get_model ()
862- model . on_epoch_end ()
785+ # update LR schedulers
786+ if self .lr_schedulers is not None :
787+ for lr_scheduler in self .lr_schedulers :
788+ lr_scheduler . step ()
863789
864790 # early stopping
865791 met_min_epochs = epoch_nb > self .min_nb_epochs
866792 if self .enable_early_stop and met_min_epochs :
867793 should_stop = self .early_stop_callback .on_epoch_end (epoch = epoch_nb ,
868794 logs = self .__tng_tqdm_dic )
869-
870795 # stop training
871796 stop = should_stop and met_min_epochs
872797 if stop :
873798 return
874799
def run_tng_epoch(self):
    """Run one full training epoch.

    Iterates the training dataloader, running the train step for each batch,
    interleaving validation checks, experiment saving/logging, and the
    model lifecycle hooks (on_epoch_start, on_batch_end, on_epoch_end).
    Updates trainer progress counters (batch_nb, global_step, total_batch_nb)
    as side effects. Returns None.
    """
    # before epoch hook
    if self.__is_function_implemented('on_epoch_start'):
        model = self.__get_model()
        model.on_epoch_start()

    # run epoch
    for batch_nb, data_batch in enumerate(self.tng_dataloader):
        self.batch_nb = batch_nb
        self.global_step += 1

        model = self.__get_model()
        model.global_step = self.global_step

        # stop when the flag is changed or we've gone past the amount
        # requested in the batches
        self.total_batch_nb += 1
        met_batch_limit = batch_nb > self.nb_tng_batches
        if met_batch_limit:
            break

        # ---------------
        # RUN TRAIN STEP
        # ---------------
        # __run_tng_batch returns -1 to signal the epoch should end early
        batch_result = self.__run_tng_batch(data_batch, batch_nb)
        early_stop_epoch = batch_result == -1

        # ---------------
        # RUN VAL STEP
        # ---------------
        is_val_check_batch = (batch_nb + 1) % self.val_check_batch == 0
        if self.fast_dev_run or is_val_check_batch or early_stop_epoch:
            self.__run_validation()

        # when batch should be saved
        if (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch:
            # only rank-0 with an experiment attached persists to disk
            if self.proc_rank == 0 and self.experiment is not None:
                self.experiment.save()

        # when metrics should be logged
        if batch_nb % self.add_log_row_interval == 0 or early_stop_epoch:
            model = self.__get_model()
            metrics = self.__tng_tqdm_dic

            # add gpu memory
            if self.on_gpu:
                mem_map = get_gpu_memory_map()
                metrics.update(mem_map)

            # add norms (model already fetched above; no need to re-fetch)
            if self.track_grad_norm > 0:
                grad_norm_dic = model.grad_norm(self.track_grad_norm)
                metrics.update(grad_norm_dic)

            if self.__is_function_implemented('on_tng_metrics'):
                model.on_tng_metrics(metrics)

            # log metrics
            scalar_metrics = self.__metrics_to_scalars(
                metrics, blacklist=self.__log_vals_blacklist())
            if self.proc_rank == 0 and self.experiment is not None:
                self.experiment.log(scalar_metrics, global_step=self.global_step)
                self.experiment.save()

        # hook -- restored: the per-batch on_batch_end hook existed before
        # this loop was extracted from __train but was dropped in the
        # refactor, silently breaking models that implement it
        if self.__is_function_implemented('on_batch_end'):
            model = self.__get_model()
            model.on_batch_end()

        # end epoch early
        if early_stop_epoch:
            break

    # epoch end hook
    if self.__is_function_implemented('on_epoch_end'):
        model = self.__get_model()
        model.on_epoch_end()
876+
875877 def __metrics_to_scalars (self , metrics , blacklist = set ()):
876878 new_metrics = {}
877879 for k , v in metrics .items ():
0 commit comments