
Commit 0e96599

AC: update data provider api (#3067)
* AC: update data provider api
* fix linting
1 parent 8414624 commit 0e96599

File tree

1 file changed: +36 -3 lines changed

  • tools/accuracy_checker/openvino/tools/accuracy_checker/dataset.py

tools/accuracy_checker/openvino/tools/accuracy_checker/dataset.py

Lines changed: 36 additions & 3 deletions
@@ -106,6 +106,7 @@ def parameters(cls):
                 is_directory=True, optional=True, description='additional data source for annotation loading'
             ),
             'subset_file': PathField(optional=True, description='file with identifiers for subset', check_exists=False),
+            'subset': ListField(optional=True, description='identifiers for subset'),
             'store_subset': BoolField(
                 optional=True, default=False,
                 description='save subset ids to file specified in subset_file parameter'
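The first hunk adds an inline `subset` option next to the existing `subset_file`, so a dataset section can enumerate identifiers directly in the config. A minimal sketch of such a section, written as a Python dict (accuracy checker configs are YAML, and every value below is a hypothetical placeholder):

# Minimal sketch of a dataset config using the new 'subset' option.
# Keys mirror parameters() above; the concrete values are made up.
dataset_config = {
    'name': 'sample_dataset',                 # hypothetical dataset name
    'data_source': 'sample/images',           # hypothetical image directory
    'reader': 'opencv_imread',                # the default reader used by from_config below
    'subset': ['000001.jpg', '000002.jpg'],   # only these identifiers are evaluated
}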
@@ -603,13 +604,39 @@ def __init__(
         if self.store_subset:
             self.sava_subset()
 
+    @classmethod
+    def from_config(cls, config, load_annotation=False):
+        if load_annotation:
+            annotation, meta = Dataset.load_annotation(config)
+            annotation_provider = AnnotationProvider(annotation, meta)
+        else:
+            annotation_provider = None
+        data_reader_config = config.get('reader', 'opencv_imread')
+        data_source = config.get('data_source')
+        if isinstance(data_reader_config, str):
+            data_reader_type = data_reader_config
+            data_reader_config = None
+        elif isinstance(data_reader_config, dict):
+            data_reader_type = data_reader_config['type']
+        else:
+            raise ConfigError('reader should be dict or string')
+        if data_reader_type in REQUIRES_ANNOTATIONS:
+            data_source = annotation_provider
+        data_reader = BaseReader.provide(data_reader_type, data_source, data_reader_config)
+        return cls(
+            data_reader, annotation_provider, dataset_config=config
+        )
+
     def create_data_list(self, data_list=None):
         if data_list is not None:
             self._data_list = data_list
             return
         self.store_subset = self.dataset_config.get('store_subset', False)
+        if 'subset' in self.dataset_config:
+            self._create_data_list(self.dataset_config['subset'])
+            return
 
-        if self.dataset_config.get('subset_file'):
+        if 'subset_file' in self.dataset_config:
             subset_file = Path(self.dataset_config['subset_file'])
             if subset_file.exists() and not self.store_subset:
                 self.read_subset(subset_file)
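The new `from_config` classmethod wires up the reader (and optionally the annotation provider) straight from a dataset config. A hedged usage sketch, assuming the enclosing class is the checker's DataProvider (the class name is not visible in this hunk) and reusing the hypothetical dataset_config above:

# Sketch only: the DataProvider class name and dataset_config dict are assumptions.
provider = DataProvider.from_config(dataset_config)
provider_with_ann = DataProvider.from_config(dataset_config, load_annotation=True)
# With load_annotation=True, Dataset.load_annotation(config) is called and the resulting
# AnnotationProvider is passed to the constructor; readers listed in REQUIRES_ANNOTATIONS
# get that provider as their data source instead of the data_source path.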
@@ -623,10 +650,13 @@ def create_data_list(self, data_list=None):
         self._data_list = [file.name for file in self.data_reader.data_source.glob('*')]
 
     def read_subset(self, subset_file):
-        identifiers = [deserialize_identifier(idx) for idx in read_yaml(subset_file)]
-        self._data_list = identifiers
+        self._create_data_list(read_yaml(subset_file))
         print_info("loaded {} data items from {}".format(len(self._data_list), subset_file))
 
+    def _create_data_list(self, subset):
+        identifiers = [deserialize_identifier(idx) for idx in subset]
+        self._data_list = identifiers
+
     def sava_subset(self):
         identifiers = [serialize_identifier(idx) for idx in self._data_list]
         subset_file = Path(self.dataset_config.get(
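`read_subset` now routes through the shared `_create_data_list` helper, so file-based subsets and the inline `subset` list take the same deserialization path. A hedged sketch of producing a `subset_file` in the shape `read_subset` expects, a YAML list of serialized identifiers, assuming plain string identifiers for which serialization is effectively a no-op:

import yaml  # the checker itself reads the file back via its read_yaml helper

def write_subset_file(path, identifiers):
    # Hypothetical helper, not part of the tool: dump identifiers as a YAML list.
    with open(path, 'w') as out_file:
        yaml.safe_dump(list(identifiers), out_file)

write_subset_file('subset_identifiers.yml', ['000001.jpg', '000002.jpg'])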
@@ -663,6 +693,9 @@ def __len__(self):
     def identifiers(self):
         return self._data_list
 
+    def get_data_path(self):
+        return [self.data_reader.data_source / identifier for identifier in self._data_list]
+
     def make_subset(self, ids=None, start=0, step=1, end=None, accept_pairs=False):
         if self.annotation_provider:
             ids = self.annotation_provider.make_subset(ids, start, step, end, accept_pairs)
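`get_data_path` resolves every identifier in the current data list against the reader's data_source, which makes it easy to inspect or copy exactly the files a subset touches. A hedged usage sketch continuing from the provider built above:

# Sketch: list the on-disk location of each item the provider will read.
for path in provider.get_data_path():
    print(path)  # e.g. sample/images/000001.jpg (hypothetical values)
# The identifiers are joined to data_source with '/', so this assumes file-based
# identifiers; readers whose data_source is an AnnotationProvider are not covered.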
