Skip to content

Commit 23620d1

Browse files
✨Add MyPy (#219)
* add mypy Signed-off-by: Ashwin Vaidya <[email protected]> * fix imports Signed-off-by: Ashwin Vaidya <[email protected]> * add __model__ to detection Signed-off-by: Ashwin Vaidya <[email protected]> * add get_model to ovms_adapter Signed-off-by: Ashwin Vaidya <[email protected]> * add __model__ to Model Signed-off-by: Ashwin Vaidya <[email protected]> * Add get_model method to ONNXRuntimeAdapter Signed-off-by: Ashwin Vaidya <[email protected]> --------- Signed-off-by: Ashwin Vaidya <[email protected]>
1 parent 25c88f8 commit 23620d1

File tree

21 files changed

+141
-87
lines changed

21 files changed

+141
-87
lines changed

.pre-commit-config.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ repos:
2323
# Run the formatter
2424
- id: ruff-format
2525

26+
# python static type checking
27+
- repo: https://github.com/pre-commit/mirrors-mypy
28+
rev: "v1.11.2"
29+
hooks:
30+
- id: mypy
31+
additional_dependencies: [types-PyYAML, types-setuptools]
32+
2633
- repo: https://github.com/pre-commit/mirrors-prettier
2734
rev: v4.0.0-alpha.8
2835
hooks:

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
]
4343

4444
templates_path = ["_templates"]
45-
exclude_patterns = []
45+
exclude_patterns: list[str] = []
4646

4747
# Automatic exclusion of prompts from the copies
4848
# https://sphinx-copybutton.readthedocs.io/en/latest/use.html#automatic-exclusion-of-prompts-from-the-copies

examples/python/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0

model_api/python/model_api/adapters/inference_adapter.py

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
# SPDX-License-Identifier: Apache-2.0
44
#
55

6-
import abc
6+
from abc import ABC, abstractmethod
77
from dataclasses import dataclass, field
8+
from typing import Any, Dict, List, Set, Tuple
89

910

1011
@dataclass
@@ -17,30 +18,37 @@ class Metadata:
1718
meta: dict = field(default_factory=dict)
1819

1920

20-
class InferenceAdapter(abc.ABC):
21-
"""An abstract Model Adapter with the following interface:
22-
23-
- Reading the model from disk or other place
24-
- Loading the model to the device
25-
- Accessing the information about inputs/outputs
26-
- The model reshaping
27-
- Synchronous model inference
28-
- Asynchronous model inference
21+
class InferenceAdapter(ABC):
22+
"""
23+
An abstract Model Adapter with the following interface:
24+
25+
- Reading the model from disk or other place
26+
- Loading the model to the device
27+
- Accessing the information about inputs/outputs
28+
- The model reshaping
29+
- Synchronous model inference
30+
- Asynchronous model inference
2931
"""
3032

3133
precisions = ("FP32", "I32", "FP16", "I16", "I8", "U8")
3234

33-
@abc.abstractmethod
34-
def __init__(self):
35-
"""An abstract Model Adapter constructor.
35+
@abstractmethod
36+
def __init__(self) -> None:
37+
"""
38+
An abstract Model Adapter constructor.
3639
Reads the model from disk or other place.
3740
"""
41+
self.model: Any
3842

39-
@abc.abstractmethod
43+
@abstractmethod
4044
def load_model(self):
4145
"""Loads the model on the device."""
4246

43-
@abc.abstractmethod
47+
@abstractmethod
48+
def get_model(self):
49+
"""Get the model."""
50+
51+
@abstractmethod
4452
def get_input_layers(self):
4553
"""Gets the names of model inputs and for each one creates the Metadata structure,
4654
which contains the information about the input shape, layout, precision
@@ -50,7 +58,7 @@ def get_input_layers(self):
5058
- the dict containing Metadata for all inputs
5159
"""
5260

53-
@abc.abstractmethod
61+
@abstractmethod
5462
def get_output_layers(self):
5563
"""Gets the names of model outputs and for each one creates the Metadata structure,
5664
which contains the information about the output shape, layout, precision
@@ -60,7 +68,7 @@ def get_output_layers(self):
6068
- the dict containing Metadata for all outputs
6169
"""
6270

63-
@abc.abstractmethod
71+
@abstractmethod
6472
def reshape_model(self, new_shape):
6573
"""Reshapes the model inputs to fit the new input shape.
6674
@@ -74,7 +82,7 @@ def reshape_model(self, new_shape):
7482
}
7583
"""
7684

77-
@abc.abstractmethod
85+
@abstractmethod
7886
def infer_sync(self, dict_data):
7987
"""Performs the synchronous model inference. The infer is a blocking method.
8088
@@ -95,9 +103,10 @@ def infer_sync(self, dict_data):
95103
}
96104
"""
97105

98-
@abc.abstractmethod
99-
def infer_async(self, dict_data, callback_fn, callback_data):
100-
"""Performs the asynchronous model inference and sets
106+
@abstractmethod
107+
def infer_async(self, dict_data, callback_data):
108+
"""
109+
Performs the asynchronous model inference and sets
101110
the callback for inference completion. Also, it should
102111
define get_raw_result() function, which handles the result
103112
of inference from the model.
@@ -109,11 +118,10 @@ def infer_async(self, dict_data, callback_fn, callback_data):
109118
'input_layer_name_2': data_2,
110119
...
111120
}
112-
- callback_fn: the callback function, which is defined outside the adapter
113121
- callback_data: the data for callback, that will be taken after the model inference is ended
114122
"""
115123

116-
@abc.abstractmethod
124+
@abstractmethod
117125
def is_ready(self):
118126
"""In case of asynchronous execution checks if one can submit input data
119127
to the model for inference, or all infer requests are busy.
@@ -123,23 +131,23 @@ def is_ready(self):
123131
submitted to the model for inference or not
124132
"""
125133

126-
@abc.abstractmethod
134+
@abstractmethod
127135
def await_all(self):
128136
"""In case of asynchronous execution waits the completion of all
129137
busy infer requests.
130138
"""
131139

132-
@abc.abstractmethod
140+
@abstractmethod
133141
def await_any(self):
134142
"""In case of asynchronous execution waits the completion of any
135143
busy infer request until it becomes available for the data submission.
136144
"""
137145

138-
@abc.abstractmethod
146+
@abstractmethod
139147
def get_rt_info(self, path):
140148
"""Forwards to openvino.Model.get_rt_info(path)"""
141149

142-
@abc.abstractmethod
150+
@abstractmethod
143151
def embed_preprocessing(
144152
self,
145153
layout,

model_api/python/model_api/adapters/onnx_adapter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,10 @@ def embed_preprocessing(
162162
reversed(preproc_funcs),
163163
)
164164

165+
def get_model(self):
166+
"""Return the reference to the ONNXRuntime session."""
167+
return self.session
168+
165169
def reshape_model(self, new_shape):
166170
raise NotImplementedError
167171

model_api/python/model_api/adapters/ovms_adapter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ def is_ready(self):
8787
def load_model(self):
8888
pass
8989

90+
def get_model(self):
91+
"""Return the reference to the GrpcClient."""
92+
return self.client
93+
9094
def await_all(self):
9195
pass
9296

model_api/python/model_api/adapters/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import math
99
from functools import partial
10+
from typing import Callable, Optional
1011

1112
import cv2
1213
import numpy as np
@@ -509,7 +510,7 @@ def crop_resize_ocv(image, size):
509510
return cv2.resize(cropped_frame, size)
510511

511512

512-
RESIZE_TYPES = {
513+
RESIZE_TYPES: dict[str, Callable] = {
513514
"crop": crop_resize_ocv,
514515
"standard": resize_image_ocv,
515516
"fit_to_window": resize_image_with_aspect_ocv,

model_api/python/model_api/models/action_classification.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,14 @@ def __init__(
6464
self.image_blob_names = self._get_inputs()
6565
self.image_blob_name = self.image_blob_names[0]
6666
self.nscthw_layout = "NSCTHW" in self.inputs[self.image_blob_name].layout
67+
self.labels: list[str]
68+
self.path_to_labels: str
69+
self.mean_values: list[int | float]
70+
self.pad_value: int
71+
self.resize_type: str
72+
self.reverse_input_channels: bool
73+
self.scale_values: list[int | float]
74+
6775
if self.nscthw_layout:
6876
self.n, self.s, self.c, self.t, self.h, self.w = self.inputs[self.image_blob_name].shape
6977
else:
@@ -118,7 +126,7 @@ def parameters(cls) -> dict[str, Any]:
118126
)
119127
return parameters
120128

121-
def _get_inputs(self) -> tuple[list[str], list[str]]:
129+
def _get_inputs(self) -> list[str]:
122130
"""Defines the model inputs for images and additional info.
123131
124132
Raises:

model_api/python/model_api/models/anomaly.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import cv2
1515
import numpy as np
1616

17+
from model_api.adapters.inference_adapter import InferenceAdapter
18+
1719
from .image_model import ImageModel
1820
from .types import ListValue, NumericalValue, StringValue
1921
from .utils import AnomalyResult
@@ -22,7 +24,9 @@
2224
class AnomalyDetection(ImageModel):
2325
__model__ = "AnomalyDetection"
2426

25-
def __init__(self, inference_adapter, configuration=dict(), preload=False):
27+
def __init__(
28+
self, inference_adapter: InferenceAdapter, configuration: dict = dict(), preload: bool = False
29+
) -> None:
2630
super().__init__(inference_adapter, configuration, preload)
2731
self._check_io_number(1, 1)
2832
self.normalization_scale: float
@@ -31,7 +35,7 @@ def __init__(self, inference_adapter, configuration=dict(), preload=False):
3135
self.task: str
3236
self.labels: list[str]
3337

34-
def postprocess(self, outputs: dict[str, np.ndarray], meta: dict[str, Any]):
38+
def postprocess(self, outputs: dict[str, np.ndarray], meta: dict[str, Any]) -> AnomalyResult:
3539
"""Post-processes the outputs and returns the results.
3640
3741
Args:

model_api/python/model_api/models/classification.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from openvino.runtime import Model, Type
1515
from openvino.runtime import opset10 as opset
1616

17+
from model_api.adapters.inference_adapter import InferenceAdapter
18+
1719
from .image_model import ImageModel
1820
from .types import BooleanValue, ListValue, NumericalValue, StringValue
1921
from .utils import ClassificationResult
@@ -22,13 +24,24 @@
2224
class ClassificationModel(ImageModel):
2325
__model__ = "Classification"
2426

25-
def __init__(self, inference_adapter, configuration=dict(), preload=False):
27+
def __init__(self, inference_adapter: InferenceAdapter, configuration: dict = dict(), preload: bool = False):
2628
super().__init__(inference_adapter, configuration, preload=False)
29+
self.topk: int
30+
self.labels: list[str]
31+
self.path_to_labels: str
32+
self.multilabel: bool
33+
self.hierarchical: bool
34+
self.hierarchical_config: str
35+
self.confidence_threshold: float
36+
self.output_raw_scores: bool
37+
self.hierarchical_postproc: str
38+
self.labels_resolver: GreedyLabelsResolver | ProbabilisticLabelsResolver
39+
2740
self._check_io_number(1, (1, 2, 3, 4, 5))
2841
if self.path_to_labels:
2942
self.labels = self._load_labels(self.path_to_labels)
3043
if len(self.outputs) == 1:
31-
self._verify_signle_output()
44+
self._verify_single_output()
3245

3346
self.raw_scores_name = _raw_scores_name
3447
if self.hierarchical:
@@ -99,7 +112,7 @@ def _load_labels(self, labels_file):
99112
labels.append(s[(begin_idx + 1) : end_idx])
100113
return labels
101114

102-
def _verify_signle_output(self):
115+
def _verify_single_output(self):
103116
layer_name = next(iter(self.outputs))
104117
layer_shape = self.outputs[layer_name].shape
105118

@@ -197,7 +210,7 @@ def get_saliency_maps(self, outputs: dict) -> np.ndarray:
197210
if not self.hierarchical:
198211
return saliency_maps
199212

200-
reordered_saliency_maps = [[] for _ in range(len(saliency_maps))]
213+
reordered_saliency_maps: list[list[np.ndarray]] = [[] for _ in range(len(saliency_maps))]
201214
model_classes = self.hierarchical_info["cls_heads_info"]["class_to_group_idx"]
202215
label_to_model_out_idx = {lbl: i for i, lbl in enumerate(model_classes.keys())}
203216
for batch in range(len(saliency_maps)):
@@ -279,7 +292,7 @@ def get_multiclass_predictions(self, outputs):
279292
return list(zip(indicesTensor, labels, scoresTensor))
280293

281294

282-
def addOrFindSoftmaxAndTopkOutputs(inference_adapter, topk, output_raw_scores):
295+
def addOrFindSoftmaxAndTopkOutputs(inference_adapter: InferenceAdapter, topk: int, output_raw_scores: bool):
283296
softmaxNode = None
284297
for i in range(len(inference_adapter.model.outputs)):
285298
output_node = inference_adapter.model.get_output_op(i).input(0).get_source_output().get_node()

0 commit comments

Comments
 (0)