@@ -599,6 +599,15 @@ def parse_args(input_args=None):
599599 "Defaults to precision dtype used for training to save memory"
600600 ),
601601 )
602+ parser .add_argument (
603+ "--image_interpolation_mode" ,
604+ type = str ,
605+ default = "lanczos" ,
606+ choices = [
607+ f .lower () for f in dir (transforms .InterpolationMode ) if not f .startswith ("__" ) and not f .endswith ("__" )
608+ ],
609+ help = "The image interpolation method to use for resizing images." ,
610+ )
602611 parser .add_argument (
603612 "--offload" ,
604613 action = "store_true" ,
@@ -724,7 +733,11 @@ def __init__(
724733 self .instance_images .extend (itertools .repeat (img , repeats ))
725734
726735 self .pixel_values = []
727- train_resize = transforms .Resize (size , interpolation = transforms .InterpolationMode .BILINEAR )
736+ interpolation = getattr (transforms .InterpolationMode , args .image_interpolation_mode .upper (), None )
737+ if interpolation is None :
738+ raise ValueError (f"Unsupported interpolation mode: { args .image_interpolation_mode } " )
739+
740+ train_resize = transforms .Resize (size , interpolation = interpolation )
728741 train_crop = transforms .CenterCrop (size ) if center_crop else transforms .RandomCrop (size )
729742 train_flip = transforms .RandomHorizontalFlip (p = 1.0 )
730743 train_transforms = transforms .Compose (
@@ -768,7 +781,7 @@ def __init__(
768781
769782 self .image_transforms = transforms .Compose (
770783 [
771- transforms .Resize (size , interpolation = transforms . InterpolationMode . BILINEAR ),
784+ transforms .Resize (size , interpolation = interpolation ),
772785 transforms .CenterCrop (size ) if center_crop else transforms .RandomCrop (size ),
773786 transforms .ToTensor (),
774787 transforms .Normalize ([0.5 ], [0.5 ]),
0 commit comments