@@ -1119,17 +1119,23 @@ def compute_text_embeddings(prompt):
     )
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
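As a sanity check on the new math, here is a minimal, self-contained sketch (not part of the diff) that reproduces the pre-`accelerator.prepare` estimate with hypothetical values: 2 processes, a 1000-batch dataloader, gradient accumulation of 4, and 3 epochs.

```python
import math

# Hypothetical inputs, for illustration only (not taken from the diff).
num_processes = 2                # accelerator.num_processes
len_train_dataloader = 1000      # un-sharded dataloader length
gradient_accumulation_steps = 4  # args.gradient_accumulation_steps
num_train_epochs = 3             # args.num_train_epochs

# Mirrors the branch taken when --max_train_steps is not set: estimate the
# per-process dataloader length before accelerator.prepare has run.
len_train_dataloader_after_sharding = math.ceil(len_train_dataloader / num_processes)  # 500
num_update_steps_per_epoch = math.ceil(
    len_train_dataloader_after_sharding / gradient_accumulation_steps
)  # 125
num_training_steps_for_scheduler = (
    num_train_epochs * num_processes * num_update_steps_per_epoch
)  # 750

print(num_training_steps_for_scheduler)  # 750
```

The `num_processes` factor is there because the scheduler returned by `accelerator.prepare` advances `num_processes` times per optimizer update (when batches are not split), so the per-process learning-rate schedule still spans `max_train_steps` updates.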
@@ -1146,8 +1152,15 @@ def compute_text_embeddings(prompt):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
+
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
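And a similar hypothetical sketch of the post-`accelerator.prepare` recalculation, showing when the new warning would fire; all numbers are made up for illustration.

```python
import math

# Hypothetical values, for illustration only.
num_processes = 2
gradient_accumulation_steps = 4
num_train_epochs = 3

# Estimate made before accelerator.prepare (see the first hunk): 1000 batches -> 500 per process.
len_train_dataloader_after_sharding = 500
num_training_steps_for_scheduler = num_train_epochs * num_processes * math.ceil(
    len_train_dataloader_after_sharding / gradient_accumulation_steps
)  # 750

# Suppose the prepared dataloader turns out to have 480 batches per process instead.
len_prepared_dataloader = 480
num_update_steps_per_epoch = math.ceil(len_prepared_dataloader / gradient_accumulation_steps)  # 120
max_train_steps = num_train_epochs * num_update_steps_per_epoch                                # 360

# 360 * 2 = 720 != 750, so the warning added in the second hunk would be logged.
print(num_training_steps_for_scheduler != max_train_steps * num_processes)  # True
```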