Commit 4646c60

Apply style fixes
1 parent 0df0ea1 commit 4646c60

File tree: 3 files changed, +54 −58 lines

examples/dreambooth/train_dreambooth_lora_flux.py
examples/dreambooth/train_dreambooth_lora_hidream.py
src/diffusers/training_utils.py

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 6 additions & 11 deletions
@@ -24,7 +24,6 @@
 import warnings
 from contextlib import nullcontext
 from pathlib import Path
-from torch.utils.data.sampler import Sampler, BatchSampler
 
 import numpy as np
 import torch
@@ -40,6 +39,7 @@
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from torch.utils.data import Dataset
+from torch.utils.data.sampler import BatchSampler
 from torchvision import transforms
 from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
@@ -58,9 +58,9 @@
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
+    find_nearest_bucket,
     free_memory,
     parse_buckets_string,
-    find_nearest_bucket
 )
 from diffusers.utils import (
     check_min_version,
@@ -911,11 +911,9 @@ def collate_fn(examples, with_prior_preservation=False):
 class BucketBatchSampler(BatchSampler):
     def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
         if not isinstance(batch_size, int) or batch_size <= 0:
-            raise ValueError("batch_size should be a positive integer value, "
-                             "but got batch_size={}".format(batch_size))
+            raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
         if not isinstance(drop_last, bool):
-            raise ValueError("drop_last should be a boolean value, but got "
-                             "drop_last={}".format(drop_last))
+            raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last))
 
         self.dataset = dataset
         self.batch_size = batch_size
@@ -935,7 +933,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
             random.shuffle(indices_in_bucket)
             # Create batches
             for i in range(0, len(indices_in_bucket), self.batch_size):
-                batch = indices_in_bucket[i:i + self.batch_size]
+                batch = indices_in_bucket[i : i + self.batch_size]
                 if len(batch) < self.batch_size and self.drop_last:
                     continue  # Skip partial batch if drop_last is True
                 self.batches.append(batch)
@@ -1512,10 +1510,7 @@ def load_model_hook(models, input_dir):
         repeats=args.repeats,
         center_crop=args.center_crop,
     )
-    batch_sampler = BucketBatchSampler(
-        train_dataset,
-        batch_size=args.train_batch_size,
-        drop_last=False)
+    batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
         batch_sampler=batch_sampler,
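For context on what these hunks reformat: BucketBatchSampler groups sample indices by aspect-ratio bucket, shuffles within each bucket, and slices the result into fixed-size batches so that every batch shares one resolution. Below is a minimal, self-contained sketch of that logic. ToyBucketedDataset and its sample_buckets attribute are illustrative assumptions, not part of this commit; the batch-level shuffle in __iter__ is likewise an assumption.

import random
from collections import defaultdict

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import BatchSampler


class ToyBucketedDataset(Dataset):
    def __init__(self):
        # Hypothetical bucket index assigned to each of 10 samples.
        self.sample_buckets = [i % 2 for i in range(10)]

    def __len__(self):
        return len(self.sample_buckets)

    def __getitem__(self, idx):
        return idx  # the real dataset returns image tensors and prompts


class BucketBatchSampler(BatchSampler):
    def __init__(self, dataset, batch_size: int, drop_last: bool = False):
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last))
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.batches = []

        # Group sample indices by their bucket, then shuffle within each bucket.
        grouped = defaultdict(list)
        for idx, bucket_idx in enumerate(dataset.sample_buckets):
            grouped[bucket_idx].append(idx)
        for indices_in_bucket in grouped.values():
            random.shuffle(indices_in_bucket)
            # Slice each bucket into fixed-size batches.
            for i in range(0, len(indices_in_bucket), self.batch_size):
                batch = indices_in_bucket[i : i + self.batch_size]
                if len(batch) < self.batch_size and self.drop_last:
                    continue  # Skip partial batch if drop_last is True
                self.batches.append(batch)

    def __iter__(self):
        random.shuffle(self.batches)  # mix bucket order across the epoch (an assumption)
        return iter(self.batches)

    def __len__(self):
        return len(self.batches)


dataset = ToyBucketedDataset()
batch_sampler = BucketBatchSampler(dataset, batch_size=4, drop_last=False)
train_dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
for batch in train_dataloader:
    print(batch)  # every batch holds indices drawn from a single bucket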

examples/dreambooth/train_dreambooth_lora_hidream.py

Lines changed: 39 additions & 42 deletions
@@ -24,7 +24,6 @@
 import warnings
 from contextlib import nullcontext
 from pathlib import Path
-from torch.utils.data.sampler import Sampler, BatchSampler
 
 import numpy as np
 import torch
@@ -40,6 +39,7 @@
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from torch.utils.data import Dataset
+from torch.utils.data.sampler import BatchSampler
 from torchvision import transforms
 from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
@@ -57,9 +57,9 @@
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
+    find_nearest_bucket,
     free_memory,
     parse_buckets_string,
-    find_nearest_bucket
 )
 from diffusers.utils import (
     check_min_version,
@@ -70,6 +70,7 @@
 from diffusers.utils.import_utils import is_torch_npu_available
 from diffusers.utils.torch_utils import is_compiled_module
 
+
 if is_wandb_available():
     import wandb
 
@@ -81,13 +82,14 @@
 if is_torch_npu_available():
     torch.npu.config.allow_internal_format = False
 
+
 def save_model_card(
-        repo_id: str,
-        images=None,
-        base_model: str = None,
-        instance_prompt=None,
-        validation_prompt=None,
-        repo_folder=None,
+    repo_id: str,
+    images=None,
+    base_model: str = None,
+    instance_prompt=None,
+    validation_prompt=None,
+    repo_folder=None,
 ):
     widget_dict = []
     if images is not None:
@@ -189,13 +191,13 @@ def load_text_encoders(class_one, class_two, class_three):
 
 
 def log_validation(
-        pipeline,
-        args,
-        accelerator,
-        pipeline_args,
-        epoch,
-        torch_dtype,
-        is_final_validation=False,
+    pipeline,
+    args,
+    accelerator,
+    pipeline_args,
+    epoch,
+    torch_dtype,
+    is_final_validation=False,
 ):
     args.num_validation_images = args.num_validation_images if args.num_validation_images else 1
     logger.info(
@@ -244,7 +246,7 @@ def log_validation(
 
 
 def import_model_class_from_model_name_or_path(
-        pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
 ):
     text_encoder_config = PretrainedConfig.from_pretrained(
         pretrained_model_name_or_path, subfolder=subfolder, revision=revision
@@ -331,8 +333,8 @@ def parse_args(input_args=None):
         type=str,
         default="image",
         help="The column of the dataset containing the target image. By "
-             "default, the standard Image Dataset maps out 'file_name' "
-             "to 'image'.",
+        "default, the standard Image Dataset maps out 'file_name' "
+        "to 'image'.",
     )
     parser.add_argument(
         "--caption_column",
@@ -598,7 +600,7 @@ def parse_args(input_args=None):
         type=float,
         default=None,
         help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
-             "uses the value of square root of beta2. Ignored if optimizer is adamW",
+        "uses the value of square root of beta2. Ignored if optimizer is adamW",
     )
     parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
@@ -629,7 +631,7 @@ def parse_args(input_args=None):
         type=bool,
         default=True,
         help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
-             "Ignored if optimizer is adamW",
+        "Ignored if optimizer is adamW",
     )
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
@@ -736,17 +738,17 @@ class DreamBoothDataset(Dataset):
     """
 
     def __init__(
-            self,
-            instance_data_root,
-            instance_prompt,
-            class_prompt,
-            class_data_root=None,
-            class_num=None,
-            size=1024,
-            repeats=1,
-            center_crop=False,
-            buckets=[(1024,1024),(768,1360),(1360, 768),(880, 1168),(1168, 880), (1248, 832), (832, 1248)],
-            # buckets=[(1024, 1024)],
+        self,
+        instance_data_root,
+        instance_prompt,
+        class_prompt,
+        class_data_root=None,
+        class_num=None,
+        size=1024,
+        repeats=1,
+        center_crop=False,
+        buckets=[(1024, 1024), (768, 1360), (1360, 768), (880, 1168), (1168, 880), (1248, 832), (832, 1248)],
+        # buckets=[(1024, 1024)],
     ):
         # self.size = (size, size)
         self.center_crop = center_crop
@@ -930,11 +932,9 @@ def collate_fn(examples, with_prior_preservation=False):
 class BucketBatchSampler(BatchSampler):
     def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
         if not isinstance(batch_size, int) or batch_size <= 0:
-            raise ValueError("batch_size should be a positive integer value, "
-                             "but got batch_size={}".format(batch_size))
+            raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
         if not isinstance(drop_last, bool):
-            raise ValueError("drop_last should be a boolean value, but got "
-                             "drop_last={}".format(drop_last))
+            raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last))
 
         self.dataset = dataset
         self.batch_size = batch_size
@@ -954,7 +954,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
             random.shuffle(indices_in_bucket)
             # Create batches
             for i in range(0, len(indices_in_bucket), self.batch_size):
-                batch = indices_in_bucket[i:i + self.batch_size]
+                batch = indices_in_bucket[i : i + self.batch_size]
                 if len(batch) < self.batch_size and self.drop_last:
                     continue  # Skip partial batch if drop_last is True
                 self.batches.append(batch)
@@ -1064,7 +1064,7 @@ def main(args):
         pipeline.to(accelerator.device)
 
         for example in tqdm(
-                sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+            sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
         ):
             images = pipeline(example["prompt"]).images
 
@@ -1278,7 +1278,7 @@ def load_model_hook(models, input_dir):
 
     if args.scale_lr:
         args.learning_rate = (
-                args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
     # Make sure the trainable params are in float32.
@@ -1368,10 +1368,7 @@ def load_model_hook(models, input_dir):
         repeats=args.repeats,
         center_crop=args.center_crop,
     )
-    batch_sampler = BucketBatchSampler(
-        train_dataset,
-        batch_size=args.train_batch_size,
-        drop_last=False)
+    batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
         batch_sampler=batch_sampler,
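The buckets default added to DreamBoothDataset above pairs square, landscape, and portrait resolutions. As a worked example of how an input resolution lands in one of them, here is a standalone sketch (not part of this commit) using the same cross-ratio metric that find_nearest_bucket in src/diffusers/training_utils.py applies, abs(h * bucket_w - w * bucket_h):

# Pick the bucket whose aspect ratio is nearest to an input image's h/w.
buckets = [(1024, 1024), (768, 1360), (1360, 768), (880, 1168), (1168, 880), (1248, 832), (832, 1248)]


def nearest(h, w):
    # abs(h * bucket_w - w * bucket_h) == 0 exactly when h / w == bucket_h / bucket_w
    return min(buckets, key=lambda b: abs(h * b[1] - w * b[0]))


print(nearest(1080, 1920))  # 16:9 landscape -> (768, 1360)
print(nearest(3000, 2000))  # 3:2 portrait -> (1248, 832), an exact ratio match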

src/diffusers/training_utils.py

Lines changed: 9 additions & 5 deletions
@@ -3,9 +3,10 @@
 import gc
 import math
 import random
+import re
 import warnings
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
-import re
+
 import numpy as np
 import torch
 
@@ -308,12 +309,13 @@ def free_memory():
     elif hasattr(torch, "xpu") and torch.xpu.is_available():
         torch.xpu.empty_cache()
 
+
 def parse_buckets_string(buckets_str):
-    """ Parses a string defining buckets into a list of (height, width) tuples. """
+    """Parses a string defining buckets into a list of (height, width) tuples."""
     if not buckets_str:
         raise ValueError("Bucket string cannot be empty.")
 
-    bucket_pairs = buckets_str.strip().split(';')
+    bucket_pairs = buckets_str.strip().split(";")
     parsed_buckets = []
     for pair_str in bucket_pairs:
         match = re.match(r"^\s*(\d+)\s*,\s*(\d+)\s*$", pair_str)
@@ -335,9 +337,10 @@ def parse_buckets_string(buckets_str):
 
     return parsed_buckets
 
+
 def find_nearest_bucket(h, w, bucket_options):
-    """ Finds the closes bucket to the given height and width. """
-    min_metric = float('inf')
+    """Finds the closest bucket to the given height and width."""
+    min_metric = float("inf")
     best_bucket_idx = None
     for bucket_idx, (bucket_h, bucket_w) in enumerate(bucket_options):
         metric = abs(h * bucket_w - w * bucket_h)
@@ -346,6 +349,7 @@ def find_nearest_bucket(h, w, bucket_options):
             best_bucket_idx = bucket_idx
     return best_bucket_idx
 
+
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
