File tree Expand file tree Collapse file tree 1 file changed +9
-1
lines changed Expand file tree Collapse file tree 1 file changed +9
-1
lines changed Original file line number Diff line number Diff line change @@ -673,6 +673,8 @@ def train(self) -> None:
673673
674674            self .transformer .train ()
675675            models_to_accumulate  =  [self .transformer ]
676+             epoch_loss  =  0.0 
677+             num_loss_updates  =  0 
676678
677679            for  step , batch  in  enumerate (self .dataloader ):
678680                logger .debug (f"Starting step { step  +  1 }  " )
@@ -843,14 +845,20 @@ def train(self) -> None:
843845                    if  should_run_validation :
844846                        self .validate (global_step )
845847
846-                 logs ["loss" ] =  loss .detach ().item ()
848+                 loss_item  =  loss .detach ().item ()
849+                 epoch_loss  +=  loss_item 
850+                 num_loss_updates  +=  1 
851+                 logs ["step_loss" ] =  loss_item 
847852                logs ["lr" ] =  self .lr_scheduler .get_last_lr ()[0 ]
848853                progress_bar .set_postfix (logs )
849854                accelerator .log (logs , step = global_step )
850855
851856                if  global_step  >=  self .state .train_steps :
852857                    break 
853858
859+             if  num_loss_updates  >  0 :
860+                 epoch_loss  /=  num_loss_updates 
861+             accelerator .log ({"epoch_loss" : epoch_loss }, step = global_step )
854862            memory_statistics  =  get_memory_statistics ()
855863            logger .info (f"Memory after epoch { epoch  +  1 }  : { json .dumps (memory_statistics , indent = 4 )}  " )
856864
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments