
Commit f142f06

BFJL authored
Change the _idx_to_token list to a dictionary (#617)

Co-authored-by: smallv0221 <[email protected]>
Co-authored-by: Guo Sheng <[email protected]>

1 parent 65db861 commit f142f06

File tree

1 file changed: +7 -9 lines changed


paddlenlp/data/vocab.py

Lines changed: 7 additions & 9 deletions
@@ -92,18 +92,16 @@ def __init__(self,
                 assert special_token in token_to_idx, '{} is not in token_to_idx'.format(
                     special_token)
             self._token_to_idx = token_to_idx
-            self._idx_to_token = sorted(
-                self._token_to_idx.keys(),
-                key=lambda token: self._token_to_idx[token])
+            self._idx_to_token = {idx: token for token, idx in token_to_idx.items()}
             if unk_token:
                 unk_index = self._token_to_idx[unk_token]
                 self._token_to_idx = collections.defaultdict(lambda: unk_index)
                 self._token_to_idx.update(token_to_idx)
         else:
-            self._idx_to_token = list(special_tokens)
+            self._idx_to_token = {idx: special_token for idx, special_token in enumerate(special_tokens)}
             self._token_to_idx = collections.defaultdict()
             self._token_to_idx.update(
-                (token, idx) for idx, token in enumerate(self._idx_to_token))
+                (token, idx) for idx, token in self._idx_to_token.items())
             self._index_counter_keys(counter, special_tokens, max_size,
                                      min_freq)
         if token_to_idx:
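
For reference, a minimal standalone sketch (plain Python with made-up tokens, not PaddleNLP's API) of what the two new comprehensions produce:

    # Inverting a user-supplied token_to_idx, as the first added line does.
    token_to_idx = {"[PAD]": 0, "[UNK]": 1, "hello": 2}
    idx_to_token = {idx: token for token, idx in token_to_idx.items()}
    print(idx_to_token)  # {0: '[PAD]', 1: '[UNK]', 2: 'hello'}

    # Seeding the mapping from special tokens alone, as the else branch does.
    special_tokens = ["[PAD]", "[UNK]"]
    idx_to_token = {idx: tok for idx, tok in enumerate(special_tokens)}
    print(idx_to_token)  # {0: '[PAD]', 1: '[UNK]'}

    # Rebuilding token_to_idx now iterates items() rather than enumerate(),
    # since the indices are explicit dict keys instead of list positions.
    rebuilt = {token: idx for idx, token in idx_to_token.items()}
    print(rebuilt)  # {'[PAD]': 0, '[UNK]': 1}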
@@ -138,8 +136,8 @@ def _index_counter_keys(self, counter, special_tokens, max_size, min_freq):
             if freq < min_freq or len(self._idx_to_token) == max_size:
                 break
             if token not in special_tokens:
-                self._idx_to_token.append(token)
-                self._token_to_idx[token] = len(self._idx_to_token) - 1
+                self._idx_to_token[max(list(self._idx_to_token.keys()) + [-1]) + 1] = token
+                self._token_to_idx[token] = max(self._idx_to_token.keys())

     def _sort_index_according_to_user_specification(self, token_to_idx):
         # Sanity checks
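
The new allocation treats the dict as an append-only sequence: the next index is one past the largest existing key, and the "+ [-1]" makes max() yield -1 on an empty mapping so the first token lands at index 0. A small sketch with made-up tokens:

    idx_to_token = {}
    for token in ["the", "a", "cat"]:
        # Next free index: one past the current maximum key. Appending -1
        # seeds the empty case, where max() over bare keys would raise.
        next_idx = max(list(idx_to_token.keys()) + [-1]) + 1
        idx_to_token[next_idx] = token
    print(idx_to_token)  # {0: 'the', 1: 'a', 2: 'cat'}

One observable cost: scanning every key per insertion makes this loop quadratic in vocabulary size, where list.append was amortized constant; a running counter would avoid that, at the price of a less local change.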
@@ -208,7 +206,7 @@ def to_tokens(self, indices):
                 'Token indices is invalid. Expected 1D array, but received {}D array. '.
                 format(len(indices.shape)))

-        max_idx = len(self._idx_to_token) - 1
+        max_idx = max(self._idx_to_token.keys())

        tokens = []
        for idx in indices:
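
With a dict, len(self._idx_to_token) - 1 would understate the bound whenever user-supplied indices are non-contiguous, which is presumably why the check now uses the maximum key. A minimal illustration with made-up values:

    idx_to_token = {0: "[PAD]", 5: "hello"}
    print(len(idx_to_token) - 1)      # 1 -- old bound, misses key 5
    print(max(idx_to_token.keys()))   # 5 -- new bound, the largest real index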
@@ -316,7 +314,7 @@ def to_json(self, path=None):
             json_str = vocab.to_json(path='./vocab.json')
         """
         vocab_dict = {}
-        vocab_dict['idx_to_token'] = self.idx_to_token
+        vocab_dict['idx_to_token'] = dict(self.idx_to_token)
         vocab_dict['token_to_idx'] = dict(self.token_to_idx)
         vocab_dict['unk_token'] = self.unk_token
         vocab_dict['identifiers_to_tokens'] = self._identifiers_to_tokens
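
One thing for consumers of the serialized form to keep in mind: json.dumps coerces integer object keys to strings, so idx_to_token round-trips with string keys unless they are converted back. A small sketch (plain json, not PaddleNLP's loader):

    import json

    idx_to_token = {0: "[PAD]", 1: "hello"}
    payload = json.dumps({"idx_to_token": idx_to_token})
    # JSON object keys are always strings:
    # {"idx_to_token": {"0": "[PAD]", "1": "hello"}}
    restored = {int(k): v for k, v in json.loads(payload)["idx_to_token"].items()}
    assert restored == idx_to_token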
