Skip to content

Commit daa4fda

Browse files
committed
Add LANCZOS as default interplotation mode.
1 parent 8fe5a14 commit daa4fda

File tree

1 file changed

+28
-4
lines changed

1 file changed

+28
-4
lines changed

examples/controlnet/train_controlnet_sdxl.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,19 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
134134

135135
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
136136
validation_image = Image.open(validation_image).convert("RGB")
137-
validation_image = validation_image.resize((args.resolution, args.resolution))
138-
137+
138+
# Use the same interpolation mode as in training
139+
if args.interpolation_type.lower() == "lanczos":
140+
interpolation_mode = transforms.InterpolationMode.LANCZOS
141+
else:
142+
interpolation_mode = transforms.InterpolationMode.BILINEAR
143+
144+
transform = transforms.Compose([
145+
transforms.Resize(args.resolution, interpolation=interpolation_mode),
146+
transforms.CenterCrop(args.resolution),
147+
])
148+
validation_image = transform(validation_image)
149+
139150
images = []
140151

141152
for _ in range(args.num_validation_images):
@@ -587,6 +598,13 @@ def parse_args(input_args=None):
587598
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
588599
),
589600
)
601+
parser.add_argument(
602+
"--interpolation_type",
603+
type=str,
604+
default="lanczos",
605+
choices=["lanczos", "bilinear"],
606+
help="The interpolation method to use for resizing images. Choose between 'lanczos' (default) and 'bilinear'.",
607+
)
590608

591609
if input_args is not None:
592610
args = parser.parse_args(input_args)
@@ -732,9 +750,15 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom
732750

733751

734752
def prepare_train_dataset(dataset, accelerator):
753+
# Set the interpolation mode based on user preference
754+
if args.interpolation_type.lower() == "lanczos":
755+
interpolation_mode = transforms.InterpolationMode.LANCZOS
756+
else:
757+
interpolation_mode = transforms.InterpolationMode.BILINEAR
758+
735759
image_transforms = transforms.Compose(
736760
[
737-
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
761+
transforms.Resize(args.resolution, interpolation=interpolation_mode),
738762
transforms.CenterCrop(args.resolution),
739763
transforms.ToTensor(),
740764
transforms.Normalize([0.5], [0.5]),
@@ -743,7 +767,7 @@ def prepare_train_dataset(dataset, accelerator):
743767

744768
conditioning_image_transforms = transforms.Compose(
745769
[
746-
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
770+
transforms.Resize(args.resolution, interpolation=interpolation_mode),
747771
transforms.CenterCrop(args.resolution),
748772
transforms.ToTensor(),
749773
]

0 commit comments

Comments
 (0)