@@ -470,6 +470,15 @@ def parse_args(input_args=None):
470470 "--enable_xformers_memory_efficient_attention" , action = "store_true" , help = "Whether or not to use xformers."
471471 )
472472 parser .add_argument ("--noise_offset" , type = float , default = 0 , help = "The scale of noise offset." )
473+ parser .add_argument (
474+ "--image_interpolation_mode" ,
475+ type = str ,
476+ default = "lanczos" ,
477+ choices = [
478+ f .lower () for f in dir (transforms .InterpolationMode ) if not f .startswith ("__" ) and not f .endswith ("__" )
479+ ],
480+ help = "The image interpolation method to use for resizing images." ,
481+ )
473482
474483 if input_args is not None :
475484 args = parser .parse_args (input_args )
@@ -861,7 +870,10 @@ def load_model_hook(models, input_dir):
861870 )
862871
863872 # Preprocessing the datasets.
864- train_resize = transforms .Resize (args .resolution , interpolation = transforms .InterpolationMode .BILINEAR )
873+ interpolation = getattr (transforms .InterpolationMode , args .image_interpolation_mode .upper (), None )
874+ if interpolation is None :
875+ raise ValueError (f"Unsupported interpolation mode { interpolation = } ." )
876+ train_resize = transforms .Resize (args .resolution , interpolation = interpolation )
865877 train_crop = transforms .CenterCrop (args .resolution ) if args .center_crop else transforms .RandomCrop (args .resolution )
866878 train_flip = transforms .RandomHorizontalFlip (p = 1.0 )
867879 train_transforms = transforms .Compose ([transforms .ToTensor (), transforms .Normalize ([0.5 ], [0.5 ])])
0 commit comments