@@ -793,17 +793,22 @@ def main():
     )
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
     )
 
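The arithmetic behind the new scheduler variables is worth spelling out. Below is a minimal standalone sketch of the same computation with made-up numbers (4 processes, a 125-batch dataloader, 2 gradient-accumulation steps, 10 epochs, 100 warmup steps); in the script these values come from args and accelerator, and the dataloader here is the one that accelerator.prepare has not yet sharded:

import math

# Hypothetical values, standing in for args.* and accelerator.num_processes.
num_processes = 4
lr_warmup_steps = 100
num_train_epochs = 10
gradient_accumulation_steps = 2
len_train_dataloader = 125  # length of the not-yet-sharded dataloader

# Branch taken when args.max_train_steps is None:
len_train_dataloader_after_sharding = math.ceil(len_train_dataloader / num_processes)  # 32
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / gradient_accumulation_steps)  # 16
num_training_steps_for_scheduler = num_train_epochs * num_update_steps_per_epoch * num_processes  # 640

# Both counts carry a * num_processes factor because, under accelerate, the
# prepared scheduler advances num_processes times per optimizer update.
num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes  # 400

print(num_warmup_steps_for_scheduler, num_training_steps_for_scheduler)  # 400 640

The substantive change over the old code is that the per-epoch step count is now estimated from the sharded dataloader length (len(train_dataloader) / num_processes) rather than the full one, so the scheduler is sized for the number of steps each process will actually take.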
@@ -829,8 +834,14 @@ def main():
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
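To see what the new warning guards against, here is a hedged sketch with invented numbers of the mismatch it detects: the scheduler was sized from an estimated sharded dataloader length, but the real per-process length after accelerator.prepare can come out differently (for example with uneven sharding or batch dropping), in which case the args.max_train_steps recomputed here no longer matches what the scheduler was built for:

import math

# Invented numbers for illustration only.
num_processes = 4
gradient_accumulation_steps = 2
num_train_epochs = 10

# Estimate used when the scheduler was created (before accelerator.prepare).
len_train_dataloader_after_sharding = 32
num_training_steps_for_scheduler = (
    num_train_epochs
    * math.ceil(len_train_dataloader_after_sharding / gradient_accumulation_steps)
    * num_processes
)  # 640

# Actual per-process dataloader length observed after accelerator.prepare.
len_prepared_dataloader = 30
num_update_steps_per_epoch = math.ceil(len_prepared_dataloader / gradient_accumulation_steps)  # 15
max_train_steps = num_train_epochs * num_update_steps_per_epoch  # 150

if num_training_steps_for_scheduler != max_train_steps * num_processes:  # 640 != 600
    print("scheduler was sized for a different step count")  # stands in for logger.warning(...)

When that happens the run still trains for the recomputed max_train_steps, but the learning-rate schedule no longer lines up with the actual number of optimizer steps, which is what the warning message points out.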