 import logging as log
 import re
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, NoReturn, Type
+from typing import TYPE_CHECKING, Any, Callable, NoReturn, Type
 
 from model_api.adapters.inference_adapter import InferenceAdapter
 from model_api.adapters.onnx_adapter import ONNXRuntimeAdapter
@@ -98,11 +98,26 @@ def __init__(self, inference_adapter: InferenceAdapter, configuration: dict = {}
         self.load()
         self.callback_fn = lambda _: None
 
-    def get_model(self):
+    def get_model(self) -> Any:
+        """
+        Returns the underlying adapter-specific model.
+
+        Returns:
+            Any: Model object.
+        """
         return self.inference_adapter.get_model()
 
     @classmethod
     def get_model_class(cls, name: str) -> Type:
+        """
+        Retrieves a wrapper class by the given wrapper name.
+
+        Args:
+            name (str): Wrapper name.
+
+        Returns:
+            Type: Model class.
+        """
         subclasses = [subclass for subclass in cls.get_subclasses() if subclass.__model__]
         if cls.__model__:
             subclasses.append(cls)
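
The two accessors above resolve a wrapper class by its registered name and expose the adapter-level model object. A minimal, hedged sketch; the "SSD" wrapper name and the wrapper instance are illustrative assumptions, not taken from this diff:

    from model_api.models import Model

    # Resolve a wrapper class by its registered __model__ name; unknown names raise an error.
    wrapper_cls = Model.get_model_class("SSD")  # "SSD" is a hypothetical wrapper name

    # On an instantiated wrapper, get_model() exposes the adapter-specific model object.
    # raw_model = some_wrapper_instance.get_model()  # some_wrapper_instance is hypothetical
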
@@ -188,14 +203,19 @@ def create_model(
 
     @classmethod
     def get_subclasses(cls) -> list[Any]:
+        """Retrieves all subclasses of the given model class."""
         all_subclasses = []
         for subclass in cls.__subclasses__():
             all_subclasses.append(subclass)
             all_subclasses.extend(subclass.get_subclasses())
         return all_subclasses
 
     @classmethod
-    def available_wrappers(cls):
+    def available_wrappers(cls) -> list[str]:
+        """
+        Prepares a list of all discoverable wrapper names
+        (including custom ones inherited from the core wrappers).
+        """
         available_classes = [cls] if cls.__model__ else []
         available_classes.extend(cls.get_subclasses())
         return [subclass.__model__ for subclass in available_classes if subclass.__model__]
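
Since available_wrappers() walks the whole subclass tree, a custom wrapper becomes discoverable as soon as its class is imported and defines __model__. A hedged sketch; the SSD import and the custom subclass are assumptions for illustration:

    from model_api.models import Model, SSD  # importing SSD as a base is an assumption

    class MyDetector(SSD):
        __model__ = "MyDetector"  # the name reported by available_wrappers()

    print(Model.available_wrappers())                         # now includes "MyDetector"
    print(Model.get_model_class("MyDetector") is MyDetector)  # True
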
@@ -368,7 +388,7 @@ def __call__(self, inputs: ndarray):
         raw_result = self.infer_sync(dict_data)
         return self.postprocess(raw_result, input_meta)
 
-    def infer_batch(self, inputs):
+    def infer_batch(self, inputs: list) -> list[Any]:
         """Applies preprocessing, asynchronous inference, postprocessing routines to a collection of inputs.
 
         Args:
@@ -402,11 +422,24 @@ def batch_infer_callback(result, id):
         return [completed_results[i] for i in range(len(inputs))]
 
     def load(self, force: bool = False) -> None:
+        """
+        Prepares the model to be executed by the inference adapter.
+
+        Args:
+            force (bool, optional): Forces reloading even if the model is already loaded. Defaults to False.
+        """
         if not self.model_loaded or force:
             self.model_loaded = True
             self.inference_adapter.load_model()
 
-    def reshape(self, new_shape):
+    def reshape(self, new_shape: dict):
+        """
+        Reshapes the model inputs to fit the new input shape.
+
+        Args:
+            new_shape (dict): A dictionary with input names as keys and
+                new shapes (lists of dimensions) as values.
+        """
         if self.model_loaded:
             self.logger.warning(
                 f"{self.__model__}: the model already loaded to device, ",
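
To make the reshape() contract concrete, a hedged sketch of the expected mapping; the model instance, input name, and dimensions are illustrative assumptions:

    # If the model is already loaded to a device, the wrapper emits the warning shown above;
    # load(force=True) then reloads the reshaped model (force reloads even when already loaded).
    model.reshape({"image": [1, 3, 544, 544]})  # hypothetical input name and NCHW shape
    model.load(force=True)
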
@@ -418,22 +451,41 @@ def reshape(self, new_shape):
         self.outputs = self.inference_adapter.get_output_layers()
 
     def infer_sync(self, dict_data: dict[str, ndarray]) -> dict[str, ndarray]:
+        """
+        Performs synchronous model inference. This is a blocking call.
+        See the InferenceAdapter documentation for details.
+        """
         if not self.model_loaded:
             self.raise_error(
                 "The model is not loaded to the device. Please, create the wrapper "
                 "with preload=True option or call load() method before infer_sync()",
             )
         return self.inference_adapter.infer_sync(dict_data)
 
-    def infer_async_raw(self, dict_data, callback_data):
+    def infer_async_raw(self, dict_data: dict, callback_data: Any):
+        """
+        Runs asynchronous inference on raw data, skipping the preprocess() call.
+
+        Args:
+            dict_data (dict): Data to be passed to the model.
+            callback_data (Any): Data to be passed to the callback along with the inference results.
+        """
         if not self.model_loaded:
             self.raise_error(
                 "The model is not loaded to the device. Please, create the wrapper "
                 "with preload=True option or call load() method before infer_async()",
             )
         self.inference_adapter.infer_async(dict_data, callback_data)
 
-    def infer_async(self, input_data, user_data):
+    def infer_async(self, input_data: dict, user_data: Any):
+        """
+        Runs asynchronous model inference.
+
+        Args:
+            input_data (dict): Input dict with model input names as keys and data objects as values.
+            user_data (Any): Data to be passed to the callback along with the inference results.
+        """
+
         if not self.model_loaded:
             self.raise_error(
                 "The model is not loaded to the device. Please, create the wrapper "
@@ -452,23 +504,35 @@ def infer_async(self, input_data, user_data):
         )
 
     @staticmethod
-    def process_callback(request, callback_data):
+    def _process_callback(request, callback_data: Any):
+        """
+        A wrapper for the async inference callback.
+        """
         meta, get_result_fn, postprocess_fn, callback_fn, user_data = callback_data
         raw_result = get_result_fn(request)
         result = postprocess_fn(raw_result, meta)
         callback_fn(result, user_data)
 
-    def set_callback(self, callback_fn):
+    def set_callback(self, callback_fn: Callable):
+        """
+        Sets a callback that grabs the results of async inference.
+
+        Args:
+            callback_fn (Callable): Callback accepting the postprocessed result and user data.
+        """
         self.callback_fn = callback_fn
-        self.inference_adapter.set_callback(Model.process_callback)
+        self.inference_adapter.set_callback(Model._process_callback)
 
     def is_ready(self):
+        """Checks whether the model is ready for async inference."""
         return self.inference_adapter.is_ready()
 
     def await_all(self):
+        """Waits for all async inference requests to be completed."""
         self.inference_adapter.await_all()
 
     def await_any(self):
+        """Waits for the model to become available for an async infer request."""
         self.inference_adapter.await_any()
 
     def log_layers_info(self):
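
To tie the async pieces together, a hedged end-to-end sketch of set_callback(), infer_async(), and await_all(); the model instance and the inputs collection are illustrative assumptions:

    results = []

    def on_result(result, user_data):
        # Invoked via Model._process_callback with the already postprocessed result.
        results.append((user_data, result))

    model.set_callback(on_result)
    for i, item in enumerate(inputs):         # 'inputs' is a hypothetical collection of model inputs
        model.infer_async(item, user_data=i)  # preprocess + submit an async request
    model.await_all()                         # block until every queued request completes
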
@@ -484,7 +548,15 @@ def log_layers_info(self):
                 f"precision: {metadata.precision}, layout: {metadata.layout}",
             )
 
-    def save(self, path: str, weights_path: str = "", version: str = "UNSPECIFIED"):
+    def save(self, path: str, weights_path: str | None = None, version: str | None = None):
+        """
+        Serializes the model to the filesystem. The model format depends on the InferenceAdapter being used.
+
+        Args:
+            path (str): Path to write the resulting model.
+            weights_path (str | None): Optional path to save weights if they are stored separately.
+            version (str | None): Optional model version.
+        """
         model_info = {
             "model_type": self.__model__,
         }
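
A hedged usage sketch of save(); the output paths and version string are illustrative, and the exact files written depend on the adapter, as the docstring notes:

    # weights_path only matters for formats that keep weights in a separate file
    # (for example an OpenVINO IR .xml plus .bin pair).
    model.save("exported/model.xml", weights_path="exported/model.bin", version="1.0")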