Commit d5068c3
Model API improvements (#2900)
* Added factory method
* Add configurable model parameters
* Fix warnings
* Factory improvement - only non-abstract wrappers are available
* Factory create object now
* Detection demo refactoring
* Segmentation demo - applied factory and config
* Fixes after rebase
* Add configurable parameters to Bert* models and Bert* demos
* Update deblurring wrapper and demo
* Add WrapperError
* Update configurable values error handling
* Add description for configurable values (some of them)
* Add WrapperError
* Fix remarks
* Fix errors
* Update HPE demo and models
* Fixes
* Add preload option to model's ctors
* Apply remarks
* Fix
* Fix
* Fix error with pose-estimation wrappers
* Update classification demo and wrapper
* Fix

Co-authored-by: Anzhella Pankratova <[email protected]>
1 parent ee811e1 commit d5068c3
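The recurring pattern in the diffs below: every model wrapper now takes a single configuration dict (apparently validated against the wrapper's declared parameters) plus an optional preload flag, instead of a growing list of positional arguments. A minimal usage sketch, assuming an OpenVINO IR at the placeholder path 'model.xml' and the adapter signatures used in the demos of this commit:

    from model_api.adapters import create_core, OpenvinoAdapter
    from model_api.models import Classification

    ie = create_core()
    model_adapter = OpenvinoAdapter(ie, 'model.xml', device='CPU')

    # All tunable values travel in one dict; keys must match the wrapper's
    # parameters() schema (see the classification.py diff below).
    config = {'topk': 5, 'path_to_labels': 'labels.txt'}  # placeholder label file
    model = Classification(model_adapter, config, preload=True)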

File tree

28 files changed: +692 -269 lines changed

demos/bert_named_entity_recognition_demo/python/bert_named_entity_recognition_demo.py

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ def main():
     serving_config = {"address": "localhost", "port": 9000}
     model_adapter = RemoteAdapter(args.model, serving_config)

-    model = BertNamedEntityRecognition(model_adapter, vocab, args.input_names)
+    model = BertNamedEntityRecognition(model_adapter, {'vocab': vocab, 'input_names': args.input_names})
     if max_sentence_length > model.max_length:
         model.reshape(max_sentence_length)
     model.log_layers_info()

demos/bert_question_answering_demo/python/bert_question_answering_demo.py

Lines changed: 8 additions & 2 deletions
@@ -172,8 +172,14 @@ def main():
     serving_config = {"address": "localhost", "port": 9000}
     model_adapter = RemoteAdapter(args.model, serving_config)

-    model = BertQuestionAnswering(model_adapter, vocab, args.input_names, args.output_names,
-                                  args.max_answer_token_num, args.model_squad_ver)
+    config = {
+        'vocab': vocab,
+        'input_names': args.input_names,
+        'output_names': args.output_names,
+        'max_answer_token_num': args.max_answer_token_num,
+        'squad_ver': args.model_squad_ver
+    }
+    model = BertQuestionAnswering(model_adapter, config)
     if args.reshape:
         # find the closest multiple of 64, if it is smaller than current network's sequence length, do reshape
         new_length = min(model.max_length, int(np.ceil((len(c_tokens[0]) + args.max_question_token_num) / 64) * 64))

demos/bert_question_answering_embedding_demo/python/bert_question_answering_embedding_demo.py

Lines changed: 9 additions & 3 deletions
@@ -172,7 +172,7 @@ def main():
     plugin_config = get_user_config(args.device, args.num_streams, args.num_threads)
     model_emb_adapter = OpenvinoAdapter(ie, args.model_emb, device=args.device, plugin_config=plugin_config,
                                         max_num_requests=args.num_infer_requests)
-    model_emb = BertEmbedding(model_emb_adapter, vocab, args.input_names_emb)
+    model_emb = BertEmbedding(model_emb_adapter, {'vocab': vocab, 'input_names': args.input_names_emb})
     model_emb.log_layers_info()

     # reshape BertEmbedding model to infer short questions and long contexts
@@ -189,8 +189,14 @@ def main():
     if args.model_qa:
         model_qa_adapter = OpenvinoAdapter(ie, args.model_qa, device=args.device, plugin_config=plugin_config,
                                            max_num_requests=args.num_infer_requests)
-        model_qa = BertQuestionAnswering(model_qa_adapter, vocab, args.input_names_qa, args.output_names_qa,
-                                         args.max_answer_token_num, args.model_qa_squad_ver)
+        config = {
+            'vocab': vocab,
+            'input_names': args.input_names_qa,
+            'output_names': args.output_names_qa,
+            'max_answer_token_num': args.max_answer_token_num,
+            'squad_ver': args.model_qa_squad_ver
+        }
+        model_qa = BertQuestionAnswering(model_qa_adapter, config)
         model_qa.log_layers_info()
         qa_pipeline = AsyncPipeline(model_qa)

demos/classification_demo/python/classification_demo.py

Lines changed: 11 additions & 5 deletions
@@ -26,7 +26,7 @@
 sys.path.append(str(Path(__file__).resolve().parents[2] / 'common/python'))
 sys.path.append(str(Path(__file__).resolve().parents[2] / 'common/python/openvino/model_zoo'))

-from model_api import models
+from model_api.models import Classification, OutputTransform
 from model_api.performance_metrics import put_highlighted_text, PerformanceMetrics
 from model_api.pipelines import get_user_config, AsyncPipeline
 from model_api.adapters import create_core, OpenvinoAdapter, RemoteAdapter
@@ -55,7 +55,7 @@ def build_argparser():
                              'Default value is CPU.')

     common_model_args = parser.add_argument_group('Common model options')
-    common_model_args.add_argument('--labels', help='Optional. Labels mapping file.', default=None, type=Path)
+    common_model_args.add_argument('--labels', help='Optional. Labels mapping file.', default=None, type=str)
     common_model_args.add_argument('-topk', help='Optional. Number of top results. Default value is 5. Must be from 1 to 10.', default=5,
                                    type=int, choices=range(1, 11))

@@ -166,8 +166,14 @@ def main():
     serving_config = {"address": "localhost", "port": 9000}
     model_adapter = RemoteAdapter(args.model, serving_config)

-    model = models.Classification(model_adapter, topk=args.topk, labels=args.labels)
-    model.set_inputs_preprocessing(args.reverse_input_channels, args.mean_values, args.scale_values)
+    config = {
+        'mean_values': args.mean_values,
+        'scale_values': args.scale_values,
+        'reverse_input_channels': args.reverse_input_channels,
+        'topk': args.topk,
+        'path_to_labels': args.labels
+    }
+    model = Classification(model_adapter, config)
     model.log_layers_info()

     async_pipeline = AsyncPipeline(model)
@@ -223,7 +229,7 @@ def main():
             raise ValueError("Can't read an image from the input")
         break
     if next_frame_id == 0:
-        output_transform = models.OutputTransform(frame.shape[:2], args.output_resolution)
+        output_transform = OutputTransform(frame.shape[:2], args.output_resolution)
         if args.output_resolution:
             output_resolution = output_transform.new_resolution
         else:

demos/common/python/openvino/model_zoo/model_api/models/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -19,10 +19,13 @@
 from .centernet import CenterNet
 from .classification import Classification
 from .deblurring import Deblurring
+from .detection_model import DetectionModel
 from .detr import DETR
 from .ctpn import CTPN
 from .faceboxes import FaceBoxes
 from .hpe_associative_embedding import HpeAssociativeEmbedding
+from .image_model import ImageModel
+from .model import Model
 from .monodepth import MonoDepthModel
 from .open_pose import OpenPose
 from .retinaface import RetinaFace, RetinaFacePyTorch
@@ -39,12 +42,15 @@
     'CenterNet',
     'Classification',
     'CTPN',
-    'DetectionWithLandmarks',
     'Deblurring',
+    'DetectionModel',
+    'DetectionWithLandmarks',
     'DETR',
     'FaceBoxes',
     'HpeAssociativeEmbedding',
+    'ImageModel',
     'InputTransform',
+    'Model',
     'MonoDepthModel',
     'OpenPose',
     'OutputTransform',
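Exporting Model, ImageModel and DetectionModel lets downstream code subclass the base wrappers directly. A hedged sketch of a custom wrapper following the pattern the diffs below establish (the class name and 'confidence_threshold' parameter are illustrative, and the model_api.models.types import path is assumed from bert.py's relative `from .types import ...`):

    from model_api.models import ImageModel
    from model_api.models.types import NumericalValue

    class MyWrapper(ImageModel):
        __model__ = 'my-wrapper'  # hypothetical wrapper name

        def __init__(self, model_adapter, configuration=None, preload=False):
            super().__init__(model_adapter, configuration, preload)

        @classmethod
        def parameters(cls):
            # extend the inherited schema rather than replacing it
            parameters = super().parameters()
            parameters.update({
                'confidence_threshold': NumericalValue(default_value=0.5),
            })
            return parameters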

demos/common/python/openvino/model_zoo/model_api/models/bert.py

Lines changed: 46 additions & 23 deletions
@@ -13,21 +13,33 @@

 import numpy as np

-from .model import Model
+from .model import Model, WrapperError
+from .types import DictValue, NumericalValue, StringValue


 class Bert(Model):
-    def __init__(self, model_adapter, vocab, input_names):
-        super().__init__(model_adapter)
-        self.token_cls = [vocab['[CLS]']]
-        self.token_sep = [vocab['[SEP]']]
-        self.token_pad = [vocab['[PAD]']]
-        self.input_names = [i.strip() for i in input_names.split(',')]
+    __model__ = 'bert'
+
+    def __init__(self, model_adapter, configuration, preload=False):
+        super().__init__(model_adapter, configuration, preload)
+        self.token_cls = [self.vocab['[CLS]']]
+        self.token_sep = [self.vocab['[SEP]']]
+        self.token_pad = [self.vocab['[PAD]']]
+        self.input_names = [i.strip() for i in self.input_names.split(',')]
         if self.inputs.keys() != set(self.input_names):
-            raise RuntimeError('The Bert model expects input names: {}, actual network input names: {}'.format(
+            raise WrapperError(self.__model__, 'The Wrapper expects input names: {}, actual network input names: {}'.format(
                 self.input_names, list(self.inputs.keys())))
         self.max_length = self.inputs[self.input_names[0]].shape[1]

+    @classmethod
+    def parameters(cls):
+        parameters = super().parameters()
+        parameters.update({
+            'vocab': DictValue(),
+            'input_names': StringValue(description='Comma-separated names of input layers'),
+        })
+        return parameters
+
     def preprocess(self, inputs):
         input_ids, attention_mask, token_type_ids = self.form_request(inputs)

@@ -71,12 +83,13 @@ def reshape(self, new_length):


 class BertNamedEntityRecognition(Bert):
-    def __init__(self, model_adapter, vocab, input_names):
-        super().__init__(model_adapter, vocab, input_names)
+    __model__ = 'bert-named-entity-recognition'
+
+    def __init__(self, model_adapter, configuration, preload=False):
+        super().__init__(model_adapter, configuration, preload)

         self.output_names = list(self.outputs)
-        if len(self.output_names) != 1:
-            raise RuntimeError("The BertNamedEntityRecognition model wrapper supports only 1 output")
+        self._check_io_number(-1, 1)

     def form_request(self, inputs):
         c_tokens_id = inputs
@@ -99,12 +112,13 @@ def postprocess(self, outputs, meta):


 class BertEmbedding(Bert):
-    def __init__(self, model_adapter, vocab, input_names):
-        super().__init__(model_adapter, vocab, input_names)
+    __model__ = 'bert-embedding'
+
+    def __init__(self, model_adapter, configuration, preload=False):
+        super().__init__(model_adapter, configuration, preload)

         self.output_names = list(self.outputs)
-        if len(self.output_names) != 1:
-            raise RuntimeError("The BertEmbedding model wrapper supports only 1 output")
+        self._check_io_number(-1, 1)

     def form_request(self, inputs):
         tokens_id, self.max_length = inputs
@@ -119,17 +133,26 @@ def postprocess(self, outputs, meta):


 class BertQuestionAnswering(Bert):
-    def __init__(self, model_adapter, vocab, input_names, output_names,
-                 max_answer_token_num, squad_ver):
-        super().__init__(model_adapter, vocab, input_names)
+    __model__ = 'bert-question-answering'

-        self.max_answer_token_num = max_answer_token_num
-        self.squad_ver = squad_ver
-        self.output_names = [o.strip() for o in output_names.split(',')]
+    def __init__(self, model_adapter, configuration, preload=False):
+        super().__init__(model_adapter, configuration, preload)
+
+        self.output_names = [o.strip() for o in self.output_names.split(',')]
         if self.outputs.keys() != set(self.output_names):
-            raise RuntimeError('The BertQuestionAnswering model output names: {}, actual network output names: {}'.format(
+            raise WrapperError(self.__model__, 'The Wrapper output names: {}, actual network output names: {}'.format(
                 self.output_names, list(self.outputs.keys())))

+    @classmethod
+    def parameters(cls):
+        parameters = super().parameters()
+        parameters.update({
+            'output_names': StringValue(description='Comma-separated names of output layers'),
+            'max_answer_token_num': NumericalValue(value_type=int),
+            'squad_ver': StringValue(),
+        })
+        return parameters
+
     def form_request(self, inputs):
         c_data, q_tokens_id = inputs
         input_ids = self.token_cls + q_tokens_id + self.token_sep + c_data.c_tokens_id + self.token_sep
demos/common/python/openvino/model_zoo/model_api/models/centernet.py

Lines changed: 10 additions & 6 deletions
@@ -23,15 +23,19 @@


 class CenterNet(DetectionModel):
-    def __init__(self, model_adapter, resize_type=None,
-                 labels=None, threshold=0.5, iou_threshold=0.5):
-        if not resize_type:
-            resize_type = 'standard'
-        super().__init__(model_adapter, resize_type=resize_type,
-                         labels=labels, threshold=threshold, iou_threshold=iou_threshold)
+    __model__ = 'centernet'
+
+    def __init__(self, model_adapter, configuration=None, preload=False):
+        super().__init__(model_adapter, configuration, preload)
         self._check_io_number(1, 3)
         self._output_layer_names = sorted(self.outputs)

+    @classmethod
+    def parameters(cls):
+        parameters = super().parameters()
+        parameters['resize_type'].update_default_value('standard')
+        return parameters
+
     def postprocess(self, outputs, meta):
         heat = outputs[self._output_layer_names[0]][0]
         reg = outputs[self._output_layer_names[1]][0]
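The old constructor patched resize_type imperatively ("if not resize_type: resize_type = 'standard'"); the new code states the default declaratively through update_default_value. A hedged two-line sketch of the effect, assuming the parameter objects expose their default (as NumericalValue's default_value keyword above suggests):

    params = CenterNet.parameters()
    # expected: the inherited 'resize_type' parameter now defaults to 'standard',
    # so a CenterNet config needs no explicit resize_type entry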

demos/common/python/openvino/model_zoo/model_api/models/classification.py

Lines changed: 19 additions & 7 deletions
@@ -15,20 +15,19 @@
 """

 import numpy as np
+
+from .types import NumericalValue, ListValue, StringValue
 from .utils import softmax

 from .image_model import ImageModel


 class Classification(ImageModel):
-    def __init__(self, model_adapter, topk = 1, labels = None, resize_type='crop'):
-        super().__init__(model_adapter, resize_type=resize_type)
+    def __init__(self, model_adapter, configuration=None, preload=False):
+        super().__init__(model_adapter, configuration, preload)
         self._check_io_number(1, 1)
-        self.topk = topk
-        if isinstance(labels, (list, tuple)):
-            self.labels = labels
-        else:
-            self.labels = self._load_labels(labels) if labels else None
+        if self.path_to_labels:
+            self.labels = self._load_labels(self.path_to_labels)
         self.out_layer_name = self._get_outputs()

     @staticmethod
@@ -61,6 +60,19 @@ def _get_outputs(self):
                 'labels must match ({} != {})'.format(layer_shape[1], len(self.labels)))
         return layer_name

+    @classmethod
+    def parameters(cls):
+        parameters = super().parameters()
+        parameters['resize_type'].update_default_value('crop')
+        parameters.update({
+            'topk': NumericalValue(value_type=int, default_value=1, min=1),
+            'labels': ListValue(description="List of class labels"),
+            'path_to_labels': StringValue(
+                description="Path to file with labels. Overrides the labels, if they sets via 'labels' parameter"
+            ),
+        })
+        return parameters
+
     def postprocess(self, outputs, meta):
         outputs = outputs[self.out_layer_name].squeeze()
         indices = np.argpartition(outputs, -self.topk)[-self.topk:]
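Per the parameter descriptions above, labels may be supplied inline as a list or loaded from a file, with the file taking precedence. A hedged construction example (the label file name is a placeholder):

    model = Classification(model_adapter, {
        'topk': 3,
        'labels': ['cat', 'dog'],          # used when no label file is given
        'path_to_labels': 'imagenet.txt',  # placeholder path; overrides 'labels'
    })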

demos/common/python/openvino/model_zoo/model_api/models/ctpn.py

Lines changed: 22 additions & 6 deletions
@@ -17,14 +17,17 @@
 import cv2
 import numpy as np

+from .model import WrapperError
 from .detection_model import DetectionModel
+from .types import ListValue, NumericalValue
 from .utils import Detection, nms, clip_detections


 class CTPN(DetectionModel):
-    def __init__(self, model_adapter, input_size, threshold=0.9, iou_threshold=0.5):
-        super().__init__(model_adapter, labels=['Text'],
-                         threshold=threshold, iou_threshold=iou_threshold)
+    __model__ = 'CTPN'
+
+    def __init__(self, model_adapter, configuration=None, preload=False):
+        super().__init__(model_adapter, configuration, False)
         self._check_io_number(1, 2)
         self.bboxes_blob_name, self.scores_blob_name = self._get_outputs()

@@ -48,24 +51,37 @@ def __init__(self, model_adapter, input_size, threshold=0.9, iou_threshold=0.5):
             [0, -134, 15, 149]
         ])

-        self.h1, self.w1 = self.ctpn_keep_aspect_ratio(1200, 600, input_size[1], input_size[0])
+        self.h1, self.w1 = self.ctpn_keep_aspect_ratio(1200, 600, self.input_size[1], self.input_size[0])
         self.h2, self.w2 = self.ctpn_keep_aspect_ratio(600, 600, self.w1, self.h1)
         default_input_shape = self.inputs[self.image_blob_name].shape
         input_shape = {self.image_blob_name: (default_input_shape[:-2] + [self.h2, self.w2])}
         self.logger.debug('\tReshape model from {} to {}'.format(default_input_shape, input_shape[self.image_blob_name]))
         self.reshape(input_shape)
+        if preload:
+            self.load()

     def _get_outputs(self):
         (boxes_name, boxes_data_repr), (scores_name, scores_data_repr) = self.outputs.items()

         if len(boxes_data_repr.shape) != 4 or len(scores_data_repr.shape) != 4:
-            raise RuntimeError("Unexpected output blob shape. Only 4D output blobs are supported")
+            raise WrapperError(self.__model__, "Unexpected output blob shape. Only 4D output blobs are supported")

         if scores_data_repr.shape[1] == boxes_data_repr.shape[1] * 2:
             return scores_name, boxes_name
         if boxes_data_repr.shape[1] == scores_data_repr.shape[1] * 2:
             return boxes_name, scores_name
-        raise RuntimeError("One of outputs must be two times larger than another for the CTPN topology")
+        raise WrapperError(self.__model__, "One of outputs must be two times larger than another")
+
+    @classmethod
+    def parameters(cls):
+        parameters = super().parameters()
+        parameters.update({
+            'iou_threshold': NumericalValue(default_value=0.5, description="Threshold for NMS filtering"),
+            'input_size': ListValue()
+        })
+        parameters['threshold'].update_default_value(0.9)
+        parameters['labels'].update_default_value(['Text'])
+        return parameters

     def preprocess(self, inputs):
         meta = {'original_shape': inputs.shape}
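CTPN is the one wrapper here that passes False to super() and defers loading: it must reshape the network to its computed input size before compilation, so load() runs only after reshape() when preload is requested. A hedged usage sketch (the [600, 600] input size is a placeholder; the demo passes its own size argument):

    ctpn = CTPN(model_adapter, {'input_size': [600, 600]}, preload=True)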
