-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathmy_transforms.py
More file actions
executable file
·125 lines (99 loc) · 4.84 KB
/
my_transforms.py
File metadata and controls
executable file
·125 lines (99 loc) · 4.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# KATANA: Simple Post-Training Robustness Using Test Time Augmentations
# https://arxiv.org/pdf/2109.08191v1.pdf
import torch
import torchvision.transforms.functional as F
from torchvision.transforms import ColorJitter, Compose, Lambda
from numpy import random
class GaussianNoise(torch.nn.Module):
    """Add i.i.d. Gaussian noise with the given mean and std to a tensor image.

    Args:
        mean: Mean of the added noise.
        std: Standard deviation of the added noise.
    """

    def __init__(self, mean=0., std=1.):
        super().__init__()
        self.std = std
        self.mean = mean

    def forward(self, img):
        """Return ``img`` plus noise drawn from N(mean, std**2), same shape as ``img``.

        The noise is sampled directly on ``img.device`` instead of being built
        on the CPU and copied over on every call. For floating-point inputs it is
        sampled in ``img.dtype`` so e.g. half-precision images stay in half
        precision; for non-floating inputs the default floating dtype is used and
        the sum is promoted as before.
        """
        noise_dtype = img.dtype if img.is_floating_point() else torch.get_default_dtype()
        noise = torch.randn(img.shape, dtype=noise_dtype, device=img.device)
        noise = noise * self.std + self.mean
        return img + noise

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
class Clip(torch.nn.Module):
    """Clamp every element of a tensor into the closed range [min_val, max_val].

    Args:
        min_val: Lower bound applied element-wise.
        max_val: Upper bound applied element-wise.
    """

    def __init__(self, min_val=0., max_val=1.):
        super().__init__()
        self.min_val, self.max_val = min_val, max_val

    def forward(self, img):
        """Return a copy of ``img`` with values limited to [min_val, max_val]."""
        lower, upper = self.min_val, self.max_val
        return torch.clamp(img, lower, upper)

    def __repr__(self):
        return f"{self.__class__.__name__}(min_val={self.min_val}, max_val={self.max_val})"
class ColorJitterPro(ColorJitter):
    """Randomly change the brightness, contrast, saturation, hue and gamma of an image.

    Extends torchvision's ``ColorJitter`` with an extra ``gamma`` jitter. On each
    call every adjustment whose range is not ``None`` is applied exactly once,
    in a freshly shuffled order.

    Args:
        brightness, contrast, saturation, hue: Passed unchanged to ``ColorJitter``.
        gamma: Gamma jitter, validated by the parent's ``_check_input`` with its
            default arguments. NOTE(review): the default ``0`` is presumably
            turned into ``None`` there (i.e. gamma jitter disabled), as for
            ``brightness`` — confirm against the installed torchvision.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, gamma=0):
        super().__init__(brightness, contrast, saturation, hue)
        self.gamma = self._check_input(gamma, 'gamma')

    @staticmethod
    @torch.jit.unused
    def get_params(brightness, contrast, saturation, hue, gamma):
        """Build a randomized transform to be applied on an image.

        Arguments are the validated ``[low, high]`` ranges (as stored on an
        instance); a ``None`` range skips that adjustment. Factors are drawn
        with ``numpy.random``, and the order is shuffled with ``numpy.random``.

        Returns:
            A ``Compose`` that randomly adjusts brightness, contrast,
            saturation, hue and gamma (those not ``None``) in a random order.
        """
        transforms = []
        # Each lambda closes over its own, never-reassigned local, so there is
        # no late-binding hazard here.
        if brightness is not None:
            brightness_factor = random.uniform(brightness[0], brightness[1])
            transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
        if contrast is not None:
            contrast_factor = random.uniform(contrast[0], contrast[1])
            transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
        if saturation is not None:
            saturation_factor = random.uniform(saturation[0], saturation[1])
            transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
        if hue is not None:
            hue_factor = random.uniform(hue[0], hue[1])
            transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
        if gamma is not None:
            gamma_factor = random.uniform(gamma[0], gamma[1])
            transforms.append(Lambda(lambda img: F.adjust_gamma(img, gamma_factor)))
        random.shuffle(transforms)
        transform = Compose(transforms)
        return transform

    def forward(self, img):
        """Apply the enabled jitters to ``img`` in a random order.

        Factors are drawn with the torch RNG, one per enabled adjustment, in
        the shuffled order.

        Args:
            img (PIL Image or Tensor): Input image.

        Returns:
            PIL Image or Tensor: Color jittered image.
        """
        for fn_id in torch.randperm(5).tolist():
            if fn_id == 0 and self.brightness is not None:
                brightness = self.brightness
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = F.adjust_brightness(img, brightness_factor)
            elif fn_id == 1 and self.contrast is not None:
                contrast = self.contrast
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = F.adjust_contrast(img, contrast_factor)
            elif fn_id == 2 and self.saturation is not None:
                saturation = self.saturation
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = F.adjust_saturation(img, saturation_factor)
            elif fn_id == 3 and self.hue is not None:
                hue = self.hue
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = F.adjust_hue(img, hue_factor)
            elif fn_id == 4 and self.gamma is not None:
                gamma = self.gamma
                gamma_factor = torch.tensor(1.0).uniform_(gamma[0], gamma[1]).item()
                if isinstance(img, torch.Tensor) and img.is_floating_point():
                    # Keep values strictly positive to avoid NaN gradients in
                    # gamma, which the original author hit when gamma ran after
                    # contrast. Only float tensors are clamped: PIL images have no
                    # ``clamp``, and integer tensors would be squashed into [0, 1].
                    img = img.clamp(1e-8, 1.0)
                img = F.adjust_gamma(img, gamma_factor)
        return img

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        format_string += 'brightness={0}'.format(self.brightness)
        format_string += ', contrast={0}'.format(self.contrast)
        format_string += ', saturation={0}'.format(self.saturation)
        format_string += ', hue={0}'.format(self.hue)
        format_string += ', gamma={0})'.format(self.gamma)
        return format_string