@@ -135,18 +135,17 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
135135    for  validation_prompt , validation_image  in  zip (validation_prompts , validation_images ):
136136        validation_image  =  Image .open (validation_image ).convert ("RGB" )
137137
138-         # Use the same interpolation mode as in training 
139-         if  args .interpolation_type .lower () ==  "lanczos" :
138+         # Get the interpolation mode from string 
139+         try :
140+             interpolation_mode  =  getattr (transforms .InterpolationMode , args .interpolation_type .upper ())
141+         except  (AttributeError , KeyError ):
140142            interpolation_mode  =  transforms .InterpolationMode .LANCZOS 
141-         else :
142-             interpolation_mode  =  transforms .InterpolationMode .BILINEAR 
143-             
144-         transform  =  transforms .Compose ([
145-             transforms .Resize (args .resolution , interpolation = interpolation_mode ),
146-             transforms .CenterCrop (args .resolution ),
147-         ])
148-         validation_image  =  transform (validation_image )
149143
144+         validation_image  =  validation_image .resize (
145+             (args .resolution , args .resolution ), 
146+             resample = Image .Resampling .LANCZOS  if  interpolation_mode  ==  transforms .InterpolationMode .LANCZOS  else  Image .Resampling .BILINEAR 
147+         )
148+ 
150149        images  =  []
151150
152151        for  _  in  range (args .num_validation_images ):
@@ -602,8 +601,11 @@ def parse_args(input_args=None):
602601        "--interpolation_type" ,
603602        type = str ,
604603        default = "lanczos" ,
605-         choices = ["lanczos" , "bilinear" ],
606-         help = "The interpolation method to use for resizing images. Choose between 'lanczos' (default) and 'bilinear'." ,
604+         help = (
605+             "The interpolation method to use for resizing images. Choose between 'bilinear', 'bicubic', 'lanczos', " 
606+             "'nearest', 'nearest-exact', 'area', etc. See https://pytorch.org/vision/stable/transforms.html for all " 
607+             "options. Default is 'lanczos'." 
608+         ),
607609    )
608610
609611    if  input_args  is  not None :
@@ -750,11 +752,14 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom
750752
751753
752754def  prepare_train_dataset (dataset , accelerator ):
753-     # Set the interpolation mode based on user preference 
754-     if  args .interpolation_type .lower () ==  "lanczos" :
755+     # Get the interpolation mode from string 
756+     try :
757+         interpolation_mode  =  getattr (transforms .InterpolationMode , args .interpolation_type .upper ())
758+     except  (AttributeError , KeyError ):
759+         logger .warning (
760+             f"Interpolation mode { args .interpolation_type }  
761+         )
755762        interpolation_mode  =  transforms .InterpolationMode .LANCZOS 
756-     else :
757-         interpolation_mode  =  transforms .InterpolationMode .BILINEAR 
758763
759764    image_transforms  =  transforms .Compose (
760765        [
0 commit comments