From 2324abbc96df53318f1c8d045a72d715a8a17725 Mon Sep 17 00:00:00 2001 From: Roger Date: Mon, 5 May 2025 19:34:58 +0530 Subject: [PATCH 1/2] Update training script for txt to img sdxl with lora supp with new interpolation. --- .../train_text_to_image_lora_sdxl.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 0a93a0d1c4c2..86e0b34243ee 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -480,6 +480,15 @@ def parse_args(input_args=None): action="store_true", help="debug loss for each image, if filenames are available in the dataset", ) + parser.add_argument( + "--image_interpolation_mode", + type=str, + default="lanczos", + choices=[ + f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__") + ], + help="The image interpolation method to use for resizing images.", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -913,8 +922,14 @@ def tokenize_captions(examples, is_train=True): tokens_two = tokenize_prompt(tokenizer_two, captions) return tokens_one, tokens_two + # Get the specified interpolation method from the args + interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None) + + # Raise an error if the interpolation method is invalid + if interpolation is None: + raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.") # Preprocessing the datasets. - train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR) + train_resize = transforms.Resize(args.resolution, interpolation=interpolation) # Use dynamic interpolation method train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose( From 6287b347d309d7c0187f56c2c839d4c5d9f9f0ad Mon Sep 17 00:00:00 2001 From: Roger Date: Mon, 5 May 2025 21:06:29 +0530 Subject: [PATCH 2/2] ran make style and make quality. --- examples/text_to_image/train_text_to_image_lora_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 86e0b34243ee..a5c72c1e9716 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -929,7 +929,7 @@ def tokenize_captions(examples, is_train=True): if interpolation is None: raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.") # Preprocessing the datasets. - train_resize = transforms.Resize(args.resolution, interpolation=interpolation) # Use dynamic interpolation method + train_resize = transforms.Resize(args.resolution, interpolation=interpolation) # Use dynamic interpolation method train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose(