@@ -905,7 +905,7 @@ def tokenize(self, text, **kwargs):
     def convert_tokens_to_ids(self, tokens):
         if tokens is None:
             return None
-        if isinstance(tokens, str):
+        if isinstance(tokens, (str, AddedToken)):
             if tokens in self.added_tokens_encoder:
                 return self.added_tokens_encoder[tokens]
             else:
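
The widened isinstance check lets convert_tokens_to_ids accept AddedToken
instances as well as plain strings. A minimal usage sketch, assuming a
tokenizer exposing the usual add_tokens method and the standard AddedToken
constructor (the token value "<new_token>" is illustrative, not from this
patch):

    # Hypothetical round trip: register a token, then look up its id
    # either as an AddedToken or as its string content.
    tok = AddedToken("<new_token>", lstrip=True)
    tokenizer.add_tokens([tok])
    tokenizer.convert_tokens_to_ids(tok)            # accepted after this change
    tokenizer.convert_tokens_to_ids("<new_token>")  # unchanged behavior
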
@@ -1066,6 +1066,20 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         init_args = init_args if not args else args
         init_kwargs.update(kwargs)
 
+        def convert_added_tokens(obj):
+            if isinstance(
+                    obj,
+                    dict) and "__type" in obj and obj["__type"] == "AddedToken":
+                obj.pop("__type")
+                return AddedToken(**obj)
+            elif isinstance(obj, (list, tuple)):
+                return list(convert_added_tokens(o) for o in obj)
+            elif isinstance(obj, dict):
+                return {k: convert_added_tokens(v) for k, v in obj.items()}
+            return obj
+
+        init_kwargs = convert_added_tokens(init_kwargs)
+
         # Merge resolved_vocab_files arguments in init_kwargs if not including.
         # Maybe need more ways to load resources.
         for args_name, file_path in resolved_vocab_files.items():
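
For reference, a sketch of the serialized shape convert_added_tokens is
undoing when a saved tokenizer config is reloaded. The "__type" tag and
the recursion over dicts and lists come from the patch above; the specific
keys ("bos_token", "content", "lstrip") are illustrative assumptions:

    # Hypothetical init_kwargs as loaded from a saved tokenizer config.
    init_kwargs = {
        "bos_token": {"__type": "AddedToken", "content": "<s>", "lstrip": False},
        "additional_special_tokens": [
            {"__type": "AddedToken", "content": "<extra_0>"},
        ],
    }
    # convert_added_tokens strips the "__type" tag and rebuilds each entry
    # as AddedToken(**obj), recursing through nested lists and dicts.
    init_kwargs = convert_added_tokens(init_kwargs)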