@@ -673,6 +673,15 @@ def parse_args(input_args=None):
673673        default = False ,
674674        help = "Cache the VAE latents" ,
675675    )
676+     parser .add_argument (
677+         "--image_interpolation_mode" ,
678+         type = str ,
679+         default = "lanczos" ,
680+         choices = [
681+             f .lower () for  f  in  dir (transforms .InterpolationMode ) if  not  f .startswith ("__" ) and  not  f .endswith ("__" )
682+         ],
683+         help = "The image interpolation method to use for resizing images." ,
684+     )
676685
677686    if  input_args  is  not None :
678687        args  =  parser .parse_args (input_args )
@@ -906,6 +915,10 @@ def __init__(
906915            self .instance_images .extend (itertools .repeat (img , repeats ))
907916        self .num_instance_images  =  len (self .instance_images )
908917        self ._length  =  self .num_instance_images 
918+         
919+         interpolation  =  getattr (transforms .InterpolationMode , args .image_interpolation_mode .upper (), None )
920+         if  interpolation  is  None :
921+             raise  ValueError (f"Unsupported interpolation mode { interpolation = }  )
909922
910923        if  class_data_root  is  not None :
911924            self .class_data_root  =  Path (class_data_root )
@@ -921,7 +934,7 @@ def __init__(
921934
922935        self .image_transforms  =  transforms .Compose (
923936            [
924-                 transforms .Resize (size , interpolation = transforms . InterpolationMode . BILINEAR ),
937+                 transforms .Resize (size , interpolation = interpolation ),
925938                transforms .CenterCrop (size ) if  center_crop  else  transforms .RandomCrop (size ),
926939                transforms .ToTensor (),
927940                transforms .Normalize ([0.5 ], [0.5 ]),
0 commit comments