Skip to content

Commit 1f335bf

Browse files
committed
Update model docs
1 parent 99c497a commit 1f335bf

File tree

6 files changed

+111
-20
lines changed

6 files changed

+111
-20
lines changed

model_api/python/model_api/adapters/inference_adapter.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from abc import ABC, abstractmethod
77
from dataclasses import dataclass, field
8-
from typing import Any
8+
from typing import Any, Callable
99

1010

1111
@dataclass
@@ -137,6 +137,15 @@ def get_raw_result(self, infer_result: dict) -> dict:
137137
}
138138
"""
139139

140+
@abstractmethod
141+
def set_callback(self, callback_fn: Callable):
142+
"""
143+
Sets callback that grabs results of async inference.
144+
145+
Args:
146+
callback_fn (Callable): Callback function.
147+
"""
148+
140149
@abstractmethod
141150
def is_ready(self) -> bool:
142151
"""In case of asynchronous execution checks if one can submit input data

model_api/python/model_api/adapters/onnx_adapter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import sys
99
from functools import partial, reduce
10-
from typing import Any
10+
from typing import Any, Callable
1111

1212
import numpy as np
1313

@@ -111,7 +111,7 @@ def infer_sync(self, dict_data):
111111
def infer_async(self, dict_data, callback_data):
112112
raise NotImplementedError
113113

114-
def set_callback(self, callback_fn):
114+
def set_callback(self, callback_fn: Callable):
115115
self.callback_fn = callback_fn
116116

117117
def is_ready(self):
@@ -126,7 +126,7 @@ def await_all(self):
126126
def await_any(self):
127127
pass
128128

129-
def get_raw_result(self, infer_result):
129+
def get_raw_result(self, infer_result: dict):
130130
pass
131131

132132
def embed_preprocessing(

model_api/python/model_api/adapters/openvino_adapter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import logging as log
99
from pathlib import Path
10-
from typing import TYPE_CHECKING, Any
10+
from typing import TYPE_CHECKING, Any, Callable
1111

1212
if TYPE_CHECKING:
1313
from os import PathLike
@@ -300,7 +300,7 @@ def infer_sync(self, dict_data: dict[str, ndarray]) -> dict[str, ndarray]:
300300
def infer_async(self, dict_data, callback_data) -> None:
301301
self.async_queue.start_async(dict_data, callback_data)
302302

303-
def set_callback(self, callback_fn):
303+
def set_callback(self, callback_fn: Callable):
304304
self.async_queue.set_callback(callback_fn)
305305

306306
def is_ready(self) -> bool:

model_api/python/model_api/adapters/ovms_adapter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#
55

66
import re
7-
from typing import Any
7+
from typing import Any, Callable
88

99
import numpy as np
1010

@@ -79,7 +79,7 @@ def infer_async(self, dict_data, callback_data):
7979
raw_result = {output_name: raw_result}
8080
self.callback_fn(raw_result, (lambda x: x, callback_data))
8181

82-
def set_callback(self, callback_fn):
82+
def set_callback(self, callback_fn: Callable):
8383
self.callback_fn = callback_fn
8484

8585
def is_ready(self):
@@ -98,7 +98,7 @@ def await_all(self):
9898
def await_any(self):
9999
pass
100100

101-
def get_raw_result(self, infer_result):
101+
def get_raw_result(self, infer_result: dict):
102102
pass
103103

104104
def embed_preprocessing(

model_api/python/model_api/models/image_model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,16 @@ def parameters(cls) -> dict[str, Any]:
146146
return parameters
147147

148148
def get_label_name(self, label_id: int) -> str:
149+
"""
150+
Returns a label name by it's index.
151+
If index is out of range, and auto-generated name is returned.
152+
153+
Args:
154+
label_id (int): label index.
155+
156+
Returns:
157+
str: label name.
158+
"""
149159
if self.labels is None:
150160
return f"#{label_id}"
151161
if label_id >= len(self.labels):

model_api/python/model_api/models/model.py

Lines changed: 83 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import logging as log
99
import re
1010
from contextlib import contextmanager
11-
from typing import TYPE_CHECKING, Any, NoReturn, Type
11+
from typing import TYPE_CHECKING, Any, Callable, NoReturn, Type
1212

1313
from model_api.adapters.inference_adapter import InferenceAdapter
1414
from model_api.adapters.onnx_adapter import ONNXRuntimeAdapter
@@ -98,11 +98,26 @@ def __init__(self, inference_adapter: InferenceAdapter, configuration: dict = {}
9898
self.load()
9999
self.callback_fn = lambda _: None
100100

101-
def get_model(self):
101+
def get_model(self) -> Any:
102+
"""
103+
Returns underlying adapter-specific model.
104+
105+
Returns:
106+
Any: Model object.
107+
"""
102108
return self.inference_adapter.get_model()
103109

104110
@classmethod
105111
def get_model_class(cls, name: str) -> Type:
112+
"""
113+
Retrieves a wrapper class by a given wrapper name.
114+
115+
Args:
116+
name (str): Wrapper name.
117+
118+
Returns:
119+
Type: Model class.
120+
"""
106121
subclasses = [subclass for subclass in cls.get_subclasses() if subclass.__model__]
107122
if cls.__model__:
108123
subclasses.append(cls)
@@ -188,14 +203,19 @@ def create_model(
188203

189204
@classmethod
190205
def get_subclasses(cls) -> list[Any]:
206+
"""Retrieves all the subclasses of the model class given."""
191207
all_subclasses = []
192208
for subclass in cls.__subclasses__():
193209
all_subclasses.append(subclass)
194210
all_subclasses.extend(subclass.get_subclasses())
195211
return all_subclasses
196212

197213
@classmethod
198-
def available_wrappers(cls):
214+
def available_wrappers(cls) -> list[str]:
215+
"""
216+
Prepares a list of all discoverable wrapper names
217+
(including custom ones inherited from the core wrappers).
218+
"""
199219
available_classes = [cls] if cls.__model__ else []
200220
available_classes.extend(cls.get_subclasses())
201221
return [subclass.__model__ for subclass in available_classes if subclass.__model__]
@@ -368,7 +388,7 @@ def __call__(self, inputs: ndarray):
368388
raw_result = self.infer_sync(dict_data)
369389
return self.postprocess(raw_result, input_meta)
370390

371-
def infer_batch(self, inputs):
391+
def infer_batch(self, inputs: list) -> list[Any]:
372392
"""Applies preprocessing, asynchronous inference, postprocessing routines to a collection of inputs.
373393
374394
Args:
@@ -402,11 +422,24 @@ def batch_infer_callback(result, id):
402422
return [completed_results[i] for i in range(len(inputs))]
403423

404424
def load(self, force: bool = False) -> None:
425+
"""
426+
Prepares the model to be executed by the inference adapter.
427+
428+
Args:
429+
force (bool, optional): Forces the process even if the model is ready. Defaults to False.
430+
"""
405431
if not self.model_loaded or force:
406432
self.model_loaded = True
407433
self.inference_adapter.load_model()
408434

409-
def reshape(self, new_shape):
435+
def reshape(self, new_shape: dict):
436+
"""
437+
Reshapes the model inputs to fit the new input shape.
438+
439+
Args:
440+
new_shape (_type_): a dictionary with inputs names as keys and
441+
list of new shape as values in the following format.
442+
"""
410443
if self.model_loaded:
411444
self.logger.warning(
412445
f"{self.__model__}: the model already loaded to device, ",
@@ -418,22 +451,41 @@ def reshape(self, new_shape):
418451
self.outputs = self.inference_adapter.get_output_layers()
419452

420453
def infer_sync(self, dict_data: dict[str, ndarray]) -> dict[str, ndarray]:
454+
"""
455+
Performs the synchronous model inference. The infer is a blocking method.
456+
See InferenceAdapter documentation for details.
457+
"""
421458
if not self.model_loaded:
422459
self.raise_error(
423460
"The model is not loaded to the device. Please, create the wrapper "
424461
"with preload=True option or call load() method before infer_sync()",
425462
)
426463
return self.inference_adapter.infer_sync(dict_data)
427464

428-
def infer_async_raw(self, dict_data, callback_data):
465+
def infer_async_raw(self, dict_data: dict, callback_data: Any):
466+
"""
467+
Runs asynchronous inference on raw data skipping preprocess() call.
468+
469+
Args:
470+
dict_data (dict): data to be passed to the model
471+
callback_data (Any): data to be passed to the callback alongside with inference results.
472+
"""
429473
if not self.model_loaded:
430474
self.raise_error(
431475
"The model is not loaded to the device. Please, create the wrapper "
432476
"with preload=True option or call load() method before infer_async()",
433477
)
434478
self.inference_adapter.infer_async(dict_data, callback_data)
435479

436-
def infer_async(self, input_data, user_data):
480+
def infer_async(self, input_data: dict, user_data: Any):
481+
"""
482+
Runs asynchronous model inference.
483+
484+
Args:
485+
input_data (dict): Input dict containing model input name as keys and data object as values.
486+
user_data (Any): data to be passed to the callback alongside with inference results.
487+
"""
488+
437489
if not self.model_loaded:
438490
self.raise_error(
439491
"The model is not loaded to the device. Please, create the wrapper "
@@ -452,23 +504,35 @@ def infer_async(self, input_data, user_data):
452504
)
453505

454506
@staticmethod
455-
def process_callback(request, callback_data):
507+
def _process_callback(request, callback_data: Any):
508+
"""
509+
A wrapper for async inference callback.
510+
"""
456511
meta, get_result_fn, postprocess_fn, callback_fn, user_data = callback_data
457512
raw_result = get_result_fn(request)
458513
result = postprocess_fn(raw_result, meta)
459514
callback_fn(result, user_data)
460515

461-
def set_callback(self, callback_fn):
516+
def set_callback(self, callback_fn: Callable):
517+
"""
518+
Sets callback that grabs results of async inference.
519+
520+
Args:
521+
callback_fn (Callable): _description_
522+
"""
462523
self.callback_fn = callback_fn
463-
self.inference_adapter.set_callback(Model.process_callback)
524+
self.inference_adapter.set_callback(Model._process_callback)
464525

465526
def is_ready(self):
527+
"""Checks if model is ready for async inference."""
466528
return self.inference_adapter.is_ready()
467529

468530
def await_all(self):
531+
"""Waits for all async inference requests to be completed."""
469532
self.inference_adapter.await_all()
470533

471534
def await_any(self):
535+
"""Waits for model to be available for an async infer request."""
472536
self.inference_adapter.await_any()
473537

474538
def log_layers_info(self):
@@ -484,7 +548,15 @@ def log_layers_info(self):
484548
f"precision: {metadata.precision}, layout: {metadata.layout}",
485549
)
486550

487-
def save(self, path: str, weights_path: str = "", version: str = "UNSPECIFIED"):
551+
def save(self, path: str, weights_path: str | None, version: str | None):
552+
"""
553+
Serializes model to the filesystem. Model format depends in the InferenceAdapter being used.
554+
555+
Args:
556+
path (str): Path to write the resulting model.
557+
weights_path (str | None): Optional path to save weights if they are stored separately.
558+
version (str | None): Optional model version.
559+
"""
488560
model_info = {
489561
"model_type": self.__model__,
490562
}

0 commit comments

Comments
 (0)