Skip to content

Commit f98cd7e

Browse files
committed
ImageShowCV2 documentation and OnnxInferenceModel force_cpu flag fix
1 parent 29cad5a commit f98cd7e

File tree

3 files changed

+40
-9
lines changed

3 files changed

+40
-9
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
## [0.1.5] - 2022-01-03
1+
## [0.1.5] - 2022-01-10
22

33
### Changed
44
- seperated CWERMetric to SER and WER Metrics in mltu.metrics, Character/word rate was calculatted in a wrong way
55
- created @setter for augmentors and transformers in DataProvider, to properlly add augmentors and transformers to the pipeline
66
- augmentors and transformers must inherit from `mltu.augmentors.base.Augmentor` and `mltu.transformers.base.Transformer` respectively
7+
- updated ImageShowCV2 transformer documentation
8+
- fixed OnnxInferenceModel in mltu.inferenceModels to use CPU even if GPU is available with force_cpu=True flag
79

810
### Added:
911
- added RandomSharpen to mltu.augmentors, used for simple image augmentation;

mltu/inferenceModel.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,20 @@
44
import onnxruntime as ort
55

66
class OnnxInferenceModel:
7-
""" Base class for all inference models that use onnxruntime """
7+
""" Base class for all inference models that use onnxruntime
8+
9+
Attributes:
10+
model_path (str, optional): Path to the model folder. Defaults to "".
11+
force_cpu (bool, optional): Force the model to run on CPU or GPU. Defaults to GPU.
12+
default_model_name (str, optional): Default model name. Defaults to "model.onnx".
13+
"""
814
def __init__(
915
self,
1016
model_path: str = "",
1117
force_cpu: bool = False,
1218
default_model_name: str = "model.onnx"
1319
):
14-
""" Initialize the model
1520

16-
Args:
17-
model_path (str, optional): Path to the model. Defaults to "".
18-
force_cpu (bool, optional): Force the model to run on CPU or GPU. Defaults to GPU.
19-
default_model_name (str, optional): Default model name. Defaults to "model.onnx".
20-
"""
2121
self.model_path = model_path
2222
self.force_cpu = force_cpu
2323
self.default_model_name = default_model_name
@@ -28,9 +28,14 @@ def __init__(
2828
if not os.path.exists(self.model_path):
2929
raise Exception(f"Model path ({self.model_path}) does not exist")
3030

31-
providers = ['CUDAExecutionProvider'] if ort.get_device() == "GPU" and not force_cpu else ['CPUExecutionProvider']
31+
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if ort.get_device() == "GPU" and not force_cpu else ['CPUExecutionProvider']
3232

3333
self.model = ort.InferenceSession(self.model_path, providers=providers)
34+
35+
# Update providers priority to only CPUExecutionProvider
36+
if self.force_cpu:
37+
self.model.set_providers(['CPUExecutionProvider'])
38+
3439
self.input_shape = self.model.get_inputs()[0].shape[1:]
3540
self.input_name = self.model._inputs_meta[0].name
3641
self.output_name = self.model._outputs_meta[0].name

mltu/transformers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
import typing
33
import numpy as np
44

5+
import logging
6+
logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s: %(message)s')
7+
logger = logging.getLogger(__name__)
8+
logger.setLevel(logging.INFO)
9+
510
class Transformer:
611
def __init__(self, *args, **kwargs):
712
pass
@@ -89,8 +94,27 @@ def __call__(self, data: np.ndarray, label: np.ndarray):
8994

9095
class ImageShowCV2(Transformer):
9196
"""Show image for visual inspection
97+
98+
Attributes:
99+
verbose (bool): Whether to log label
92100
"""
101+
def __init__(self, verbose: bool=True) -> None:
102+
self.verbose = verbose
103+
93104
def __call__(self, data: np.ndarray, label: np.ndarray):
105+
""" Show image for visual inspection
106+
107+
Args:
108+
data (np.ndarray): Image data
109+
label (np.ndarray): Label data
110+
111+
Returns:
112+
data (np.ndarray): Image data
113+
label (np.ndarray): Label data (unchanged)
114+
"""
115+
if self.verbose:
116+
logger.info('Label: ', label)
117+
94118
cv2.imshow('image', data)
95119
cv2.waitKey(0)
96120
cv2.destroyAllWindows()

0 commit comments

Comments
 (0)