55
66import logging as log
77from pathlib import Path
8+ from typing import Any
9+
10+ from numpy import ndarray
811
912try :
1013 import openvino .runtime as ov
1114 from openvino import (
1215 AsyncInferQueue ,
1316 Core ,
1417 Dimension ,
18+ OVAny ,
1519 PartialShape ,
1620 Type ,
1721 get_version ,
3539)
3640
3741
38- def create_core ():
42+ def create_core () -> Core :
3943 if openvino_absent :
4044 msg = "The OpenVINO package is not installed"
4145 raise ImportError (msg )
@@ -45,7 +49,7 @@ def create_core():
4549 return Core ()
4650
4751
48- def parse_devices (device_string ) :
52+ def parse_devices (device_string : str ) -> tuple [ str ] | list [ str ] :
4953 colon_position = device_string .find (":" )
5054 if colon_position != - 1 :
5155 device_type = device_string [:colon_position ]
@@ -111,17 +115,17 @@ class OpenvinoAdapter(InferenceAdapter):
111115
112116 def __init__ (
113117 self ,
114- core ,
115- model ,
116- weights_path = "" ,
117- model_parameters = {},
118- device = "CPU" ,
119- plugin_config = None ,
120- max_num_requests = 0 ,
121- precision = "FP16" ,
122- download_dir = None ,
123- cache_dir = None ,
124- ):
118+ core : Core ,
119+ model : str ,
120+ weights_path : str = "" ,
121+ model_parameters : dict [ str , Any ] = {},
122+ device : str = "CPU" ,
123+ plugin_config : dict [ str , Any ] | None = None ,
124+ max_num_requests : int = 0 ,
125+ precision : str = "FP16" ,
126+ download_dir : None = None ,
127+ cache_dir : None = None ,
128+ ) -> None :
125129 """precision, download_dir and cache_dir are ignored if model is a path to a file"""
126130 self .core = core
127131 self .model_path = model
@@ -179,7 +183,7 @@ def __init__(
179183 msg = "Model must be bytes, a file or existing OMZ model name"
180184 raise RuntimeError (msg )
181185
182- def load_model (self ):
186+ def load_model (self ) -> None :
183187 self .compiled_model = self .core .compile_model (
184188 self .model ,
185189 self .device ,
@@ -201,7 +205,7 @@ def load_model(self):
201205 )
202206 self .log_runtime_settings ()
203207
204- def log_runtime_settings (self ):
208+ def log_runtime_settings (self ) -> None :
205209 devices = set (parse_devices (self .device ))
206210 if "AUTO" not in devices :
207211 for device in devices :
@@ -222,7 +226,7 @@ def log_runtime_settings(self):
222226 pass
223227 log .info (f"\t Number of model infer requests: { len (self .async_queue )} " )
224228
225- def get_input_layers (self ):
229+ def get_input_layers (self ) -> dict [ str , Metadata ] :
226230 inputs = {}
227231 for input in self .model .inputs :
228232 input_shape = get_input_shape (input )
@@ -235,7 +239,11 @@ def get_input_layers(self):
235239 )
236240 return self ._get_meta_from_ngraph (inputs )
237241
238- def get_layout_for_input (self , input , shape = None ) -> str :
242+ def get_layout_for_input (
243+ self ,
244+ input : ov .Output ,
245+ shape : list [int ] | tuple [int , int , int , int ] | None = None ,
246+ ) -> str :
239247 input_layout = ""
240248 if self .model_parameters ["input_layouts" ]:
241249 input_layout = Layout .from_user_layouts (
@@ -251,7 +259,7 @@ def get_layout_for_input(self, input, shape=None) -> str:
251259 )
252260 return input_layout
253261
254- def get_output_layers (self ):
262+ def get_output_layers (self ) -> dict [ str , Metadata ] :
255263 outputs = {}
256264 for i , output in enumerate (self .model .outputs ):
257265 output_shape = output .partial_shape .get_min_shape () if self .model .is_dynamic () else output .shape
@@ -273,13 +281,13 @@ def reshape_model(self, new_shape):
273281 }
274282 self .model .reshape (new_shape )
275283
276- def get_raw_result (self , request ) :
284+ def get_raw_result (self , request : ov . InferRequest ) -> dict [ str , ndarray ] :
277285 return {key : request .get_tensor (key ).data for key in self .get_output_layers ()}
278286
279287 def copy_raw_result (self , request ):
280288 return {key : request .get_tensor (key ).data .copy () for key in self .get_output_layers ()}
281289
282- def infer_sync (self , dict_data ) :
290+ def infer_sync (self , dict_data : dict [ str , ndarray ]) -> dict [ str , ndarray ] :
283291 self .infer_request = self .async_queue [self .async_queue .get_idle_request_id ()]
284292 self .infer_request .infer (dict_data )
285293 return self .get_raw_result (self .infer_request )
@@ -299,7 +307,7 @@ def await_all(self) -> None:
299307 def await_any (self ) -> None :
300308 self .async_queue .get_idle_request_id ()
301309
302- def _get_meta_from_ngraph (self , layers_info ) :
310+ def _get_meta_from_ngraph (self , layers_info : dict [ str , Metadata ]) -> dict [ str , Metadata ] :
303311 for node in self .model .get_ordered_ops ():
304312 layer_name = node .get_friendly_name ()
305313 if layer_name not in layers_info :
@@ -319,24 +327,24 @@ def operations_by_type(self, operation_type):
319327 )
320328 return layers_info
321329
322- def get_rt_info (self , path ) :
330+ def get_rt_info (self , path : list [ str ]) -> OVAny :
323331 if self .is_onnx_file :
324332 return get_rt_info_from_dict (self .onnx_metadata , path )
325333 return self .model .get_rt_info (path )
326334
327335 def embed_preprocessing (
328336 self ,
329- layout ,
337+ layout : str ,
330338 resize_mode : str ,
331- interpolation_mode ,
339+ interpolation_mode : str ,
332340 target_shape : tuple [int ],
333- pad_value ,
341+ pad_value : int ,
334342 dtype : type = int ,
335- brg2rgb = False ,
336- mean = None ,
337- scale = None ,
338- input_idx = 0 ,
339- ):
343+ brg2rgb : bool = False ,
344+ mean : list [ Any ] | None = None ,
345+ scale : list [ Any ] | None = None ,
346+ input_idx : int = 0 ,
347+ ) -> None :
340348 ppp = PrePostProcessor (self .model )
341349
342350 # Change the input type to the 8-bit image
@@ -407,7 +415,7 @@ def get_model(self):
407415 return self .model
408416
409417
410- def get_input_shape (input_tensor ) :
418+ def get_input_shape (input_tensor : ov . Output ) -> list [ int ] :
411419 def string_to_tuple (string , casting_type = int ):
412420 processed = string .replace (" " , "" ).replace ("(" , "" ).replace (")" , "" ).split ("," )
413421 processed = filter (lambda x : x , processed )
@@ -428,4 +436,4 @@ def string_to_tuple(string, casting_type=int):
428436 else :
429437 shape_list .append (int (dim ))
430438 return shape_list
431- return string_to_tuple (preprocessed )
439+ return list ( string_to_tuple (preprocessed ) )
0 commit comments