Skip to content

Commit cae7273

Browse files
authored
Merge branch 'main' into feat/lanczos-resize-sd15-advanced
2 parents 77cf919 + ee1516e commit cae7273

File tree

3 files changed

+29
-4
lines changed

3 files changed

+29
-4
lines changed

examples/dreambooth/train_dreambooth_lora_lumina2.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,15 @@ def parse_args(input_args=None):
599599
"Defaults to precision dtype used for training to save memory"
600600
),
601601
)
602+
parser.add_argument(
603+
"--image_interpolation_mode",
604+
type=str,
605+
default="lanczos",
606+
choices=[
607+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
608+
],
609+
help="The image interpolation method to use for resizing images.",
610+
)
602611
parser.add_argument(
603612
"--offload",
604613
action="store_true",
@@ -724,7 +733,11 @@ def __init__(
724733
self.instance_images.extend(itertools.repeat(img, repeats))
725734

726735
self.pixel_values = []
727-
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
736+
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
737+
if interpolation is None:
738+
raise ValueError(f"Unsupported interpolation mode: {args.image_interpolation_mode}")
739+
740+
train_resize = transforms.Resize(size, interpolation=interpolation)
728741
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
729742
train_flip = transforms.RandomHorizontalFlip(p=1.0)
730743
train_transforms = transforms.Compose(
@@ -768,7 +781,7 @@ def __init__(
768781

769782
self.image_transforms = transforms.Compose(
770783
[
771-
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
784+
transforms.Resize(size, interpolation=interpolation),
772785
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
773786
transforms.ToTensor(),
774787
transforms.Normalize([0.5], [0.5]),

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -852,7 +852,7 @@ def __init__(
852852

853853
self.image_transforms = transforms.Compose(
854854
[
855-
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
855+
transforms.Resize(size, interpolation=interpolation),
856856
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
857857
transforms.ToTensor(),
858858
transforms.Normalize([0.5], [0.5]),

examples/text_to_image/train_text_to_image_sdxl.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,15 @@ def parse_args(input_args=None):
470470
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
471471
)
472472
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
473+
parser.add_argument(
474+
"--image_interpolation_mode",
475+
type=str,
476+
default="lanczos",
477+
choices=[
478+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
479+
],
480+
help="The image interpolation method to use for resizing images.",
481+
)
473482

474483
if input_args is not None:
475484
args = parser.parse_args(input_args)
@@ -861,7 +870,10 @@ def load_model_hook(models, input_dir):
861870
)
862871

863872
# Preprocessing the datasets.
864-
train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
873+
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
874+
if interpolation is None:
875+
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
876+
train_resize = transforms.Resize(args.resolution, interpolation=interpolation)
865877
train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
866878
train_flip = transforms.RandomHorizontalFlip(p=1.0)
867879
train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

0 commit comments

Comments
 (0)