Commit 1e5e2a5

fix num_train_epochs
1 parent 7811185

1 file changed: 18 additions, 5 deletions

examples/controlnet/train_controlnet.py

Lines changed: 18 additions & 5 deletions
@@ -931,17 +931,24 @@ def load_model_hook(models, input_dir):
     )
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
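The first hunk sizes the learning rate scheduler before accelerator.prepare: it estimates the per-process dataloader length after sharding and scales the warmup and total step counts by accelerator.num_processes, as the script already did before this commit. The snippet below is a minimal sketch of that arithmetic, not part of the commit; the process count, dataloader length, and argument values are made-up numbers for illustration.

# Minimal sketch (not from the commit) of the step arithmetic added in the hunk above.
# All concrete values below are hypothetical.
import math

num_processes = 4                # assumed number of training processes (GPUs)
gradient_accumulation_steps = 2  # assumed --gradient_accumulation_steps
lr_warmup_steps = 500            # assumed --lr_warmup_steps
num_train_epochs = 3             # assumed --num_train_epochs
len_train_dataloader = 10_000    # assumed dataloader length before sharding

# The scheduler is created before accelerator.prepare shards the dataloader, so the
# expected per-process length is estimated with a ceil division; the step counts are
# multiplied by num_processes, matching what the script already did for warmup steps.
num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes                       # 2000
len_train_dataloader_after_sharding = math.ceil(len_train_dataloader / num_processes)  # 2500
num_update_steps_per_epoch = math.ceil(
    len_train_dataloader_after_sharding / gradient_accumulation_steps
)                                                                                      # 1250
num_training_steps_for_scheduler = (
    num_train_epochs * num_update_steps_per_epoch * num_processes
)                                                                                      # 15000

print(num_warmup_steps_for_scheduler, num_training_steps_for_scheduler)                # 2000 15000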
@@ -966,8 +973,14 @@ def load_model_hook(models, input_dir):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
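The second hunk recomputes the update steps from the dataloader actually returned by accelerator.prepare and logs a warning when the scheduler was sized for a different length. The sketch below is not part of the commit; it shows one hypothetical way the estimated and actual per-process lengths can diverge, for example if the prepared dataloader drops an incomplete last batch. All numbers are made up.

# Minimal sketch (not from the commit) of the consistency check in the hunk above,
# using hypothetical numbers and a hypothetical drop-last sharding behaviour.
import math

num_processes = 4
gradient_accumulation_steps = 1
num_train_epochs = 3

len_before_sharding = 10_002                                    # hypothetical batch count
estimated_len = math.ceil(len_before_sharding / num_processes)  # 2501, used when the scheduler was created
actual_len = len_before_sharding // num_processes               # 2500, e.g. incomplete last batch dropped

num_training_steps_for_scheduler = (
    num_train_epochs * math.ceil(estimated_len / gradient_accumulation_steps) * num_processes
)                                                               # 30012
num_update_steps_per_epoch = math.ceil(actual_len / gradient_accumulation_steps)
max_train_steps = num_train_epochs * num_update_steps_per_epoch  # 7500

if num_training_steps_for_scheduler != max_train_steps * num_processes:
    # Mirrors the logger.warning in the diff: the scheduler would decay over a
    # slightly different horizon than the number of steps actually trained.
    print(
        "scheduler sized for", num_training_steps_for_scheduler,
        "steps, but training will take", max_train_steps * num_processes,
    )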