@@ -758,25 +758,16 @@ def __run_pretrain_routine(self, model):
         # ---------------------------
         # CORE TRAINING LOOP
         # ---------------------------
-
         self.__train()
 
     def __train(self):
         # run all epochs
         for epoch_nb in range(self.current_epoch, self.max_nb_epochs):
-            # update the lr scheduler
-            if self.lr_schedulers is not None:
-                for lr_scheduler in self.lr_schedulers:
-                    lr_scheduler.step()
-
+            # get model
             model = self.__get_model()
-            model.current_epoch = epoch_nb
-
-            # hook
-            if self.__is_function_implemented('on_epoch_start'):
-                model = self.__get_model()
-                model.on_epoch_start()
 
+            # update training progress in trainer and model
+            model.current_epoch = epoch_nb
             self.current_epoch = epoch_nb
             self.total_batches = self.nb_tng_batches + self.nb_val_batches
             self.batch_loss_value = 0  # accumulated grads
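The net effect of this hunk and the next one: the per-epoch hooks move into the new run_tng_epoch(), and lr_scheduler.step() moves from the top of the epoch to after the epoch has run. A minimal sketch of that per-epoch scheduler ordering, using plain torch.optim stand-ins rather than the trainer's own objects (model, optimizer, and loop names here are illustrative only):

import torch

# illustrative stand-ins; the real trainer pulls these from the LightningModule
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

for epoch in range(3):
    # ... run the training epoch: forward, backward, optimizer.step() per batch ...
    # step the scheduler once per epoch, after the epoch's optimizer steps,
    # matching the move from epoch start to epoch end in this commit
    scheduler.step()
    print(epoch, optimizer.param_groups[0]['lr'])

Stepping after the epoch means epoch 0 trains at the initial learning rate rather than an already-decayed one.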
@@ -786,92 +777,103 @@ def __train(self):
             self.prog_bar = tqdm.tqdm(range(self.total_batches),
                                       position=self.process_position)
 
-            for batch_nb, data_batch in enumerate(self.tng_dataloader):
-                self.batch_nb = batch_nb
-                self.global_step += 1
-
-                model = self.__get_model()
-                model.global_step = self.global_step
-
-                # stop when the flag is changed or we've gone past the amount
-                # requested in the batches
-                self.total_batch_nb += 1
-                met_batch_limit = batch_nb > self.nb_tng_batches
-                if met_batch_limit:
-                    break
-
-                # ---------------
-                # RUN TRAIN STEP
-                # ---------------
-                batch_result = self.__run_tng_batch(data_batch, batch_nb)
-                early_stop_epoch = batch_result == -1
-
-                # ---------------
-                # RUN VAL STEP
-                # ---------------
-                is_val_check_batch = (batch_nb + 1) % self.val_check_batch == 0
-                if self.fast_dev_run or is_val_check_batch or early_stop_epoch:
-                    self.__run_validation()
-
-                # when batch should be saved
-                if (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch:
-                    if self.proc_rank == 0 and self.experiment is not None:
-                        self.experiment.save()
-
-                # when metrics should be logged
-                if batch_nb % self.add_log_row_interval == 0 or early_stop_epoch:
-                    # count items in memory
-                    # nb_params, nb_tensors = count_mem_items()
-
-                    model = self.__get_model()
-                    metrics = self.__tng_tqdm_dic
-
-                    # add gpu memory
-                    if self.on_gpu:
-                        mem_map = get_gpu_memory_map()
-                        metrics.update(mem_map)
-
-                    # add norms
-                    if self.track_grad_norm > 0:
-                        model = self.__get_model()
-                        grad_norm_dic = model.grad_norm(self.track_grad_norm)
-                        metrics.update(grad_norm_dic)
-
-                    if self.__is_function_implemented('on_tng_metrics'):
-                        model.on_tng_metrics(metrics)
-
-                    # log metrics
-                    scalar_metrics = self.__metrics_to_scalars(
-                        metrics, blacklist=self.__log_vals_blacklist())
-                    if self.proc_rank == 0 and self.experiment is not None:
-                        self.experiment.log(scalar_metrics, global_step=self.global_step)
-                        self.experiment.save()
-
-                # hook
-                if self.__is_function_implemented('on_batch_end'):
-                    model = self.__get_model()
-                    model.on_batch_end()
-
-                # end epoch early
-                if early_stop_epoch:
-                    break
+            # -----------------
+            # RUN TNG EPOCH
+            # -----------------
+            self.run_tng_epoch()
 
-            # hook
-            if self.__is_function_implemented('on_epoch_end'):
-                model = self.__get_model()
-                model.on_epoch_end()
+            # update LR schedulers
+            if self.lr_schedulers is not None:
+                for lr_scheduler in self.lr_schedulers:
+                    lr_scheduler.step()
 
             # early stopping
             met_min_epochs = epoch_nb > self.min_nb_epochs
             if self.enable_early_stop and met_min_epochs:
                 should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch_nb,
                                                                     logs=self.__tng_tqdm_dic)
-
             # stop training
             stop = should_stop and met_min_epochs
             if stop:
                 return
 
+    def run_tng_epoch(self):
+        # before epoch hook
+        if self.__is_function_implemented('on_epoch_start'):
+            model = self.__get_model()
+            model.on_epoch_start()
+
+        # run epoch
+        for batch_nb, data_batch in enumerate(self.tng_dataloader):
+            self.batch_nb = batch_nb
+            self.global_step += 1
+
+            model = self.__get_model()
+            model.global_step = self.global_step
+
+            # stop when the flag is changed or we've gone past the amount
+            # requested in the batches
+            self.total_batch_nb += 1
+            met_batch_limit = batch_nb > self.nb_tng_batches
+            if met_batch_limit:
+                break
+
+            # ---------------
+            # RUN TRAIN STEP
+            # ---------------
+            batch_result = self.__run_tng_batch(data_batch, batch_nb)
+            early_stop_epoch = batch_result == -1
+
+            # ---------------
+            # RUN VAL STEP
+            # ---------------
+            is_val_check_batch = (batch_nb + 1) % self.val_check_batch == 0
+            if self.fast_dev_run or is_val_check_batch or early_stop_epoch:
+                self.__run_validation()
+
+            # when batch should be saved
+            if (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch:
+                if self.proc_rank == 0 and self.experiment is not None:
+                    self.experiment.save()
+
+            # when metrics should be logged
+            if batch_nb % self.add_log_row_interval == 0 or early_stop_epoch:
+                # count items in memory
+                # nb_params, nb_tensors = count_mem_items()
+
+                model = self.__get_model()
+                metrics = self.__tng_tqdm_dic
+
+                # add gpu memory
+                if self.on_gpu:
+                    mem_map = get_gpu_memory_map()
+                    metrics.update(mem_map)
+
+                # add norms
+                if self.track_grad_norm > 0:
+                    model = self.__get_model()
+                    grad_norm_dic = model.grad_norm(self.track_grad_norm)
+                    metrics.update(grad_norm_dic)
+
+                if self.__is_function_implemented('on_tng_metrics'):
+                    model.on_tng_metrics(metrics)
+
+                # log metrics
+                scalar_metrics = self.__metrics_to_scalars(
+                    metrics, blacklist=self.__log_vals_blacklist())
+                if self.proc_rank == 0 and self.experiment is not None:
+                    self.experiment.log(scalar_metrics, global_step=self.global_step)
+                    self.experiment.save()
+
+            # end epoch early
+            if early_stop_epoch:
+                break
+
+        # epoch end hook
+        if self.__is_function_implemented('on_epoch_end'):
+            model = self.__get_model()
+            model.on_epoch_end()
+
     def __metrics_to_scalars(self, metrics, blacklist=set()):
         new_metrics = {}
         for k, v in metrics.items():
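The diff cuts off inside __metrics_to_scalars. For context, a self-contained sketch of what a helper with this signature plausibly does; the body below is a guess from the visible signature and its call site (blacklist=self.__log_vals_blacklist()), not the trainer's actual code:

import torch

def metrics_to_scalars(metrics, blacklist=set()):
    # convert each metric to a plain Python scalar so it can be logged;
    # skip blacklisted keys and recurse into nested dicts
    new_metrics = {}
    for k, v in metrics.items():
        if k in blacklist:
            continue
        if isinstance(v, torch.Tensor):
            v = v.item()  # 0-dim tensor -> Python number
        elif isinstance(v, dict):
            v = metrics_to_scalars(v, blacklist)
        new_metrics[k] = v
    return new_metrics

For example, metrics_to_scalars({'loss': torch.tensor(0.25), 'epoch': 3}) returns {'loss': 0.25, 'epoch': 3}.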