Skip to content

Commit 51c4579

Browse files
committed
fix ImageReader to work either with image path or np.ndarray
1 parent c4821c2 commit 51c4579

File tree

5 files changed

+56
-18
lines changed

5 files changed

+56
-18
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
## [1.0.4] - 2022-03-22
2+
### Changed
3+
- Fix `ImageReader` to work either with image path or `np.ndarray`
4+
- Added `metadata` support to `callbacks/tf2onnx` when converting to onnx format
5+
6+
17
## [1.0.3] - 2022-03-20
28
### Changed
39
- Changed `mltu.augmentors` to work only with `Image` objects

mltu/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
__version__ = "1.0.3"
1+
__version__ = "1.0.4"
22

33
from .annotations.image import Image

mltu/annotations/image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(
1717
image: typing.Union[str, np.ndarray],
1818
method: int = cv2.IMREAD_COLOR,
1919
path: str = "",
20-
color: str = ""
20+
color: str = "BGR"
2121
) -> None:
2222

2323
if isinstance(image, str):

mltu/preprocessors.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,25 @@ def __init__(self, method: int = cv2.IMREAD_COLOR, log_level: int = logging.INFO
2020
self.logger = logging.getLogger(self.__class__.__name__)
2121
self.logger.setLevel(log_level)
2222

23-
def __call__(self, image_path: str, label: typing.Any) -> typing.Tuple[Image, typing.Any]:
24-
# check whether image_path exists
25-
if not os.path.exists(image_path):
26-
raise FileNotFoundError(f"Image {image_path} not found.")
23+
def __call__(self, image_path: typing.Union[str, np.ndarray], label: typing.Any) -> typing.Tuple[Image, typing.Any]:
24+
""" Read image with cv2 from path and return image and label
25+
26+
Args:
27+
image_path (typing.Union[str, np.ndarray]): Path to image or numpy array
28+
label (Any): Label of image
29+
30+
Returns:
31+
Image: Image object
32+
Any: Label of image
33+
"""
34+
if isinstance(image_path, str):
35+
# check whether image_path exists
36+
if not os.path.exists(image_path):
37+
raise FileNotFoundError(f"Image {image_path} not found.")
38+
elif isinstance(image_path, np.ndarray):
39+
pass
40+
else:
41+
raise TypeError(f"Image {image_path} is not a string or numpy array.")
2742

2843
image = Image(image = image_path, method = self._method)
2944
if image.image is None:

mltu/tensorflow/callbacks.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,43 @@
11
import os
2+
import tf2onnx
3+
import onnx
24
from keras.callbacks import Callback
35

46
import logging
57

68
class Model2onnx(Callback):
7-
"""Converts the model to onnx format after training is finished.
8-
9-
Args:
10-
saved_model_path (str): Path to the saved .h5 model.
11-
"""
12-
try:
13-
import tf2onnx
14-
except ImportError:
15-
raise ImportError("tf2onnx not installed, skipping model export to onnx")
16-
17-
def __init__(self, saved_model_path: str) -> None:
9+
""" Converts the model to onnx format after training is finished. """
10+
def __init__(
11+
self,
12+
saved_model_path: str,
13+
metadata: dict=None
14+
) -> None:
15+
""" Converts the model to onnx format after training is finished.
16+
Args:
17+
saved_model_path (str): Path to the saved .h5 model.
18+
metadata (dict, optional): Dictionary containing metadata to be added to the onnx model. Defaults to None.
19+
"""
1820
super().__init__()
1921
self.saved_model_path = saved_model_path
22+
self.metadata = metadata
2023

2124
def on_train_end(self, logs=None):
2225
self.model.load_weights(self.saved_model_path)
23-
self.tf2onnx.convert.from_keras(self.model, output_path=self.saved_model_path.replace(".h5", ".onnx"), )
26+
self.onnx_model_path = self.saved_model_path.replace(".h5", ".onnx")
27+
self.tf2onnx.convert.from_keras(self.model, output_path=self.onnx_model_path)
28+
29+
if self.metadata and isinstance(self.metadata, dict):
30+
# Load the ONNX model
31+
onnx_model = onnx.load(self.onnx_model_path)
32+
33+
# Add the metadata dictionary to the model's metadata_props attribute
34+
for key, value in self.metadata.items():
35+
meta = onnx_model.metadata_props.add()
36+
meta.key = key
37+
meta.value = value
38+
39+
# Save the modified ONNX model
40+
onnx.save(onnx_model, self.onnx_model_path)
2441

2542
class TrainLogger(Callback):
2643
"""Logs training metrics to a file.

0 commit comments

Comments
 (0)