
Commit 4c0df6b

Merge branch 'feature/imageSegmentation' into develop

2 parents 59a8a0e + df2ebf8

File tree

8 files changed: +474 −17 lines changed

CHANGELOG.md

Lines changed: 9 additions & 0 deletions

@@ -1,3 +1,12 @@
+## [1.0.12] - 2022-06-08
+### Changed
+- Moved `onnx` and `tf2onnx` import inside `mltu.tensorflow.callbacks.Model2onnx` to avoid import errors when not using this callback
+- Removed `onnx` and `tf2onnx` install requirements from global requirements
+
+### Added
+- Added `RandomMirror` and `RandomFlip` augmentors into `mltu.augmentors`
+- Added `u2net` segmentation model into `mltu.tensorflow.models`
+
 ## [1.0.11] - 2022-06-07
 ### Changed
 - Downgrade `tf2onnx` and `onnx` versions, they don't work with newest TensorFlow version
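
The two new augmentors follow the same callable interface as the existing ones. A minimal usage sketch (not part of the commit) chaining them over an image / segmentation-mask pair; `augment_pair` is a hypothetical helper, and both arguments are assumed to be wrapped in mltu's Image type as in the diffs below:

from mltu.augmentors import RandomMirror, RandomFlip

def augment_pair(image, mask):
    # each augmentor takes (image, annotation) and, when the annotation is an
    # Image (e.g. a segmentation mask), flips it together with the image
    for augmentor in [RandomMirror(random_chance=0.5), RandomFlip(random_chance=0.5)]:
        image, mask = augmentor(image, mask)
    return image, mask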

mltu/augmentors.py

Lines changed: 75 additions & 0 deletions

@@ -12,6 +12,8 @@
 - RandomSharpen
 - RandomGaussianBlur
 - RandomSaltAndPepper
+- RandomMirror
+- RandomFlip
 """


@@ -177,6 +179,11 @@ def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image,

         # perform the actual rotation and return the image
         img = cv2.warpAffine(image.numpy(), M, (nW, nH), borderValue=borderValue)
+
+        if isinstance(annotation, Image):
+            annotation_warp = cv2.warpAffine(annotation.numpy(), M, (nW, nH), borderValue=(0, 0, 0))
+            annotation.update(annotation_warp)
+
         image.update(img)

         return image, annotation
@@ -377,4 +384,72 @@ def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image,

         image.update(img)

+        return image, annotation
+
+
+class RandomMirror(Augmentor):
+    """ Randomly mirror image"""
+    def __init__(
+        self,
+        random_chance: float = 0.5,
+        log_level: int = logging.INFO,
+        ) -> None:
+        """ Randomly mirror image
+
+        Args:
+            random_chance (float): Float between 0.0 and 1.0 setting bounds for random probability. Defaults to 0.5.
+            log_level (int): Log level for the augmentor. Defaults to logging.INFO.
+        """
+        super(RandomMirror, self).__init__(random_chance, log_level)
+
+    @randomness_decorator
+    def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image, typing.Any]:
+        """ Randomly mirror an image
+
+        Args:
+            image (Image): Image to be mirrored
+            annotation (typing.Any): Annotation to be mirrored
+
+        Returns:
+            image (Image): Mirrored image
+            annotation (typing.Any): Mirrored annotation if necessary
+        """
+        image = image.flip(0)
+        if isinstance(annotation, Image):
+            annotation = annotation.flip(0)
+
+        return image, annotation
+
+
+class RandomFlip(Augmentor):
+    """ Randomly flip image"""
+    def __init__(
+        self,
+        random_chance: float = 0.5,
+        log_level: int = logging.INFO,
+        ) -> None:
+        """ Randomly mirror image
+
+        Args:
+            random_chance (float): Float between 0.0 and 1.0 setting bounds for random probability. Defaults to 0.5.
+            log_level (int): Log level for the augmentor. Defaults to logging.INFO.
+        """
+        super(RandomFlip, self).__init__(random_chance, log_level)
+
+    @randomness_decorator
+    def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image, typing.Any]:
+        """ Randomly mirror an image
+
+        Args:
+            image (Image): Image to be flipped
+            annotation (typing.Any): Annotation to be flipped
+
+        Returns:
+            image (Image): Flipped image
+            annotation (typing.Any): Flipped annotation if necessary
+        """
+        image = image.flip(1)
+        if isinstance(annotation, Image):
+            annotation = annotation.flip(1)
+
         return image, annotation
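
The RandomRotate change above keeps a segmentation mask aligned with the rotated image by warping it with the same affine matrix and a black (background) border. A standalone sketch of the same idea in plain cv2/numpy, not taken from the commit: `rotate_pair` is a hypothetical helper, and the INTER_NEAREST flag is an extra precaution for integer label masks (the commit itself reuses the default interpolation).

import cv2
import numpy as np

def rotate_pair(image: np.ndarray, mask: np.ndarray, angle: float):
    # build one rotation matrix and apply it to both arrays so they stay aligned
    h, w = image.shape[:2]
    M = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1.0)
    rotated_image = cv2.warpAffine(image, M, (w, h))
    # nearest-neighbour keeps mask labels discrete; border filled with background (0)
    rotated_mask = cv2.warpAffine(mask, M, (w, h), flags=cv2.INTER_NEAREST, borderValue=0)
    return rotated_image, rotated_mask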

mltu/preprocessors.py

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@

 class ImageReader:
     """Read image from path and return image and label"""
-    def __init__(self, image_class, log_level: int = logging.INFO, ) -> None:
+    def __init__(self, image_class: Image, log_level: int = logging.INFO, ) -> None:
         self.logger = logging.getLogger(self.__class__.__name__)
         self.logger.setLevel(log_level)
         self._image_class = image_class
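
For context, a hedged sketch of how ImageReader is constructed given the new type hint. The `CVImage` import path and the `(image_path, label)` call signature are assumptions based on how mltu uses the reader elsewhere; neither is shown in this diff.

from mltu.preprocessors import ImageReader
from mltu.annotations.images import CVImage  # assumed import path for an Image subclass

reader = ImageReader(image_class=CVImage)
# assumed call signature: wraps the path in the given Image subclass and passes the label through
image, label = reader("dataset/sample.png", "cat")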

mltu/tensorflow/callbacks.py

Lines changed: 10 additions & 5 deletions

@@ -4,6 +4,11 @@
 import logging

 class Model2onnx(Callback):
+
+    # import onnx and use as global
+    import onnx
+    import tf2onnx
+
     """ Converts the model to onnx format after training is finished. """
     def __init__(
         self,
@@ -22,15 +27,15 @@ def __init__(
     def on_train_end(self, logs=None):
         """ Converts the model to onnx format after training is finished. """
         try:
-            import onnx
-            import tf2onnx
+            # import onnx
+            # import tf2onnx
             self.model.load_weights(self.saved_model_path)
             self.onnx_model_path = self.saved_model_path.replace(".h5", ".onnx")
-            tf2onnx.convert.from_keras(self.model, output_path=self.onnx_model_path)
+            self.tf2onnx.convert.from_keras(self.model, output_path=self.onnx_model_path)

             if self.metadata and isinstance(self.metadata, dict):
                 # Load the ONNX model
-                onnx_model = onnx.load(self.onnx_model_path)
+                onnx_model = self.onnx.load(self.onnx_model_path)

                 # Add the metadata dictionary to the model's metadata_props attribute
                 for key, value in self.metadata.items():
@@ -39,7 +44,7 @@ def on_train_end(self, logs=None):
                     meta.value = value

                 # Save the modified ONNX model
-                onnx.save(onnx_model, self.onnx_model_path)
+                self.onnx.save(onnx_model, self.onnx_model_path)

         except Exception as e:
             print(e)
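
A hedged usage sketch tying this back to the CHANGELOG entry: the callback is only constructed by users who want ONNX export, which is why `onnx` and `tf2onnx` no longer need to be global requirements. The constructor arguments (`saved_model_path` and `metadata`) follow the attributes referenced in `on_train_end` above; `train_with_onnx_export` and the remaining arguments are illustrative.

from mltu.tensorflow.callbacks import Model2onnx

def train_with_onnx_export(model, train_data, saved_model_path="Models/model.h5"):
    # on_train_end reloads the weights from saved_model_path and writes a
    # .onnx file next to it, embedding the metadata dict as metadata_props
    model2onnx = Model2onnx(saved_model_path, metadata={"mltu": "1.0.12"})
    model.fit(train_data, epochs=10, callbacks=[model2onnx])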
