52 changes: 51 additions & 1 deletion library/config_util.py
@@ -1,4 +1,5 @@
import argparse
import os
from dataclasses import (
asdict,
dataclass,
@@ -108,6 +109,7 @@ class BaseDatasetParams:
validation_seed: Optional[int] = None
validation_split: float = 0.0
resize_interpolation: Optional[str] = None
skip_duplicate_bucketed_images: bool = False

@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
@@ -118,7 +120,7 @@ class DreamBoothDatasetParams(BaseDatasetParams):
bucket_reso_steps: int = 64
bucket_no_upscale: bool = False
prior_loss_weight: float = 1.0

@dataclass
class FineTuningDatasetParams(BaseDatasetParams):
batch_size: int = 1
@@ -244,6 +246,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
"resize_interpolation": str,
"skip_duplicate_bucketed_images": bool,
}

# options handled by argparse but not handled by user config
@@ -530,6 +533,7 @@ def print_info(_datasets, dataset_type: str):
resolution: {(dataset.width, dataset.height)}
resize_interpolation: {dataset.resize_interpolation}
enable_bucket: {dataset.enable_bucket}
skip_duplicate_bucketed_images: {dataset.skip_duplicate_bucketed_images}
""")

if dataset.enable_bucket:
@@ -593,6 +597,52 @@ def print_info(_datasets, dataset_type: str):
dataset.make_buckets()
dataset.set_seed(seed)

# Optional dedup: remove later duplicates when the same image lands in the
# same bucket resolution across datasets.
seen_items = set()
for i, dataset in enumerate(datasets):
if not (dataset.bucket_no_upscale and dataset.skip_duplicate_bucketed_images):
continue

removed = 0
# Materialize items() into a list so entries can be popped from the dict while iterating
for image_key, info in list(dataset.image_data.items()):
subset = dataset.image_to_subset.get(image_key)
if subset is None or info.bucket_reso is None:
continue

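# Identity for dedup: include every field that changes the produced sample
# (file on disk, bucket, caption, reg status, subset augmentation/caption options)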
dedup_key = (
os.path.normcase(os.path.abspath(info.absolute_path)),
info.bucket_reso,
info.caption,
info.is_reg,
dataset.network_multiplier,
subset.flip_aug,
subset.alpha_mask,
subset.random_crop,
subset.shuffle_caption,
subset.caption_dropout_rate,
subset.caption_dropout_every_n_epochs,
subset.caption_tag_dropout_rate,
)

if dedup_key in seen_items:
dataset.image_data.pop(image_key, None)
dataset.image_to_subset.pop(image_key, None)
if info.is_reg:
dataset.num_reg_images -= info.num_repeats
else:
dataset.num_train_images -= info.num_repeats
removed += 1
else:
seen_items.add(dedup_key)

if removed > 0:
logger.info(f"[Prepare dataset {i}] removed {removed} duplicated images with same bucket resolution across datasets")
# make_buckets reuses existing bucket_manager when present, so clear it to avoid stale keys
dataset.bucket_manager = None
dataset.make_buckets()

for i, dataset in enumerate(val_datasets):
logger.info(f"[Prepare validation dataset {i}]")
dataset.make_buckets()
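
The dedup key above treats two entries as the same sample only when every field that affects the produced training example matches: file identity, bucket resolution, caption, reg status, and the subset-level augmentation and caption options. A minimal, self-contained sketch of that rule in isolation, using simplified stand-in types (StubImage, StubSubset) rather than the repo's actual ImageInfo and subset classes:

import os
from dataclasses import dataclass
from typing import List, Tuple

@dataclass(frozen=True)
class StubSubset:  # stand-in: only augmentation-relevant fields
    flip_aug: bool
    random_crop: bool

@dataclass
class StubImage:  # stand-in for ImageInfo
    absolute_path: str
    bucket_reso: Tuple[int, int]
    caption: str
    is_reg: bool

def dedup(entries: List[Tuple[StubImage, StubSubset]]) -> List[Tuple[StubImage, StubSubset]]:
    """Keep only the first occurrence of each (path, bucket, caption, aug) combination."""
    seen = set()
    kept = []
    for info, subset in entries:
        key = (
            os.path.normcase(os.path.abspath(info.absolute_path)),  # same file on disk
            info.bucket_reso,   # same bucket resolution
            info.caption,       # same caption text
            info.is_reg,        # never merge reg and train images
            subset.flip_aug,    # different augmentation -> a different sample
            subset.random_crop,
        )
        if key not in seen:
            seen.add(key)
            kept.append((info, subset))
    return kept

Normalizing the path with os.path.abspath plus os.path.normcase makes the key robust to case and relative-path differences between dataset definitions that point at the same files.
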
8 changes: 8 additions & 0 deletions library/train_util.py
@@ -703,6 +703,7 @@ def __init__(
self.token_strings = None

self.enable_bucket = False
self.skip_duplicate_bucketed_images = False
self.bucket_manager: BucketManager = None # not initialized
self.min_bucket_reso = None
self.max_bucket_reso = None
@@ -1914,6 +1915,7 @@ def __init__(
debug_dataset: bool,
validation_split: float,
validation_seed: Optional[int],
skip_duplicate_bucketed_images: bool,
resize_interpolation: Optional[str],
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
@@ -1929,6 +1931,7 @@ def __init__(
self.validation_split = validation_split

self.enable_bucket = enable_bucket
self.skip_duplicate_bucketed_images = skip_duplicate_bucketed_images
if self.enable_bucket:
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
@@ -2199,6 +2202,7 @@ def __init__(
debug_dataset: bool,
validation_seed: int,
validation_split: float,
skip_duplicate_bucketed_images: bool,
resize_interpolation: Optional[str],
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
@@ -2208,6 +2212,7 @@ def __init__(
self.latents_cache = None

self.enable_bucket = enable_bucket
self.skip_duplicate_bucketed_images = skip_duplicate_bucketed_images
if self.enable_bucket:
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
@@ -2386,6 +2391,7 @@ def __init__(
debug_dataset: bool,
validation_split: float,
validation_seed: Optional[int],
skip_duplicate_bucketed_images: bool,
resize_interpolation: Optional[str] = None,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
@@ -2439,6 +2445,7 @@ def __init__(
debug_dataset,
validation_split,
validation_seed,
skip_duplicate_bucketed_images,
resize_interpolation,
)

@@ -2449,6 +2456,7 @@ def __init__(
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
self.validation_split = validation_split
self.validation_seed = validation_seed
self.skip_duplicate_bucketed_images = skip_duplicate_bucketed_images
self.resize_interpolation = resize_interpolation

# assert all conditioning data exists
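
For completeness, how the new option would be switched on: skip_duplicate_bucketed_images is registered in the dataset schema next to resolution and resize_interpolation, so it belongs at the dataset (or general) level of the usual dataset config. A hedged sketch of the loaded config as a Python dict, assuming the standard kohya-ss dataset-config layout; the paths are placeholders, and note that the dedup pass above only runs when bucket_no_upscale is also enabled:

user_config = {
    "general": {
        "enable_bucket": True,
        "bucket_no_upscale": True,  # dedup only runs when this is also set
    },
    "datasets": [
        {
            "resolution": 1024,
            "skip_duplicate_bucketed_images": True,  # new option from this PR
            "subsets": [{"image_dir": "/path/to/images", "num_repeats": 1}],
        }
    ],
}
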