Skip to content

Commit 6c211be

Browse files
authored
refactor: model parameters definition (#439)
* refactor parameters
* fix
1 parent d88f6d3 commit 6c211be

19 files changed

+543
-446
lines changed

src/model_api/models/action_classification.py

Lines changed: 15 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from model_api.models.result import ClassificationResult, Label
1414

1515
from .model import Model
16-
from .types import BooleanValue, ListValue, NumericalValue, StringValue
16+
from .parameters import ParameterRegistry
1717
from .utils import load_labels
1818

1919
if TYPE_CHECKING:
@@ -65,26 +65,19 @@ def __init__(
6565
self.image_blob_names = self._get_inputs()
6666
self.image_blob_name = self.image_blob_names[0]
6767
self.nscthw_layout = "NSCTHW" in self.inputs[self.image_blob_name].layout
68-
self.labels: list[str]
69-
self.path_to_labels: str
70-
self.mean_values: list[int | float]
71-
self.pad_value: int
72-
self.resize_type: str
73-
self.reverse_input_channels: bool
74-
self.scale_values: list[int | float]
7568

7669
if self.nscthw_layout:
7770
self.n, self.s, self.c, self.t, self.h, self.w = self.inputs[self.image_blob_name].shape
7871
else:
7972
self.n, self.s, self.t, self.h, self.w, self.c = self.inputs[self.image_blob_name].shape
80-
self.resize = RESIZE_TYPES[self.resize_type]
73+
self.resize = RESIZE_TYPES[self.params.resize_type]
8174
self.input_transform = InputTransform(
82-
self.reverse_input_channels,
83-
self.mean_values,
84-
self.scale_values,
75+
self.params.reverse_input_channels,
76+
self.params.mean_values,
77+
self.params.scale_values,
8578
)
86-
if self.path_to_labels:
87-
self.labels = load_labels(self.path_to_labels)
79+
if self.params.path_to_labels:
80+
self._labels = load_labels(self.params.path_to_labels)
8881

8982
@property
9083
def clip_size(self) -> int:
@@ -94,39 +87,11 @@ def clip_size(self) -> int:
9487
def parameters(cls) -> dict[str, Any]:
9588
parameters = super().parameters()
9689
parameters.update(
97-
{
98-
"labels": ListValue(description="List of class labels"),
99-
"path_to_labels": StringValue(
100-
description="Path to file with labels. Overrides the labels, if they sets via 'labels' parameter",
101-
),
102-
"mean_values": ListValue(
103-
description=(
104-
"Normalization values, which will be subtracted from image channels "
105-
"for image-input layer during preprocessing"
106-
),
107-
default_value=[],
108-
),
109-
"pad_value": NumericalValue(
110-
int,
111-
min=0,
112-
max=255,
113-
description="Pad value for resize_image_letterbox embedded into a model",
114-
default_value=0,
115-
),
116-
"resize_type": StringValue(
117-
default_value="standard",
118-
choices=tuple(RESIZE_TYPES.keys()),
119-
description="Type of input image resizing",
120-
),
121-
"reverse_input_channels": BooleanValue(
122-
default_value=False,
123-
description="Reverse the input channel order",
124-
),
125-
"scale_values": ListValue(
126-
default_value=[],
127-
description="Normalization values, which will divide the image channels for image-input layer",
128-
),
129-
},
90+
ParameterRegistry.merge(
91+
ParameterRegistry.LABELS,
92+
ParameterRegistry.IMAGE_RESIZE,
93+
ParameterRegistry.IMAGE_PREPROCESSING,
94+
),
13095
)
13196
return parameters
13297

@@ -193,7 +158,7 @@ def preprocess(
193158
"original_shape": inputs.shape,
194159
"resized_shape": (self.n, self.s, self.c, self.t, self.h, self.w),
195160
}
196-
resized_inputs = [self.resize(frame, (self.w, self.h), pad_value=self.pad_value) for frame in inputs]
161+
resized_inputs = [self.resize(frame, (self.w, self.h), pad_value=self.params.pad_value) for frame in inputs]
197162
np_frames = self._change_layout(
198163
[self.input_transform(inputs) for inputs in resized_inputs],
199164
)
@@ -222,8 +187,9 @@ def postprocess(
222187
"""Post-process."""
223188
logits = next(iter(outputs.values())).squeeze()
224189
index = np.argmax(logits)
190+
labels = self.params.labels
225191
return ClassificationResult(
226-
[Label(int(index), self.labels[index], logits[index])],
192+
[Label(int(index), labels[index], logits[index])],
227193
np.ndarray(0),
228194
np.ndarray(0),
229195
np.ndarray(0),

src/model_api/models/anomaly.py

Lines changed: 12 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
import numpy as np
1616

1717
from model_api.models.image_model import ImageModel
18+
from model_api.models.parameters import ParameterRegistry
1819
from model_api.models.result import AnomalyResult
19-
from model_api.models.types import ListValue, NumericalValue, StringValue
2020

2121
if TYPE_CHECKING:
2222
from model_api.adapters.inference_adapter import InferenceAdapter
@@ -67,11 +67,6 @@ def __init__(
6767
) -> None:
6868
super().__init__(inference_adapter, configuration, preload)
6969
self._check_io_number(1, (1, 4))
70-
self.normalization_scale: float
71-
self.image_threshold: float
72-
self.pixel_threshold: float
73-
self.task: str
74-
self.labels: list[str]
7570

7671
def preprocess(self, inputs: np.ndarray) -> list[dict]:
7772
"""Data preprocess method for Anomalib models.
@@ -103,7 +98,7 @@ def preprocess(self, inputs: np.ndarray) -> list[dict]:
10398
else:
10499
resized_shape = (self.w, self.h, self.c)
105100
# For fixed models, use standard preprocessing
106-
if self.embedded_processing:
101+
if self.params.embedded_processing:
107102
processed_image = inputs[None]
108103
else:
109104
# Resize image to expected model input dimensions
@@ -148,16 +143,17 @@ def postprocess(self, outputs: dict[str, np.ndarray], meta: dict[str, Any]) -> A
148143
anomaly_map = predictions.squeeze()
149144
npred_score = anomaly_map.reshape(-1).max()
150145

151-
pred_label = self.labels[1] if npred_score > self.image_threshold else self.labels[0]
146+
labels_list = self.params.labels
147+
pred_label = labels_list[1] if npred_score > self.params.image_threshold else labels_list[0]
152148

153149
assert anomaly_map is not None
154-
pred_mask = (anomaly_map >= self.pixel_threshold).astype(np.uint8)
155-
anomaly_map = self._normalize(anomaly_map, self.pixel_threshold)
150+
pred_mask = (anomaly_map >= self.params.pixel_threshold).astype(np.uint8)
151+
anomaly_map = self._normalize(anomaly_map, self.params.pixel_threshold)
156152

157153
# normalize
158-
npred_score = self._normalize(npred_score, self.image_threshold)
154+
npred_score = self._normalize(npred_score, self.params.image_threshold)
159155

160-
if pred_label == self.labels[0]: # normal
156+
if pred_label == labels_list[0]: # normal
161157
npred_score = 1 - npred_score # Score of normal is 1 - score of anomaly
162158
pred_score = npred_score.item()
163159
else:
@@ -180,7 +176,7 @@ def postprocess(self, outputs: dict[str, np.ndarray], meta: dict[str, Any]) -> A
180176
(meta["original_shape"][1], meta["original_shape"][0]),
181177
)
182178

183-
if self.task == "detection":
179+
if self.params.task == "detection":
184180
pred_boxes = self._get_boxes(pred_mask)
185181

186182
return AnomalyResult(
@@ -194,33 +190,13 @@ def postprocess(self, outputs: dict[str, np.ndarray], meta: dict[str, Any]) -> A
194190
@classmethod
195191
def parameters(cls) -> dict:
196192
parameters = super().parameters()
197-
parameters.update(
198-
{
199-
"image_threshold": NumericalValue(
200-
description="Image threshold",
201-
min=0.0,
202-
default_value=0.5,
203-
),
204-
"pixel_threshold": NumericalValue(
205-
description="Pixel threshold",
206-
min=0.0,
207-
default_value=0.5,
208-
),
209-
"normalization_scale": NumericalValue(
210-
description="Value used for normalization",
211-
),
212-
"task": StringValue(
213-
description="Task type",
214-
default_value="segmentation",
215-
),
216-
"labels": ListValue(description="List of class labels", value_type=str),
217-
},
218-
)
193+
parameters.update(ParameterRegistry.ANOMALY)
194+
parameters.update(ParameterRegistry.LABELS)
219195
return parameters
220196

221197
def _normalize(self, tensor: np.ndarray, threshold: float) -> np.ndarray:
222198
"""Currently supports only min-max normalization."""
223-
normalized = ((tensor - threshold) / self.normalization_scale) + 0.5
199+
normalized = ((tensor - threshold) / self.params.normalization_scale) + 0.5
224200
return np.clip(normalized, 0, 1)
225201

226202
@staticmethod

0 commit comments

Comments (0)