@@ -92,13 +92,19 @@ def __init__(self,
92
92
assert special_token in token_to_idx , '{} is not in token_to_idx' .format (
93
93
special_token )
94
94
self ._token_to_idx = token_to_idx
95
- self ._idx_to_token = {idx : token for token , idx in token_to_idx .items ()}
95
+ self ._idx_to_token = {
96
+ idx : token
97
+ for token , idx in token_to_idx .items ()
98
+ }
96
99
if unk_token :
97
100
unk_index = self ._token_to_idx [unk_token ]
98
101
self ._token_to_idx = collections .defaultdict (lambda : unk_index )
99
102
self ._token_to_idx .update (token_to_idx )
100
103
else :
101
- self ._idx_to_token = {idx : special_token for idx , special_token in enumerate (special_tokens )}
104
+ self ._idx_to_token = {
105
+ idx : special_token
106
+ for idx , special_token in enumerate (special_tokens )
107
+ }
102
108
self ._token_to_idx = collections .defaultdict ()
103
109
self ._token_to_idx .update (
104
110
(token , idx ) for idx , token in self ._idx_to_token .items ())
@@ -136,7 +142,8 @@ def _index_counter_keys(self, counter, special_tokens, max_size, min_freq):
136
142
if freq < min_freq or len (self ._idx_to_token ) == max_size :
137
143
break
138
144
if token not in special_tokens :
139
- self ._idx_to_token [max (list (self ._idx_to_token .keys ()) + [- 1 ]) + 1 ] = token
145
+ self ._idx_to_token [max (list (self ._idx_to_token .keys ()) + [- 1 ]) +
146
+ 1 ] = token
140
147
self ._token_to_idx [token ] = max (self ._idx_to_token .keys ())
141
148
142
149
def _sort_index_according_to_user_specification (self , token_to_idx ):
@@ -206,20 +213,20 @@ def to_tokens(self, indices):
206
213
'Token indices is invalid. Expected 1D array, but received {}D array. ' .
207
214
format (len (indices .shape )))
208
215
209
- max_idx = max (self ._idx_to_token .keys ())
210
-
211
216
tokens = []
212
217
for idx in indices :
213
218
if not isinstance (idx , int ):
214
219
warnings .warn (
215
220
"The type of `to_tokens()`'s input `indices` is not `int` which will be forcibly transfered to `int`. "
216
221
)
217
222
idx = int (idx )
218
- if idx > max_idx :
223
+
224
+ try :
225
+ tokens .append (self ._idx_to_token [idx ])
226
+ except KeyError :
219
227
raise ValueError (
220
228
'Token index {} in the provided `indices` is invalid.' .
221
229
format (idx ))
222
- tokens .append (self ._idx_to_token [idx ])
223
230
224
231
return tokens [0 ] if to_reduce else tokens
225
232
0 commit comments