|
import unicodedata
from typing import List, Optional

+from ..tokenizer_utils import AddedToken, TextInput
from ...utils.log import logger
from paddle.utils import try_import

@@ -780,6 +781,81 @@ def bpe(self, token):
        self.cache[token] = word
        return word

+    def tokenize(self, text: TextInput, **kwargs) -> List[str]:
+        """
+        Converts a string into a sequence of tokens, using the tokenizer.
+
+        Splits into words for word-based vocabulary or sub-words for sub-word-based vocabularies
+        (BPE/SentencePieces/WordPieces). Takes care of added tokens.
+
+        Args:
+            text (`str`):
+                The sequence to be encoded.
+            **kwargs (additional keyword arguments):
+                Passed along to the model-specific `prepare_for_tokenization` preprocessing method.
+
+        Returns:
+            `List[str]`: The list of tokens.
+        """
+        # Simple mapping string => AddedToken for special tokens with specific tokenization behaviors
+        all_special_tokens_extended = dict(
+            (str(t), t) for t in self.all_special_tokens_extended
+            if isinstance(t, AddedToken))
+
+        text, kwargs = self.prepare_for_tokenization(text, **kwargs)
+
+        # TODO: should this be in the base class?
+        if hasattr(self, "do_lower_case") and self.do_lower_case:
+            # convert non-special tokens to lowercase
+            escaped_special_toks = [
+                re.escape(s_tok) for s_tok in (self.unique_no_split_tokens +
+                                               self.all_special_tokens)
+            ]
+            pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
+            text = re.sub(pattern,
+                          lambda m: m.groups()[0] or m.groups()[1].lower(),
+                          text)
+
+        no_split_token = set(self.unique_no_split_tokens)
+        tokens = self.tokens_trie.split(text)
+        # ["This is something", "<special_token_1>", " else"]
+        for i, token in enumerate(tokens):
+            if token in no_split_token:
+                tok_extended = all_special_tokens_extended.get(token, None)
+                left = tokens[i - 1] if i > 0 else None
+                right = tokens[i + 1] if i < len(tokens) - 1 else None
+                if isinstance(tok_extended, AddedToken):
+                    if tok_extended.rstrip and right:
+                        # A bit counter-intuitive but we strip the left of the string
+                        # since tok_extended.rstrip means the special token is eating all white spaces on its right
+                        tokens[i + 1] = right.lstrip()
+                    # Strip white spaces on the left
+                    if tok_extended.lstrip and left:
+                        tokens[i - 1] = left.rstrip()  # Opposite here
+                else:
+                    # We strip left and right by default
+                    if right:
+                        tokens[i + 1] = right.lstrip()
+                    if left:
+                        tokens[i - 1] = left.rstrip()
+        # ["This is something", "<special_token_1>", "else"]
+        tokenized_text = []
+        lang = kwargs.pop("lang", "en")
+        bypass_tokenizer = kwargs.pop("bypass_tokenizer", False)
+        for token in tokens:
+            # Need to skip eventual empty (fully stripped) tokens
+            if not token:
+                continue
+            if token in no_split_token:
+                tokenized_text.append(token)
+            else:
+                tokenized_text.extend(
+                    self._tokenize(token,
+                                   lang=lang,
+                                   bypass_tokenizer=bypass_tokenizer))
+        # ["This", " is", " something", "<special_token_1>", "else"]
+        return tokenized_text
+
    def _tokenize(self, text, lang="en", bypass_tokenizer=False):
        """
        Tokenize a string given language code. For Chinese, Japanese and Thai, we use a language specific tokenizer.
|
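For reference, a minimal usage sketch of the new `tokenize` entry point (not part of the diff; the `tokenizer` instance and the registered `<special_token_1>` below are hypothetical, and the exact sub-word pieces depend on the loaded vocabulary):

    # Hypothetical setup: `tokenizer` is an instance of the tokenizer class this
    # method is added to, with "<special_token_1>" registered as an AddedToken.
    pieces = tokenizer.tokenize("This is something<special_token_1> else", lang="en")
    # The trie split first yields ["This is something", "<special_token_1>", " else"];
    # whitespace around the special token is then stripped, the special token is kept
    # as a single piece, and the remaining chunks are sub-tokenized via
    # _tokenize(..., lang="en", bypass_tokenizer=False).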
|