
Commit 50ecdde

mzuevxeldercrow authored and committed
IMDB annotation for XLNet accuracy test fixes. (openvinotoolkit#831)
* fix path for Sentence Piece model loading (PosixPath -> str)
* fix method name (PieceTold -> PieceToId)
* fix convert_single_example() to be able to work without sample.text_b
1 parent ea9f415 commit 50ecdde

2 files changed: +21 -8 lines changed

tools/accuracy_checker/accuracy_checker/annotation_converters/_nlp_common.py

Lines changed: 2 additions & 2 deletions
@@ -257,7 +257,7 @@ def __init__(self, tokenizer_model, lower_case=True, remove_space=True):
         if spm is None:
             raise ConfigError('Sentence piece tokenizer required sentencepiece, please install it before usage')
         self.encoder = spm.SentencePieceProcessor()
-        self.encoder.Load(tokenizer_model)
+        self.encoder.Load(str(tokenizer_model))
         self.lower_case = lower_case
         self.remove_space = remove_space

@@ -275,7 +275,7 @@ def preprocess_text(self, inputs):

     def encode_ids(self, text, sample=False):
         pieces = self.encode_pieces(text, sample)
-        ids = [self.encoder.PieceTold(piece) for piece in pieces]
+        ids = [self.encoder.PieceToId(piece) for piece in pieces]
         return ids

     def encode_pieces(self, text, sample=False):
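For context, a minimal sketch (not part of the commit) of how the two fixed calls behave, assuming the standard sentencepiece Python bindings and a hypothetical spiece.model path; the converter's ConfigError handling is omitted here.

# Hypothetical standalone example; path and sample text are placeholders.
from pathlib import Path

import sentencepiece as spm

tokenizer_model = Path('xlnet_cased_L-12/spiece.model')  # hypothetical path

sp = spm.SentencePieceProcessor()
# Some sentencepiece builds expect a plain str here and fail on a PosixPath,
# hence the str(...) conversion added in the diff above.
sp.Load(str(tokenizer_model))

pieces = sp.EncodeAsPieces('this movie was surprisingly good')
# PieceToId is the actual method name; the old PieceTold call raised
# AttributeError at runtime.
ids = [sp.PieceToId(piece) for piece in pieces]
print(pieces, ids)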

tools/accuracy_checker/accuracy_checker/annotation_converters/text_classification.py

Lines changed: 19 additions & 6 deletions
@@ -73,8 +73,19 @@ def convert_single_example(self, example):
             'segment_ids_{}'.format(example.guid)
         ]
         tokens_a = self.tokenizer.tokenize(example.text_a)
-        tokens_b = self.tokenizer.tokenize(example.text_b)
-        truncate_seq_pair(tokens_a, tokens_b, self.max_seq_length - 3)
+        tokens_b = None
+        if example.text_b:
+            tokens_b = self.tokenizer.tokenize(example.text_b if example.text_b is not None else '')
+
+        if tokens_b:
+            # Modifies `tokens_a` and `tokens_b` in place so that the total
+            # length is less than the specified length.
+            # Account for two [SEP] & one [CLS] with "- 3"
+            truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
+        else:
+            # Account for one [SEP] & one [CLS] with "- 2"
+            if len(tokens_a) > self.max_seq_length - 2:
+                tokens_a = tokens_a[:self.max_seq_length - 2]

         tokens = []
         segment_ids = []

@@ -83,11 +94,13 @@ def convert_single_example(self, example):
             segment_ids.append(SEG_ID_A)
         tokens.append('[SEP]' if self.support_vocab else SEP_ID)
         segment_ids.append(SEG_ID_A)
-        for token in tokens_b:
-            tokens.append(token)
+
+        if tokens_b:
+            for token in tokens_b:
+                tokens.append(token)
+                segment_ids.append(SEG_ID_B)
+            tokens.append('[SEP]' if self.support_vocab else SEP_ID)
             segment_ids.append(SEG_ID_B)
-        tokens.append('[SEP]' if self.support_vocab else SEP_ID)
-        segment_ids.append(SEG_ID_B)

         tokens.append("[CLS]" if self.support_vocab else CLS_ID)
         segment_ids.append(SEG_ID_CLS)
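The new branch distinguishes sentence pairs from single sentences: with a pair, truncate_seq_pair is asked to shrink both token lists in place so they fit in max_seq_length - 3 (two [SEP] plus one [CLS]); for single-sentence IMDB examples, tokens_a is simply sliced to max_seq_length - 2 (one [SEP] plus one [CLS]). The helper itself is not part of this diff; the sketch below is an assumption modelled on the BERT/XLNet reference preprocessing and only illustrates the expected in-place behaviour.

# Hypothetical sketch of the pair-truncation helper the diff relies on.
def truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Pops tokens off the longer list, in place, until the pair fits."""
    while len(tokens_a) + len(tokens_b) > max_length:
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()

# Single-sentence case (text_b is empty), mirroring the new else-branch:
tokens_a = ['▁great', '▁plot', ',', '▁weak', '▁ending']
max_seq_length = 6
if len(tokens_a) > max_seq_length - 2:
    tokens_a = tokens_a[:max_seq_length - 2]
print(tokens_a)  # ['▁great', '▁plot', ',', '▁weak']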
