From 561ccbc558565f24ba3d07e830080b589a140b45 Mon Sep 17 00:00:00 2001 From: Parag Ekbote Date: Tue, 29 Apr 2025 15:41:03 +0000 Subject: [PATCH 1/4] Add LANCZOS as default interplotation mode. --- examples/text_to_image/train_text_to_image_sdxl.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 06b9bf5f3ef0..b65441d7c70e 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -470,6 +470,15 @@ def parse_args(input_args=None): "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--image_interpolation_mode", + type=str, + default="lanczos", + choices=[ + f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__") + ], + help="The image interpolation method to use for resizing images.", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -861,7 +870,10 @@ def load_model_hook(models, input_dir): ) # Preprocessing the datasets. - train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR) + interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None) + if interpolation is None: + raise ValueError(f"Unsupported interpolation mode {interpolation=}.") + train_resize = transforms.Resize(size, interpolation=interpolation) train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) From 5936be4dc187a4db5eec528bc0f71ed3066bf576 Mon Sep 17 00:00:00 2001 From: Parag Ekbote Date: Tue, 29 Apr 2025 17:26:54 +0000 Subject: [PATCH 2/4] update script --- examples/text_to_image/train_text_to_image_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index b65441d7c70e..0147b4e8bfad 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -873,7 +873,7 @@ def load_model_hook(models, input_dir): interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None) if interpolation is None: raise ValueError(f"Unsupported interpolation mode {interpolation=}.") - train_resize = transforms.Resize(size, interpolation=interpolation) + train_resize = transforms.Resize(args.resolution,, interpolation=interpolation) train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) From 2e8dfe40f49bd3bb97f6e4d1822e6d7fe7956e98 Mon Sep 17 00:00:00 2001 From: Parag Ekbote Date: Wed, 30 Apr 2025 12:24:56 +0530 Subject: [PATCH 3/4] Update as per code review. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Álvaro Somoza --- examples/text_to_image/train_text_to_image_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 0147b4e8bfad..87241646fe70 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -873,7 +873,7 @@ def load_model_hook(models, input_dir): interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None) if interpolation is None: raise ValueError(f"Unsupported interpolation mode {interpolation=}.") - train_resize = transforms.Resize(args.resolution,, interpolation=interpolation) + train_resize = transforms.Resize(args.resolution, interpolation=interpolation) train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) From 3a2759f839e4db810e58ea11fafaf4148f0f01e3 Mon Sep 17 00:00:00 2001 From: Parag Ekbote Date: Sat, 3 May 2025 04:11:40 +0000 Subject: [PATCH 4/4] make style. --- examples/text_to_image/train_text_to_image_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 87241646fe70..ba059fd6fa11 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -872,7 +872,7 @@ def load_model_hook(models, input_dir): # Preprocessing the datasets. interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None) if interpolation is None: - raise ValueError(f"Unsupported interpolation mode {interpolation=}.") + raise ValueError(f"Unsupported interpolation mode {interpolation=}.") train_resize = transforms.Resize(args.resolution, interpolation=interpolation) train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution) train_flip = transforms.RandomHorizontalFlip(p=1.0)