@@ -931,17 +931,22 @@ def load_model_hook(models, input_dir):
     )

     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -966,8 +971,14 @@ def load_model_hook(models, input_dir):

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

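And a similar sketch of the consistency check in this second hunk: if the per-process dataloader length after `accelerator.prepare` differs from the pre-`prepare` estimate, the scheduler was sized for the wrong number of steps and the script can only warn. The lengths below (250 estimated vs. 248 actual) are hypothetical:

```python
# Minimal sketch (assumed values) of the post-prepare consistency check above:
# when the per-process dataloader length after accelerator.prepare differs from
# the estimate used when the scheduler was built, the scheduler's total step
# count no longer matches max_train_steps and a warning is emitted.
import math

len_train_dataloader_after_sharding = 250  # estimate used when creating the scheduler
len_train_dataloader_after_prepare = 248   # hypothetical actual length after prepare
gradient_accumulation_steps = 2
num_train_epochs = 10
num_processes = 4

num_training_steps_for_scheduler = (
    num_train_epochs
    * math.ceil(len_train_dataloader_after_sharding / gradient_accumulation_steps)
    * num_processes
)  # 5000

num_update_steps_per_epoch = math.ceil(
    len_train_dataloader_after_prepare / gradient_accumulation_steps
)  # 124
max_train_steps = num_train_epochs * num_update_steps_per_epoch  # 1240

if num_training_steps_for_scheduler != max_train_steps * num_processes:  # 5000 != 4960
    print(
        f"Dataloader length after prepare ({len_train_dataloader_after_prepare}) does not match "
        f"the expected length ({len_train_dataloader_after_sharding}); the learning rate "
        "scheduler may not behave as intended."
    )
```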