2424import math
2525import re
2626from functools import partial
27- from typing import Dict , List , Optional , Union
27+ from typing import Any , Dict , List , Optional , Union
2828
29- from PIL import Image , ImageOps , ImageEnhance , ImageChops , ImageFilter
29+ import torch
3030import PIL
3131import numpy as np
32-
32+ from PIL import Image , ImageFilter
33+ from torchvision .transforms import InterpolationMode
34+ import torchvision .transforms .functional as TF
35+ try :
36+ import torchvision .transforms .v2 .functional as TF2
37+ except ImportError :
38+ TF2 = None
3339
3440_PIL_VER = tuple ([int (x ) for x in PIL .__version__ .split ('.' )[:2 ]])
3541
4248 img_mean = _FILL ,
4349)
4450
45- if hasattr (Image , "Resampling" ):
46- _RANDOM_INTERPOLATION = (Image .Resampling .BILINEAR , Image .Resampling .BICUBIC )
47- _DEFAULT_INTERPOLATION = Image .Resampling .BICUBIC
48- else :
49- _RANDOM_INTERPOLATION = (Image .BILINEAR , Image .BICUBIC )
50- _DEFAULT_INTERPOLATION = Image .BICUBIC
51+
52+ _RANDOM_INTERPOLATION = (InterpolationMode .BILINEAR , InterpolationMode .BICUBIC )
53+ _DEFAULT_INTERPOLATION = InterpolationMode .BICUBIC
5154
5255
53- def _interpolation (kwargs ):
54- interpolation = kwargs .pop ('resample ' , _DEFAULT_INTERPOLATION )
56+ def _interpolation (kwargs , basic_only = False ):
57+ interpolation = kwargs .pop ('interpolation ' , _DEFAULT_INTERPOLATION )
5558 if isinstance (interpolation , (list , tuple )):
56- return random .choice (interpolation )
59+ interpolation = random .choice (interpolation )
60+ if basic_only :
61+ if interpolation not in (InterpolationMode .NEAREST , InterpolationMode .BILINEAR ):
62+ interpolation = InterpolationMode .BILINEAR
5763 return interpolation
5864
5965
6066def _check_args_tf (kwargs ):
61- if 'fillcolor' in kwargs and _PIL_VER < (5 , 0 ):
62- kwargs .pop ('fillcolor' )
63- kwargs ['resample' ] = _interpolation (kwargs )
67+ kwargs ['interpolation' ] = _interpolation (kwargs )
68+
69+
70+ def _check_args_affine (img , kwargs ):
71+ if isinstance (img , torch .Tensor ):
72+ kwargs ['interpolation' ] = _interpolation (kwargs , basic_only = True )
73+ else :
74+ kwargs ['interpolation' ] = _interpolation (kwargs )
6475
6576
6677def shear_x (img , factor , ** kwargs ):
67- _check_args_tf ( kwargs )
68- return img . transform (img . size , Image . AFFINE , ( 1 , factor , 0 , 0 , 1 , 0 ) , ** kwargs )
78+ _check_args_affine ( img , kwargs )
79+ return TF . affine (img , angle = 0 , translate = [ 0 , 0 ], scale = 1 , shear = [ math . degrees ( math . atan ( factor )), 0 ] , ** kwargs )
6980
7081
7182def shear_y (img , factor , ** kwargs ):
72- _check_args_tf ( kwargs )
73- return img . transform (img . size , Image . AFFINE , ( 1 , 0 , 0 , factor , 1 , 0 ) , ** kwargs )
83+ _check_args_affine ( img , kwargs )
84+ return TF . affine (img , angle = 0 , translate = [ 0 , 0 ], scale = 1 , shear = [ 0 , math . degrees ( math . atan ( factor ))] , ** kwargs )
7485
7586
76- def translate_x_rel (img , pct , ** kwargs ):
77- pixels = pct * img .size [0 ]
78- _check_args_tf (kwargs )
79- return img .transform (img .size , Image .AFFINE , (1 , 0 , pixels , 0 , 1 , 0 ), ** kwargs )
87+ def translate_x_abs (img , pixels , ** kwargs ):
88+ _check_args_affine (img , kwargs )
89+ return TF .affine (img , angle = 0 , translate = [pixels , 0 ], scale = 1 , shear = [0 , 0 ], ** kwargs )
8090
8191
82- def translate_y_rel (img , pct , ** kwargs ):
83- pixels = pct * img .size [1 ]
84- _check_args_tf (kwargs )
85- return img .transform (img .size , Image .AFFINE , (1 , 0 , 0 , 0 , 1 , pixels ), ** kwargs )
92+ def translate_y_abs (img , pixels , ** kwargs ):
93+ _check_args_affine (img , kwargs )
94+ return TF .affine (img , angle = 0 , translate = [0 , pixels ], scale = 1 , shear = [0 , 0 ], ** kwargs )
8695
8796
88- def translate_x_abs (img , pixels , ** kwargs ):
89- _check_args_tf ( kwargs )
90- return img . transform (img . size , Image . AFFINE , ( 1 , 0 , pixels , 0 , 1 , 0 ) , ** kwargs )
97+ def translate_x_rel (img , pct , ** kwargs ):
98+ pixels = pct * TF . get_image_size ( img )[ 0 ]
99+ return translate_x_abs (img , pixels , ** kwargs )
91100
92101
93- def translate_y_abs (img , pixels , ** kwargs ):
94- _check_args_tf ( kwargs )
95- return img . transform (img . size , Image . AFFINE , ( 1 , 0 , 0 , 0 , 1 , pixels ) , ** kwargs )
102+ def translate_y_rel (img , pct , ** kwargs ):
103+ pixels = pct * TF . get_image_size ( img )[ 1 ]
104+ return translate_y_abs (img , pixels , ** kwargs )
96105
97106
98107def rotate (img , degrees , ** kwargs ):
99- _check_args_tf (kwargs )
100- if _PIL_VER >= (5 , 2 ):
101- return img .rotate (degrees , ** kwargs )
102- if _PIL_VER >= (5 , 0 ):
103- w , h = img .size
104- post_trans = (0 , 0 )
105- rotn_center = (w / 2.0 , h / 2.0 )
106- angle = - math .radians (degrees )
107- matrix = [
108- round (math .cos (angle ), 15 ),
109- round (math .sin (angle ), 15 ),
110- 0.0 ,
111- round (- math .sin (angle ), 15 ),
112- round (math .cos (angle ), 15 ),
113- 0.0 ,
114- ]
115-
116- def transform (x , y , matrix ):
117- (a , b , c , d , e , f ) = matrix
118- return a * x + b * y + c , d * x + e * y + f
119-
120- matrix [2 ], matrix [5 ] = transform (
121- - rotn_center [0 ] - post_trans [0 ], - rotn_center [1 ] - post_trans [1 ], matrix
122- )
123- matrix [2 ] += rotn_center [0 ]
124- matrix [5 ] += rotn_center [1 ]
125- return img .transform (img .size , Image .AFFINE , matrix , ** kwargs )
126- return img .rotate (degrees , resample = kwargs ['resample' ])
108+ _check_args_affine (img , kwargs )
109+ return TF .rotate (img , degrees , ** kwargs )
127110
128111
129112def auto_contrast (img , ** __ ):
130- return ImageOps .autocontrast (img )
113+ return TF .autocontrast (img )
131114
132115
133116def invert (img , ** __ ):
134- return ImageOps .invert (img )
117+ return TF .invert (img )
135118
136119
137120def equalize (img , ** __ ):
138- return ImageOps .equalize (img )
121+ if isinstance (img , torch .Tensor ) and img .is_floating_point ():
122+ if TF2 is None :
123+ # FIXME warn / assert?
124+ return img
125+ return TF2 .equalize (img )
126+ return TF .equalize (img )
139127
140128
141129def solarize (img , thresh , ** __ ):
142- return ImageOps .solarize (img , thresh )
130+ if isinstance (img , torch .Tensor ) and img .is_floating_point ():
131+ thresh = min (thresh / 255 , 1.0 )
132+ return TF .solarize (img , thresh )
143133
144134
145135def solarize_add (img , add , thresh = 128 , ** __ ):
146- lut = []
147- for i in range (256 ):
148- if i < thresh :
149- lut .append (min (255 , i + add ))
136+ if isinstance (img , torch .Tensor ):
137+ if img .is_floating_point ():
138+ thresh = thresh / 255
139+ add = add / 255
140+ img_sum = (img + add ).clamp_ (max = 1.0 )
150141 else :
151- lut .append (i )
142+ img_sum = (img + add ).clamp_ (max = 255 )
143+ return torch .where (img >= thresh , img_sum , img )
144+ else :
145+ lut = []
146+ for i in range (256 ):
147+ if i < thresh :
148+ lut .append (min (255 , i + add ))
149+ else :
150+ lut .append (i )
152151
153- if img .mode in ("L" , "RGB" ):
154- if img .mode == "RGB" and len (lut ) == 256 :
155- lut = lut + lut + lut
156- return img .point (lut )
152+ if img .mode in ("L" , "RGB" ):
153+ if img .mode == "RGB" and len (lut ) == 256 :
154+ lut = lut + lut + lut
155+ return img .point (lut )
157156
158157 return img
159158
160159
161160def posterize (img , bits_to_keep , ** __ ):
162161 if bits_to_keep >= 8 :
163162 return img
164- return ImageOps .posterize (img , bits_to_keep )
163+ if isinstance (img , torch .Tensor ) and img .is_floating_point ():
164+ if TF2 is None :
165+ # FIXME warn / assert?
166+ return img
167+ return TF2 .posterize (img , bits_to_keep )
168+ return TF .posterize (img , bits_to_keep )
165169
166170
167171def contrast (img , factor , ** __ ):
168- return ImageEnhance . Contrast (img ). enhance ( factor )
172+ return TF . adjust_contrast (img , factor )
169173
170174
171175def color (img , factor , ** __ ):
172- return ImageEnhance . Color (img ). enhance ( factor )
176+ return TF . adjust_saturation (img , factor )
173177
174178
175179def brightness (img , factor , ** __ ):
176- return ImageEnhance . Brightness (img ). enhance ( factor )
180+ return TF . adjust_brightness (img , factor )
177181
178182
179183def sharpness (img , factor , ** __ ):
180- return ImageEnhance . Sharpness (img ). enhance ( factor )
184+ return TF . adjust_sharpness (img , factor )
181185
182186
183187def gaussian_blur (img , factor , ** __ ):
184- img = img .filter (ImageFilter .GaussianBlur (radius = factor ))
188+ if isinstance (img , torch .Tensor ):
189+ kernel_size = 2 * int (3 * factor ) + 1 # could be bigger, but more expensive
190+ img = TF .gaussian_blur (img , kernel_size = kernel_size , sigma = factor )
191+ else :
192+ img = img .filter (ImageFilter .GaussianBlur (radius = factor ))
185193 return img
186194
187195
188196def gaussian_blur_rand (img , factor , ** __ ):
189197 radius_min = 0.1
190198 radius_max = 2.0
191- img = img . filter ( ImageFilter . GaussianBlur ( radius = random .uniform (radius_min , radius_max * factor )) )
192- return img
199+ radius = random .uniform (radius_min , radius_max * factor )
200+ return gaussian_blur ( img , radius )
193201
194202
195203def desaturate (img , factor , ** _ ):
196204 factor = min (1. , max (0. , 1. - factor ))
197205 # enhance factor 0 = grayscale, 1.0 = no-change
198- return ImageEnhance . Color (img ). enhance ( factor )
206+ return TF . adjust_saturation (img , factor )
199207
200208
201209def _randomly_negate (v ):
@@ -356,7 +364,13 @@ def _solarize_add_level_to_arg(level, _hparams):
356364
357365class AugmentOp :
358366
359- def __init__ (self , name , prob = 0.5 , magnitude = 10 , hparams = None ):
367+ def __init__ (
368+ self ,
369+ name : str ,
370+ prob : float = 0.5 ,
371+ magnitude : float = 10 ,
372+ hparams : Optional [Dict [str , Any ]] = None
373+ ):
360374 hparams = hparams or _HPARAMS_DEFAULT
361375 self .name = name
362376 self .aug_fn = NAME_TO_OP [name ]
@@ -365,8 +379,8 @@ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
365379 self .magnitude = magnitude
366380 self .hparams = hparams .copy ()
367381 self .kwargs = dict (
368- fillcolor = hparams ['img_mean' ] if 'img_mean' in hparams else _FILL ,
369- resample = hparams ['interpolation' ] if 'interpolation' in hparams else _RANDOM_INTERPOLATION ,
382+ fill = hparams ['img_mean' ] if 'img_mean' in hparams else _FILL ,
383+ interpolation = hparams ['interpolation' ] if 'interpolation' in hparams else _RANDOM_INTERPOLATION ,
370384 )
371385
372386 # If magnitude_std is > 0, we introduce some randomness
@@ -564,7 +578,7 @@ def auto_augment_policy(name='v0', hparams=None):
564578
565579class AutoAugment :
566580
567- def __init__ (self , policy ):
581+ def __init__ (self , policy : List ):
568582 self .policy = policy
569583
570584 def __call__ (self , img ):
@@ -729,8 +743,14 @@ def rand_augment_ops(
729743):
730744 hparams = hparams or _HPARAMS_DEFAULT
731745 transforms = transforms or _RAND_TRANSFORMS
732- return [AugmentOp (
733- name , prob = prob , magnitude = magnitude , hparams = hparams ) for name in transforms ]
746+ return [
747+ AugmentOp (
748+ name ,
749+ prob = prob ,
750+ magnitude = magnitude ,
751+ hparams = hparams
752+ ) for name in transforms
753+ ]
734754
735755
736756class RandAugment :
0 commit comments