
Commit 04521f6

Merge branch 'main' into patch-7
2 parents: 3488d0d + df55f05

File tree

5 files changed: +34 -23 lines changed


examples/controlnet/train_controlnet.py

Lines changed: 4 additions & 4 deletions
@@ -178,11 +178,11 @@ def log_validation(
178  178        else:
179  179            logger.warning(f"image logging not implemented for {tracker.name}")
180  180
181       -    del pipeline
182       -    gc.collect()
183       -    torch.cuda.empty_cache()
     181  +    del pipeline
     182  +    gc.collect()
     183  +    torch.cuda.empty_cache()
184  184
185       -    return image_logs
     185  +    return image_logs
186  186
187  187
188  188    def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):

examples/controlnet/train_controlnet_flux.py

Lines changed: 3 additions & 3 deletions
@@ -192,9 +192,9 @@ def log_validation(
192  192        else:
193  193            logger.warning(f"image logging not implemented for {tracker.name}")
194  194
195       -    del pipeline
196       -    free_memory()
197       -    return image_logs
     195  +    del pipeline
     196  +    free_memory()
     197  +    return image_logs
198  198
199  199
200  200    def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):

examples/controlnet/train_controlnet_sd3.py

Lines changed: 5 additions & 5 deletions
@@ -199,13 +199,13 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
199  199        else:
200  200            logger.warning(f"image logging not implemented for {tracker.name}")
201  201
202       -    del pipeline
203       -    free_memory()
     202  +    del pipeline
     203  +    free_memory()
204  204
205       -    if not is_final_validation:
206       -        controlnet.to(accelerator.device)
     205  +    if not is_final_validation:
     206  +        controlnet.to(accelerator.device)
207  207
208       -    return image_logs
     208  +    return image_logs
209  209
210  210
211  211    # Copied from dreambooth sd3 example

examples/controlnet/train_controlnet_sdxl.py

Lines changed: 4 additions & 4 deletions
@@ -201,11 +201,11 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
201  201        else:
202  202            logger.warning(f"image logging not implemented for {tracker.name}")
203  203
204       -    del pipeline
205       -    gc.collect()
206       -    torch.cuda.empty_cache()
     204  +    del pipeline
     205  +    gc.collect()
     206  +    torch.cuda.empty_cache()
207  207
208       -    return image_logs
     208  +    return image_logs
209  209
210  210
211  211    def import_model_class_from_model_name_or_path(
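All four controlnet hunks touch the same post-validation cleanup step. For reference, a minimal sketch of that pattern outside the training scripts: the function name cleanup_after_validation is hypothetical, and the free_memory import path is an assumption based on how the flux and SD3 scripts use it, with a fallback mirroring the gc.collect()/torch.cuda.empty_cache() calls the SD 1.x and SDXL scripts make inline.

```python
import gc

import torch

try:
    # Assumed import path for the helper used by the flux/SD3 scripts.
    from diffusers.training_utils import free_memory
except ImportError:
    def free_memory():
        # Fallback mirroring what train_controlnet.py / train_controlnet_sdxl.py do inline.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def cleanup_after_validation(pipeline, image_logs):
    # Drop the local reference to the validation pipeline and release cached
    # GPU memory before training resumes, then hand the logs back to the caller.
    del pipeline
    free_memory()
    return image_logs
```

In the actual scripts this happens inline at the end of log_validation, after the tracker loop has logged all validation images.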

examples/textual_inversion/textual_inversion_sdxl.py

Lines changed: 18 additions & 7 deletions
@@ -793,17 +793,22 @@ def main():
793  793    )
794  794
795  795    # Scheduler and math around the number of training steps.
796       -  overrode_max_train_steps = False
797       -  num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     796  +  # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
     797  +  num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
798  798    if args.max_train_steps is None:
799       -      args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
800       -      overrode_max_train_steps = True
     799  +      len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
     800  +      num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
     801  +      num_training_steps_for_scheduler = (
     802  +          args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
     803  +      )
     804  +  else:
     805  +      num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
801  806
802  807    lr_scheduler = get_scheduler(
803  808        args.lr_scheduler,
804  809        optimizer=optimizer,
805       -      num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
806       -      num_training_steps=args.max_train_steps * accelerator.num_processes,
     810  +      num_warmup_steps=num_warmup_steps_for_scheduler,
     811  +      num_training_steps=num_training_steps_for_scheduler,
807  812        num_cycles=args.lr_num_cycles,
808  813    )
809  814

@@ -829,8 +834,14 @@ def main():
829  834
830  835    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
831  836    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
832       -  if overrode_max_train_steps:
     837  +  if args.max_train_steps is None:
833  838        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
     839  +      if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
     840  +          logger.warning(
     841  +              f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
     842  +              f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
     843  +              f"This inconsistency may result in the learning rate scheduler not functioning properly."
     844  +          )
834  845    # Afterwards we recalculate our number of training epochs
835  846    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
836  847
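To make the scheduler arithmetic in this hunk concrete, here is a small worked sketch of the new step counts; the dataset size, batch size, process count, and other numbers are illustrative assumptions, not values from the commit.

```python
import math

# Illustrative values only (not taken from the commit).
num_examples = 640
train_batch_size = 4
gradient_accumulation_steps = 2
num_processes = 2
num_train_epochs = 10
lr_warmup_steps = 500

# Dataloader length before accelerator.prepare shards it across processes.
len_train_dataloader = math.ceil(num_examples / train_batch_size)                       # 160
# After sharding, each process sees roughly 1/num_processes of the batches.
len_train_dataloader_after_sharding = math.ceil(len_train_dataloader / num_processes)   # 80
num_update_steps_per_epoch = math.ceil(
    len_train_dataloader_after_sharding / gradient_accumulation_steps
)                                                                                        # 40

# Both counts are scaled by num_processes, matching the new code in the hunk above.
num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes                         # 1000
num_training_steps_for_scheduler = (
    num_train_epochs * num_update_steps_per_epoch * num_processes
)                                                                                        # 800

print(num_warmup_steps_for_scheduler, num_training_steps_for_scheduler)
```

If the dataloader length changes after accelerator.prepare, the recalculated args.max_train_steps no longer matches the count the scheduler was built with, which is what the new warning in the second part of the hunk reports.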
