diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index ba733efe6003..480f4b36df61 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -418,6 +418,13 @@ def parse_args():
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
+    parser.add_argument(
+        "--image_interpolation_mode",
+        type=str,
+        default="lanczos",
+        choices=[mode.name.lower() for mode in transforms.InterpolationMode],
+        help="The image interpolation method to use for resizing images.",
+    )
 
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -649,10 +656,17 @@ def tokenize_captions(examples, is_train=True):
         )
         return inputs.input_ids
 
-    # Preprocessing the datasets.
+    # Get the specified interpolation method from the args
+    interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
+
+    # Raise an error if the interpolation method is invalid
+    if interpolation is None:
+        raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
+
+    # Data preprocessing transformations
     train_transforms = transforms.Compose(
         [
-            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.Resize(args.resolution, interpolation=interpolation),  # Use dynamic interpolation method
             transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
             transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
             transforms.ToTensor(),