@@ -1524,17 +1524,22 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     torch.cuda.empty_cache()
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1551,8 +1556,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
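For reference, the step accounting in this change can be sketched in isolation. The following is a minimal illustrative snippet, not part of the diff: the helper name `scheduler_step_counts` and the example numbers are hypothetical, and it simply mirrors the variables introduced above, where both warmup and total steps are scaled by `accelerator.num_processes` because the prepared scheduler is advanced by every process in this setup.

```python
import math


def scheduler_step_counts(
    num_dataloader_batches,        # len(train_dataloader) before accelerator.prepare
    num_processes,                 # accelerator.num_processes
    gradient_accumulation_steps,
    num_train_epochs,
    lr_warmup_steps,
    max_train_steps=None,
):
    """Mirror the diff's math for the values passed to get_scheduler().

    Hypothetical helper for illustration only; the training script computes
    these values inline.
    """
    num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes

    if max_train_steps is None:
        # The dataloader is not yet sharded at this point, so the per-process
        # length after sharding is estimated up front.
        len_train_dataloader_after_sharding = math.ceil(num_dataloader_batches / num_processes)
        num_update_steps_per_epoch = math.ceil(
            len_train_dataloader_after_sharding / gradient_accumulation_steps
        )
        num_training_steps_for_scheduler = (
            num_train_epochs * num_update_steps_per_epoch * num_processes
        )
    else:
        num_training_steps_for_scheduler = max_train_steps * num_processes

    return num_warmup_steps_for_scheduler, num_training_steps_for_scheduler


# Example: 1000 batches, 2 processes, grad accumulation of 4, 10 epochs, 500 warmup steps.
print(scheduler_step_counts(1000, 2, 4, 10, 500))  # -> (1000, 2500)
```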