Skip to content

Commit d22f814

Browse files
author
PyLessons
committed
add option to choose whether augment annotation or not
1 parent 6a4eb20 commit d22f814

File tree

1 file changed

+112
-55
lines changed

1 file changed

+112
-55
lines changed

mltu/augmentors.py

Lines changed: 112 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def wrapper(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image, t
3333
# check if image is Image object
3434
if not isinstance(image, Image):
3535
self.logger.error(f"image must be Image object, not {type(image)}, skipping augmentor")
36+
# TODO instead of error convert image into Image object
3637
return image, annotation
3738

3839
if np.random.rand() > self._random_chance:
@@ -51,9 +52,10 @@ class Augmentor:
5152
random_chance (float, optional): Chance of applying the augmentor. Where 0.0 is never and 1.0 is always. Defaults to 0.5.
5253
log_level (int, optional): Log level for the augmentor. Defaults to logging.INFO.
5354
"""
54-
def __init__(self, random_chance: float=0.5, log_level: int = logging.INFO) -> None:
55+
def __init__(self, random_chance: float=0.5, log_level: int = logging.INFO, augment_annotation: bool = False) -> None:
5556
self._random_chance = random_chance
5657
self._log_level = log_level
58+
self._augment_annotation = augment_annotation
5759

5860
self.logger = logging.getLogger(self.__class__.__name__)
5961
self.logger.setLevel(logging.INFO)
@@ -73,20 +75,37 @@ def __init__(
7375
random_chance: float = 0.5,
7476
delta: int = 100,
7577
log_level: int = logging.INFO,
78+
augment_annotation: bool = False
7679
) -> None:
7780
""" Randomly adjust image brightness
7881
7982
Args:
8083
random_chance (float, optional): Chance of applying the augmentor. Where 0.0 is never and 1.0 is always. Defaults to 0.5.
8184
delta (int, optional): Integer value for brightness adjustment. Defaults to 100.
8285
log_level (int, optional): Log level for the augmentor. Defaults to logging.INFO.
86+
augment_annotation (bool, optional): If True, the annotation will be adjusted as well. Defaults to False.
8387
"""
84-
super(RandomBrightness, self).__init__(random_chance, log_level)
88+
super(RandomBrightness, self).__init__(random_chance, log_level, augment_annotation)
8589

8690
assert 0 <= delta <= 255.0, "Delta must be between 0.0 and 255.0"
8791

8892
self._delta = delta
8993

94+
def augment(self, image: Image, value: float) -> Image:
95+
""" Augment image brightness """
96+
hsv = np.array(image.HSV(), dtype = np.float32)
97+
98+
hsv[:, :, 1] = hsv[:, :, 1] * value
99+
hsv[:, :, 2] = hsv[:, :, 2] * value
100+
101+
hsv = np.uint8(np.clip(hsv, 0, 255))
102+
103+
img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
104+
105+
image.update(img)
106+
107+
return image
108+
90109
@randomness_decorator
91110
def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image, typing.Any]:
92111
""" Randomly adjust image brightness
@@ -101,16 +120,10 @@ def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image,
101120
"""
102121
value = 1 + np.random.uniform(-self._delta, self._delta) / 255
103122

104-
hsv = np.array(image.HSV(), dtype = np.float32)
105-
106-
hsv[:, :, 1] = hsv[:, :, 1] * value
107-
hsv[:, :, 2] = hsv[:, :, 2] * value
123+
image = self.augment(image, value)
108124

109-
hsv = np.uint8(np.clip(hsv, 0, 255))
110-
111-
img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
112-
113-
image.update(img)
125+
if self._augment_annotation and isinstance(annotation, Image):
126+
annotation = self.augment(annotation, value)
114127

115128
return image, annotation
116129

@@ -123,6 +136,7 @@ def __init__(
123136
angle: typing.Union[int, typing.List]=30,
124137
borderValue: typing.Tuple[int, int, int]=None,
125138
log_level: int = logging.INFO,
139+
augment_annotation: bool = True
126140
) -> None:
127141
""" Randomly rotate image
128142
@@ -131,8 +145,9 @@ def __init__(
131145
angle (int, list): Integer value or list of integer values for image rotation
132146
borderValue (tuple): Tuple of 3 integers, setting border color for image rotation
133147
log_level (int): Log level for the augmentor. Defaults to logging.INFO.
148+
augment_annotation (bool): If True, the annotation will be adjusted as well. Defaults to True.
134149
"""
135-
super(RandomRotate, self).__init__(random_chance, log_level)
150+
super(RandomRotate, self).__init__(random_chance, log_level, augment_annotation)
136151

137152
self._angle = angle
138153
self._borderValue = borderValue
@@ -149,7 +164,7 @@ def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image,
149164
image (Image): Adjusted image
150165
annotation (typing.Any): Adjusted annotation
151166
"""
152-
# check if angle is list of angles or signle angle value
167+
# check if angle is list of angles or a single angle value
153168
if isinstance(self._angle, list):
154169
angle = float(np.random.choice(self._angle))
155170
else:
@@ -180,7 +195,7 @@ def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image,
180195
# perform the actual rotation and return the image
181196
img = cv2.warpAffine(image.numpy(), M, (nW, nH), borderValue=borderValue)
182197

183-
if isinstance(annotation, Image):
198+
if self._augment_annotation and isinstance(annotation, Image):
184199
annotation_warp = cv2.warpAffine(annotation.numpy(), M, (nW, nH), borderValue=(0, 0, 0))
185200
annotation.update(annotation_warp)
186201

@@ -196,16 +211,29 @@ def __init__(
196211
random_chance: float = 0.5,
197212
kernel_size: typing.Tuple[int, int]=(1, 1),
198213
log_level: int = logging.INFO,
214+
augment_annotation: bool = False,
199215
) -> None:
200216
""" Randomly erode and dilate image
201217
202218
Args:
203219
random_chance (float): Float between 0.0 and 1.0 setting bounds for random probability. Defaults to 0.5.
204220
kernel_size (tuple): Tuple of 2 integers, setting kernel size for erosion and dilation
205221
log_level (int): Log level for the augmentor. Defaults to logging.INFO.
222+
augment_annotation (bool): Boolean value to determine if annotation should be adjusted. Defaults to False.
206223
"""
207-
super(RandomErodeDilate, self).__init__(random_chance, log_level)
224+
super(RandomErodeDilate, self).__init__(random_chance, log_level, augment_annotation)
208225
self._kernel_size = kernel_size
226+
self.kernel = np.ones(self._kernel_size, np.uint8)
227+
228+
def augment(self, image: Image) -> Image:
229+
if np.random.rand() <= 0.5:
230+
img = cv2.erode(image.numpy(), self.kernel, iterations=1)
231+
else:
232+
img = cv2.dilate(image.numpy(), self.kernel, iterations=1)
233+
234+
image.update(img)
235+
236+
return image
209237

210238
@randomness_decorator
211239
def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image, typing.Any]:
@@ -219,14 +247,10 @@ def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image,
219247
image (Image): Eroded and dilated image
220248
annotation (typing.Any): Adjusted annotation if necessary
221249
"""
222-
kernel = np.ones(self._kernel_size, np.uint8)
223-
224-
if np.random.rand() <= 0.5:
225-
img = cv2.erode(image.numpy(), kernel, iterations=1)
226-
else:
227-
img = cv2.dilate(image.numpy(), kernel, iterations=1)
250+
image = self.augment(image)
228251

229-
image.update(img)
252+
if self._augment_annotation and isinstance(annotation, Image):
253+
annotation = self.augment(annotation)
230254

231255
return image, annotation
232256

@@ -241,6 +265,7 @@ def __init__(
241265
kernel: np.ndarray = None,
242266
kernel_anchor: np.ndarray = None,
243267
log_level: int = logging.INFO,
268+
augment_annotation: bool = False,
244269
) -> None:
245270
""" Randomly sharpen image
246271
@@ -251,8 +276,9 @@ def __init__(
251276
kernel (np.ndarray): Numpy array of kernel for image convolution
252277
kernel_anchor (np.ndarray): Numpy array of kernel anchor for image convolution
253278
log_level (int): Log level for the augmentor. Defaults to logging.INFO.
279+
augment_annotation (bool): Boolean to determine if annotation should be augmented. Defaults to False.
254280
"""
255-
super(RandomSharpen, self).__init__(random_chance, log_level)
281+
super(RandomSharpen, self).__init__(random_chance, log_level, augment_annotation)
256282

257283
self._alpha_range = (alpha, 1.0)
258284
self._ligtness_range = lightness_range
@@ -263,18 +289,7 @@ def __init__(
263289

264290
assert 0 <= alpha <= 1.0, "Alpha must be between 0.0 and 1.0"
265291

266-
@randomness_decorator
267-
def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image, typing.Any]:
268-
""" Randomly sharpen image
269-
270-
Args:
271-
image (Image): Image to be sharpened
272-
annotation (typing.Any): Annotation to be adjusted
273-
274-
Returns:
275-
image (Image): Sharpened image
276-
annotation (typing.Any): Adjusted annotation if necessary
277-
"""
292+
def augment(self, image: Image) -> Image:
278293
lightness = np.random.uniform(*self._ligtness_range)
279294
alpha = np.random.uniform(*self._alpha_range)
280295

@@ -291,6 +306,25 @@ def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image,
291306
# Merge the sharpened channels back into the original image
292307
image.update(cv2.merge([r_sharp, g_sharp, b_sharp]))
293308

309+
return image
310+
311+
@randomness_decorator
312+
def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image, typing.Any]:
313+
""" Randomly sharpen image
314+
315+
Args:
316+
image (Image): Image to be sharpened
317+
annotation (typing.Any): Annotation to be adjusted
318+
319+
Returns:
320+
image (Image): Sharpened image
321+
annotation (typing.Any): Adjusted annotation if necessary
322+
"""
323+
image = self.augment(image)
324+
325+
if self._augment_annotation and isinstance(annotation, Image):
326+
annotation = self.augment(annotation)
327+
294328
return image, annotation
295329

296330

@@ -301,6 +335,7 @@ def __init__(
301335
random_chance: float = 0.5,
302336
log_level: int = logging.INFO,
303337
sigma: typing.Union[int, float] = 0.5,
338+
augment_annotation: bool = False,
304339
) -> None:
305340
""" Randomly erode and dilate image
306341
@@ -309,9 +344,16 @@ def __init__(
309344
log_level (int): Log level for the augmentor. Defaults to logging.INFO.
310345
sigma (int, float): standard deviation of the Gaussian kernel
311346
"""
312-
super(RandomGaussianBlur, self).__init__(random_chance, log_level)
347+
super(RandomGaussianBlur, self).__init__(random_chance, log_level, augment_annotation)
313348
self.sigma = sigma
314349

350+
def augment(self, image: Image) -> Image:
351+
img = cv2.GaussianBlur(image.numpy(), (0, 0), self.sigma)
352+
353+
image.update(img)
354+
355+
return image
356+
315357
@randomness_decorator
316358
def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image, typing.Any]:
317359
""" Randomly blurs an image with a Gaussian filter
@@ -324,9 +366,10 @@ def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image,
324366
image (Image): Blurred image
325367
annotation (typing.Any): Blurred annotation if necessary
326368
"""
327-
img = cv2.GaussianBlur(image.numpy(), (0, 0), self.sigma)
369+
image = self.augment(image)
328370

329-
image.update(img)
371+
if self._augment_annotation and isinstance(annotation, Image):
372+
annotation = self.augment(annotation)
330373

331374
return image, annotation
332375

@@ -339,6 +382,7 @@ def __init__(
339382
log_level: int = logging.INFO,
340383
salt_vs_pepper: float = 0.5,
341384
amount: float = 0.1,
385+
augment_annotation: bool = False,
342386
) -> None:
343387
""" Randomly add Salt and Pepper noise to image
344388
@@ -347,26 +391,16 @@ def __init__(
347391
log_level (int): Log level for the augmentor. Defaults to logging.INFO.
348392
salt_vs_pepper (float): ratio of salt vs pepper. Defaults to 0.5.
349393
amount (float): proportion of the image to be salted and peppered. Defaults to 0.1.
394+
augment_annotation (bool): Whether to augment the annotation. Defaults to False.
350395
"""
351-
super(RandomSaltAndPepper, self).__init__(random_chance, log_level)
396+
super(RandomSaltAndPepper, self).__init__(random_chance, log_level, augment_annotation)
352397
self.salt_vs_pepper = salt_vs_pepper
353398
self.amount = amount
354399

355400
assert 0 <= salt_vs_pepper <= 1.0, "salt_vs_pepper must be between 0.0 and 1.0"
356401
assert 0 <= amount <= 1.0, "amount must be between 0.0 and 1.0"
357402

358-
@randomness_decorator
359-
def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image, typing.Any]:
360-
""" Randomly add salt and pepper noise to an image
361-
362-
Args:
363-
image (Image): Image to be noised
364-
annotation (typing.Any): Annotation to be noised
365-
366-
Returns:
367-
image (Image): Noised image
368-
annotation (typing.Any): Noised annotation if necessary
369-
"""
403+
def augment(self, image: Image) -> Image:
370404
img = image.numpy()
371405
height, width, channels = img.shape
372406

@@ -384,6 +418,25 @@ def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image,
384418

385419
image.update(img)
386420

421+
return image
422+
423+
@randomness_decorator
424+
def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image, typing.Any]:
425+
""" Randomly add salt and pepper noise to an image
426+
427+
Args:
428+
image (Image): Image to be noised
429+
annotation (typing.Any): Annotation to be noised
430+
431+
Returns:
432+
image (Image): Noised image
433+
annotation (typing.Any): Noised annotation if necessary
434+
"""
435+
image = self.augment(image)
436+
437+
if self._augment_annotation and isinstance(annotation, Image):
438+
annotation = self.augment(annotation)
439+
387440
return image, annotation
388441

389442

@@ -393,14 +446,16 @@ def __init__(
393446
self,
394447
random_chance: float = 0.5,
395448
log_level: int = logging.INFO,
449+
augment_annotation: bool = False,
396450
) -> None:
397451
""" Randomly mirror image
398452
399453
Args:
400454
random_chance (float): Float between 0.0 and 1.0 setting bounds for random probability. Defaults to 0.5.
401455
log_level (int): Log level for the augmentor. Defaults to logging.INFO.
456+
augment_annotation (bool): Whether to augment the annotation. Defaults to False.
402457
"""
403-
super(RandomMirror, self).__init__(random_chance, log_level)
458+
super(RandomMirror, self).__init__(random_chance, log_level, augment_annotation)
404459

405460
@randomness_decorator
406461
def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image, typing.Any]:
@@ -415,7 +470,7 @@ def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image,
415470
annotation (typing.Any): Mirrored annotation if necessary
416471
"""
417472
image = image.flip(0)
418-
if isinstance(annotation, Image):
473+
if self._augment_annotation and isinstance(annotation, Image):
419474
annotation = annotation.flip(0)
420475

421476
return image, annotation
@@ -427,14 +482,16 @@ def __init__(
427482
self,
428483
random_chance: float = 0.5,
429484
log_level: int = logging.INFO,
485+
augment_annotation: bool = False,
430486
) -> None:
431487
""" Randomly mirror image
432488
433489
Args:
434490
random_chance (float): Float between 0.0 and 1.0 setting bounds for random probability. Defaults to 0.5.
435491
log_level (int): Log level for the augmentor. Defaults to logging.INFO.
492+
augment_annotation (bool): Whether to augment the annotation. Defaults to False.
436493
"""
437-
super(RandomFlip, self).__init__(random_chance, log_level)
494+
super(RandomFlip, self).__init__(random_chance, log_level, augment_annotation)
438495

439496
@randomness_decorator
440497
def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image, typing.Any]:
@@ -449,7 +506,7 @@ def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image,
449506
annotation (typing.Any): Flipped annotation if necessary
450507
"""
451508
image = image.flip(1)
452-
if isinstance(annotation, Image):
509+
if self._augment_annotation and isinstance(annotation, Image):
453510
annotation = annotation.flip(1)
454511

455512
return image, annotation

0 commit comments

Comments
 (0)