@@ -826,17 +826,22 @@ def collate_fn(examples):
     )
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
     )
 
     # Prepare everything with our `accelerator`.
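As the PR linked in the new comment explains, the prepared scheduler is effectively stepped once per process for each optimizer update, so both the warmup and the total step budget passed to `get_scheduler` are scaled by `accelerator.num_processes`. Below is a minimal sketch of that arithmetic with hypothetical values (none of the numbers come from the PR itself):

```python
import math

# Hypothetical values for illustration only.
num_batches = 1000              # len(train_dataloader) before sharding
num_processes = 4               # accelerator.num_processes
gradient_accumulation_steps = 2
num_train_epochs = 3
lr_warmup_steps = 100

# Mirrors the branch taken when args.max_train_steps is None.
len_train_dataloader_after_sharding = math.ceil(num_batches / num_processes)      # 250
num_update_steps_per_epoch = math.ceil(
    len_train_dataloader_after_sharding / gradient_accumulation_steps
)                                                                                  # 125
num_training_steps_for_scheduler = (
    num_train_epochs * num_update_steps_per_epoch * num_processes
)                                                                                  # 1500
num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes                   # 400

print(num_warmup_steps_for_scheduler, num_training_steps_for_scheduler)  # 400 1500
```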
@@ -866,8 +871,14 @@ def collate_fn(examples):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
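The warning added in this second hunk guards against the case where `accelerator.prepare` changes the dataloader length relative to the pre-sharding estimate the scheduler was sized with. A small sketch of that check with made-up numbers (the variable names mirror the diff, the values are hypothetical):

```python
import math

# Hypothetical values for illustration only.
num_processes = 4
gradient_accumulation_steps = 2
num_train_epochs = 3

# Estimate made before accelerator.prepare (the scheduler was sized with this).
len_train_dataloader_after_sharding = 250
num_training_steps_for_scheduler = (
    num_train_epochs
    * math.ceil(len_train_dataloader_after_sharding / gradient_accumulation_steps)
    * num_processes
)  # 3 * 125 * 4 = 1500

# Suppose the prepared dataloader turns out shorter, e.g. 240 batches per process.
len_train_dataloader = 240
max_train_steps = num_train_epochs * math.ceil(
    len_train_dataloader / gradient_accumulation_steps
)  # 3 * 120 = 360

if num_training_steps_for_scheduler != max_train_steps * num_processes:  # 1500 != 1440
    print("warning: LR scheduler was sized for a different number of training steps")
```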