Skip to content

Commit ad28907

Browse files
authored
Merge branch 'huggingface:main' into aspect_ratio_bucketing
2 parents 4646c60 + ec3d582 commit ad28907

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,15 @@ def parse_args(input_args=None):
770770
),
771771
)
772772
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
773+
parser.add_argument(
774+
"--image_interpolation_mode",
775+
type=str,
776+
default="lanczos",
777+
choices=[
778+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
779+
],
780+
help="The image interpolation method to use for resizing images.",
781+
)
773782

774783
if input_args is not None:
775784
args = parser.parse_args(input_args)
@@ -1034,7 +1043,10 @@ def __init__(
10341043
self.instance_images.extend(itertools.repeat(img, repeats))
10351044

10361045
self.pixel_values = []
1037-
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
1046+
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
1047+
if interpolation is None:
1048+
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
1049+
train_resize = transforms.Resize(size, interpolation=interpolation)
10381050
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
10391051
train_flip = transforms.RandomHorizontalFlip(p=1.0)
10401052
train_transforms = transforms.Compose(
@@ -1078,7 +1090,7 @@ def __init__(
10781090

10791091
self.image_transforms = transforms.Compose(
10801092
[
1081-
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
1093+
transforms.Resize(size, interpolation=interpolation),
10821094
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
10831095
transforms.ToTensor(),
10841096
transforms.Normalize([0.5], [0.5]),

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,15 @@ def parse_args(input_args=None):
799799
default=False,
800800
help="Cache the VAE latents",
801801
)
802+
parser.add_argument(
803+
"--image_interpolation_mode",
804+
type=str,
805+
default="lanczos",
806+
choices=[
807+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
808+
],
809+
help="The image interpolation method to use for resizing images.",
810+
)
802811

803812
if input_args is not None:
804813
args = parser.parse_args(input_args)
@@ -1069,7 +1078,10 @@ def __init__(
10691078
self.original_sizes = []
10701079
self.crop_top_lefts = []
10711080
self.pixel_values = []
1072-
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
1081+
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
1082+
if interpolation is None:
1083+
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
1084+
train_resize = transforms.Resize(size, interpolation=interpolation)
10731085
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
10741086
train_flip = transforms.RandomHorizontalFlip(p=1.0)
10751087
train_transforms = transforms.Compose(
@@ -1146,7 +1158,7 @@ def __init__(
11461158

11471159
self.image_transforms = transforms.Compose(
11481160
[
1149-
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
1161+
transforms.Resize(size, interpolation=interpolation),
11501162
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
11511163
transforms.ToTensor(),
11521164
transforms.Normalize([0.5], [0.5]),

0 commit comments

Comments
 (0)