Skip to content

Commit 5cadc0a

Browse files
authored
AC: fix text classification truncation by max len (#3015)
* AC: fix text classification truncation by max len
* update data reading for annotation
1 parent 073dc69 commit 5cadc0a

File tree

6 files changed

+51
-14
lines changed

6 files changed

+51
-14
lines changed

tools/accuracy_checker/openvino/tools/accuracy_checker/annotation_converters/text_classification.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020

2121

2222
from ..config import PathField, StringField, NumberField, BoolField, ListField, ConfigError
23+
from ..data_readers import AnnotationDataIdentifier
2324
from ..representation import TextClassificationAnnotation
2425
from ..utils import string_to_list, UnsupportedPackage, read_json
2526
from .format_converter import BaseFormatConverter, ConverterReturn, verify_label_map
@@ -112,11 +113,7 @@ def read_annotation(self):
112113
return lines
113114

114115
def convert_single_example(self, example): # pylint:disable=R0912
115-
identifier = [
116-
'input_ids_{}'.format(example.guid),
117-
'input_mask_{}'.format(example.guid),
118-
'segment_ids_{}'.format(example.guid)
119-
]
116+
identifier = AnnotationDataIdentifier(example.guid, [])
120117
if not self.external_tok:
121118
tokens_a = self.tokenizer.tokenize(example.text_a)
122119
tokens_b = None
@@ -165,6 +162,7 @@ def convert_single_example(self, example): # pylint:disable=R0912
165162

166163
if len(tokens) > self.max_seq_length:
167164
tokens = tokens[:self.max_seq_length]
165+
segment_ids = segment_ids[:self.max_seq_length]
168166

169167
input_ids = self.tokenizer.convert_tokens_to_ids(tokens) if self.support_vocab or self.external_tok else tokens
170168
input_mask = [0 if not self.class_token_first else 1] * len(input_ids)

tools/accuracy_checker/openvino/tools/accuracy_checker/data_readers/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
DataReaderField,
2020
ReaderCombiner,
2121
DataRepresentation,
22+
AnnotationDataIdentifier,
2223
ClipIdentifier,
2324
MultiFramesInputIdentifier,
2425
ImagePairIdentifier,
@@ -32,6 +33,7 @@
3233
serialize_identifier,
3334
deserialize_identifier,
3435
create_identifier_key,
36+
create_ann_identifier_key,
3537

3638
create_reader,
3739
REQUIRES_ANNOTATIONS
@@ -67,6 +69,7 @@
6769
'KaldiFrameIdentifier',
6870
'ParametricImageIdentifier',
6971
'VideoFrameIdentifier',
72+
'AnnotationDataIdentifier',
7073

7174
'OpenCVFrameReader',
7275
'OpenCVImageReader',
@@ -94,5 +97,6 @@
9497

9598
'serialize_identifier',
9699
'deserialize_identifier',
97-
'create_identifier_key'
100+
'create_identifier_key',
101+
'create_ann_identifier_key'
98102
]

tools/accuracy_checker/openvino/tools/accuracy_checker/data_readers/annotation_readers.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
"""
1616

1717
from ..config import ListField, ConfigError
18-
from .data_reader import BaseReader, create_identifier_key
18+
from .data_reader import BaseReader, create_ann_identifier_key, AnnotationDataIdentifier
1919
from ..utils import contains_all
2020

2121

@@ -47,7 +47,10 @@ def configure(self):
4747
self.multi_infer = self.get_value_from_config('multi_infer')
4848

4949
def read(self, data_id):
50-
relevant_annotation = self.data_source[create_identifier_key(data_id)]
50+
if isinstance(data_id, AnnotationDataIdentifier):
51+
ordered_data_id = ['{}_{}'.format(feat, data_id.annotation_id) for feat in self.feature_list]
52+
data_id.data_id = ordered_data_id if not self.single else ordered_data_id[0]
53+
relevant_annotation = self.data_source[create_ann_identifier_key(data_id)]
5154
if not contains_all(relevant_annotation.__dict__, self.feature_list):
5255
raise ConfigError(
5356
'annotation_class prototype does not contain provided features {}'.format(', '.join(self.feature_list))

tools/accuracy_checker/openvino/tools/accuracy_checker/data_readers/data_reader.py

Lines changed: 26 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,11 @@ def __init__(self, data, meta=None, identifier=''):
4949
self.metadata['image_size'] = data.shape if not isinstance(data, list) else np.shape(data[0])
5050

5151

52+
class AnnotationDataIdentifier:
53+
def __init__(self, ann_id, data_id):
54+
self.annotation_id = ann_id
55+
self.data_id = data_id
56+
5257
ClipIdentifier = namedtuple('ClipIdentifier', ['video', 'clip_id', 'frames'])
5358
MultiFramesInputIdentifier = namedtuple('MultiFramesInputIdentifier', ['input_id', 'frames'])
5459
ImagePairIdentifier = namedtuple('ImagePairIdentifier', ['first', 'second'])
@@ -64,6 +69,10 @@ def __init__(self, data, meta=None, identifier=''):
6469
)
6570

6671
identifier_serialization = {
72+
'AnnotationDataIdentifier': IdentifierSerializationOptions(
73+
'annotation_data_identifier', ['annotation_id', 'data_id'],
74+
AnnotationDataIdentifier, [False, True], [False, True]
75+
),
6776
'ClipIdentifier': IdentifierSerializationOptions(
6877
'clip_identifier', ['video', 'clip_id', 'frames'], ClipIdentifier, [False, False, False], [False, False, True]),
6978
'MultiFramesInputIdentifier': IdentifierSerializationOptions(
@@ -120,15 +129,31 @@ def deserialize_identifier(identifier):
120129
return identifier
121130

122131

132+
def create_ann_identifier_key(identifier):
133+
if isinstance(identifier, list):
134+
return ListIdentifier(tuple(create_ann_identifier_key(elem) for elem in identifier))
135+
if isinstance(identifier, ClipIdentifier):
136+
return ClipIdentifier(identifier.video, identifier.clip_id, tuple(identifier.frames))
137+
if isinstance(identifier, MultiFramesInputIdentifier):
138+
return MultiFramesInputIdentifier(tuple(identifier.input_id), tuple(identifier.frames))
139+
if isinstance(identifier, ParametricImageIdentifier):
140+
return ParametricImageIdentifier(identifier.identifier, tuple(identifier.parameters))
141+
if isinstance(identifier, AnnotationDataIdentifier):
142+
return identifier.annotation_id
143+
return identifier
144+
145+
123146
def create_identifier_key(identifier):
124147
if isinstance(identifier, list):
125-
return ListIdentifier(tuple(create_identifier_key(elem) for elem in identifier))
148+
return ListIdentifier(tuple(create_ann_identifier_key(elem) for elem in identifier))
126149
if isinstance(identifier, ClipIdentifier):
127150
return ClipIdentifier(identifier.video, identifier.clip_id, tuple(identifier.frames))
128151
if isinstance(identifier, MultiFramesInputIdentifier):
129152
return MultiFramesInputIdentifier(tuple(identifier.input_id), tuple(identifier.frames))
130153
if isinstance(identifier, ParametricImageIdentifier):
131154
return ParametricImageIdentifier(identifier.identifier, tuple(identifier.parameters))
155+
if isinstance(identifier, AnnotationDataIdentifier):
156+
return AnnotationDataIdentifier(identifier.annotation_id, tuple(identifier.data_id))
132157
return identifier
133158

134159

tools/accuracy_checker/openvino/tools/accuracy_checker/dataset.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@
5353
)
5454
from .data_readers import (
5555
DataReaderField, REQUIRES_ANNOTATIONS, BaseReader,
56-
serialize_identifier, deserialize_identifier, create_identifier_key
56+
serialize_identifier, deserialize_identifier, create_ann_identifier_key
5757
)
5858
from .logging import print_info
5959

@@ -460,15 +460,15 @@ def __init__(self, annotations, meta, name='', config=None):
460460
self._data_buffer = OrderedDict()
461461
self._meta = meta
462462
for ann in annotations:
463-
idx = create_identifier_key(ann.identifier)
463+
idx = create_ann_identifier_key(ann.identifier)
464464
self._data_buffer[idx] = ann
465465

466466
def __getitem__(self, item):
467-
return self._data_buffer[item]
467+
return self._data_buffer[create_ann_identifier_key(item)]
468468

469469
@property
470470
def identifiers(self):
471-
return list(self._data_buffer)
471+
return list(map(lambda ann: ann.identifier, self._data_buffer.values()))
472472

473473
def __len__(self):
474474
return len(self._data_buffer)

tools/accuracy_checker/openvino/tools/accuracy_checker/launcher/input_feeder.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,11 @@
2121
from ..config import ConfigError
2222
from ..utils import extract_image_representations
2323
from ..data_readers import (
24-
MultiFramesInputIdentifier, KaldiFrameIdentifier, KaldiMatrixIdentifier, ParametricImageIdentifier
24+
MultiFramesInputIdentifier,
25+
KaldiFrameIdentifier,
26+
KaldiMatrixIdentifier,
27+
ParametricImageIdentifier,
28+
AnnotationDataIdentifier
2529
)
2630

2731
LAYER_LAYOUT_TO_IMAGE_LAYOUT = {
@@ -192,6 +196,9 @@ def match_by_regex(data, identifiers, input_regex):
192196
for data_representation in data_representation_batch:
193197
identifiers = data_representation.identifier
194198
data = data_representation.data
199+
if isinstance(identifiers, AnnotationDataIdentifier):
200+
identifiers = identifiers.data_id
201+
195202
if isinstance(identifiers, ParametricImageIdentifier):
196203
input_batch.append(data[idx])
197204
continue

0 commit comments

Comments
 (0)