Commit 90cef20

[Tokenizer] Add replace_additional_special_tokens parameter to add_special_tokens (#9144)
1 parent 7faad55 commit 90cef20

File tree

3 files changed, +86 -14 lines changed


paddlenlp/transformers/luke/tokenizer.py

Lines changed: 28 additions & 3 deletions
@@ -482,6 +482,12 @@ def __call__(
 
         return encode_output
 
+    def __len__(self):
+        """
+        Size of the full vocabulary with the added tokens.
+        """
+        return len(self.encoder) + len(self.added_tokens_encoder)
+
     def tokenize(self, text, add_prefix_space=False):
         """
         Tokenize a string.
@@ -608,22 +614,41 @@ def _convert_token_to_id_with_added_voc(self, token):
 
         return self._convert_token_to_id(token)
 
-    def add_special_tokens(self, token_list: Union[List[int], Dict]):
+    def add_special_tokens(self, token_list: Union[List[int], Dict], replace_additional_special_tokens: bool = True):
         """
         Adding special tokens if you need.
 
         Args:
             token_list (List[int], Dict[List[int]]):
                 The special token list you provided. If you provide a Dict, the key of the Dict must
                 be "additional_special_tokens" and the value must be token list.
+            replace_additional_special_tokens (bool, optional, defaults to True):
+                If True, the existing list of additional special tokens will be replaced by the list provided in
+                `token_list`. Otherwise, `self._additional_special_tokens` is just extended. In the former
+                case, the tokens will NOT be removed from the tokenizer's full vocabulary - they are only being flagged
+                as non-special tokens. Remember, this only affects which tokens are skipped during decoding, not the
+                `added_tokens_encoder` and `added_tokens_decoder`. This means that the previous
+                `additional_special_tokens` are still added tokens, and will not be split by the model.
         """
         if isinstance(token_list, dict):
            token_list = token_list["additional_special_tokens"]
+
+        if replace_additional_special_tokens:
+            self._additional_special_tokens = list(token_list)
+        else:
+            self._additional_special_tokens.extend(
+                [token for token in token_list if token not in self._additional_special_tokens]
+            )
         encoder_dict = dict()
         decoder_dict = dict()
+
+        token_id_counter = len(self)
         for token in token_list:
-            encoder_dict[token] = len(self.encoder.keys())
-            decoder_dict[len(self.decoder.keys())] = token
+            if token not in self.added_tokens_encoder:
+                encoder_dict[token] = token_id_counter
+                decoder_dict[token_id_counter] = token
+                token_id_counter += 1
+
         self.added_tokens_encoder.update(encoder_dict)
         self.added_tokens_decoder.update(decoder_dict)
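For context, a minimal usage sketch of the updated LukeTokenizer.add_special_tokens. It is illustrative only: the "luke-base" checkpoint name and the example tokens are assumptions, not part of this commit.

    from paddlenlp.transformers import LukeTokenizer

    # Assumed example checkpoint; substitute any available LUKE weights.
    tokenizer = LukeTokenizer.from_pretrained("luke-base")

    # Default behaviour (replace_additional_special_tokens=True): the list of
    # additional special tokens is replaced by the one passed in.
    tokenizer.add_special_tokens({"additional_special_tokens": ["<ent>", "<ent2>"]})

    # Extend instead of replace: previously added tokens keep their entries in
    # added_tokens_encoder/added_tokens_decoder and stay flagged as special.
    tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<task>"]},
        replace_additional_special_tokens=False,
    )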

paddlenlp/transformers/tokenizer_utils_base.py

Lines changed: 33 additions & 11 deletions
@@ -801,14 +801,16 @@ def sanitize_special_tokens(self) -> int:
         """
         return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
 
-    def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int:
+    def add_special_tokens(
+        self, special_tokens_dict: Dict[str, Union[str, AddedToken]], replace_additional_special_tokens=True
+    ) -> int:
         """
         Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If
         special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the
         current vocabulary).
 
-        Note,None When adding new tokens to the vocabulary, you should make sure to also resize the token embedding
-        matrix of the model so that its embedding matrix matches the tokenizer.
+        When adding new tokens to the vocabulary, you should make sure to also resize the token embedding matrix of the
+        model so that its embedding matrix matches the tokenizer.
 
         In order to do that, please use the [`~PreTrainedModel.resize_token_embeddings`] method.
 
@@ -829,6 +831,13 @@ def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToke
 
                 Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer
                 assign the index of the `unk_token` to them).
+            replace_additional_special_tokens (`bool`, *optional*, defaults to `True`):
+                If `True`, the existing list of additional special tokens will be replaced by the list provided in
+                `special_tokens_dict`. Otherwise, `self._additional_special_tokens` is just extended. In the former
+                case, the tokens will NOT be removed from the tokenizer's full vocabulary - they are only being flagged
+                as non-special tokens. Remember, this only affects which tokens are skipped during decoding, not the
+                `added_tokens_encoder` and `added_tokens_decoder`. This means that the previous
+                `additional_special_tokens` are still added tokens, and will not be split by the model.
 
         Returns:
             `int`: Number of tokens added to the vocabulary.
@@ -852,25 +861,38 @@ def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToke
         if not special_tokens_dict:
             return 0
 
-        added_tokens = 0
+        added_tokens = []
         for key, value in special_tokens_dict.items():
             assert key in self.SPECIAL_TOKENS_ATTRIBUTES, f"Key {key} is not a special token"
 
             if self.verbose:
                 logger.info(f"Assigning {value} to the {key} key of the tokenizer")
-            setattr(self, key, value)
 
             if key == "additional_special_tokens":
                 assert isinstance(value, (list, tuple)) and all(
                     isinstance(t, (str, AddedToken)) for t in value
                 ), f"Tokens {value} for key {key} should all be str or AddedToken instances"
-                added_tokens += self.add_tokens(value, special_tokens=True)
-            else:
-                assert isinstance(
-                    value, (str, AddedToken)
-                ), f"Token {value} for key {key} should be a str or an AddedToken instance"
-                added_tokens += self.add_tokens([value], special_tokens=True)
 
+                to_add = []
+                for token in value:
+                    if not replace_additional_special_tokens and str(token) in self.additional_special_tokens:
+                        continue
+                    to_add.append(token)
+                if replace_additional_special_tokens and len(to_add) > 0:
+                    setattr(self, key, list(to_add))
+                else:
+                    self._additional_special_tokens.extend(to_add)
+                added_tokens += to_add
+
+            else:
+                if not isinstance(value, (str, AddedToken)):
+                    raise ValueError(f"Token {value} for key {key} should be a str or an AddedToken instance")
+                setattr(self, key, value)
+                if value not in added_tokens:
+                    added_tokens.append(value)
+
+        # if we are adding tokens that were not part of the vocab, we ought to add them
+        added_tokens = self.add_tokens(added_tokens, special_tokens=True)
         return added_tokens
 
     def add_tokens(
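To make the new base-class behaviour concrete, here is a hedged sketch of how the flag plays out (the "bert-base-uncased" checkpoint and the placeholder tokens are assumptions for illustration; the expected values mirror the new test in tests/transformers/test_tokenizer_common.py below).

    from paddlenlp.transformers import AutoTokenizer

    # Assumed example checkpoint, used only for illustration.
    tok = AutoTokenizer.from_pretrained("bert-base-uncased")

    num_added = tok.add_special_tokens({"additional_special_tokens": ["<ctx>"]})
    # num_added == 1, tok.additional_special_tokens == ["<ctx>"]

    # Default (replace_additional_special_tokens=True): the list is replaced;
    # "<ctx>" stays in the vocabulary but is no longer flagged as special.
    tok.add_special_tokens({"additional_special_tokens": ["<qry>"]})
    # tok.additional_special_tokens == ["<qry>"]

    # replace_additional_special_tokens=False extends the list instead.
    tok.add_special_tokens(
        {"additional_special_tokens": ["<ctx>"]},
        replace_additional_special_tokens=False,
    )
    # tok.additional_special_tokens == ["<qry>", "<ctx>"]

As the docstring notes, remember to resize the model's token embeddings (PreTrainedModel.resize_token_embeddings) after adding tokens that were not already in the vocabulary.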

tests/transformers/test_tokenizer_common.py

Lines changed: 25 additions & 0 deletions
@@ -1156,6 +1156,31 @@ def test_maximum_encoding_length_pair_input(self):
 
         # self.assertEqual(encoded_masked, encoded_1)
 
+    def test_special_token_addition(self):
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+                # Create tokenizer and add an additional special token
+                tokenizer_1 = tokenizer.from_pretrained(pretrained_name)
+                tokenizer_1.add_special_tokens({"additional_special_tokens": ["<tok>"]})
+                self.assertEqual(tokenizer_1.additional_special_tokens, ["<tok>"])
+                with tempfile.TemporaryDirectory() as tmp_dir:
+                    tokenizer_1.save_pretrained(tmp_dir)
+                    # Load the above tokenizer and add the same special token a second time
+                    tokenizer_2 = tokenizer.from_pretrained(pretrained_name)
+                    tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>"]})
+                    self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>"])
+
+                    tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>", "<other>"]})
+                    self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>", "<other>"])
+                    tokenizer_2.add_special_tokens({"additional_special_tokens": ["<other>", "<another>"]})
+                    self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>"])
+
+                    tokenizer_2.add_special_tokens(
+                        {"additional_special_tokens": ["<tok>"]},
+                        replace_additional_special_tokens=False,
+                    )
+                    self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>", "<tok>"])
+
     def test_special_tokens_mask(self):
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:
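The common test class is mixed into the per-model tokenizer test modules, so one way to exercise the new check locally is to select it by name from such a module; the exact path below is only an illustrative assumption.

    python -m pytest tests/transformers/bert/test_tokenizer.py -k test_special_token_addition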
