Changes from 11 commits

Commits (25):
2da7c20  initial commit (linoytsaban, Apr 25, 2025)
6517a70  initial commit (linoytsaban, Apr 25, 2025)
873fe89  fix best bucket (linoytsaban, Apr 27, 2025)
fa4765c  fix best bucket (linoytsaban, Apr 27, 2025)
9ad4b61  fix best bucket (linoytsaban, Apr 28, 2025)
5211ffa  Merge branch 'huggingface:main' into aspect_ratio_bucketing (linoytsaban, Apr 28, 2025)
4130560  make it configurable (linoytsaban, Apr 28, 2025)
bd7a8b8  Merge branch 'main' into aspect_ratio_bucketing (linoytsaban, Apr 28, 2025)
50782b7  Merge branch 'main' into aspect_ratio_bucketing (linoytsaban, Apr 30, 2025)
314cbdb  Merge branch 'main' into aspect_ratio_bucketing (linoytsaban, Apr 30, 2025)
1571961  Merge branch 'main' into aspect_ratio_bucketing (linoytsaban, May 2, 2025)
b817ca1  move `find_nearest_bucket`, `parse_buckets_string` to training_utils.py (linoytsaban, May 2, 2025)
b0d77ee  move `find_nearest_bucket`, `parse_buckets_string` to training_utils.py (linoytsaban, May 2, 2025)
f5636c6  fix import (linoytsaban, May 2, 2025)
0df0ea1  fix flip (linoytsaban, May 2, 2025)
4646c60  Apply style fixes (github-actions[bot], May 2, 2025)
ad28907  Merge branch 'huggingface:main' into aspect_ratio_bucketing (linoytsaban, May 5, 2025)
d442e5a  cleanup (linoytsaban, May 5, 2025)
b6d180b  Merge branch 'main' into aspect_ratio_bucketing (linoytsaban, May 5, 2025)
1f235b5  Merge branch 'main' into aspect_ratio_bucketing (linoytsaban, May 5, 2025)
bcd6a60  Merge branch 'main' into aspect_ratio_bucketing (linoytsaban, May 8, 2025)
fd8bccf  Merge branch 'main' into aspect_ratio_bucketing (linoytsaban, May 22, 2025)
eb59238  Merge branch 'main' into aspect_ratio_bucketing (linoytsaban, May 27, 2025)
84d5b20  Merge branch 'main' into aspect_ratio_bucketing (linoytsaban, May 29, 2025)
4ec0872  Merge branch 'main' into aspect_ratio_bucketing (linoytsaban, Jun 25, 2025)
173 changes: 148 additions & 25 deletions examples/dreambooth/train_dreambooth_lora_flux.py
@@ -24,6 +24,7 @@
import warnings
from contextlib import nullcontext
from pathlib import Path
from torch.utils.data.sampler import BatchSampler

import numpy as np
import torch
@@ -65,6 +66,7 @@
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module


@@ -76,6 +78,38 @@

logger = get_logger(__name__)

if is_torch_npu_available():
torch.npu.config.allow_internal_format = False


def parse_buckets_string(buckets_str):
""" Parses a string defining buckets into a list of (height, width) tuples. """
if not buckets_str:
raise ValueError("Bucket string cannot be empty.")

bucket_pairs = buckets_str.strip().split(';')
parsed_buckets = []
for pair_str in bucket_pairs:
match = re.match(r"^\s*(\d+)\s*,\s*(\d+)\s*$", pair_str)
if not match:
raise ValueError(f"Invalid bucket format: '{pair_str}'. Expected 'height,width'.")
try:
height = int(match.group(1))
width = int(match.group(2))
if height <= 0 or width <= 0:
raise ValueError("Bucket dimensions must be positive integers.")
if height % 8 != 0 or width % 8 != 0:
logger.warning(f"Bucket dimension ({height},{width}) not divisible by 8. This might cause issues.")
parsed_buckets.append((height, width))
except ValueError as e:
raise ValueError(f"Invalid integer in bucket pair '{pair_str}': {e}") from e

if not parsed_buckets:
raise ValueError("No valid buckets found in the provided string.")

logger.info(f"Using parsed aspect ratio buckets: {parsed_buckets}")
return parsed_buckets
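
A quick usage sketch (mine, not part of the PR), showing what the helper returns for a flag value like the one in the argparse help text below:

# Hypothetical example; parse_buckets_string is the helper defined above.
buckets = parse_buckets_string("1024,1024;768,1360;1360,768")
assert buckets == [(1024, 1024), (768, 1360), (1360, 768)]
# A pair like "100,100" still parses but logs a warning, since 100 % 8 != 0.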


def save_model_card(
repo_id: str,
@@ -390,6 +424,16 @@ def parse_args(input_args=None):
" resolution"
),
)
parser.add_argument(
"--aspect_ratio_buckets",
type=str,
default=None,
help=(
"Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. "
"e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'"
"Images will be resized and cropped to fit the nearest bucket."
),
)
parser.add_argument(
"--center_crop",
default=False,
@@ -700,6 +744,19 @@ class DreamBoothDataset(Dataset):
It pre-processes the images.
"""

@staticmethod
def find_nearest_bucket(h, w, bucket_options):
    """Returns the index of the bucket whose aspect ratio is closest to h/w."""
    min_metric = float("inf")
    best_bucket_idx = None
    for bucket_idx, (bucket_h, bucket_w) in enumerate(bucket_options):
        # |h/w - bucket_h/bucket_w|, cross-multiplied to avoid division
        metric = abs(h * bucket_w - w * bucket_h)
        if metric <= min_metric:
            min_metric = metric
            best_bucket_idx = bucket_idx
    return best_bucket_idx
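
The metric abs(h * bucket_w - w * bucket_h) is a cross-multiplied comparison of aspect ratios: it is zero exactly when h / w == bucket_h / bucket_w, so buckets are ranked by aspect-ratio closeness without any division. A worked check with my own numbers (not from the PR):

# h=700, w=1000 against the candidate buckets:
#   (1024, 1024): |700*1024 - 1000*1024| = 307_200
#   (768, 1360):  |700*1360 - 1000*768|  = 184_000  <- smallest
#   (1360, 768):  |700*768  - 1000*1360| = 822_400
idx = DreamBoothDataset.find_nearest_bucket(700, 1000, [(1024, 1024), (768, 1360), (1360, 768)])
assert idx == 1  # the landscape bucket (768, 1360) wins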

def __init__(
self,
instance_data_root,
@@ -710,14 +767,18 @@ def __init__(
size=1024,
repeats=1,
center_crop=False,
buckets=[(1024, 1024), (768, 1360), (1360, 768), (880, 1168), (1168, 880), (1248, 832), (832, 1248)],
):
self.size = size
self.center_crop = center_crop

self.instance_prompt = instance_prompt
self.custom_instance_prompts = None
self.class_prompt = class_prompt

self.buckets = buckets

# if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
# we load the training data using load_dataset
if args.dataset_name is not None:
@@ -782,32 +843,40 @@ def __init__(
self.instance_images.extend(itertools.repeat(img, repeats))

self.pixel_values = []
for image in self.instance_images:
    image = exif_transpose(image)
    if image.mode != "RGB":
        image = image.convert("RGB")

    width, height = image.size
    # Find the closest bucket
    bucket_idx = self.find_nearest_bucket(height, width, self.buckets)
    target_height, target_width = self.buckets[bucket_idx]
    self.size = (target_height, target_width)

    # based on the bucket assignment, define the transformations
    train_resize = transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR)
    train_crop = transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size)
    train_flip = transforms.RandomHorizontalFlip(p=1.0)
    train_transforms = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    image = train_resize(image)
    if args.random_flip and random.random() < 0.5:
        # flip
        image = train_flip(image)
    if args.center_crop:
        y1 = max(0, int(round((image.height - self.size[0]) / 2.0)))
        x1 = max(0, int(round((image.width - self.size[1]) / 2.0)))
        image = train_crop(image)
    else:
        y1, x1, h, w = train_crop.get_params(image, self.size)
        image = crop(image, y1, x1, h, w)
    image = train_transforms(image)
    self.pixel_values.append((image, bucket_idx))

self.num_instance_images = len(self.instance_images)
self._length = self.num_instance_images
@@ -826,8 +895,8 @@ def __init__(

self.image_transforms = transforms.Compose(
[
transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
@@ -838,8 +907,9 @@ def __len__(self):

def __getitem__(self, index):
example = {}
instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
example["instance_images"] = instance_image
example["bucket_idx"] = bucket_idx

if self.custom_instance_prompts:
caption = self.custom_instance_prompts[index % self.num_instance_images]
Expand Down Expand Up @@ -880,6 +950,49 @@ def collate_fn(examples, with_prior_preservation=False):
return batch


class BucketBatchSampler(BatchSampler):
def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError("batch_size should be a positive integer value, "
"but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
raise ValueError("drop_last should be a boolean value, but got "
"drop_last={}".format(drop_last))

self.dataset = dataset
self.batch_size = batch_size
self.drop_last = drop_last

# Group indices by bucket
self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):
self.bucket_indices[bucket_idx].append(idx)

self.sampler_len = 0
self.batches = []

# Pre-generate batches for each bucket
for indices_in_bucket in self.bucket_indices:
# Shuffle indices within the bucket
random.shuffle(indices_in_bucket)
# Create batches
for i in range(0, len(indices_in_bucket), self.batch_size):
batch = indices_in_bucket[i:i + self.batch_size]
if len(batch) < self.batch_size and self.drop_last:
continue # Skip partial batch if drop_last is True
self.batches.append(batch)
self.sampler_len += 1 # Count the number of batches

def __iter__(self):
# Shuffle the order of the batches each epoch
random.shuffle(self.batches)
for batch in self.batches:
yield batch

def __len__(self):
return self.sampler_len
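
A minimal sanity check (my sketch, not in the diff): every batch the sampler yields should draw from exactly one bucket, which is what lets collate_fn stack same-shaped image tensors.

# Assumes train_dataset is a DreamBoothDataset constructed with buckets, as below.
sampler = BucketBatchSampler(train_dataset, batch_size=4, drop_last=False)
for batch_indices in sampler:
    buckets_in_batch = {train_dataset.pixel_values[i][1] for i in batch_indices}
    assert len(buckets_in_batch) == 1  # one bucket per batch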


class PromptDataset(Dataset):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."

@@ -1134,8 +1247,7 @@ def main(args):
image.save(image_filename)

del pipeline
free_memory()

# Handle the repository creation
if accelerator.is_main_process:
@@ -1425,6 +1537,11 @@ def load_model_hook(models, input_dir):
safeguard_warmup=args.prodigy_safeguard_warmup,
)

if args.aspect_ratio_buckets:
buckets = parse_buckets_string(args.aspect_ratio_buckets)
else:
buckets = [(args.resolution, args.resolution)]
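
With the flag unset this degenerates to a single square bucket, so the script reproduces the previous fixed-resolution behavior; a sketch of both cases (flag values are my own examples):

# --aspect_ratio_buckets "768,1360;1360,768"  ->  buckets == [(768, 1360), (1360, 768)]
# flag unset, --resolution 512                ->  buckets == [(512, 512)]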

# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
@@ -1433,14 +1550,19 @@ def load_model_hook(models, input_dir):
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_num=args.num_class_images,
size=args.resolution,
buckets=buckets,
repeats=args.repeats,
center_crop=args.center_crop,
)

batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
# batch_sampler replaces batch_size/shuffle/drop_last on the DataLoader itself
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_sampler=batch_sampler,
    collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
    num_workers=args.dataloader_num_workers,
)
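
Worth noting (general PyTorch behavior, not specific to this PR): DataLoader treats batch_sampler as mutually exclusive with batch_size, shuffle, sampler, and drop_last, which is why those arguments no longer appear on the call above.

# Passing both would raise a ValueError, e.g.:
# torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler, shuffle=True)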
@@ -1879,7 +2001,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
free_memory()

images = None
free_memory()

# Save the lora layers
accelerator.wait_for_everyone()
@@ -1927,6 +2049,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
is_final_validation=True,
torch_dtype=weight_dtype,
)
del pipeline
free_memory()

if args.push_to_hub:
save_model_card(
@@ -1946,7 +2070,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
)

images = None

accelerator.end_training()
