Skip to content

Commit 453d692

Browse files
committed
Cover ImageModel
1 parent 61f9004 commit 453d692

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

model_api/python/model_api/adapters/inference_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def embed_preprocessing(
169169
layout,
170170
resize_mode: str,
171171
interpolation_mode,
172-
target_shape: tuple[int],
172+
target_shape: tuple[int, ...],
173173
pad_value,
174174
dtype: type = int,
175175
brg2rgb=False,

model_api/python/model_api/adapters/openvino_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def embed_preprocessing(
338338
layout: str,
339339
resize_mode: str,
340340
interpolation_mode: str,
341-
target_shape: tuple[int],
341+
target_shape: tuple[int, ...],
342342
pad_value: int,
343343
dtype: type = int,
344344
brg2rgb: bool = False,

model_api/python/model_api/models/image_model.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33
# SPDX-License-Identifier: Apache-2.0
44
#
55

6+
from typing import Any
7+
68
import numpy as np
79

10+
from model_api.adapters.inference_adapter import InferenceAdapter
811
from model_api.adapters.utils import RESIZE_TYPES, InputTransform
912
from model_api.models.model import Model
1013
from model_api.models.types import BooleanValue, ListValue, NumericalValue, StringValue
@@ -33,7 +36,7 @@ class ImageModel(Model):
3336

3437
__model__ = "ImageModel"
3538

36-
def __init__(self, inference_adapter, configuration: dict = {}, preload=False):
39+
def __init__(self, inference_adapter: InferenceAdapter, configuration: dict = {}, preload: bool = False) -> None:
3740
"""Image model constructor
3841
3942
It extends the `Model` constructor.
@@ -59,6 +62,7 @@ def __init__(self, inference_adapter, configuration: dict = {}, preload=False):
5962
self.scale_values: list
6063
self.reverse_input_channels: bool
6164
self.embedded_processing: bool
65+
self.labels: list[str]
6266

6367
self.nchw_layout = self.inputs[self.image_blob_name].layout == "NCHW"
6468
if self.nchw_layout:
@@ -90,7 +94,7 @@ def __init__(self, inference_adapter, configuration: dict = {}, preload=False):
9094
self.orig_height, self.orig_width = self.h, self.w
9195

9296
@classmethod
93-
def parameters(cls) -> dict:
97+
def parameters(cls) -> dict[str, Any]:
9498
parameters = super().parameters()
9599
parameters.update(
96100
{
@@ -137,14 +141,14 @@ def parameters(cls) -> dict:
137141
)
138142
return parameters
139143

140-
def get_label_name(self, label_id):
144+
def get_label_name(self, label_id: int) -> str:
141145
if self.labels is None:
142146
return f"#{label_id}"
143147
if label_id >= len(self.labels):
144148
return f"#{label_id}"
145149
return self.labels[label_id]
146150

147-
def _get_inputs(self):
151+
def _get_inputs(self) -> tuple[list[str], ...]:
148152
"""Defines the model inputs for images and additional info.
149153
150154
Raises:
@@ -170,7 +174,7 @@ def _get_inputs(self):
170174
)
171175
return image_blob_names, image_info_blob_names
172176

173-
def preprocess(self, inputs) -> list[dict]:
177+
def preprocess(self, inputs: np.ndarray) -> list[dict]:
174178
"""Data preprocess method
175179
176180
It performs basic preprocessing of a single image:
@@ -203,7 +207,7 @@ def preprocess(self, inputs) -> list[dict]:
203207
},
204208
]
205209

206-
def _change_layout(self, image):
210+
def _change_layout(self, image: np.ndarray) -> np.ndarray:
207211
"""Changes the input image layout to fit the layout of the model input layer.
208212
209213
Args:

0 commit comments

Comments
 (0)