@@ -480,6 +480,15 @@ def parse_args(input_args=None):
480480 action = "store_true" ,
481481 help = "debug loss for each image, if filenames are available in the dataset" ,
482482 )
483+ parser .add_argument (
484+ "--image_interpolation_mode" ,
485+ type = str ,
486+ default = "lanczos" ,
487+ choices = [
488+ f .lower () for f in dir (transforms .InterpolationMode ) if not f .startswith ("__" ) and not f .endswith ("__" )
489+ ],
490+ help = "The image interpolation method to use for resizing images." ,
491+ )
483492
484493 if input_args is not None :
485494 args = parser .parse_args (input_args )
@@ -913,8 +922,14 @@ def tokenize_captions(examples, is_train=True):
913922 tokens_two = tokenize_prompt (tokenizer_two , captions )
914923 return tokens_one , tokens_two
915924
925+ # Get the specified interpolation method from the args
926+ interpolation = getattr (transforms .InterpolationMode , args .image_interpolation_mode .upper (), None )
927+
928+ # Raise an error if the interpolation method is invalid
929+ if interpolation is None :
930+ raise ValueError (f"Unsupported interpolation mode { args .image_interpolation_mode } ." )
916931 # Preprocessing the datasets.
917- train_resize = transforms .Resize (args .resolution , interpolation = transforms . InterpolationMode . BILINEAR )
932+ train_resize = transforms .Resize (args .resolution , interpolation = interpolation ) # Use dynamic interpolation method
918933 train_crop = transforms .CenterCrop (args .resolution ) if args .center_crop else transforms .RandomCrop (args .resolution )
919934 train_flip = transforms .RandomHorizontalFlip (p = 1.0 )
920935 train_transforms = transforms .Compose (
0 commit comments