@@ -92,18 +92,16 @@ def __init__(self,
92
92
assert special_token in token_to_idx , '{} is not in token_to_idx' .format (
93
93
special_token )
94
94
self ._token_to_idx = token_to_idx
95
- self ._idx_to_token = sorted (
96
- self ._token_to_idx .keys (),
97
- key = lambda token : self ._token_to_idx [token ])
95
+ self ._idx_to_token = {idx : token for token , idx in token_to_idx .items ()}
98
96
if unk_token :
99
97
unk_index = self ._token_to_idx [unk_token ]
100
98
self ._token_to_idx = collections .defaultdict (lambda : unk_index )
101
99
self ._token_to_idx .update (token_to_idx )
102
100
else :
103
- self ._idx_to_token = list (special_tokens )
101
+ self ._idx_to_token = { idx : special_token for idx , special_token in enumerate (special_tokens )}
104
102
self ._token_to_idx = collections .defaultdict ()
105
103
self ._token_to_idx .update (
106
- (token , idx ) for idx , token in enumerate ( self ._idx_to_token ))
104
+ (token , idx ) for idx , token in self ._idx_to_token . items ( ))
107
105
self ._index_counter_keys (counter , special_tokens , max_size ,
108
106
min_freq )
109
107
if token_to_idx :
@@ -138,8 +136,8 @@ def _index_counter_keys(self, counter, special_tokens, max_size, min_freq):
138
136
if freq < min_freq or len (self ._idx_to_token ) == max_size :
139
137
break
140
138
if token not in special_tokens :
141
- self ._idx_to_token . append ( token )
142
- self ._token_to_idx [token ] = len (self ._idx_to_token ) - 1
139
+ self ._idx_to_token [ max ( list ( self . _idx_to_token . keys ()) + [ - 1 ]) + 1 ] = token
140
+ self ._token_to_idx [token ] = max (self ._idx_to_token . keys ())
143
141
144
142
def _sort_index_according_to_user_specification (self , token_to_idx ):
145
143
# Sanity checks
@@ -208,7 +206,7 @@ def to_tokens(self, indices):
208
206
'Token indices is invalid. Expected 1D array, but received {}D array. ' .
209
207
format (len (indices .shape )))
210
208
211
- max_idx = len (self ._idx_to_token ) - 1
209
+ max_idx = max (self ._idx_to_token . keys ())
212
210
213
211
tokens = []
214
212
for idx in indices :
@@ -316,7 +314,7 @@ def to_json(self, path=None):
316
314
json_str = vocab.to_json(path='./vocab.json')
317
315
"""
318
316
vocab_dict = {}
319
- vocab_dict ['idx_to_token' ] = self .idx_to_token
317
+ vocab_dict ['idx_to_token' ] = dict ( self .idx_to_token )
320
318
vocab_dict ['token_to_idx' ] = dict (self .token_to_idx )
321
319
vocab_dict ['unk_token' ] = self .unk_token
322
320
vocab_dict ['identifiers_to_tokens' ] = self ._identifiers_to_tokens
0 commit comments