@@ -793,17 +793,22 @@ def main():
     )
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
     )
 
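Outside the diff, a minimal sketch of the new step arithmetic, using made-up values (2 processes, 1000 batches before sharding, gradient accumulation of 4, 3 epochs, 500 warmup steps). The scaling by `accelerator.num_processes` mirrors how an accelerate-prepared scheduler is stepped across processes; none of the concrete numbers below come from the commit.

import math

# Illustrative values only (not from the commit).
num_processes = 2
len_train_dataloader = 1000          # batches before sharding
gradient_accumulation_steps = 4
num_train_epochs = 3
lr_warmup_steps = 500

# Same arithmetic as the branch taken when --max_train_steps is not set.
len_train_dataloader_after_sharding = math.ceil(len_train_dataloader / num_processes)                      # 500
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / gradient_accumulation_steps)  # 125
num_training_steps_for_scheduler = num_train_epochs * num_update_steps_per_epoch * num_processes           # 750
num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes                                           # 1000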
@@ -829,8 +834,14 @@ def main():
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
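A sketch of the post-`prepare` consistency check added in the second hunk, again with hypothetical numbers: if the dataloader length that `accelerator.prepare` actually produces differs enough from the pre-`prepare` estimate, the scheduler was created with a stale total step count and the new warning fires.

import math

# Continuing the illustration: same hypothetical setup (2 processes, gradient
# accumulation of 4, 3 epochs), scheduler built for 750 total steps, but
# suppose accelerator.prepare() ends up yielding 480 batches per process
# instead of the estimated 500.
num_processes, gradient_accumulation_steps, num_train_epochs = 2, 4, 3
num_training_steps_for_scheduler = 750
len_train_dataloader_after_prepare = 480

num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_prepare / gradient_accumulation_steps)  # 120
max_train_steps = num_train_epochs * num_update_steps_per_epoch                                           # 360
if num_training_steps_for_scheduler != max_train_steps * num_processes:                                   # 720 != 750
    print("LR scheduler was created with a stale total step count; the warning fires")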