@@ -134,8 +134,19 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
134134
135135 for validation_prompt , validation_image in zip (validation_prompts , validation_images ):
136136 validation_image = Image .open (validation_image ).convert ("RGB" )
137- validation_image = validation_image .resize ((args .resolution , args .resolution ))
138-
137+
138+ # Use the same interpolation mode as in training
139+ if args .interpolation_type .lower () == "lanczos" :
140+ interpolation_mode = transforms .InterpolationMode .LANCZOS
141+ else :
142+ interpolation_mode = transforms .InterpolationMode .BILINEAR
143+
144+ transform = transforms .Compose ([
145+ transforms .Resize (args .resolution , interpolation = interpolation_mode ),
146+ transforms .CenterCrop (args .resolution ),
147+ ])
148+ validation_image = transform (validation_image )
149+
139150 images = []
140151
141152 for _ in range (args .num_validation_images ):
@@ -587,6 +598,13 @@ def parse_args(input_args=None):
587598 " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
588599 ),
589600 )
601+ parser .add_argument (
602+ "--interpolation_type" ,
603+ type = str ,
604+ default = "lanczos" ,
605+ choices = ["lanczos" , "bilinear" ],
606+ help = "The interpolation method to use for resizing images. Choose between 'lanczos' (default) and 'bilinear'." ,
607+ )
590608
591609 if input_args is not None :
592610 args = parser .parse_args (input_args )
@@ -732,9 +750,15 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom
732750
733751
734752def prepare_train_dataset (dataset , accelerator ):
753+ # Set the interpolation mode based on user preference
754+ if args .interpolation_type .lower () == "lanczos" :
755+ interpolation_mode = transforms .InterpolationMode .LANCZOS
756+ else :
757+ interpolation_mode = transforms .InterpolationMode .BILINEAR
758+
735759 image_transforms = transforms .Compose (
736760 [
737- transforms .Resize (args .resolution , interpolation = transforms . InterpolationMode . BILINEAR ),
761+ transforms .Resize (args .resolution , interpolation = interpolation_mode ),
738762 transforms .CenterCrop (args .resolution ),
739763 transforms .ToTensor (),
740764 transforms .Normalize ([0.5 ], [0.5 ]),
@@ -743,7 +767,7 @@ def prepare_train_dataset(dataset, accelerator):
743767
744768 conditioning_image_transforms = transforms .Compose (
745769 [
746- transforms .Resize (args .resolution , interpolation = transforms . InterpolationMode . BILINEAR ),
770+ transforms .Resize (args .resolution , interpolation = interpolation_mode ),
747771 transforms .CenterCrop (args .resolution ),
748772 transforms .ToTensor (),
749773 ]
0 commit comments