Skip to content

Commit 42db1f8

Browse files
author
Anna Grebneva
authored
AC: Refactoring mtcnn models and init in cascade models (#2952)
* Minor refactoring * Refactoring mtcnn * Aligned init in cascade models
1 parent 3dff421 commit 42db1f8

19 files changed

+564
-997
lines changed

tools/accuracy_checker/openvino/tools/accuracy_checker/evaluators/custom_evaluators/asr_custom_encoder_decoder_joint.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .asr_encoder_prediction_joint_evaluator import ASREvaluator
2020
from .base_models import create_model, BaseCascadeModel, BaseDLSDKModel, BaseONNXModel, BaseOpenVINOModel
2121
from ...adapters import create_adapter
22-
from ...utils import generate_layer_name, contains_all, contains_any
22+
from ...utils import generate_layer_name, contains_all
2323
from ...config import ConfigError
2424

2525

@@ -382,22 +382,10 @@ class ASRModel(BaseCascadeModel):
382382

383383
def __init__(self, network_info, adapter_config, launcher, models_args, is_blob, delayed_model_loading=False):
384384
super().__init__(network_info, launcher)
385-
if models_args and not delayed_model_loading:
386-
encoder = network_info.get('encoder', {})
387-
decoder = network_info.get('decoder', {})
388-
joint = network_info.get('joint', {})
389-
if not contains_any(encoder, ['model', 'onnx_model']) and models_args:
390-
encoder['model'] = models_args[0]
391-
encoder['_model_is_blob'] = is_blob
392-
if not contains_any(decoder, ['model', 'onnx_model']) and models_args:
393-
decoder['model'] = models_args[1 if len(models_args) > 1 else 0]
394-
decoder['_model_is_blob'] = is_blob
395-
if not contains_any(joint, ['model', 'onnx_model']) and models_args:
396-
joint['model'] = models_args[2 if len(models_args) > 2 else 0]
397-
joint['_model_is_blob'] = is_blob
398-
network_info.update({'encoder': encoder, 'decoder': decoder, 'joint': joint})
399-
if not contains_all(network_info, ['encoder', 'decoder', 'joint']) and not delayed_model_loading:
400-
raise ConfigError('network_info should contain encoder, prediction and joint fields')
385+
parts = ['encoder', 'decoder', 'joint']
386+
network_info = self.fill_part_with_model(network_info, parts, models_args, is_blob, delayed_model_loading)
387+
if not contains_all(network_info, parts) and not delayed_model_loading:
388+
raise ConfigError('network_info should contain encoder, decoder and joint fields')
401389
self._decoder_mapping = {
402390
'dlsdk': DLSDKDecoder,
403391
'openvino': OVDecoder,

tools/accuracy_checker/openvino/tools/accuracy_checker/evaluators/custom_evaluators/asr_encoder_decoder_evaluator.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
create_model, create_encoder)
2727
from ...adapters import create_adapter
2828
from ...config import ConfigError
29-
from ...utils import contains_all, contains_any, extract_image_representations, read_pickle
29+
from ...utils import contains_all, extract_image_representations, read_pickle
3030

3131

3232
class AutomaticSpeechRecognitionEvaluator(BaseCustomEvaluator):
@@ -67,17 +67,9 @@ def _process(self, output_callback, calculate_metrics, progress_reporter, metric
6767
class ASRModel(BaseCascadeModel):
6868
def __init__(self, network_info, launcher, models_args, is_blob, delayed_model_loading=False):
6969
super().__init__(network_info, launcher)
70-
if models_args and not delayed_model_loading:
71-
encoder = network_info.get('encoder', {})
72-
decoder = network_info.get('decoder', {})
73-
if not contains_any(encoder, ['model', 'onnx_model']) and models_args:
74-
encoder['model'] = models_args[0]
75-
encoder['_model_is_blob'] = is_blob
76-
if not contains_any(decoder, ['model', 'onnx_model']) and models_args:
77-
decoder['model'] = models_args[1 if len(models_args) > 1 else 0]
78-
decoder['_model_is_blob'] = is_blob
79-
network_info.update({'encoder': encoder, 'decoder': decoder})
80-
if not contains_all(network_info, ['encoder', 'decoder']) and not delayed_model_loading:
70+
parts = ['encoder', 'decoder']
71+
network_info = self.fill_part_with_model(network_info, parts, models_args, is_blob, delayed_model_loading)
72+
if not contains_all(network_info, parts) and not delayed_model_loading:
8173
raise ConfigError('network_info should contain encoder and decoder fields')
8274
self._decoder_mapping = {
8375
'dlsdk': DecoderDLSDKModel,

tools/accuracy_checker/openvino/tools/accuracy_checker/evaluators/custom_evaluators/asr_encoder_prediction_joint_evaluator.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from ...adapters import create_adapter
2323
from ...config import ConfigError
24-
from ...utils import contains_all, contains_any, read_pickle, parse_partial_shape
24+
from ...utils import contains_all, read_pickle, parse_partial_shape
2525
from .asr_encoder_decoder_evaluator import AutomaticSpeechRecognitionEvaluator
2626
from .base_models import (
2727
BaseCascadeModel, BaseDLSDKModel, BaseOpenVINOModel, BaseONNXModel, create_model, create_encoder
@@ -43,21 +43,9 @@ def from_configs(cls, config, delayed_model_loading=False, orig_config=None):
4343
class ASRModel(BaseCascadeModel):
4444
def __init__(self, network_info, launcher, models_args, is_blob, adapter_info, delayed_model_loading=False):
4545
super().__init__(network_info, launcher)
46-
if models_args and not delayed_model_loading:
47-
encoder = network_info.get('encoder', {})
48-
prediction = network_info.get('prediction', {})
49-
joint = network_info.get('joint', {})
50-
if not contains_any(encoder, ['model', 'onnx_model']) and models_args:
51-
encoder['model'] = models_args[0]
52-
encoder['_model_is_blob'] = is_blob
53-
if not contains_any(prediction, ['model', 'onnx_model']) and models_args:
54-
prediction['model'] = models_args[1 if len(models_args) > 1 else 0]
55-
prediction['_model_is_blob'] = is_blob
56-
if not contains_any(joint, ['model', 'onnx_model']) and models_args:
57-
joint['model'] = models_args[2 if len(models_args) > 2 else 0]
58-
joint['_model_is_blob'] = is_blob
59-
network_info.update({'encoder': encoder, 'prediction': prediction, 'joint': joint})
60-
if not contains_all(network_info, ['encoder', 'prediction', 'joint']) and not delayed_model_loading:
46+
parts = ['encoder', 'prediction', 'joint']
47+
network_info = self.fill_part_with_model(network_info, parts, models_args, is_blob, delayed_model_loading)
48+
if not contains_all(network_info, parts) and not delayed_model_loading:
6149
raise ConfigError('network_info should contain encoder, prediction and joint fields')
6250
self._encoder_mapping = {
6351
'dlsdk': EncoderDLSDKModel,

tools/accuracy_checker/openvino/tools/accuracy_checker/evaluators/custom_evaluators/base_models.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import numpy as np
1919

2020
from ...config import ConfigError
21-
from ...utils import get_path, parse_partial_shape
21+
from ...utils import get_path, parse_partial_shape, contains_any
2222
from ...logging import print_info
2323

2424

@@ -70,6 +70,17 @@ def get_network(self):
7070
def reset(self):
7171
pass
7272

73+
@staticmethod
74+
def fill_part_with_model(network_info, parts, models_args, is_blob, delayed_model_loading):
75+
if models_args and not delayed_model_loading:
76+
for idx, part in enumerate(parts):
77+
part_info = network_info.get(part, {})
78+
if not contains_any(part_info, ['model', 'onnx_model']) and models_args:
79+
part_info['model'] = models_args[idx if len(models_args) > idx else 0]
80+
part_info['_model_is_blob'] = is_blob
81+
network_info.update({part: part_info})
82+
return network_info
83+
7384

7485
class BaseDLSDKModel:
7586
def __init__(self, network_info, launcher, suffix=None, delayed_model_loading=False):
@@ -415,3 +426,52 @@ def release(self):
415426
def automatic_model_search(network_info):
416427
model = Path(network_info['model'])
417428
return model
429+
430+
431+
class BaseCaffeModel:
432+
def __init__(self, network_info, launcher, suffix=None, delayed_model_loading=False):
433+
self.network_info = network_info
434+
self.launcher = launcher
435+
self.default_model_suffix = suffix
436+
437+
def fit_to_input(self, data, layer_name, layout, precision, tmpl=None):
438+
return self.launcher.fit_to_input(data, layer_name, layout, precision, template=tmpl)
439+
440+
def predict(self, identifiers, input_data):
441+
raise NotImplementedError
442+
443+
def release(self):
444+
del self.net
445+
446+
def automatic_model_search(self, network_info):
447+
model = Path(network_info.get('model', ''))
448+
weights = network_info.get('weights')
449+
if model.is_dir():
450+
models_list = list(Path(model).glob('{}.prototxt'.format(self.default_model_name)))
451+
if not models_list:
452+
models_list = list(Path(model).glob('*.prototxt'))
453+
if not models_list:
454+
raise ConfigError('Suitable model description is not detected')
455+
if len(models_list) != 1:
456+
raise ConfigError('Several suitable models found, please specify required model')
457+
model = models_list[0]
458+
if weights is None or Path(weights).is_dir():
459+
weights_dir = weights or model.parent
460+
weights = Path(weights_dir) / model.name.replace('prototxt', 'caffemodel')
461+
if not weights.exists():
462+
weights_list = list(weights_dir.glob('*.caffemodel'))
463+
if not weights_list:
464+
raise ConfigError('Suitable weights is not detected')
465+
if len(weights_list) != 1:
466+
raise ConfigError('Several suitable weights found, please specify required explicitly')
467+
weights = weights_list[0]
468+
weights = Path(weights)
469+
accepted_suffixes = ['.prototxt']
470+
if model.suffix not in accepted_suffixes:
471+
raise ConfigError('Models with following suffixes are allowed: {}'.format(accepted_suffixes))
472+
print_info('{} - Found model: {}'.format(self.default_model_name, model))
473+
accepted_weights_suffixes = ['.caffemodel']
474+
if weights.suffix not in accepted_weights_suffixes:
475+
raise ConfigError('Weights with following suffixes are allowed: {}'.format(accepted_weights_suffixes))
476+
print_info('{} - Found weights: {}'.format(self.default_model_name, weights))
477+
return model, weights

tools/accuracy_checker/openvino/tools/accuracy_checker/evaluators/custom_evaluators/cocosnet_evaluator.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import cv2
2020

2121
from .base_custom_evaluator import BaseCustomEvaluator
22-
from .base_models import BaseDLSDKModel, BaseOpenVINOModel, BaseCascadeModel
22+
from .base_models import BaseDLSDKModel, BaseOpenVINOModel, BaseCascadeModel, create_model
2323
from ...adapters import create_adapter
2424
from ...config import ConfigError
2525
from ...data_readers import DataRepresentation
@@ -117,21 +117,19 @@ def __init__(self, network_info, launcher, models_args, is_blob, delayed_model_l
117117
})
118118
if not contains_all(network_info, ['cocosnet_network']) and not delayed_model_loading:
119119
raise ConfigError('network_info should contain cocosnet_network field')
120-
use_api2 = launcher.config['framework'] == 'openvino'
121-
122-
if not use_api2:
123-
self.test_model = CocosnetModel(network_info.get('cocosnet_network', {}), launcher, 'cocosnet_network',
124-
delayed_model_loading)
125-
else:
126-
self.test_model = CoCosNetModelOV(network_info.get('cocosnet_network', {}), launcher, 'cocosnet_network',
127-
delayed_model_loading)
120+
self._test_mapping = {
121+
'dlsdk': CocosnetModel,
122+
'openvino': CoCosNetModelOV
123+
}
124+
self._check_mapping = {
125+
'dlsdk': GanCheckModel,
126+
'openvino': GANCheckOVModel
127+
}
128+
self.test_model = create_model(network_info.get('cocosnet_network', {}), launcher, self._test_mapping,
129+
'cocosnet_network', delayed_model_loading)
128130
if network_info.get('verification_network'):
129-
if not use_api2:
130-
self.check_model = GanCheckModel(network_info.get('verification_network', {}), launcher,
131-
'verification_network', delayed_model_loading)
132-
else:
133-
self.check_model = GANCheckOVModel(network_info.get('verification_network', {}), launcher,
134-
'verification_network', delayed_model_loading)
131+
self.check_model = create_model(network_info.get('verification_network', {}), launcher, self._check_mapping,
132+
'verification_network', delayed_model_loading)
135133
else:
136134
self.check_model = None
137135
self._part_by_name = {

tools/accuracy_checker/openvino/tools/accuracy_checker/evaluators/custom_evaluators/colorization_evaluator.py

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import cv2
1919

2020
from .base_custom_evaluator import BaseCustomEvaluator
21-
from .base_models import BaseDLSDKModel, BaseCascadeModel, BaseOpenVINOModel
21+
from .base_models import BaseDLSDKModel, BaseCascadeModel, BaseOpenVINOModel, create_model
2222
from ...adapters import create_adapter
2323
from ...config import ConfigError
2424
from ...utils import extract_image_representations, contains_all, parse_partial_shape
@@ -63,37 +63,24 @@ def _process(self, output_callback, calculate_metrics, progress_reporter, metric
6363
class ColorizationCascadeModel(BaseCascadeModel):
6464
def __init__(self, network_info, launcher, models_args, is_blob, delayed_model_loading=False):
6565
super().__init__(network_info, launcher)
66-
if models_args and not delayed_model_loading:
67-
colorization_network = network_info.get('colorization_network', {})
68-
verification_network = network_info.get('verification_network', {})
69-
if 'model' not in colorization_network and models_args:
70-
colorization_network['model'] = models_args[0]
71-
colorization_network['_model_is_blob'] = is_blob
72-
if 'model' not in verification_network and models_args:
73-
verification_network['model'] = models_args[1 if len(models_args) > 1 else 0]
74-
verification_network['_model_is_blob'] = is_blob
75-
network_info.update({
76-
'colorization_network': colorization_network,
77-
'verification_network': verification_network
78-
})
79-
if not contains_all(network_info, ['colorization_network', 'verification_network']):
66+
parts = ['colorization_network', 'verification_network']
67+
network_info = self.fill_part_with_model(network_info, parts, models_args, is_blob, delayed_model_loading)
68+
if not contains_all(network_info, parts) and not delayed_model_loading:
8069
raise ConfigError('configuration for colorization_network/verification_network does not exist')
81-
use_api2 = launcher.config['framework'] == 'openvino'
82-
83-
if not use_api2:
84-
self.test_model = ColorizationTestModel(network_info.get('colorization_network', {}), launcher,
85-
'colorization_network', delayed_model_loading)
86-
self.check_model = ColorizationCheckModel(network_info.get('verification_network', {}), launcher,
87-
'verification_network', delayed_model_loading)
88-
else:
89-
self.test_model = ColorizationTestOVModel(network_info.get('colorization_network', {}), launcher,
90-
'colorization_network', delayed_model_loading)
91-
self.check_model = ColorizationCheckOVModel(network_info.get('verification_network', {}), launcher,
92-
'verification_network', delayed_model_loading)
93-
self._part_by_name = {
94-
'colorization_network': self.test_model,
95-
'verification_network': self.check_model
70+
71+
self._test_mapping = {
72+
'dlsdk': ColorizationTestModel,
73+
'openvino': ColorizationTestOVModel
74+
}
75+
self._check_mapping = {
76+
'dlsdk': ColorizationCheckModel,
77+
'openvino': ColorizationCheckOVModel
9678
}
79+
self.test_model = create_model(network_info.get('colorization_network', {}), launcher, self._test_mapping,
80+
'colorization_network', delayed_model_loading)
81+
self.check_model = create_model(network_info.get('verification_network', {}), launcher, self._check_mapping,
82+
'verification_network', delayed_model_loading)
83+
self._part_by_name = {'colorization_network': self.test_model, 'verification_network': self.check_model}
9784

9885
@property
9986
def adapter(self):

0 commit comments

Comments
 (0)