From e3cf1deb9113268e174aff304f60c0ec50b240f7 Mon Sep 17 00:00:00 2001
From: flyxiv <78268298+flyxiv@users.noreply.github.com>
Date: Wed, 5 Mar 2025 23:22:16 +0900
Subject: [PATCH 1/4] updated train_dreambooth_lora to fix the LR schedulers
 for `num_train_epochs` in distributed training env

---
 examples/dreambooth/train_dreambooth_lora.py | 29 ++++++++++++++------
 1 file changed, 21 insertions(+), 8 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 07b14e1ddc0c..70f7cc8a2e5f 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -1119,17 +1119,23 @@ def compute_text_embeddings(prompt):
     )
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1146,8 +1152,15 @@ def compute_text_embeddings(prompt):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
+
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 

From 58c90b112daa61f62a3e22ebd1ece1a266ebdf2c Mon Sep 17 00:00:00 2001
From: flyxiv <78268298+flyxiv@users.noreply.github.com>
Date: Wed, 5 Mar 2025 23:49:05 +0900
Subject: [PATCH 2/4] fixed formatting

---
 examples/dreambooth/train_dreambooth_lora.py | 22 ++++++++++----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 70f7cc8a2e5f..c5fe62f031bc 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -1119,16 +1119,16 @@ def compute_text_embeddings(prompt):
     )
 
     # Scheduler and math around the number of training steps.
-    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. 
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
     num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
         len_train_dataloader_after_sharding = ceil(len(train_dataloader) / accelerator.num_processes)
-        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
-        num_training_steps_for_scheduler = (
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
             args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
-        )
-    else:
-        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
 
     lr_scheduler = get_scheduler(
@@ -1153,13 +1153,13 @@ def compute_text_embeddings(prompt):
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        if num_training_steps_for_scheduler != args.max_train_steps:
-            logger.warning(
-                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
                 f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
                 f"This inconsistency may result in the learning rate scheduler not functioning properly."
-            )
+            )
 
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

From 00c29f93fa5e2e0d4b7a278296960fb823b02799 Mon Sep 17 00:00:00 2001
From: flyxiv <78268298+flyxiv@users.noreply.github.com>
Date: Thu, 6 Mar 2025 00:04:06 +0900
Subject: [PATCH 3/4] remove trailing newlines

---
 examples/dreambooth/train_dreambooth_lora.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index c5fe62f031bc..8089612cf1b4 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -1130,7 +1130,6 @@ def compute_text_embeddings(prompt):
     else:
         num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
-
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,

From 41271239138dbf0032327970ca923fc46a73696b Mon Sep 17 00:00:00 2001
From: flyxiv <78268298+flyxiv@users.noreply.github.com>
Date: Thu, 6 Mar 2025 10:43:05 +0900
Subject: [PATCH 4/4] fixed style error

---
 examples/dreambooth/train_dreambooth_lora.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 8089612cf1b4..9584e7762dbd 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -1122,7 +1122,7 @@ def compute_text_embeddings(prompt):
     # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
     num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        len_train_dataloader_after_sharding = ceil(len(train_dataloader) / accelerator.num_processes)
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
         num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
         num_training_steps_for_scheduler = (
             args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
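
Note (not part of the patch series): below is a minimal, self-contained sketch of the step accounting the patches above converge on, useful for sanity-checking the scheduler lengths by hand. The helper name scheduler_steps and the example values are illustrative assumptions; only the intermediate variable names mirror the script, and the num_processes scaling follows the rationale referenced in PR #8312.

import math
from typing import Optional, Tuple


def scheduler_steps(
    len_train_dataloader: int,  # batches per epoch before accelerator.prepare shards the dataloader
    num_processes: int,
    gradient_accumulation_steps: int,
    num_train_epochs: int,
    max_train_steps: Optional[int],
    lr_warmup_steps: int,
) -> Tuple[int, int]:
    # Warmup and total steps handed to get_scheduler are scaled by num_processes,
    # mirroring the script above.
    num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes
    if max_train_steps is None:
        # After accelerator.prepare, each process only iterates over its shard of the batches.
        len_train_dataloader_after_sharding = math.ceil(len_train_dataloader / num_processes)
        num_update_steps_per_epoch = math.ceil(
            len_train_dataloader_after_sharding / gradient_accumulation_steps
        )
        num_training_steps_for_scheduler = (
            num_train_epochs * num_processes * num_update_steps_per_epoch
        )
    else:
        num_training_steps_for_scheduler = max_train_steps * num_processes
    return num_warmup_steps_for_scheduler, num_training_steps_for_scheduler


# Example: 250 batches per epoch, 2 GPUs, gradient accumulation 2, 3 epochs, 100 warmup steps.
# ceil(250 / 2) = 125 shard batches -> ceil(125 / 2) = 63 updates per epoch -> 3 * 2 * 63 = 378.
print(scheduler_steps(250, 2, 2, 3, None, 100))  # -> (200, 378)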