
Commit cd1e30b

Refine Vocab.to_indice (#641)
1 parent f142f06 commit cd1e30b

File tree

1 file changed: +14 -7 lines changed

paddlenlp/data/vocab.py

Lines changed: 14 additions & 7 deletions

@@ -92,13 +92,19 @@ def __init__(self,
                 assert special_token in token_to_idx, '{} is not in token_to_idx'.format(
                     special_token)
             self._token_to_idx = token_to_idx
-            self._idx_to_token = {idx: token for token, idx in token_to_idx.items()}
+            self._idx_to_token = {
+                idx: token
+                for token, idx in token_to_idx.items()
+            }
             if unk_token:
                 unk_index = self._token_to_idx[unk_token]
                 self._token_to_idx = collections.defaultdict(lambda: unk_index)
                 self._token_to_idx.update(token_to_idx)
         else:
-            self._idx_to_token = {idx: special_token for idx, special_token in enumerate(special_tokens)}
+            self._idx_to_token = {
+                idx: special_token
+                for idx, special_token in enumerate(special_tokens)
+            }
             self._token_to_idx = collections.defaultdict()
             self._token_to_idx.update(
                 (token, idx) for idx, token in self._idx_to_token.items())

@@ -136,7 +142,8 @@ def _index_counter_keys(self, counter, special_tokens, max_size, min_freq):
             if freq < min_freq or len(self._idx_to_token) == max_size:
                 break
             if token not in special_tokens:
-                self._idx_to_token[max(list(self._idx_to_token.keys()) + [-1]) + 1] = token
+                self._idx_to_token[max(list(self._idx_to_token.keys()) + [-1]) +
+                                   1] = token
                 self._token_to_idx[token] = max(self._idx_to_token.keys())

     def _sort_index_according_to_user_specification(self, token_to_idx):

@@ -206,20 +213,20 @@ def to_tokens(self, indices):
                     'Token indices is invalid. Expected 1D array, but received {}D array. '.
                     format(len(indices.shape)))

-        max_idx = max(self._idx_to_token.keys())
-
         tokens = []
         for idx in indices:
             if not isinstance(idx, int):
                 warnings.warn(
                     "The type of `to_tokens()`'s input `indices` is not `int` which will be forcibly transfered to `int`. "
                 )
                 idx = int(idx)
-            if idx > max_idx:
+
+            try:
+                tokens.append(self._idx_to_token[idx])
+            except KeyError:
                 raise ValueError(
                     'Token index {} in the provided `indices` is invalid.'.
                     format(idx))
-            tokens.append(self._idx_to_token[idx])

         return tokens[0] if to_reduce else tokens
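A note on the `to_tokens` hunk: the previous guard only rejected indices larger than the biggest key in `_idx_to_token`, so a negative index, or a gap in a non-contiguous user-supplied `token_to_idx` mapping, slipped past the check and surfaced as a bare KeyError. Moving the lookup inside try/except maps every absent key to the same ValueError. The sketch below is illustrative only, not the paddlenlp API; `idx_to_token`, `to_tokens_old`, and `to_tokens_new` are made-up names that mimic the before/after lookup logic.

# Illustrative sketch only; names below are invented for the example.
idx_to_token = {0: '[PAD]', 1: '[UNK]', 5: 'hello'}  # non-contiguous indices


def to_tokens_old(indices):
    # Old path: a single upper-bound check before the dict lookup.
    max_idx = max(idx_to_token.keys())
    tokens = []
    for idx in indices:
        if idx > max_idx:
            raise ValueError(
                'Token index {} in the provided `indices` is invalid.'.format(idx))
        # A negative index, or a hole such as 3, passes the check above
        # and escapes as an unhandled KeyError here.
        tokens.append(idx_to_token[idx])
    return tokens


def to_tokens_new(indices):
    # New path: let the dict lookup decide; any absent key becomes ValueError.
    tokens = []
    for idx in indices:
        try:
            tokens.append(idx_to_token[idx])
        except KeyError:
            raise ValueError(
                'Token index {} in the provided `indices` is invalid.'.format(idx))
    return tokens


print(to_tokens_new([0, 5]))   # ['[PAD]', 'hello']
# to_tokens_new([3]) and to_tokens_new([-7]) both raise the uniform ValueError;
# to_tokens_old([3]) instead fails with a raw KeyError, and so does
# to_tokens_old([-7]), because -7 is not greater than max_idx.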
