Skip to content

Commit a35acc9

Browse files
authored
LANCZOS as default interplotation
1 parent 4a09f96 commit a35acc9

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

examples/controlnet/train_controlnet_sdxl.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -135,18 +135,17 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
135135
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
136136
validation_image = Image.open(validation_image).convert("RGB")
137137

138-
# Use the same interpolation mode as in training
139-
if args.interpolation_type.lower() == "lanczos":
138+
# Get the interpolation mode from string
139+
try:
140+
interpolation_mode = getattr(transforms.InterpolationMode, args.interpolation_type.upper())
141+
except (AttributeError, KeyError):
140142
interpolation_mode = transforms.InterpolationMode.LANCZOS
141-
else:
142-
interpolation_mode = transforms.InterpolationMode.BILINEAR
143-
144-
transform = transforms.Compose([
145-
transforms.Resize(args.resolution, interpolation=interpolation_mode),
146-
transforms.CenterCrop(args.resolution),
147-
])
148-
validation_image = transform(validation_image)
149143

144+
validation_image = validation_image.resize(
145+
(args.resolution, args.resolution),
146+
resample=Image.Resampling.LANCZOS if interpolation_mode == transforms.InterpolationMode.LANCZOS else Image.Resampling.BILINEAR
147+
)
148+
150149
images = []
151150

152151
for _ in range(args.num_validation_images):
@@ -602,8 +601,11 @@ def parse_args(input_args=None):
602601
"--interpolation_type",
603602
type=str,
604603
default="lanczos",
605-
choices=["lanczos", "bilinear"],
606-
help="The interpolation method to use for resizing images. Choose between 'lanczos' (default) and 'bilinear'.",
604+
help=(
605+
"The interpolation method to use for resizing images. Choose between 'bilinear', 'bicubic', 'lanczos', "
606+
"'nearest', 'nearest-exact', 'area', etc. See https://pytorch.org/vision/stable/transforms.html for all "
607+
"options. Default is 'lanczos'."
608+
),
607609
)
608610

609611
if input_args is not None:
@@ -750,11 +752,14 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom
750752

751753

752754
def prepare_train_dataset(dataset, accelerator):
753-
# Set the interpolation mode based on user preference
754-
if args.interpolation_type.lower() == "lanczos":
755+
# Get the interpolation mode from string
756+
try:
757+
interpolation_mode = getattr(transforms.InterpolationMode, args.interpolation_type.upper())
758+
except (AttributeError, KeyError):
759+
logger.warning(
760+
f"Interpolation mode {args.interpolation_type} not found. Falling back to LANCZOS."
761+
)
755762
interpolation_mode = transforms.InterpolationMode.LANCZOS
756-
else:
757-
interpolation_mode = transforms.InterpolationMode.BILINEAR
758763

759764
image_transforms = transforms.Compose(
760765
[

0 commit comments

Comments
 (0)