Commit 7e1a433

JunnYu and gongenlei authored
override xlm's tokenize method (#2551)
Co-authored-by: gongenlei <[email protected]>
1 parent ad36a14 commit 7e1a433

File tree

1 file changed: +76 -0 lines changed

paddlenlp/transformers/xlm/tokenizer.py

Lines changed: 76 additions & 0 deletions
@@ -21,6 +21,7 @@
 import unicodedata
 from typing import List, Optional
 
+from ..tokenizer_utils import AddedToken, TextInput
 from ...utils.log import logger
 from paddle.utils import try_import
 
@@ -780,6 +781,81 @@ def bpe(self, token):
         self.cache[token] = word
         return word
 
+    def tokenize(self, text: TextInput, **kwargs) -> List[str]:
+        """
+        Converts a string in a sequence of tokens, using the tokenizer.
+
+        Split in words for word-based vocabulary or sub-words for sub-word-based vocabularies
+        (BPE/SentencePieces/WordPieces). Takes care of added tokens.
+
+        Args:
+            text (`str`):
+                The sequence to be encoded.
+            **kwargs (additional keyword arguments):
+                Passed along to the model-specific `prepare_for_tokenization` preprocessing method.
+
+        Returns:
+            `List[str]`: The list of tokens.
+        """
+        # Simple mapping string => AddedToken for special tokens with specific tokenization behaviors
+        all_special_tokens_extended = dict(
+            (str(t), t) for t in self.all_special_tokens_extended
+            if isinstance(t, AddedToken))
+
+        text, kwargs = self.prepare_for_tokenization(text, **kwargs)
+
+        # TODO: should this be in the base class?
+        if hasattr(self, "do_lower_case") and self.do_lower_case:
+            # convert non-special tokens to lowercase
+            escaped_special_toks = [
+                re.escape(s_tok) for s_tok in (self.unique_no_split_tokens +
+                                               self.all_special_tokens)
+            ]
+            pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
+            text = re.sub(pattern,
+                          lambda m: m.groups()[0] or m.groups()[1].lower(),
+                          text)
+
+        no_split_token = set(self.unique_no_split_tokens)
+        tokens = self.tokens_trie.split(text)
+        # ["This is something", "<special_token_1>", " else"]
+        for i, token in enumerate(tokens):
+            if token in no_split_token:
+                tok_extended = all_special_tokens_extended.get(token, None)
+                left = tokens[i - 1] if i > 0 else None
+                right = tokens[i + 1] if i < len(tokens) - 1 else None
+                if isinstance(tok_extended, AddedToken):
+                    if tok_extended.rstrip and right:
+                        # A bit counter-intuitive but we strip the left of the string
+                        # since tok_extended.rstrip means the special token is eating all white spaces on its right
+                        tokens[i + 1] = right.lstrip()
+                    # Strip white spaces on the left
+                    if tok_extended.lstrip and left:
+                        tokens[i - 1] = left.rstrip()  # Opposite here
+                else:
+                    # We strip left and right by default
+                    if right:
+                        tokens[i + 1] = right.lstrip()
+                    if left:
+                        tokens[i - 1] = left.rstrip()
+        # ["This is something", "<special_token_1>", "else"]
+        tokenized_text = []
+        lang = kwargs.pop("lang", "en")
+        bypass_tokenizer = kwargs.pop("bypass_tokenizer", False)
+        for token in tokens:
+            # Need to skip eventual empty (fully stripped) tokens
+            if not token:
+                continue
+            if token in no_split_token:
+                tokenized_text.append(token)
+            else:
+                tokenized_text.extend(
+                    self._tokenize(token,
+                                   lang=lang,
+                                   bypass_tokenizer=bypass_tokenizer))
+        # ["This", " is", " something", "<special_token_1>", "else"]
+        return tokenized_text
+
     def _tokenize(self, text, lang="en", bypass_tokenizer=False):
         """
         Tokenize a string given language code. For Chinese, Japanese and Thai, we use a language specific tokenizer.
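
A note on the lowercasing branch in the diff: the regex matches any registered special token first and passes it through unchanged, while every other character falls into the lazy `(.+?)` group and is lowercased. A minimal standalone sketch of that pattern; the token strings below are illustrative placeholders, not the XLM vocabulary:

import re

# Hypothetical special tokens, just to demonstrate the pattern used above.
special_tokens = ["<unk>", "</s>", "<pad>"]
escaped = [re.escape(t) for t in special_tokens]
pattern = r"(" + r"|".join(escaped) + r")|" + r"(.+?)"

text = "Hello World </s> GOODBYE"
# Group 1 (a special token) is kept verbatim; otherwise group 2 (one char) is lowercased.
lowered = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
print(lowered)  # hello world </s> goodbye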

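For context, the override mirrors the generic added-token handling of the base tokenizer but pops the XLM-specific `lang` and `bypass_tokenizer` keyword arguments and forwards them to `_tokenize`. A minimal usage sketch, assuming a PaddleNLP install that includes this change; the checkpoint name is an assumption, substitute any XLM checkpoint available to you:

# Minimal usage sketch (not part of the commit). Assumes PaddleNLP with this
# change applied; "xlm-mlm-tlm-xnli15-1024" is an assumed checkpoint name.
from paddlenlp.transformers import XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-tlm-xnli15-1024")

# `lang` is popped inside tokenize() and forwarded to _tokenize(); special
# tokens such as "</s>" should be split out by the trie and kept verbatim.
tokens = tokenizer.tokenize("Hello world </s>", lang="en")
print(tokens)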