
Commit 61f9004

Cover model
1 parent 1c8b0f0 commit 61f9004

6 files changed: +78 -63 lines changed

model_api/python/model_api/adapters/inference_adapter.py

Lines changed: 17 additions & 1 deletion

@@ -83,7 +83,7 @@ def reshape_model(self, new_shape):
         """

     @abstractmethod
-    def infer_sync(self, dict_data):
+    def infer_sync(self, dict_data) -> dict:
         """Performs the synchronous model inference. The infer is a blocking method.

         Args:
@@ -121,6 +121,22 @@ def infer_async(self, dict_data, callback_data):
         - callback_data: the data for callback, that will be taken after the model inference is ended
         """

+    @abstractmethod
+    def get_raw_result(self, infer_result) -> dict:
+        """Gets raw results from the internal inference framework representation as a dict.
+
+        Args:
+        - infer_result: framework-specific result of inference from the model
+
+        Returns:
+        - raw result (dict) - model raw output in the following format:
+        {
+            'output_layer_name_1': raw_result_1,
+            'output_layer_name_2': raw_result_2,
+            ...
+        }
+        """
+
     @abstractmethod
     def is_ready(self):
         """In case of asynchronous execution checks if one can submit input data

model_api/python/model_api/adapters/onnx_adapter.py

Lines changed: 3 additions & 0 deletions

@@ -122,6 +122,9 @@ def await_all(self):
     def await_any(self):
         pass

+    def get_raw_result(self, infer_result):
+        pass
+
     def embed_preprocessing(
         self,
         layout,
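
The commit satisfies the new interface with a stub here. For orientation
only, a hypothetical implementation could look like the sketch below (it is
not part of this commit): ONNX Runtime's InferenceSession.run returns a list
of arrays ordered like the session's declared outputs, so zipping that list
with the output names yields the dict format documented in InferenceAdapter.

    import onnxruntime as ort

    def onnx_raw_result(session: ort.InferenceSession, infer_result: list) -> dict:
        # get_outputs() preserves the order session.run uses for its results.
        output_names = [out.name for out in session.get_outputs()]
        return dict(zip(output_names, infer_result))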

model_api/python/model_api/adapters/openvino_adapter.py

Lines changed: 5 additions & 4 deletions

@@ -4,6 +4,7 @@
 #

 import logging as log
+from os import PathLike
 from pathlib import Path
 from typing import Any

@@ -86,7 +87,7 @@ def parse_value_per_device(devices: set[str], values_string: str) -> dict[str, int]:
 def get_user_config(
     flags_d: str,
     flags_nstreams: str,
-    flags_nthreads: int,
+    flags_nthreads: int | None = None,
 ) -> dict[str, str]:
     config = {}

@@ -117,14 +118,14 @@ def __init__(
         self,
         core: Core,
         model: str,
-        weights_path: str = "",
+        weights_path: PathLike | None = None,
         model_parameters: dict[str, Any] = {},
         device: str = "CPU",
         plugin_config: dict[str, Any] | None = None,
         max_num_requests: int = 0,
         precision: str = "FP16",
-        download_dir: None = None,
-        cache_dir: None = None,
+        download_dir: PathLike | None = None,
+        cache_dir: PathLike | None = None,
     ) -> None:
         """precision, download_dir and cache_dir are ignored if model is a path to a file"""
         self.core = core
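
A usage sketch under the retyped constructor: weights_path, download_dir,
and cache_dir now accept any os.PathLike (e.g. pathlib.Path), with None as
the explicit "not set" default instead of the old empty string or bare None
annotation. The file names below are placeholders:

    from pathlib import Path

    from openvino import Core

    from model_api.adapters.openvino_adapter import OpenvinoAdapter

    adapter = OpenvinoAdapter(
        core=Core(),
        model="model.xml",               # placeholder model path
        weights_path=Path("model.bin"),  # a Path now matches the annotation
        cache_dir=Path(".ov_cache"),
    )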

model_api/python/model_api/adapters/ovms_adapter.py

Lines changed: 3 additions & 0 deletions

@@ -97,6 +97,9 @@ def await_all(self):
     def await_any(self):
         pass

+    def get_raw_result(self, infer_result):
+        pass
+
     def embed_preprocessing(
         self,
         layout,

model_api/python/model_api/models/model.py

Lines changed: 48 additions & 56 deletions

@@ -6,6 +6,10 @@
 import logging as log
 import re
 from contextlib import contextmanager
+from os import PathLike
+from typing import Any, NoReturn, Type
+
+from numpy import ndarray

 from model_api.adapters.inference_adapter import InferenceAdapter
 from model_api.adapters.onnx_adapter import ONNXRuntimeAdapter
@@ -20,7 +24,7 @@
 class WrapperError(Exception):
     """The class for errors occurred in Model API wrappers"""

-    def __init__(self, wrapper_name, message):
+    def __init__(self, wrapper_name, message) -> None:
         super().__init__(f"{wrapper_name}: {message}")


@@ -52,7 +56,7 @@ class Model:

     __model__: str = "Model"

-    def __init__(self, inference_adapter, configuration: dict = {}, preload=False):
+    def __init__(self, inference_adapter: InferenceAdapter, configuration: dict = {}, preload: bool = False) -> None:
         """Model constructor

         Args:
@@ -98,7 +102,7 @@ def get_model(self):
         return model

     @classmethod
-    def get_model_class(cls, name):
+    def get_model_class(cls, name: str) -> Type:
         subclasses = [subclass for subclass in cls.get_subclasses() if subclass.__model__]
         if cls.__model__:
             subclasses.append(cls)
@@ -113,21 +117,21 @@ def get_model_class(cls, name):
     @classmethod
     def create_model(
         cls,
-        model,
-        model_type=None,
-        configuration={},
-        preload=True,
-        core=None,
-        weights_path="",
-        adaptor_parameters={},
-        device="AUTO",
-        nstreams="1",
-        nthreads=None,
-        max_num_requests=0,
-        precision="FP16",
-        download_dir=None,
-        cache_dir=None,
-    ):
+        model: str,
+        model_type: Any | None = None,
+        configuration: dict[str, Any] = {},
+        preload: bool = True,
+        core: Any | None = None,
+        weights_path: PathLike | None = None,
+        adaptor_parameters: dict[str, Any] = {},
+        device: str = "AUTO",
+        nstreams: str = "1",
+        nthreads: int | None = None,
+        max_num_requests: int = 0,
+        precision: str = "FP16",
+        download_dir: PathLike | None = None,
+        cache_dir: PathLike | None = None,
+    ) -> Any:
         """Create an instance of the Model API model

         Args:
@@ -152,9 +156,8 @@ def create_model(
         Returns:
             Model object
         """
-        if isinstance(model, InferenceAdapter):
-            inference_adapter = model
-        elif isinstance(model, str) and re.compile(
+        inference_adapter: InferenceAdapter
+        if isinstance(model, str) and re.compile(
             r"(\w+\.*\-*)*\w+:\d+\/models\/[a-zA-Z0-9._-]+(\:\d+)*",
         ).fullmatch(model):
             inference_adapter = OVMSAdapter(model)
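
Worth spelling out from the last hunk: model is now annotated as str, the
isinstance(model, InferenceAdapter) branch is removed, and a string matching
the host:port/models/name pattern is routed to OVMSAdapter while any other
string falls through to the local-model path. A usage sketch (the address
and model names below are placeholders):

    from model_api.models import Model

    # Served model: "localhost:9000/models/my_model" matches the OVMS
    # pattern, so create_model builds an OVMSAdapter under the hood.
    served = Model.create_model("localhost:9000/models/my_model")

    # Local model: any non-matching string, with the newly typed keywords.
    local = Model.create_model(
        "model.xml",
        device="AUTO",
        nstreams="1",   # str
        nthreads=None,  # int | None
    )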

@@ -182,7 +185,7 @@ def create_model(
         return Model(inference_adapter, configuration, preload)

     @classmethod
-    def get_subclasses(cls):
+    def get_subclasses(cls) -> list[Any]:
         all_subclasses = []
         for subclass in cls.__subclasses__():
             all_subclasses.append(subclass)
@@ -196,7 +199,7 @@ def available_wrappers(cls):
         return [subclass.__model__ for subclass in available_classes if subclass.__model__]

     @classmethod
-    def parameters(cls):
+    def parameters(cls) -> dict[str, Any]:
         """Defines the description and type of configurable data parameters for the wrapper.

         See `types.py` to find available types of the data parameter. For each parameter
@@ -214,7 +217,7 @@ def parameters(cls):
         """
         return {}

-    def _load_config(self, config):
+    def _load_config(self, config: dict[str, Any]) -> None:
         """Reads the configuration and creates data attributes
         by setting the wrapper parameters with values from configuration.

@@ -265,7 +268,7 @@ def _load_config(self, config):
         )

     @classmethod
-    def raise_error(cls, message):
+    def raise_error(cls, message) -> NoReturn:
         """Raises the WrapperError.

         Args:
@@ -292,7 +295,7 @@ def preprocess(self, inputs):
         """
         raise NotImplementedError

-    def postprocess(self, outputs, meta):
+    def postprocess(self, outputs: dict[str, Any], meta: dict[str, Any]):
         """Interface for postprocess method.

         Args:
@@ -309,7 +312,11 @@ def postprocess(self, outputs, meta):
         """
         raise NotImplementedError

-    def _check_io_number(self, number_of_inputs, number_of_outputs):
+    def _check_io_number(
+        self,
+        number_of_inputs: int | tuple[int, ...],
+        number_of_outputs: int | tuple[int, ...],
+    ) -> None:
         """Checks whether the number of model inputs/outputs is supported.

         Args:
@@ -321,47 +328,32 @@ def _check_io_number(self, number_of_inputs, number_of_outputs):
         Raises:
             WrapperError: if the model has unsupported number of inputs/outputs
         """
-        if not isinstance(number_of_inputs, tuple):
+        if isinstance(number_of_inputs, int):
             if len(self.inputs) != number_of_inputs and number_of_inputs != -1:
                 self.raise_error(
-                    "Expected {} input blob{}, but {} found: {}".format(
-                        number_of_inputs,
-                        "s" if number_of_inputs != 1 else "",
-                        len(self.inputs),
-                        ", ".join(self.inputs),
-                    ),
+                    f"Expected {number_of_inputs} input blob{'s' if number_of_inputs != 1 else ''}, "
+                    f"but {len(self.inputs)} found: {', '.join(self.inputs)}",
                 )
         elif len(self.inputs) not in number_of_inputs:
             self.raise_error(
-                "Expected {} or {} input blobs, but {} found: {}".format(
-                    ", ".join(str(n) for n in number_of_inputs[:-1]),
-                    int(number_of_inputs[-1]),
-                    len(self.inputs),
-                    ", ".join(self.inputs),
-                ),
+                f"Expected {', '.join(str(n) for n in number_of_inputs[:-1])} or "
+                f"{int(number_of_inputs[-1])} input blobs, but {len(self.inputs)} found: {', '.join(self.inputs)}",
             )

-        if not isinstance(number_of_outputs, tuple):
+        if isinstance(number_of_outputs, int):
             if len(self.outputs) != number_of_outputs and number_of_outputs != -1:
                 self.raise_error(
-                    "Expected {} output blob{}, but {} found: {}".format(
-                        number_of_outputs,
-                        "s" if number_of_outputs != 1 else "",
-                        len(self.outputs),
-                        ", ".join(self.outputs),
-                    ),
+                    f"Expected {number_of_outputs} output blob{'s' if number_of_outputs != 1 else ''}, "
+                    f"but {len(self.outputs)} found: {', '.join(self.outputs)}",
                 )
         elif len(self.outputs) not in number_of_outputs:
             self.raise_error(
-                "Expected {} or {} output blobs, but {} found: {}".format(
-                    ", ".join(str(n) for n in number_of_outputs[:-1]),
-                    int(number_of_outputs[-1]),
-                    len(self.outputs),
-                    ", ".join(self.outputs),
-                ),
+                f"Expected {', '.join(str(n) for n in number_of_outputs[:-1])} or "
+                f"{int(number_of_outputs[-1])} output blobs, "
+                f"but {len(self.outputs)} found: {', '.join(self.outputs)}",
             )

-    def __call__(self, inputs):
+    def __call__(self, inputs: ndarray):
         """Applies preprocessing, synchronous inference, postprocessing routines while one call.

         Args:
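
The rewritten guard also makes the dispatch explicit: an int argument
demands exactly that many blobs (with -1 meaning "any number"), while a
tuple enumerates every allowed count. A self-contained restatement of that
rule (the helper below is illustrative, not part of the commit):

    def count_allowed(found: int, expected: int | tuple[int, ...]) -> bool:
        # Mirrors _check_io_number's branching: int -> exact match or the
        # -1 wildcard, tuple -> membership in the enumerated counts.
        if isinstance(expected, int):
            return expected == -1 or found == expected
        return found in expected

    assert count_allowed(1, 1)
    assert count_allowed(5, -1)             # -1 accepts any count
    assert count_allowed(2, (1, 2, 3))      # tuple lists each allowed count
    assert not count_allowed(4, (1, 2, 3))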

@@ -407,7 +399,7 @@ def batch_infer_callback(result, id):

         return [completed_results[i] for i in range(len(inputs))]

-    def load(self, force=False):
+    def load(self, force: bool = False) -> None:
         if not self.model_loaded or force:
             self.model_loaded = True
             self.inference_adapter.load_model()
@@ -423,7 +415,7 @@ def reshape(self, new_shape):
         self.inputs = self.inference_adapter.get_input_layers()
         self.outputs = self.inference_adapter.get_output_layers()

-    def infer_sync(self, dict_data):
+    def infer_sync(self, dict_data: dict[str, ndarray]) -> dict[str, ndarray]:
         if not self.model_loaded:
             self.raise_error(
                 "The model is not loaded to the device. Please, create the wrapper "

model_api/python/model_api/models/sam_models.py

Lines changed: 2 additions & 2 deletions

@@ -192,8 +192,8 @@ def _get_preprocess_shape(

     def _check_io_number(
         self,
-        number_of_inputs: int | tuple[int],
-        number_of_outputs: int | tuple[int],
+        number_of_inputs: int | tuple[int, ...],
+        number_of_outputs: int | tuple[int, ...],
     ) -> None:
         pass
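
The fix matters because tuple[int] annotates a tuple of exactly one int,
while tuple[int, ...] annotates an int tuple of any length, which is what
callers of _check_io_number actually pass. A quick illustration:

    one: tuple[int] = (2,)             # exactly one element
    many: tuple[int, ...] = (2, 3, 4)  # any number of elements
    # Under the old annotation, a checker such as mypy rejects (2, 3, 4).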
