Skip to content

Commit 3b181b7

Browse files
committed
Update augmentations, especially RandAugment, to support a full torch.Tensor pipeline
1 parent ea23107 commit 3b181b7

File tree

4 files changed: +157 additions, -96 deletions (lines changed)

timm/data/auto_augment.py

Lines changed: 109 additions & 89 deletions
Original file line number | Diff line number | Diff line change
@@ -24,12 +24,18 @@
2424
import math
2525
import re
2626
from functools import partial
27-
from typing import Dict, List, Optional, Union
27+
from typing import Any, Dict, List, Optional, Union
2828

29-
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter
29+
import torch
3030
import PIL
3131
import numpy as np
32-
32+
from PIL import Image, ImageFilter
33+
from torchvision.transforms import InterpolationMode
34+
import torchvision.transforms.functional as TF
35+
try:
36+
import torchvision.transforms.v2.functional as TF2
37+
except ImportError:
38+
TF2 = None
3339

3440
_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
3541

@@ -42,160 +48,162 @@
4248
img_mean=_FILL,
4349
)
4450

45-
if hasattr(Image, "Resampling"):
46-
_RANDOM_INTERPOLATION = (Image.Resampling.BILINEAR, Image.Resampling.BICUBIC)
47-
_DEFAULT_INTERPOLATION = Image.Resampling.BICUBIC
48-
else:
49-
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
50-
_DEFAULT_INTERPOLATION = Image.BICUBIC
51+
52+
_RANDOM_INTERPOLATION = (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
53+
_DEFAULT_INTERPOLATION = InterpolationMode.BICUBIC
5154

5255

53-
def _interpolation(kwargs):
54-
interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION)
56+
def _interpolation(kwargs, basic_only=False):
57+
interpolation = kwargs.pop('interpolation', _DEFAULT_INTERPOLATION)
5558
if isinstance(interpolation, (list, tuple)):
56-
return random.choice(interpolation)
59+
interpolation = random.choice(interpolation)
60+
if basic_only:
61+
if interpolation not in (InterpolationMode.NEAREST, InterpolationMode.BILINEAR):
62+
interpolation = InterpolationMode.BILINEAR
5763
return interpolation
5864

5965

6066
def _check_args_tf(kwargs):
61-
if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
62-
kwargs.pop('fillcolor')
63-
kwargs['resample'] = _interpolation(kwargs)
67+
kwargs['interpolation'] = _interpolation(kwargs)
68+
69+
70+
def _check_args_affine(img, kwargs):
71+
if isinstance(img, torch.Tensor):
72+
kwargs['interpolation'] = _interpolation(kwargs, basic_only=True)
73+
else:
74+
kwargs['interpolation'] = _interpolation(kwargs)
6475

6576

6677
def shear_x(img, factor, **kwargs):
67-
_check_args_tf(kwargs)
68-
return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
78+
_check_args_affine(img, kwargs)
79+
return TF.affine(img, angle=0, translate=[0, 0], scale=1, shear=[math.degrees(math.atan(factor)), 0], **kwargs)
6980

7081

7182
def shear_y(img, factor, **kwargs):
72-
_check_args_tf(kwargs)
73-
return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
83+
_check_args_affine(img, kwargs)
84+
return TF.affine(img, angle=0, translate=[0, 0], scale=1, shear=[0, math.degrees(math.atan(factor))], **kwargs)
7485

7586

76-
def translate_x_rel(img, pct, **kwargs):
77-
pixels = pct * img.size[0]
78-
_check_args_tf(kwargs)
79-
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
87+
def translate_x_abs(img, pixels, **kwargs):
88+
_check_args_affine(img, kwargs)
89+
return TF.affine(img, angle=0, translate=[pixels, 0], scale=1, shear=[0, 0], **kwargs)
8090

8191

82-
def translate_y_rel(img, pct, **kwargs):
83-
pixels = pct * img.size[1]
84-
_check_args_tf(kwargs)
85-
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
92+
def translate_y_abs(img, pixels, **kwargs):
93+
_check_args_affine(img, kwargs)
94+
return TF.affine(img, angle=0, translate=[0, pixels], scale=1, shear=[0, 0], **kwargs)
8695

8796

88-
def translate_x_abs(img, pixels, **kwargs):
89-
_check_args_tf(kwargs)
90-
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
97+
def translate_x_rel(img, pct, **kwargs):
98+
pixels = pct * TF.get_image_size(img)[0]
99+
return translate_x_abs(img, pixels, **kwargs)
91100

92101

93-
def translate_y_abs(img, pixels, **kwargs):
94-
_check_args_tf(kwargs)
95-
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
102+
def translate_y_rel(img, pct, **kwargs):
103+
pixels = pct * TF.get_image_size(img)[1]
104+
return translate_y_abs(img, pixels, **kwargs)
96105

97106

98107
def rotate(img, degrees, **kwargs):
99-
_check_args_tf(kwargs)
100-
if _PIL_VER >= (5, 2):
101-
return img.rotate(degrees, **kwargs)
102-
if _PIL_VER >= (5, 0):
103-
w, h = img.size
104-
post_trans = (0, 0)
105-
rotn_center = (w / 2.0, h / 2.0)
106-
angle = -math.radians(degrees)
107-
matrix = [
108-
round(math.cos(angle), 15),
109-
round(math.sin(angle), 15),
110-
0.0,
111-
round(-math.sin(angle), 15),
112-
round(math.cos(angle), 15),
113-
0.0,
114-
]
115-
116-
def transform(x, y, matrix):
117-
(a, b, c, d, e, f) = matrix
118-
return a * x + b * y + c, d * x + e * y + f
119-
120-
matrix[2], matrix[5] = transform(
121-
-rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
122-
)
123-
matrix[2] += rotn_center[0]
124-
matrix[5] += rotn_center[1]
125-
return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
126-
return img.rotate(degrees, resample=kwargs['resample'])
108+
_check_args_affine(img, kwargs)
109+
return TF.rotate(img, degrees, **kwargs)
127110

128111

129112
def auto_contrast(img, **__):
130-
return ImageOps.autocontrast(img)
113+
return TF.autocontrast(img)
131114

132115

133116
def invert(img, **__):
134-
return ImageOps.invert(img)
117+
return TF.invert(img)
135118

136119

137120
def equalize(img, **__):
138-
return ImageOps.equalize(img)
121+
if isinstance(img, torch.Tensor) and img.is_floating_point():
122+
if TF2 is None:
123+
# FIXME warn / assert?
124+
return img
125+
return TF2.equalize(img)
126+
return TF.equalize(img)
139127

140128

141129
def solarize(img, thresh, **__):
142-
return ImageOps.solarize(img, thresh)
130+
if isinstance(img, torch.Tensor) and img.is_floating_point():
131+
thresh = min(thresh / 255, 1.0)
132+
return TF.solarize(img, thresh)
143133

144134

145135
def solarize_add(img, add, thresh=128, **__):
146-
lut = []
147-
for i in range(256):
148-
if i < thresh:
149-
lut.append(min(255, i + add))
136+
if isinstance(img, torch.Tensor):
137+
if img.is_floating_point():
138+
thresh = thresh / 255
139+
add = add / 255
140+
img_sum = (img + add).clamp_(max=1.0)
150141
else:
151-
lut.append(i)
142+
img_sum = (img + add).clamp_(max=255)
143+
return torch.where(img >= thresh, img_sum, img)
144+
else:
145+
lut = []
146+
for i in range(256):
147+
if i < thresh:
148+
lut.append(min(255, i + add))
149+
else:
150+
lut.append(i)
152151

153-
if img.mode in ("L", "RGB"):
154-
if img.mode == "RGB" and len(lut) == 256:
155-
lut = lut + lut + lut
156-
return img.point(lut)
152+
if img.mode in ("L", "RGB"):
153+
if img.mode == "RGB" and len(lut) == 256:
154+
lut = lut + lut + lut
155+
return img.point(lut)
157156

158157
return img
159158

160159

161160
def posterize(img, bits_to_keep, **__):
162161
if bits_to_keep >= 8:
163162
return img
164-
return ImageOps.posterize(img, bits_to_keep)
163+
if isinstance(img, torch.Tensor) and img.is_floating_point():
164+
if TF2 is None:
165+
# FIXME warn / assert?
166+
return img
167+
return TF2.posterize(img, bits_to_keep)
168+
return TF.posterize(img, bits_to_keep)
165169

166170

167171
def contrast(img, factor, **__):
168-
return ImageEnhance.Contrast(img).enhance(factor)
172+
return TF.adjust_contrast(img, factor)
169173

170174

171175
def color(img, factor, **__):
172-
return ImageEnhance.Color(img).enhance(factor)
176+
return TF.adjust_saturation(img, factor)
173177

174178

175179
def brightness(img, factor, **__):
176-
return ImageEnhance.Brightness(img).enhance(factor)
180+
return TF.adjust_brightness(img, factor)
177181

178182

179183
def sharpness(img, factor, **__):
180-
return ImageEnhance.Sharpness(img).enhance(factor)
184+
return TF.adjust_sharpness(img, factor)
181185

182186

183187
def gaussian_blur(img, factor, **__):
184-
img = img.filter(ImageFilter.GaussianBlur(radius=factor))
188+
if isinstance(img, torch.Tensor):
189+
kernel_size = 2 * int(3 * factor) + 1 # could be bigger, but more expensive
190+
img = TF.gaussian_blur(img, kernel_size=kernel_size, sigma=factor)
191+
else:
192+
img = img.filter(ImageFilter.GaussianBlur(radius=factor))
185193
return img
186194

187195

188196
def gaussian_blur_rand(img, factor, **__):
189197
radius_min = 0.1
190198
radius_max = 2.0
191-
img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(radius_min, radius_max * factor)))
192-
return img
199+
radius = random.uniform(radius_min, radius_max * factor)
200+
return gaussian_blur(img, radius)
193201

194202

195203
def desaturate(img, factor, **_):
196204
factor = min(1., max(0., 1. - factor))
197205
# enhance factor 0 = grayscale, 1.0 = no-change
198-
return ImageEnhance.Color(img).enhance(factor)
206+
return TF.adjust_saturation(img, factor)
199207

200208

201209
def _randomly_negate(v):
@@ -356,7 +364,13 @@ def _solarize_add_level_to_arg(level, _hparams):
356364

357365
class AugmentOp:
358366

359-
def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
367+
def __init__(
368+
self,
369+
name: str,
370+
prob: float = 0.5,
371+
magnitude: float = 10,
372+
hparams: Optional[Dict[str, Any]] = None
373+
):
360374
hparams = hparams or _HPARAMS_DEFAULT
361375
self.name = name
362376
self.aug_fn = NAME_TO_OP[name]
@@ -365,8 +379,8 @@ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
365379
self.magnitude = magnitude
366380
self.hparams = hparams.copy()
367381
self.kwargs = dict(
368-
fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
369-
resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
382+
fill=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
383+
interpolation=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
370384
)
371385

372386
# If magnitude_std is > 0, we introduce some randomness
@@ -564,7 +578,7 @@ def auto_augment_policy(name='v0', hparams=None):
564578

565579
class AutoAugment:
566580

567-
def __init__(self, policy):
581+
def __init__(self, policy: List):
568582
self.policy = policy
569583

570584
def __call__(self, img):
@@ -729,8 +743,14 @@ def rand_augment_ops(
729743
):
730744
hparams = hparams or _HPARAMS_DEFAULT
731745
transforms = transforms or _RAND_TRANSFORMS
732-
return [AugmentOp(
733-
name, prob=prob, magnitude=magnitude, hparams=hparams) for name in transforms]
746+
return [
747+
AugmentOp(
748+
name,
749+
prob=prob,
750+
magnitude=magnitude,
751+
hparams=hparams
752+
) for name in transforms
753+
]
734754

735755

736756
class RandAugment:

timm/data/loader.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -87,7 +87,8 @@ def __init__(
8787
re_prob=0.,
8888
re_mode='const',
8989
re_count=1,
90-
re_num_splits=0):
90+
re_num_splits=0,
91+
):
9192

9293
mean = adapt_to_chs(mean, channels)
9394
std = adapt_to_chs(std, channels)

timm/data/transforms.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
has_interpolation_mode = True
1313
except ImportError:
1414
has_interpolation_mode = False
15-
from PIL import Image
15+
from PIL import Image, ImageCms
16+
1617
import numpy as np
1718

1819
__all__ = [
@@ -89,6 +90,31 @@ def __repr__(self) -> str:
8990
return f"{self.__class__.__name__}()"
9091

9192

93+
class ToLab(transforms.ToTensor):
94+
95+
def __init__(self) -> None:
96+
super().__init__()
97+
rgb_profile = ImageCms.createProfile(colorSpace='sRGB')
98+
lab_profile = ImageCms.createProfile(colorSpace='LAB')
99+
# Create a transform object from the input and output profiles
100+
self.rgb_to_lab_transform = ImageCms.buildTransform(
101+
inputProfile=rgb_profile,
102+
outputProfile=lab_profile,
103+
inMode='RGB',
104+
outMode='LAB'
105+
)
106+
107+
def __call__(self, pic) -> torch.Tensor:
108+
lab_image = ImageCms.applyTransform(
109+
im=pic,
110+
transform=self.rgb_to_lab_transform
111+
)
112+
return lab_image
113+
114+
def __repr__(self) -> str:
115+
return f"{self.__class__.__name__}()"
116+
117+
92118
# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
93119
# favor of the Image.Resampling enum. The top-level resampling attributes will be
94120
# removed in Pillow 10.

0 commit comments

Comments
 (0)