Commit 31a7b8e

Typing for OV adapter
1 parent c4667bd commit 31a7b8e

1 file changed

model_api/python/model_api/adapters/openvino_adapter.py

Lines changed: 40 additions & 32 deletions
@@ -5,13 +5,17 @@

 import logging as log
 from pathlib import Path
+from typing import Any
+
+from numpy import ndarray

 try:
     import openvino.runtime as ov
     from openvino import (
         AsyncInferQueue,
         Core,
         Dimension,
+        OVAny,
         PartialShape,
         Type,
         get_version,
@@ -35,7 +39,7 @@
 )


-def create_core():
+def create_core() -> Core:
     if openvino_absent:
         msg = "The OpenVINO package is not installed"
         raise ImportError(msg)
@@ -45,7 +49,7 @@ def create_core():
     return Core()


-def parse_devices(device_string):
+def parse_devices(device_string: str) -> tuple[str] | list[str]:
     colon_position = device_string.find(":")
     if colon_position != -1:
         device_type = device_string[:colon_position]
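
The tuple[str] | list[str] return type reflects the two paths through this
function: a plain device name presumably comes back as a one-element tuple,
while a prefixed string such as "MULTI:CPU,GPU" is split into a list. A
minimal sketch of the assumed behavior (the full function body is not shown
in this hunk):

    parse_devices("CPU")            # ("CPU",)        - tuple branch
    parse_devices("MULTI:CPU,GPU")  # ["CPU", "GPU"]  - list branch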
@@ -111,17 +115,17 @@ class OpenvinoAdapter(InferenceAdapter):

     def __init__(
         self,
-        core,
-        model,
-        weights_path="",
-        model_parameters={},
-        device="CPU",
-        plugin_config=None,
-        max_num_requests=0,
-        precision="FP16",
-        download_dir=None,
-        cache_dir=None,
-    ):
+        core: Core,
+        model: str,
+        weights_path: str = "",
+        model_parameters: dict[str, Any] = {},
+        device: str = "CPU",
+        plugin_config: dict[str, Any] | None = None,
+        max_num_requests: int = 0,
+        precision: str = "FP16",
+        download_dir: None = None,
+        cache_dir: None = None,
+    ) -> None:
         """precision, download_dir and cache_dir are ignored if model is a path to a file"""
         self.core = core
         self.model_path = model
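
Two annotations here are worth flagging. model_parameters: dict[str, Any] = {}
types the parameter but keeps the shared-mutable-default pitfall, and
download_dir: None = None / cache_dir: None = None permit only None, even
though the docstring implies real values can be passed (str | Path | None
would be the broader spelling). A sketch of the usual sentinel pattern,
which is not what this commit does:

    from typing import Any

    class Example:
        def __init__(self, model_parameters: dict[str, Any] | None = None) -> None:
            # A fresh dict per instance instead of one dict shared by all callers.
            self.model_parameters = model_parameters if model_parameters is not None else {}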
@@ -179,7 +183,7 @@ def __init__(
             msg = "Model must be bytes, a file or existing OMZ model name"
             raise RuntimeError(msg)

-    def load_model(self):
+    def load_model(self) -> None:
         self.compiled_model = self.core.compile_model(
             self.model,
             self.device,
@@ -201,7 +205,7 @@ def load_model(self):
         )
         self.log_runtime_settings()

-    def log_runtime_settings(self):
+    def log_runtime_settings(self) -> None:
         devices = set(parse_devices(self.device))
         if "AUTO" not in devices:
             for device in devices:
@@ -222,7 +226,7 @@ def log_runtime_settings(self):
                 pass
         log.info(f"\tNumber of model infer requests: {len(self.async_queue)}")

-    def get_input_layers(self):
+    def get_input_layers(self) -> dict[str, Metadata]:
         inputs = {}
         for input in self.model.inputs:
             input_shape = get_input_shape(input)
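
Both layer getters now advertise dict[str, Metadata], which makes consuming
them straightforward. A sketch, assuming Metadata exposes a shape attribute
(its exact fields are defined elsewhere in model_api):

    for name, meta in adapter.get_input_layers().items():
        print(name, meta.shape)  # Metadata.shape is an assumption here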
@@ -235,7 +239,11 @@ def get_input_layers(self):
             )
         return self._get_meta_from_ngraph(inputs)

-    def get_layout_for_input(self, input, shape=None) -> str:
+    def get_layout_for_input(
+        self,
+        input: ov.Output,
+        shape: list[int] | tuple[int, int, int, int] | None = None,
+    ) -> str:
         input_layout = ""
         if self.model_parameters["input_layouts"]:
             input_layout = Layout.from_user_layouts(
@@ -251,7 +259,7 @@ def get_layout_for_input(self, input, shape=None) -> str:
             )
         return input_layout

-    def get_output_layers(self):
+    def get_output_layers(self) -> dict[str, Metadata]:
         outputs = {}
         for i, output in enumerate(self.model.outputs):
             output_shape = output.partial_shape.get_min_shape() if self.model.is_dynamic() else output.shape
@@ -273,13 +281,13 @@ def reshape_model(self, new_shape):
         }
         self.model.reshape(new_shape)

-    def get_raw_result(self, request):
+    def get_raw_result(self, request: ov.InferRequest) -> dict[str, ndarray]:
         return {key: request.get_tensor(key).data for key in self.get_output_layers()}

     def copy_raw_result(self, request):
         return {key: request.get_tensor(key).data.copy() for key in self.get_output_layers()}

-    def infer_sync(self, dict_data):
+    def infer_sync(self, dict_data: dict[str, ndarray]) -> dict[str, ndarray]:
         self.infer_request = self.async_queue[self.async_queue.get_idle_request_id()]
         self.infer_request.infer(dict_data)
         return self.get_raw_result(self.infer_request)
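
With these annotations the synchronous path is dict[str, ndarray] in and
dict[str, ndarray] out. A usage sketch; the model path "model.xml" and the
input name "image" are placeholders, not names from this repository:

    import numpy as np

    adapter = OpenvinoAdapter(create_core(), "model.xml")
    adapter.load_model()
    outputs = adapter.infer_sync({"image": np.zeros((1, 3, 224, 224), dtype=np.uint8)})
    for name, tensor in outputs.items():
        print(name, tensor.shape)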
@@ -299,7 +307,7 @@ def await_all(self) -> None:
     def await_any(self) -> None:
         self.async_queue.get_idle_request_id()

-    def _get_meta_from_ngraph(self, layers_info):
+    def _get_meta_from_ngraph(self, layers_info: dict[str, Metadata]) -> dict[str, Metadata]:
         for node in self.model.get_ordered_ops():
             layer_name = node.get_friendly_name()
             if layer_name not in layers_info:
@@ -319,24 +327,24 @@ def operations_by_type(self, operation_type):
             )
         return layers_info

-    def get_rt_info(self, path):
+    def get_rt_info(self, path: list[str]) -> OVAny:
         if self.is_onnx_file:
             return get_rt_info_from_dict(self.onnx_metadata, path)
         return self.model.get_rt_info(path)

     def embed_preprocessing(
         self,
-        layout,
+        layout: str,
         resize_mode: str,
-        interpolation_mode,
+        interpolation_mode: str,
         target_shape: tuple[int],
-        pad_value,
+        pad_value: int,
         dtype: type = int,
-        brg2rgb=False,
-        mean=None,
-        scale=None,
-        input_idx=0,
-    ):
+        brg2rgb: bool = False,
+        mean: list[Any] | None = None,
+        scale: list[Any] | None = None,
+        input_idx: int = 0,
+    ) -> None:
         ppp = PrePostProcessor(self.model)

         # Change the input type to the 8-bit image
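
One annotation worth a second look: target_shape: tuple[int] describes a
one-element tuple, while callers presumably pass (height, width);
tuple[int, ...] would be the variable-length spelling. An illustrative call
with the newly typed parameters (all values are assumptions, and the
accepted mode strings are defined elsewhere in model_api):

    adapter.embed_preprocessing(
        layout="NCHW",
        resize_mode="standard",
        interpolation_mode="LINEAR",
        target_shape=(224, 224),
        pad_value=0,
        brg2rgb=True,
        mean=[123.675, 116.28, 103.53],
        scale=[58.395, 57.12, 57.375],
    )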
@@ -407,7 +415,7 @@ def get_model(self):
         return self.model


-def get_input_shape(input_tensor):
+def get_input_shape(input_tensor: ov.Output) -> list[int]:
     def string_to_tuple(string, casting_type=int):
         processed = string.replace(" ", "").replace("(", "").replace(")", "").split(",")
         processed = filter(lambda x: x, processed)
@@ -428,4 +436,4 @@ def string_to_tuple(string, casting_type=int):
             else:
                 shape_list.append(int(dim))
         return shape_list
-    return string_to_tuple(preprocessed)
+    return list(string_to_tuple(preprocessed))
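
The final change is a small consistency fix: string_to_tuple returns a
tuple, so wrapping it in list() makes this branch agree with the declared
-> list[int]. Assuming a preprocessed shape string like "1,3,-1,-1":

    string_to_tuple("1,3,-1,-1")        # (1, 3, -1, -1)
    list(string_to_tuple("1,3,-1,-1"))  # [1, 3, -1, -1], matches list[int]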
