Skip to content

Commit e94b91d

Browse files
authored
LANCZOS as default interplotation mode
1 parent a35acc9 commit e94b91d

File tree

1 file changed

+26
-19
lines changed

1 file changed

+26
-19
lines changed

examples/controlnet/train_controlnet_sdxl.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -135,16 +135,22 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
135135
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
136136
validation_image = Image.open(validation_image).convert("RGB")
137137

138-
# Get the interpolation mode from string
139138
try:
140-
interpolation_mode = getattr(transforms.InterpolationMode, args.interpolation_type.upper())
139+
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper())
141140
except (AttributeError, KeyError):
142-
interpolation_mode = transforms.InterpolationMode.LANCZOS
143-
144-
validation_image = validation_image.resize(
145-
(args.resolution, args.resolution),
146-
resample=Image.Resampling.LANCZOS if interpolation_mode == transforms.InterpolationMode.LANCZOS else Image.Resampling.BILINEAR
147-
)
141+
supported_interpolation_modes = [
142+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
143+
]
144+
raise ValueError(
145+
f"Interpolation mode {args.image_interpolation_mode} is not supported. "
146+
f"Please select one of the following: {', '.join(supported_interpolation_modes)}"
147+
)
148+
149+
transform = transforms.Compose([
150+
transforms.Resize(args.resolution, interpolation=interpolation),
151+
transforms.CenterCrop(args.resolution),
152+
])
153+
validation_image = transform(validation_image)
148154

149155
images = []
150156

@@ -598,14 +604,13 @@ def parse_args(input_args=None):
598604
),
599605
)
600606
parser.add_argument(
601-
"--interpolation_type",
607+
"--image_interpolation_mode",
602608
type=str,
603609
default="lanczos",
604-
help=(
605-
"The interpolation method to use for resizing images. Choose between 'bilinear', 'bicubic', 'lanczos', "
606-
"'nearest', 'nearest-exact', 'area', etc. See https://pytorch.org/vision/stable/transforms.html for all "
607-
"options. Default is 'lanczos'."
608-
),
610+
choices=[
611+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
612+
],
613+
help="The image interpolation method to use for resizing images.",
609614
)
610615

611616
if input_args is not None:
@@ -752,14 +757,16 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom
752757

753758

754759
def prepare_train_dataset(dataset, accelerator):
755-
# Get the interpolation mode from string
756760
try:
757-
interpolation_mode = getattr(transforms.InterpolationMode, args.interpolation_type.upper())
761+
interpolation_mode = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper())
758762
except (AttributeError, KeyError):
759-
logger.warning(
760-
f"Interpolation mode {args.interpolation_type} not found. Falling back to LANCZOS."
763+
supported_interpolation_modes = [
764+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
765+
]
766+
raise ValueError(
767+
f"Interpolation mode {args.image_interpolation_mode} is not supported. "
768+
f"Please select one of the following: {', '.join(supported_interpolation_modes)}"
761769
)
762-
interpolation_mode = transforms.InterpolationMode.LANCZOS
763770

764771
image_transforms = transforms.Compose(
765772
[

0 commit comments

Comments
 (0)