66import logging as log
77import re
88from contextlib import contextmanager
9+ from os import PathLike
10+ from typing import Any , NoReturn , Type
11+
12+ from numpy import ndarray
913
1014from model_api .adapters .inference_adapter import InferenceAdapter
1115from model_api .adapters .onnx_adapter import ONNXRuntimeAdapter
2024class WrapperError (Exception ):
2125 """The class for errors occurred in Model API wrappers"""
2226
23- def __init__ (self , wrapper_name , message ):
27+ def __init__ (self , wrapper_name , message ) -> None :
2428 super ().__init__ (f"{ wrapper_name } : { message } " )
2529
2630
@@ -52,7 +56,7 @@ class Model:
5256
5357 __model__ : str = "Model"
5458
55- def __init__ (self , inference_adapter , configuration : dict = {}, preload = False ):
59+ def __init__ (self , inference_adapter : InferenceAdapter , configuration : dict = {}, preload : bool = False ) -> None :
5660 """Model constructor
5761
5862 Args:
@@ -98,7 +102,7 @@ def get_model(self):
98102 return model
99103
100104 @classmethod
101- def get_model_class (cls , name ) :
105+ def get_model_class (cls , name : str ) -> Type :
102106 subclasses = [subclass for subclass in cls .get_subclasses () if subclass .__model__ ]
103107 if cls .__model__ :
104108 subclasses .append (cls )
@@ -113,21 +117,21 @@ def get_model_class(cls, name):
113117 @classmethod
114118 def create_model (
115119 cls ,
116- model ,
117- model_type = None ,
118- configuration = {},
119- preload = True ,
120- core = None ,
121- weights_path = "" ,
122- adaptor_parameters = {},
123- device = "AUTO" ,
124- nstreams = "1" ,
125- nthreads = None ,
126- max_num_requests = 0 ,
127- precision = "FP16" ,
128- download_dir = None ,
129- cache_dir = None ,
130- ):
120+ model : str ,
121+ model_type : Any | None = None ,
122+ configuration : dict [ str , Any ] = {},
123+ preload : bool = True ,
124+ core : Any | None = None ,
125+ weights_path : PathLike | None = None ,
126+ adaptor_parameters : dict [ str , Any ] = {},
127+ device : str = "AUTO" ,
128+ nstreams : str = "1" ,
129+ nthreads : int | None = None ,
130+ max_num_requests : int = 0 ,
131+ precision : str = "FP16" ,
132+ download_dir : PathLike | None = None ,
133+ cache_dir : PathLike | None = None ,
134+ ) -> Any :
131135 """Create an instance of the Model API model
132136
133137 Args:
@@ -152,9 +156,8 @@ def create_model(
152156 Returns:
153157 Model object
154158 """
155- if isinstance (model , InferenceAdapter ):
156- inference_adapter = model
157- elif isinstance (model , str ) and re .compile (
159+ inference_adapter : InferenceAdapter
160+ if isinstance (model , str ) and re .compile (
158161 r"(\w+\.*\-*)*\w+:\d+\/models\/[a-zA-Z0-9._-]+(\:\d+)*" ,
159162 ).fullmatch (model ):
160163 inference_adapter = OVMSAdapter (model )
@@ -182,7 +185,7 @@ def create_model(
182185 return Model (inference_adapter , configuration , preload )
183186
184187 @classmethod
185- def get_subclasses (cls ):
188+ def get_subclasses (cls ) -> list [ Any ] :
186189 all_subclasses = []
187190 for subclass in cls .__subclasses__ ():
188191 all_subclasses .append (subclass )
@@ -196,7 +199,7 @@ def available_wrappers(cls):
196199 return [subclass .__model__ for subclass in available_classes if subclass .__model__ ]
197200
198201 @classmethod
199- def parameters (cls ):
202+ def parameters (cls ) -> dict [ str , Any ] :
200203 """Defines the description and type of configurable data parameters for the wrapper.
201204
202205 See `types.py` to find available types of the data parameter. For each parameter
@@ -214,7 +217,7 @@ def parameters(cls):
214217 """
215218 return {}
216219
217- def _load_config (self , config ) :
220+ def _load_config (self , config : dict [ str , Any ]) -> None :
218221 """Reads the configuration and creates data attributes
219222 by setting the wrapper parameters with values from configuration.
220223
@@ -265,7 +268,7 @@ def _load_config(self, config):
265268 )
266269
267270 @classmethod
268- def raise_error (cls , message ):
271+ def raise_error (cls , message ) -> NoReturn :
269272 """Raises the WrapperError.
270273
271274 Args:
@@ -292,7 +295,7 @@ def preprocess(self, inputs):
292295 """
293296 raise NotImplementedError
294297
295- def postprocess (self , outputs , meta ):
298+ def postprocess (self , outputs : dict [ str , Any ], meta : dict [ str , Any ] ):
296299 """Interface for postprocess method.
297300
298301 Args:
@@ -309,7 +312,11 @@ def postprocess(self, outputs, meta):
309312 """
310313 raise NotImplementedError
311314
312- def _check_io_number (self , number_of_inputs , number_of_outputs ):
315+ def _check_io_number (
316+ self ,
317+ number_of_inputs : int | tuple [int , ...],
318+ number_of_outputs : int | tuple [int , ...],
319+ ) -> None :
313320 """Checks whether the number of model inputs/outputs is supported.
314321
315322 Args:
@@ -321,47 +328,32 @@ def _check_io_number(self, number_of_inputs, number_of_outputs):
321328 Raises:
322329 WrapperError: if the model has unsupported number of inputs/outputs
323330 """
324- if not isinstance (number_of_inputs , tuple ):
331+ if isinstance (number_of_inputs , int ):
325332 if len (self .inputs ) != number_of_inputs and number_of_inputs != - 1 :
326333 self .raise_error (
327- "Expected {} input blob{}, but {} found: {}" .format (
328- number_of_inputs ,
329- "s" if number_of_inputs != 1 else "" ,
330- len (self .inputs ),
331- ", " .join (self .inputs ),
332- ),
334+ f"Expected { number_of_inputs } input blob { 's' if number_of_inputs != 1 else '' } , "
335+ f"but { len (self .inputs )} found: { ', ' .join (self .inputs )} " ,
333336 )
334337 elif len (self .inputs ) not in number_of_inputs :
335338 self .raise_error (
336- "Expected {} or {} input blobs, but {} found: {}" .format (
337- ", " .join (str (n ) for n in number_of_inputs [:- 1 ]),
338- int (number_of_inputs [- 1 ]),
339- len (self .inputs ),
340- ", " .join (self .inputs ),
341- ),
339+ f"Expected { ', ' .join (str (n ) for n in number_of_inputs [:- 1 ])} or "
340+ f"{ int (number_of_inputs [- 1 ])} input blobs, but { len (self .inputs )} found: { ', ' .join (self .inputs )} " ,
342341 )
343342
344- if not isinstance (number_of_outputs , tuple ):
343+ if isinstance (number_of_outputs , int ):
345344 if len (self .outputs ) != number_of_outputs and number_of_outputs != - 1 :
346345 self .raise_error (
347- "Expected {} output blob{}, but {} found: {}" .format (
348- number_of_outputs ,
349- "s" if number_of_outputs != 1 else "" ,
350- len (self .outputs ),
351- ", " .join (self .outputs ),
352- ),
346+ f"Expected { number_of_outputs } output blob { 's' if number_of_outputs != 1 else '' } , "
347+ f"but { len (self .outputs )} found: { ', ' .join (self .outputs )} " ,
353348 )
354349 elif len (self .outputs ) not in number_of_outputs :
355350 self .raise_error (
356- "Expected {} or {} output blobs, but {} found: {}" .format (
357- ", " .join (str (n ) for n in number_of_outputs [:- 1 ]),
358- int (number_of_outputs [- 1 ]),
359- len (self .outputs ),
360- ", " .join (self .outputs ),
361- ),
351+ f"Expected { ', ' .join (str (n ) for n in number_of_outputs [:- 1 ])} or "
352+ f"{ int (number_of_outputs [- 1 ])} output blobs, "
353+ f"but { len (self .outputs )} found: { ', ' .join (self .outputs )} " ,
362354 )
363355
364- def __call__ (self , inputs ):
356+ def __call__ (self , inputs : ndarray ):
365357 """Applies preprocessing, synchronous inference, postprocessing routines while one call.
366358
367359 Args:
@@ -407,7 +399,7 @@ def batch_infer_callback(result, id):
407399
408400 return [completed_results [i ] for i in range (len (inputs ))]
409401
410- def load (self , force = False ):
402+ def load (self , force : bool = False ) -> None :
411403 if not self .model_loaded or force :
412404 self .model_loaded = True
413405 self .inference_adapter .load_model ()
@@ -423,7 +415,7 @@ def reshape(self, new_shape):
423415 self .inputs = self .inference_adapter .get_input_layers ()
424416 self .outputs = self .inference_adapter .get_output_layers ()
425417
426- def infer_sync (self , dict_data ) :
418+ def infer_sync (self , dict_data : dict [ str , ndarray ]) -> dict [ str , ndarray ] :
427419 if not self .model_loaded :
428420 self .raise_error (
429421 "The model is not loaded to the device. Please, create the wrapper "
0 commit comments