1313from model_api .models .result import ClassificationResult , Label
1414
1515from .model import Model
16- from .types import BooleanValue , ListValue , NumericalValue , StringValue
16+ from .parameters import ParameterRegistry
1717from .utils import load_labels
1818
1919if TYPE_CHECKING :
@@ -65,26 +65,19 @@ def __init__(
6565 self .image_blob_names = self ._get_inputs ()
6666 self .image_blob_name = self .image_blob_names [0 ]
6767 self .nscthw_layout = "NSCTHW" in self .inputs [self .image_blob_name ].layout
68- self .labels : list [str ]
69- self .path_to_labels : str
70- self .mean_values : list [int | float ]
71- self .pad_value : int
72- self .resize_type : str
73- self .reverse_input_channels : bool
74- self .scale_values : list [int | float ]
7568
7669 if self .nscthw_layout :
7770 self .n , self .s , self .c , self .t , self .h , self .w = self .inputs [self .image_blob_name ].shape
7871 else :
7972 self .n , self .s , self .t , self .h , self .w , self .c = self .inputs [self .image_blob_name ].shape
80- self .resize = RESIZE_TYPES [self .resize_type ]
73+ self .resize = RESIZE_TYPES [self .params . resize_type ]
8174 self .input_transform = InputTransform (
82- self .reverse_input_channels ,
83- self .mean_values ,
84- self .scale_values ,
75+ self .params . reverse_input_channels ,
76+ self .params . mean_values ,
77+ self .params . scale_values ,
8578 )
86- if self .path_to_labels :
87- self .labels = load_labels (self .path_to_labels )
79+ if self .params . path_to_labels :
80+ self ._labels = load_labels (self . params .path_to_labels )
8881
8982 @property
9083 def clip_size (self ) -> int :
@@ -94,39 +87,11 @@ def clip_size(self) -> int:
9487 def parameters (cls ) -> dict [str , Any ]:
9588 parameters = super ().parameters ()
9689 parameters .update (
97- {
98- "labels" : ListValue (description = "List of class labels" ),
99- "path_to_labels" : StringValue (
100- description = "Path to file with labels. Overrides the labels, if they sets via 'labels' parameter" ,
101- ),
102- "mean_values" : ListValue (
103- description = (
104- "Normalization values, which will be subtracted from image channels "
105- "for image-input layer during preprocessing"
106- ),
107- default_value = [],
108- ),
109- "pad_value" : NumericalValue (
110- int ,
111- min = 0 ,
112- max = 255 ,
113- description = "Pad value for resize_image_letterbox embedded into a model" ,
114- default_value = 0 ,
115- ),
116- "resize_type" : StringValue (
117- default_value = "standard" ,
118- choices = tuple (RESIZE_TYPES .keys ()),
119- description = "Type of input image resizing" ,
120- ),
121- "reverse_input_channels" : BooleanValue (
122- default_value = False ,
123- description = "Reverse the input channel order" ,
124- ),
125- "scale_values" : ListValue (
126- default_value = [],
127- description = "Normalization values, which will divide the image channels for image-input layer" ,
128- ),
129- },
90+ ParameterRegistry .merge (
91+ ParameterRegistry .LABELS ,
92+ ParameterRegistry .IMAGE_RESIZE ,
93+ ParameterRegistry .IMAGE_PREPROCESSING ,
94+ ),
13095 )
13196 return parameters
13297
@@ -193,7 +158,7 @@ def preprocess(
193158 "original_shape" : inputs .shape ,
194159 "resized_shape" : (self .n , self .s , self .c , self .t , self .h , self .w ),
195160 }
196- resized_inputs = [self .resize (frame , (self .w , self .h ), pad_value = self .pad_value ) for frame in inputs ]
161+ resized_inputs = [self .resize (frame , (self .w , self .h ), pad_value = self .params . pad_value ) for frame in inputs ]
197162 np_frames = self ._change_layout (
198163 [self .input_transform (inputs ) for inputs in resized_inputs ],
199164 )
@@ -222,8 +187,9 @@ def postprocess(
222187 """Post-process."""
223188 logits = next (iter (outputs .values ())).squeeze ()
224189 index = np .argmax (logits )
190+ labels = self .params .labels
225191 return ClassificationResult (
226- [Label (int (index ), self . labels [index ], logits [index ])],
192+ [Label (int (index ), labels [index ], logits [index ])],
227193 np .ndarray (0 ),
228194 np .ndarray (0 ),
229195 np .ndarray (0 ),
0 commit comments