Skip to content

Commit 545a1a1

Browse files
authored
If multiple models found by automatic_model_search select one with model_name (#4001)
1 parent 9f99ec4 commit 545a1a1

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

tools/accuracy_checker/accuracy_checker/launcher/dlsdk_launcher_config.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Copyright (c) 2018-2024 Intel Corporation
2+
Copyright (c) 2018-2025 Intel Corporation
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -271,6 +271,13 @@ def automatic_model_search(model_name, model_cfg, weights_cfg, model_type=None):
271271
'tflite': 'tflite',
272272
}
273273

274+
275+
def get_model_by_name(model_name, model_list):
276+
models = [model for model in model_list if model.name == '{}.xml'.format(model_name)]
277+
if not models:
278+
models = [model for model in model_list if model.name == 'openvino_model.xml']
279+
return models
280+
274281
def get_model_by_suffix(model_name, model_dir, suffix):
275282
model_list = list(Path(model_dir).glob('{}.{}'.format(model_name, suffix)))
276283
if not model_list:
@@ -296,7 +303,9 @@ def get_model():
296303
if not model_list:
297304
raise ConfigError('suitable model is not found')
298305
if len(model_list) != 1:
299-
raise ConfigError('More than one model matched, please specify explicitly')
306+
model_list = get_model_by_name(model_name, model_list)
307+
if len(model_list) != 1:
308+
raise ConfigError('More than one model matched, please specify explicitly')
300309
model = model_list[0]
301310
print_info('Found model {}'.format(model))
302311
return model, model.suffix == '.blob'

0 commit comments

Comments
 (0)