From bac27a1df3cc2f2e093377706ef74ef8a356705c Mon Sep 17 00:00:00 2001 From: tongyu <119610311+tongyu0924@users.noreply.github.com> Date: Sun, 27 Apr 2025 15:11:32 +0800 Subject: [PATCH 1/2] Update train_text_to_image.py --- examples/text_to_image/train_text_to_image.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 8ab136179996..102037dac079 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -499,6 +499,15 @@ def parse_args(): " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" ), ) + parser.add_argument( + "--image_interpolation_mode", + type=str, + default="lanczos", + choices=[ + f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__") + ], + help="The image interpolation method to use for resizing images.", + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -787,10 +796,17 @@ def tokenize_captions(examples, is_train=True): ) return inputs.input_ids - # Preprocessing the datasets. + # Get the specified interpolation method from the args + interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None) + + # Raise an error if the interpolation method is invalid + if interpolation is None: + raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.") + + # Data preprocessing transformations train_transforms = transforms.Compose( [ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.Resize(args.resolution, interpolation=interpolation), # Use dynamic interpolation method transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), transforms.ToTensor(), From ec0caa14b345e70454926cff055247cc2bf91ef6 Mon Sep 17 00:00:00 2001 From: tongyu0924 Date: Mon, 28 Apr 2025 20:24:19 +0800 Subject: [PATCH 2/2] update --- examples/text_to_image/train_text_to_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 102037dac079..324891dc79ff 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -798,11 +798,11 @@ def tokenize_captions(examples, is_train=True): # Get the specified interpolation method from the args interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None) - + # Raise an error if the interpolation method is invalid if interpolation is None: raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.") - + # Data preprocessing transformations train_transforms = transforms.Compose( [