@@ -801,14 +801,16 @@ def sanitize_special_tokens(self) -> int:
         """
         return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
 
-    def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int:
+    def add_special_tokens(
+        self, special_tokens_dict: Dict[str, Union[str, AddedToken]], replace_additional_special_tokens=True
+    ) -> int:
         """
         Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If
         special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the
         current vocabulary).
 
-        Note: When adding new tokens to the vocabulary, you should make sure to also resize the token embedding
-        matrix of the model so that its embedding matrix matches the tokenizer.
+        When adding new tokens to the vocabulary, you should make sure to also resize the token embedding matrix of the
+        model so that its embedding matrix matches the tokenizer.
 
         In order to do that, please use the [`~PreTrainedModel.resize_token_embeddings`] method.
 
@@ -829,6 +831,13 @@ def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToke
 
                 Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer
                 assign the index of the `unk_token` to them).
+            replace_additional_special_tokens (`bool`, *optional*, defaults to `True`):
+                If `True`, the existing list of additional special tokens will be replaced by the list provided in
+                `special_tokens_dict`. Otherwise, `self._additional_special_tokens` is just extended. In the former
+                case, the tokens will NOT be removed from the tokenizer's full vocabulary - they are only being flagged
+                as non-special tokens. Remember, this only affects which tokens are skipped during decoding, not the
+                `added_tokens_encoder` and `added_tokens_decoder`. This means that the previous
+                `additional_special_tokens` are still added tokens, and will not be split by the model.
 
         Returns:
             `int`: Number of tokens added to the vocabulary.
@@ -852,25 +861,38 @@ def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToke
         if not special_tokens_dict:
             return 0
 
-        added_tokens = 0
+        added_tokens = []
         for key, value in special_tokens_dict.items():
             assert key in self.SPECIAL_TOKENS_ATTRIBUTES, f"Key {key} is not a special token"
 
             if self.verbose:
                 logger.info(f"Assigning {value} to the {key} key of the tokenizer")
-            setattr(self, key, value)
 
             if key == "additional_special_tokens":
                 assert isinstance(value, (list, tuple)) and all(
                     isinstance(t, (str, AddedToken)) for t in value
                 ), f"Tokens {value} for key {key} should all be str or AddedToken instances"
-                added_tokens += self.add_tokens(value, special_tokens=True)
-            else:
-                assert isinstance(
-                    value, (str, AddedToken)
-                ), f"Token {value} for key {key} should be a str or an AddedToken instance"
-                added_tokens += self.add_tokens([value], special_tokens=True)
 
+                to_add = []
+                for token in value:
+                    if not replace_additional_special_tokens and str(token) in self.additional_special_tokens:
+                        continue
+                    to_add.append(token)
+                if replace_additional_special_tokens and len(to_add) > 0:
+                    setattr(self, key, list(to_add))
+                else:
+                    self._additional_special_tokens.extend(to_add)
+                added_tokens += to_add
+
+            else:
+                if not isinstance(value, (str, AddedToken)):
+                    raise ValueError(f"Token {value} for key {key} should be a str or an AddedToken instance")
+                setattr(self, key, value)
+                if value not in added_tokens:
+                    added_tokens.append(value)
+
+        # if we are adding tokens that were not part of the vocab, we ought to add them
+        added_tokens = self.add_tokens(added_tokens, special_tokens=True)
         return added_tokens
 
     def add_tokens(
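
A minimal sketch of what the new flag changes, assuming a standard checkpoint ("bert-base-uncased" is only illustrative): with replace_additional_special_tokens=False the previously registered additional special tokens keep their special status, while the default (True) replaces the list but, as the docstring notes, leaves the old tokens in the vocabulary as plain added tokens.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Flag "<ent>" as an additional special token.
tokenizer.add_special_tokens({"additional_special_tokens": ["<ent>"]})

# Extend the existing list: "<ent>" stays special alongside "<ent2>".
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<ent2>"]}, replace_additional_special_tokens=False
)

# Default (replace): only "<ent3>" is now flagged as special. "<ent>" and
# "<ent2>" remain added tokens (never split by the tokenizer), but they are
# no longer skipped when decoding with skip_special_tokens=True.
tokenizer.add_special_tokens({"additional_special_tokens": ["<ent3>"]})
print(tokenizer.additional_special_tokens)  # ["<ent3>"]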
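And the embedding-resize pairing the docstring asks for, again as a hedged sketch with an illustrative checkpoint name: add_special_tokens returns the number of genuinely new tokens, so the model's embedding matrix only needs resizing when that count is non-zero.

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

# Returns the number of tokens added to the vocabulary.
num_added = tokenizer.add_special_tokens({"additional_special_tokens": ["<ent>"]})
if num_added > 0:
    # Grow the embedding matrix so the new token ids have embedding rows.
    model.resize_token_embeddings(len(tokenizer))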