@@ -80,7 +80,7 @@ class WordPieceTokenizer(tokenizer.Tokenizer):
8080
8181 If a more custom pre-tokenization step is desired, the layer can be
8282 configured to apply only the strict WordPiece algorithm by passing
83- `lowercase=False`, `strip_accents=False` and `split_pattern=None `. In
83+ `lowercase=False`, `strip_accents=False` and `split=False`. In
8484 this case, inputs should be pre-split string tensors or ragged tensors.
8585
8686 By default, the layer will output a `tf.RaggedTensor` where the last
@@ -101,10 +101,11 @@ class WordPieceTokenizer(tokenizer.Tokenizer):
101101 tokenization.
102102 strip_accents: If true, all accent marks will be removed from text
103103 before tokenization.
104- split_pattern: A regex pattern to match delimiters to split, or None
105- indicating that the input is pre-split and no splitting should be
106- performed. By default, all whitespace and punctuation marks will
107- be split on.
104+ split: If true, input will be split according to `split_pattern`
105+ and `keep_pattern`. If false, input should be split before calling
106+ the layer.
107+ split_pattern: A regex pattern to match delimiters to split. By default,
108+ all whitespace and punctuation marks will be split on.
108109 keep_pattern: A regex pattern of delimiters contained in the
109110 `split_pattern` of delimiters that should be kept as independent
110111 tokens. By default, all punctuation marks will be kept as tokens.
@@ -167,8 +168,9 @@ def __init__(
167168 sequence_length : int = None ,
168169 lowercase : bool = True ,
169170 strip_accents : bool = True ,
170- split_pattern : str = WHITESPACE_AND_PUNCTUATION_REGEX ,
171- keep_pattern : str = PUNCTUATION_REGEX ,
171+ split : bool = True ,
172+ split_pattern : str = None ,
173+ keep_pattern : str = None ,
172174 suffix_indicator : str = "##" ,
173175 oov_token : str = "[UNK]" ,
174176 ** kwargs ,
@@ -201,9 +203,16 @@ def __init__(
201203 if oov_token is None :
202204 raise ValueError ("`oov_token` cannot be None." )
203205
206+ if split_pattern is None :
207+ split_pattern = WHITESPACE_AND_PUNCTUATION_REGEX
208+
209+ if keep_pattern is None :
210+ keep_pattern = PUNCTUATION_REGEX
211+
204212 self .sequence_length = sequence_length
205213 self .lowercase = lowercase
206214 self .strip_accents = strip_accents
215+ self .split = split
207216 self .split_pattern = split_pattern
208217 self .keep_pattern = keep_pattern
209218 self .suffix_indicator = suffix_indicator
@@ -257,6 +266,7 @@ def get_config(self) -> Dict[str, Any]:
257266 "sequence_length" : self .sequence_length ,
258267 "lowercase" : self .lowercase ,
259268 "strip_accents" : self .strip_accents ,
269+ "split" : self .split ,
260270 "split_pattern" : self .split_pattern ,
261271 "keep_pattern" : self .keep_pattern ,
262272 "suffix_indicator" : self .suffix_indicator ,
@@ -280,7 +290,7 @@ def tokenize(self, inputs):
280290 inputs = tf_text .normalize_utf8 (inputs , "NFD" )
281291 # Remove the accent marks.
282292 inputs = tf .strings .regex_replace (inputs , r"\p{Mn}" , "" )
283- if self .split_pattern :
293+ if self .split :
284294 inputs = tf_text .regex_split (
285295 inputs ,
286296 delim_regex_pattern = self .split_pattern ,
0 commit comments