@@ -135,16 +135,22 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
135135    for  validation_prompt , validation_image  in  zip (validation_prompts , validation_images ):
136136        validation_image  =  Image .open (validation_image ).convert ("RGB" )
137137
138-         # Get the interpolation mode from string 
139138        try :
140-             interpolation_mode  =  getattr (transforms .InterpolationMode , args .interpolation_type .upper ())
139+             interpolation  =  getattr (transforms .InterpolationMode , args .image_interpolation_mode .upper ())
141140        except  (AttributeError , KeyError ):
142-             interpolation_mode  =  transforms .InterpolationMode .LANCZOS 
143-         
144-         validation_image  =  validation_image .resize (
145-             (args .resolution , args .resolution ), 
146-             resample = Image .Resampling .LANCZOS  if  interpolation_mode  ==  transforms .InterpolationMode .LANCZOS  else  Image .Resampling .BILINEAR 
147-         )
141+             supported_interpolation_modes  =  [
142+                 f .lower () for  f  in  dir (transforms .InterpolationMode ) if  not  f .startswith ("__" ) and  not  f .endswith ("__" )
143+             ]
144+             raise  ValueError (
145+                 f"Interpolation mode { args .image_interpolation_mode }  
146+                 f"Please select one of the following: { ', ' .join (supported_interpolation_modes )}  
147+             )
148+             
149+         transform  =  transforms .Compose ([
150+             transforms .Resize (args .resolution , interpolation = interpolation ),
151+             transforms .CenterCrop (args .resolution ),
152+         ])
153+         validation_image  =  transform (validation_image )
148154
149155        images  =  []
150156
@@ -598,14 +604,13 @@ def parse_args(input_args=None):
598604        ),
599605    )
600606    parser .add_argument (
601-         "--interpolation_type " ,
607+         "--image_interpolation_mode " ,
602608        type = str ,
603609        default = "lanczos" ,
604-         help = (
605-             "The interpolation method to use for resizing images. Choose between 'bilinear', 'bicubic', 'lanczos', " 
606-             "'nearest', 'nearest-exact', 'area', etc. See https://pytorch.org/vision/stable/transforms.html for all " 
607-             "options. Default is 'lanczos'." 
608-         ),
610+         choices = [
611+             f .lower () for  f  in  dir (transforms .InterpolationMode ) if  not  f .startswith ("__" ) and  not  f .endswith ("__" )
612+         ],
613+         help = "The image interpolation method to use for resizing images." ,
609614    )
610615
611616    if  input_args  is  not None :
@@ -752,14 +757,16 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom
752757
753758
754759def  prepare_train_dataset (dataset , accelerator ):
755-     # Get the interpolation mode from string 
756760    try :
757-         interpolation_mode  =  getattr (transforms .InterpolationMode , args .interpolation_type .upper ())
761+         interpolation_mode  =  getattr (transforms .InterpolationMode , args .image_interpolation_mode .upper ())
758762    except  (AttributeError , KeyError ):
759-         logger .warning (
760-             f"Interpolation mode { args .interpolation_type }  
763+         supported_interpolation_modes  =  [
764+             f .lower () for  f  in  dir (transforms .InterpolationMode ) if  not  f .startswith ("__" ) and  not  f .endswith ("__" )
765+         ]
766+         raise  ValueError (
767+             f"Interpolation mode { args .image_interpolation_mode }  
768+             f"Please select one of the following: { ', ' .join (supported_interpolation_modes )}  
761769        )
762-         interpolation_mode  =  transforms .InterpolationMode .LANCZOS 
763770
764771    image_transforms  =  transforms .Compose (
765772        [
0 commit comments