Skip to content

Commit 99c497a

Browse files
committed
Update adapter doc and api
1 parent 2a5f66f commit 99c497a

File tree

4 files changed

+115
-31
lines changed

4 files changed

+115
-31
lines changed

model_api/python/model_api/adapters/inference_adapter.py

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def get_output_layers(self):
6969
"""
7070

7171
@abstractmethod
72-
def reshape_model(self, new_shape):
72+
def reshape_model(self, new_shape: dict):
7373
"""Reshapes the model inputs to fit the new input shape.
7474
7575
Args:
@@ -83,7 +83,7 @@ def reshape_model(self, new_shape):
8383
"""
8484

8585
@abstractmethod
86-
def infer_sync(self, dict_data) -> dict:
86+
def infer_sync(self, dict_data: dict) -> dict:
8787
"""Performs the synchronous model inference. The infer is a blocking method.
8888
8989
Args:
@@ -104,7 +104,7 @@ def infer_sync(self, dict_data) -> dict:
104104
"""
105105

106106
@abstractmethod
107-
def infer_async(self, dict_data, callback_data):
107+
def infer_async(self, dict_data: dict, callback_data: Any):
108108
"""
109109
Performs the asynchronous model inference and sets
110110
the callback for inference completion. Also, it should
@@ -122,11 +122,11 @@ def infer_async(self, dict_data, callback_data):
122122
"""
123123

124124
@abstractmethod
125-
def get_raw_result(self, infer_result) -> dict:
125+
def get_raw_result(self, infer_result: dict) -> dict:
126126
"""Gets raw results from the internal inference framework representation as a dict.
127127
128128
Args:
129-
- infer_result: framework-specific result of inference from the model
129+
- infer_resul (dict): framework-specific result of inference from the model
130130
131131
Returns:
132132
- raw result (dict) - model raw output in the following format:
@@ -138,7 +138,7 @@ def get_raw_result(self, infer_result) -> dict:
138138
"""
139139

140140
@abstractmethod
141-
def is_ready(self):
141+
def is_ready(self) -> bool:
142142
"""In case of asynchronous execution checks if one can submit input data
143143
to the model for inference, or all infer requests are busy.
144144
@@ -160,29 +160,67 @@ def await_any(self):
160160
"""
161161

162162
@abstractmethod
163-
def get_rt_info(self, path):
164-
"""Forwards to openvino.Model.get_rt_info(path)"""
163+
def get_rt_info(self, path: list[str]) -> Any:
164+
"""
165+
Returns an attribute stored in model info.
166+
167+
Args:
168+
path (list[str]): a sequence of tag names leading to the attribute.
169+
170+
Returns:
171+
Any: a value stored under corresponding tag sequence.
172+
"""
165173

166174
@abstractmethod
167175
def update_model_info(self, model_info: dict[str, Any]):
168-
"""Updates model with the provided model info."""
176+
"""
177+
Updates model with the provided model info. Model info dict can
178+
also contain nested dicts.
179+
180+
Args:
181+
model_info (dict[str, Any]): model info dict to write to the model.
182+
"""
169183

170184
@abstractmethod
171-
def save_model(self, path: str, weights_path: str, version: str):
172-
"""Serializes model to the filesystem."""
185+
def save_model(self, path: str, weights_path: str | None, version: str | None):
186+
"""
187+
Serializes model to the filesystem.
188+
189+
Args:
190+
path (str): Path to write the resulting model.
191+
weights_path (str | None): Optional path to save weights if they are stored separately.
192+
version (str | None): Optional model version.
193+
"""
173194

174195
@abstractmethod
175196
def embed_preprocessing(
176197
self,
177-
layout,
198+
layout: str,
178199
resize_mode: str,
179-
interpolation_mode,
200+
interpolation_mode: str,
180201
target_shape: tuple[int, ...],
181-
pad_value,
202+
pad_value: int,
182203
dtype: type = int,
183-
brg2rgb=False,
184-
mean=None,
185-
scale=None,
186-
input_idx=0,
204+
brg2rgb: bool = False,
205+
mean: list[Any] | None = None,
206+
scale: list[Any] | None = None,
207+
input_idx: int = 0,
187208
):
188-
"""Embeds preprocessing into the model using OpenVINO preprocessing API"""
209+
"""
210+
Embeds preprocessing into the model if possible with the adapter being used.
211+
In some cases, this method would just add extra python preprocessing steps
212+
instaed actuall of embedding it into the model representation.
213+
214+
Args:
215+
layout (str): Layout, for instance NCHW.
216+
resize_mode (str): Resize type to use for preprocessing.
217+
interpolation_mode (str): Resize interpolation mode.
218+
target_shape (tuple[int, ...]): Target resize shape.
219+
pad_value (int): Value to pad with if resize implies padding.
220+
dtype (type, optional): Input data type for the preprocessing module. Defaults to int.
221+
bgr2rgb (bool, optional): Defines if we need to swap R and B channels in case of image input.
222+
Defaults to False.
223+
mean (list[Any] | None, optional): Mean values to perform input normalization. Defaults to None.
224+
scale (list[Any] | None, optional): Scale values to perform input normalization. Defaults to None.
225+
input_idx (int, optional): Index of the model input to apply preprocessing to. Defaults to 0.
226+
"""

model_api/python/model_api/adapters/onnx_adapter.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -131,17 +131,20 @@ def get_raw_result(self, infer_result):
131131

132132
def embed_preprocessing(
133133
self,
134-
layout,
134+
layout: str,
135135
resize_mode: str,
136-
interpolation_mode,
137-
target_shape,
138-
pad_value,
136+
interpolation_mode: str,
137+
target_shape: tuple[int, ...],
138+
pad_value: int,
139139
dtype: type = int,
140-
brg2rgb=False,
141-
mean=None,
142-
scale=None,
143-
input_idx=0,
140+
brg2rgb: bool = False,
141+
mean: list[Any] | None = None,
142+
scale: list[Any] | None = None,
143+
input_idx: int = 0,
144144
):
145+
"""
146+
Adds external preprocessing steps done before ONNX model execution.
147+
"""
145148
preproc_funcs = [np.squeeze]
146149
if resize_mode != "crop":
147150
if resize_mode == "fit_to_window_letterbox":
@@ -170,13 +173,23 @@ def embed_preprocessing(
170173
)
171174

172175
def get_model(self):
173-
"""Return the reference to the ONNXRuntime session."""
176+
"""Return a reference to the ONNXRuntime session."""
174177
return self.model
175178

176179
def reshape_model(self, new_shape):
180+
""" "Not supported by ONNX adapter."""
177181
raise NotImplementedError
178182

179183
def get_rt_info(self, path):
184+
"""
185+
Returns an attribute stored in model info.
186+
187+
Args:
188+
path (list[str]): a sequence of tag names leading to the attribute.
189+
190+
Returns:
191+
Any: a value stored under corresponding tag sequence.
192+
"""
180193
return get_rt_info_from_dict(self.onnx_metadata, path)
181194

182195
def update_model_info(self, model_info: dict[str, Any]):
@@ -189,7 +202,15 @@ def update_model_info(self, model_info: dict[str, Any]):
189202
else:
190203
meta.value = str(model_info[item])
191204

192-
def save_model(self, path: str, weights_path: str = "", version: str = "UNSPECIFIED"):
205+
def save_model(self, path: str, weights_path: str | None, version: str | None):
206+
"""
207+
Serializes model to the filesystem.
208+
209+
Args:
210+
path (str): paths to save .onnx file.
211+
weights_path (str | None): not used by ONNX adapter.
212+
version (str | None): not used by ONNX adapter.
213+
"""
193214
onnx.save(self.model, path)
194215

195216

model_api/python/model_api/adapters/openvino_adapter.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,15 @@ def operations_by_type(self, operation_type):
333333
return layers_info
334334

335335
def get_rt_info(self, path: list[str]) -> OVAny:
336+
"""
337+
Gets an attribute value from OV.model_info structure.
338+
339+
Args:
340+
path (list[str]): a suquence of tag names leading to the attribute.
341+
342+
Returns:
343+
OVAny: attribute value wrapped into OVAny object.
344+
"""
336345
if self.is_onnx_file:
337346
return get_rt_info_from_dict(self.onnx_metadata, path)
338347
return self.model.get_rt_info(path)
@@ -350,6 +359,9 @@ def embed_preprocessing(
350359
scale: list[Any] | None = None,
351360
input_idx: int = 0,
352361
) -> None:
362+
"""
363+
Embeds OpenVINO PrePostProcessor module into the model.
364+
"""
353365
ppp = PrePostProcessor(self.model)
354366

355367
# Change the input type to the 8-bit image
@@ -429,7 +441,20 @@ def update_model_info(self, model_info: dict[str, Any]):
429441
for name in model_info:
430442
self.model.set_rt_info(model_info[name], ["model_info", name])
431443

432-
def save_model(self, path: str, weights_path: str = "", version: str = "UNSPECIFIED"):
444+
def save_model(self, path: str, weights_path: str | None, version: str | None):
445+
"""
446+
Saves OV model as two files: .xml (architecture) and .bin (weights).
447+
448+
Args:
449+
path (str): path to save the model files (.xml and .bin).
450+
weights_path (str, optional): Optional path to save .bin if it differs from .xml path. Defaults to None.
451+
version (str, optional): Output IR model version (for instance, IR_V10). Defaults to None.
452+
"""
453+
if weights_path is None:
454+
weights_path = ""
455+
if version is None:
456+
version = "UNSPECIFIED"
457+
433458
ov.serialize(self.get_model(), path, weights_path, version)
434459

435460

model_api/python/model_api/adapters/ovms_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def update_model_info(self, model_info: dict[str, Any]):
127127
msg = "OVMSAdapter does not support updating model info"
128128
raise NotImplementedError(msg)
129129

130-
def save_model(self, path: str, weights_path: str = "", version: str = "UNSPECIFIED"):
130+
def save_model(self, path: str, weights_path: str | None, version: str | None):
131131
msg = "OVMSAdapter does not support saving a model"
132132
raise NotImplementedError(msg)
133133

0 commit comments

Comments
 (0)