Commit 22fc9d9

Merge branch 'main' into auraflow-lora
2 parents dbc8427 + 723dbdd

File tree

2 files changed: 34 additions & 8 deletions


examples/controlnet/train_controlnet.py

Lines changed: 18 additions & 7 deletions
@@ -927,17 +927,22 @@ def load_model_hook(models, input_dir):
         )
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
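
This hunk sizes the LR schedule before accelerator.prepare has sharded the dataloader across processes. Below is a minimal sketch of the arithmetic with made-up numbers; the variable names mirror the diff, but none of the values or the standalone script come from the commit itself:

import math

# Hypothetical setup: 1000 batches before sharding, 4 processes,
# gradient accumulation of 2, 3 epochs, 500 warmup steps.
num_batches, num_processes = 1000, 4
gradient_accumulation_steps, num_train_epochs, lr_warmup_steps = 2, 3, 500

# Each process sees roughly 1/num_processes of the batches after sharding.
len_train_dataloader_after_sharding = math.ceil(num_batches / num_processes)  # 250
num_update_steps_per_epoch = math.ceil(
    len_train_dataloader_after_sharding / gradient_accumulation_steps
)  # 125

# The prepared scheduler is stepped once per process per update, which is why
# both budgets are multiplied by num_processes (see PR #8312).
num_training_steps_for_scheduler = (
    num_train_epochs * num_update_steps_per_epoch * num_processes
)  # 1500
num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes  # 2000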
@@ -962,8 +967,14 @@ def load_model_hook(models, input_dir):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
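This second hunk re-derives max_train_steps from the actual post-prepare dataloader and warns when the earlier estimate went stale. A hypothetical mismatch, again with illustrative numbers rather than anything from the commit:

import math

num_processes, gradient_accumulation_steps, num_train_epochs = 4, 2, 3
estimated_len = 250  # len_train_dataloader_after_sharding at scheduler creation
actual_len = 248     # len(train_dataloader) after accelerator.prepare

num_training_steps_for_scheduler = (
    num_train_epochs * math.ceil(estimated_len / gradient_accumulation_steps) * num_processes
)  # 1500
max_train_steps = num_train_epochs * math.ceil(actual_len / gradient_accumulation_steps)  # 372

# The inequality the new warning checks: 1500 != 372 * 4 == 1488, so the
# scheduler would decay over more steps than training actually performs.
print(num_training_steps_for_scheduler != max_train_steps * num_processes)  # True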

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 16 additions & 1 deletion
@@ -669,6 +669,16 @@ def parse_args(input_args=None):
         ),
     )
 
+    parser.add_argument(
+        "--image_interpolation_mode",
+        type=str,
+        default="lanczos",
+        choices=[
+            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
+        ],
+        help="The image interpolation method to use for resizing images.",
+    )
+
     if input_args is not None:
         args = parser.parse_args(input_args)
     else:
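
The choices list is built from whatever transforms.InterpolationMode exposes at runtime, so the flag stays in sync with the installed torchvision. A sketch of what the comprehension evaluates to; the exact member set depends on the torchvision version:

from torchvision import transforms

choices = [
    f.lower()
    for f in dir(transforms.InterpolationMode)
    if not f.startswith("__") and not f.endswith("__")
]
# Typically something like:
# ['bicubic', 'bilinear', 'box', 'hamming', 'lanczos', 'nearest', 'nearest_exact']
print(choices)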
@@ -790,7 +800,12 @@ def __init__(
         self.original_sizes = []
         self.crop_top_lefts = []
         self.pixel_values = []
-        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+
+        interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
+        if interpolation is None:
+            raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
+        train_resize = transforms.Resize(size, interpolation=interpolation)
+
         train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
         train_flip = transforms.RandomHorizontalFlip(p=1.0)
         train_transforms = transforms.Compose(
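
The dataset then resolves the flag to an enum member with getattr, so an unknown name fails fast instead of silently falling back to bilinear. A minimal usage sketch; the mode string and size here are hypothetical, not values from the commit:

from torchvision import transforms

mode_name = "lanczos"  # hypothetical value of args.image_interpolation_mode
interpolation = getattr(transforms.InterpolationMode, mode_name.upper(), None)
if interpolation is None:
    raise ValueError(f"Unsupported interpolation mode {mode_name!r}.")

train_resize = transforms.Resize(1024, interpolation=interpolation)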