Skip to content

Commit d67718f

Browse files
authored
MD accept select mode for model loading (#2987)
* MD accept select mode for model loading * provide composite model name in ignore composite
1 parent 950d497 commit d67718f

File tree

1 file changed

+61
-48
lines changed

1 file changed

+61
-48
lines changed

tools/model_tools/src/openvino/model_zoo/_configuration.py

Lines changed: 61 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import collections
16+
import enum
1617
import fnmatch
1718
import re
1819
import shlex
@@ -200,6 +201,12 @@ def deserialize(cls, model, name, subdirectory, stages):
200201
return cls(name, subdirectory, task_type, model_stages, description, framework,
201202
license_url, precisions, quantization_output_precisions, name)
202203

204+
class ModelLoadingMode(enum.Enum):
205+
all = 0 # return all models
206+
composite_only = 1 # return only composite models
207+
non_composite_only = 2 # return only non composite models
208+
ignore_composite = 3 # ignore composite structure, return flatten models list
209+
203210
def check_composite_model_dir(model_dir):
204211
with validation.deserialization_context('In directory "{}"'.format(model_dir)):
205212
if list(model_dir.glob('*/*/**/model.yml')):
@@ -217,7 +224,7 @@ def check_composite_model_dir(model_dir):
217224
raise validation.DeserializationError(
218225
'Names of composite model parts should start with composite model name')
219226

220-
def load_models(models_root, args):
227+
def load_models(models_root, args, mode=ModelLoadingMode.all):
221228
models = []
222229
model_names = set()
223230

@@ -226,67 +233,73 @@ def load_models(models_root, args):
226233

227234
schema = _common.get_schema()
228235

229-
for composite_model_config in sorted(models_root.glob('**/composite-model.yml')):
230-
composite_model_name = composite_model_config.parent.name
231-
with validation.deserialization_context('In model "{}"'.format(composite_model_name)):
232-
if not RE_MODEL_NAME.fullmatch(composite_model_name):
233-
raise validation.DeserializationError('Invalid name, must consist only of letters, digits or ._-')
236+
if mode in (ModelLoadingMode.all, ModelLoadingMode.composite_only):
234237

235-
check_composite_model_dir(composite_model_config.parent)
238+
for composite_model_config in sorted(models_root.glob('**/composite-model.yml')):
239+
composite_model_name = composite_model_config.parent.name
240+
with validation.deserialization_context('In model "{}"'.format(composite_model_name)):
241+
if not RE_MODEL_NAME.fullmatch(composite_model_name):
242+
raise validation.DeserializationError('Invalid name, must consist only of letters, digits or ._-')
236243

237-
with composite_model_config.open('rb') as config_file, \
238-
validation.deserialization_context('In config "{}"'.format(composite_model_config)):
244+
check_composite_model_dir(composite_model_config.parent)
239245

240-
composite_model = yaml.safe_load(config_file)
241-
model_stages = {}
242-
for stage in sorted(composite_model_config.parent.glob('*/model.yml')):
243-
with stage.open('rb') as stage_config_file, \
244-
validation.deserialization_context('In config "{}"'.format(stage_config_file)):
245-
model = yaml.safe_load(stage_config_file)
246-
if not schema.check(model):
247-
raise validation.DeserializationError('Configuration file check was\'t successful.')
246+
with composite_model_config.open('rb') as config_file, \
247+
validation.deserialization_context('In config "{}"'.format(composite_model_config)):
248248

249-
stage_subdirectory = stage.parent.relative_to(models_root)
250-
model_stages[stage_subdirectory] = model
249+
composite_model = yaml.safe_load(config_file)
250+
model_stages = {}
251+
for stage in sorted(composite_model_config.parent.glob('*/model.yml')):
252+
with stage.open('rb') as stage_config_file, \
253+
validation.deserialization_context('In config "{}"'.format(stage_config_file)):
254+
model = yaml.safe_load(stage_config_file)
255+
if not schema.check(model):
256+
raise validation.DeserializationError('Configuration file check was\'t successful.')
251257

252-
if len(model_stages) == 0:
253-
continue
254-
subdirectory = composite_model_config.parent.relative_to(models_root)
255-
composite_models.append(CompositeModel.deserialize(
256-
composite_model, composite_model_name, subdirectory, model_stages
257-
))
258+
stage_subdirectory = stage.parent.relative_to(models_root)
259+
model_stages[stage_subdirectory] = model
258260

259-
if composite_model_name in composite_model_names:
260-
raise validation.DeserializationError(
261-
'Duplicate composite model name "{}"'.format(composite_model_name))
262-
composite_model_names.add(composite_model_name)
261+
if len(model_stages) == 0:
262+
continue
263+
subdirectory = composite_model_config.parent.relative_to(models_root)
264+
composite_models.append(CompositeModel.deserialize(
265+
composite_model, composite_model_name, subdirectory, model_stages
266+
))
263267

264-
for config_path in sorted(models_root.glob('**/model.yml')):
265-
subdirectory = config_path.parent
268+
if composite_model_name in composite_model_names:
269+
raise validation.DeserializationError(
270+
'Duplicate composite model name "{}"'.format(composite_model_name))
271+
composite_model_names.add(composite_model_name)
266272

267-
is_composite = (subdirectory.parent / 'composite-model.yml').exists()
268-
if is_composite:
269-
continue
273+
if mode != ModelLoadingMode.composite_only:
274+
for config_path in sorted(models_root.glob('**/model.yml')):
275+
subdirectory = config_path.parent
270276

271-
subdirectory = subdirectory.relative_to(models_root)
277+
is_composite = (subdirectory.parent / 'composite-model.yml').exists()
278+
composite_model_name = None
279+
if is_composite:
280+
if mode != ModelLoadingMode.ignore_composite:
281+
continue
282+
composite_model_name = subdirectory.parent.name
272283

273-
with config_path.open('rb') as config_file, \
274-
validation.deserialization_context('In config "{}"'.format(config_path)):
284+
subdirectory = subdirectory.relative_to(models_root)
275285

276-
model = yaml.safe_load(config_file)
277-
if not schema.check(model):
278-
raise validation.DeserializationError('Configuration file check was\'t successful.')
286+
with config_path.open('rb') as config_file, \
287+
validation.deserialization_context('In config "{}"'.format(config_path)):
279288

280-
for bad_key in ['name', 'subdirectory']:
281-
if bad_key in model:
282-
raise validation.DeserializationError('Unsupported key "{}"'.format(bad_key))
289+
model = yaml.safe_load(config_file)
290+
if not schema.check(model):
291+
raise validation.DeserializationError('Configuration file check was\'t successful.')
283292

284-
models.append(Model.deserialize(model, subdirectory.name, subdirectory, None))
293+
for bad_key in ['name', 'subdirectory']:
294+
if bad_key in model:
295+
raise validation.DeserializationError('Unsupported key "{}"'.format(bad_key))
285296

286-
if models[-1].name in model_names:
287-
raise validation.DeserializationError(
288-
'Duplicate model name "{}"'.format(models[-1].name))
289-
model_names.add(models[-1].name)
297+
models.append(Model.deserialize(model, subdirectory.name, subdirectory, composite_model_name))
298+
299+
if models[-1].name in model_names:
300+
raise validation.DeserializationError(
301+
'Duplicate model name "{}"'.format(models[-1].name))
302+
model_names.add(models[-1].name)
290303

291304
return sorted(models + composite_models, key=lambda model : model.name)
292305

0 commit comments

Comments
 (0)