 ]
 
 import errno
+import functools
 import io
 import os
 import time
@@ -42,7 +43,9 @@
 import numpy as np
 
 from ..base import get_home_dir
+from ..vocab.vocab import Vocab
 from .utils import _extract_archive
+from .wordpiece import tokenize as wordpiece_tokenize
 
 
 class ClipSequence:
@@ -790,14 +793,17 @@ class BERTTokenizer:
 
     Parameters
     ----------
-    vocab : gluonnlp.Vocab or None, default None
+    vocab
         Vocabulary for the corpus.
-    lower : bool, default True
+    lower
         whether the text strips accents and converts to lower case.
         If you use the BERT pre-training model,
         lower is set to False when using the cased model,
         otherwise it is set to True.
-    max_input_chars_per_word : int, default 200
+    max_input_chars_per_word
+    lru_cache_size
+        Maximum size of a least-recently-used cache to speed up tokenization.
+        For example, use a size of 2**20.
 
     Examples
     --------
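
For orientation, a hedged usage sketch of the new keyword argument. The toy vocabulary built here is illustrative only; a real run would load a pretrained BERT vocabulary, so the resulting wordpieces below are not meaningful.

import gluonnlp as nlp

# Hypothetical toy vocab; real BERT wordpieces need the pretrained vocab.
counter = nlp.data.count_tokens(['hello', 'world'])
vocab = nlp.Vocab(counter)

tokenizer = nlp.data.BERTTokenizer(vocab, lower=True, lru_cache_size=2**20)
tokens = tokenizer('hello world hello world')  # repeated words are served from the cache
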
@@ -812,10 +818,14 @@ class BERTTokenizer:
 
     _special_prefix = '##'
 
-    def __init__(self, vocab, lower=True, max_input_chars_per_word=200):
+    def __init__(self, vocab: Vocab, lower: bool = True, max_input_chars_per_word: int = 200,
+                 lru_cache_size: Optional[int] = None):
         self.vocab = vocab
         self.max_input_chars_per_word = max_input_chars_per_word
         self.basic_tokenizer = BERTBasicTokenizer(lower=lower)
+        if lru_cache_size:
+            self._word_to_wordpiece_optimized = functools.lru_cache(maxsize=lru_cache_size)(
+                self._word_to_wordpiece_optimized)
 
     def __call__(self, sample):
         """
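
The __init__ change wraps the bound method in functools.lru_cache per instance, so each tokenizer keeps its own cache keyed only on the word string rather than one class-level cache shared across vocabularies; that is also why the wrapped method further down carries a pylint method-hidden disable. A minimal self-contained sketch of the same pattern, with illustrative names:

import functools

class CachedWorker:
    """Toy class (illustrative names) mirroring the per-instance cache setup."""

    def __init__(self, cache_size=None):
        if cache_size:
            # Bind an LRU-cached wrapper of the bound method as an instance
            # attribute, shadowing the class method; `self` is already bound,
            # so the cache key is just `word`.
            self.transform = functools.lru_cache(maxsize=cache_size)(self.transform)

    def transform(self, word):  # pylint: disable=method-hidden
        return word.upper()

worker = CachedWorker(cache_size=128)
worker.transform('token')   # computed once
worker.transform('token')   # served from this instance's own cache
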
@@ -841,6 +851,10 @@ def _tokenizer(self, text):
 
         return split_tokens
 
+    def _word_to_wordpiece_optimized(self, text):  # pylint: disable=method-hidden
+        return wordpiece_tokenize(text, self.vocab, self.vocab.unknown_token,
+                                  self.max_input_chars_per_word)
+
     def _tokenize_wordpiece(self, text):
         """Tokenizes a piece of text into its word pieces.
 
@@ -861,35 +875,14 @@ def _tokenize_wordpiece(self, text):
         ret : A list of wordpiece tokens.
         """
 
+        # case where text is a single token
+        whitespace_tokenized_tokens = self.basic_tokenizer._whitespace_tokenize(text)
+        if len(whitespace_tokenized_tokens) == 1:
+            return self._word_to_wordpiece_optimized(whitespace_tokenized_tokens[0])
+
         output_tokens = []
-        for token in self.basic_tokenizer._whitespace_tokenize(text):
-            chars = list(token)
-            if len(chars) > self.max_input_chars_per_word:
-                output_tokens.append(self.vocab.unknown_token)
-                continue
-            is_bad = False
-            start = 0
-            sub_tokens = []
-            while start < len(chars):
-                end = len(chars)
-                cur_substr = None
-                while start < end:
-                    substr = ''.join(chars[start:end])
-                    if start > 0:
-                        substr = self._special_prefix + substr
-                    if substr in self.vocab:
-                        cur_substr = substr
-                        break
-                    end -= 1
-                if cur_substr is None:
-                    is_bad = True
-                    break
-                sub_tokens.append(cur_substr)
-                start = end
-            if is_bad:
-                output_tokens.append(self.vocab.unknown_token)
-            else:
-                output_tokens.extend(sub_tokens)
+        for token in whitespace_tokenized_tokens:
+            output_tokens.extend(self._word_to_wordpiece_optimized(token))
         return output_tokens
 
     def convert_tokens_to_ids(self, tokens):
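
The removed block above is the greedy longest-match-first wordpiece loop that now sits behind _word_to_wordpiece_optimized. For reference, a standalone sketch of that algorithm; the function name, the set-based vocab, and the example values are illustrative, and the new .wordpiece tokenize helper is only visible here through its call signature, so this is a reconstruction of the removed logic rather than that helper's definition:

def greedy_wordpiece(word, vocab, unknown_token='[UNK]', max_chars=200, prefix='##'):
    """Split one whitespace token into wordpieces by longest-match-first."""
    if len(word) > max_chars:
        return [unknown_token]
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        cur = None
        while start < end:
            sub = word[start:end]
            if start > 0:
                sub = prefix + sub       # continuation pieces carry the '##' prefix
            if sub in vocab:
                cur = sub
                break
            end -= 1                     # shrink the candidate until it is in the vocab
        if cur is None:
            return [unknown_token]       # no piece matched: the whole word is unknown
        pieces.append(cur)
        start = end
    return pieces

vocab = {'un', '##aff', '##able'}
greedy_wordpiece('unaffable', vocab)     # ['un', '##aff', '##able']
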