Commit 8e3affc

fix for lr scheduler in distributed training (#9103)
* fix for lr scheduler in distributed training
* Fixed the recalculation of the total training step section
* Fixed lint error

Co-authored-by: Sayak Paul <[email protected]>
1 parent ba7e484 commit 8e3affc
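For context, a minimal sketch of the step accounting this commit introduces: the counts passed to get_scheduler are scaled by accelerator.num_processes because, per the PR referenced in the diff (https://github.com/huggingface/diffusers/pull/8312), the scheduler prepared by Accelerate is advanced once per process for every optimizer step, and when --max_train_steps is not set, the per-epoch step count has to be estimated from the dataloader length after sharding. The helper below mirrors the names used in the diff; dataset_len, batch_size, and num_processes are illustrative stand-ins, not values from the training script.

import math

def scheduler_steps(
    dataset_len,
    batch_size,
    gradient_accumulation_steps,
    num_processes,
    num_train_epochs,
    lr_warmup_steps,
    max_train_steps=None,
):
    # Warmup is scaled because every optimizer step advances the prepared
    # scheduler once per process.
    num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes

    if max_train_steps is None:
        # Estimate the dataloader length after Accelerate shards it across processes.
        len_train_dataloader = math.ceil(dataset_len / batch_size)
        len_train_dataloader_after_sharding = math.ceil(len_train_dataloader / num_processes)
        num_update_steps_per_epoch = math.ceil(
            len_train_dataloader_after_sharding / gradient_accumulation_steps
        )
        num_training_steps_for_scheduler = (
            num_train_epochs * num_update_steps_per_epoch * num_processes
        )
    else:
        num_training_steps_for_scheduler = max_train_steps * num_processes

    return num_warmup_steps_for_scheduler, num_training_steps_for_scheduler

For example, scheduler_steps(10_000, 16, 1, 4, 3, 500) returns (2000, 1884): 625 batches shrink to 157 per process after sharding, 3 epochs of 157 updates are 471 optimizer steps, and both counts are multiplied by 4 because each of the 4 processes ticks the scheduler.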

File tree

1 file changed (+18 additions, -7 deletions)


examples/text_to_image/train_text_to_image.py

Lines changed: 18 additions & 7 deletions
@@ -826,17 +826,22 @@ def collate_fn(examples):
     )
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
     )
 
     # Prepare everything with our `accelerator`.
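As a quick sanity check of the scaling above (outside the training script), one can build the scheduler with the scaled counts and tick it num_processes times per simulated optimizer step, which is roughly how the Accelerate-prepared scheduler behaves in a multi-GPU run according to the PR linked in the diff. The optimizer below is a throwaway stand-in, and "linear" is just one example value of --lr_scheduler.

import torch
from diffusers.optimization import get_scheduler

num_processes = 4
lr_warmup_steps = 500
max_train_steps = 3000

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=lr_warmup_steps * num_processes,
    num_training_steps=max_train_steps * num_processes,
)

for step in range(1, max_train_steps + 1):
    optimizer.step()
    for _ in range(num_processes):  # one scheduler tick per process
        lr_scheduler.step()
    if step == lr_warmup_steps:
        # With the scaling in place, warmup ends exactly at lr_warmup_steps
        # optimizer steps and the learning rate sits at its peak (1e-4).
        print(step, lr_scheduler.get_last_lr()[0])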
@@ -866,8 +871,14 @@ def collate_fn(examples):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
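The second hunk re-derives max_train_steps from the dataloader that actually comes out of accelerator.prepare and warns when the scheduler was built for a different total. A small numeric sketch of that check, with a hypothetical post-prepare length chosen to make the two sides disagree:

import math

num_processes = 4
gradient_accumulation_steps = 1
num_train_epochs = 3
num_training_steps_for_scheduler = 1884   # baked in before prepare (see the sketch above)
len_train_dataloader_after_prepare = 160  # hypothetical; the pre-prepare estimate was 157

num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_prepare / gradient_accumulation_steps)
max_train_steps = num_train_epochs * num_update_steps_per_epoch          # 480
if num_training_steps_for_scheduler != max_train_steps * num_processes:  # 1884 != 1920
    print("warning: the scheduler was created for a different number of training steps")

In the real script this mismatch only logs a warning; training still runs for the recomputed max_train_steps, so the likely consequence is a schedule slightly misaligned with the actual number of optimizer steps rather than an outright failure.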
