Skip to content

Commit 1f39ecc

Browse files
authored
Word piece arg change (#148)
* Add a `split_input` argument for WordPiece. This way we can have three flags for our pre-tokenization, which should allow for a more uniform experience. We can also remove the large regexes from the input-argument defaults, which simplifies our rendered documentation quite a bit.
* Fix
* Another fix
* Format fix
1 parent 46d1719 commit 1f39ecc

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

keras_nlp/tokenizers/word_piece_tokenizer.py

Lines changed: 18 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,7 @@ class WordPieceTokenizer(tokenizer.Tokenizer):
8080
8181
If a more custom pre-tokenization step is desired, the layer can be
8282
configured to apply only the strict WordPiece algorithm by passing
83-
`lowercase=False`, `strip_accents=False` and `split_pattern=None`. In
83+
`lowercase=False`, `strip_accents=False` and `split=False`. In
8484
this case, inputs should be pre-split string tensors or ragged tensors.
8585
8686
By default, the layer will output a `tf.RaggedTensor` where the last
@@ -101,10 +101,11 @@ class WordPieceTokenizer(tokenizer.Tokenizer):
101101
tokenization.
102102
strip_accents: If true, all accent marks will be removed from text
103103
before tokenization.
104-
split_pattern: A regex pattern to match delimiters to split, or None
105-
indicating that the input is pre-split and no splitting should be
106-
performed. By default, all whitespace and punctuation marks will
107-
be split on.
104+
split: If true, input will be split according to `split_pattern`
105+
and `keep_pattern`. If false, input should be split before calling
106+
the layer.
107+
split_pattern: A regex pattern to match delimiters to split. By default,
108+
all whitespace and punctuation marks will be split on.
108109
keep_pattern: A regex pattern of delimiters contained in the
109110
`split_pattern` of delimeters that should be kept as independent
110111
tokens. By default, all punctuation marks will be kept as tokens.
@@ -167,8 +168,9 @@ def __init__(
167168
sequence_length: int = None,
168169
lowercase: bool = True,
169170
strip_accents: bool = True,
170-
split_pattern: str = WHITESPACE_AND_PUNCTUATION_REGEX,
171-
keep_pattern: str = PUNCTUATION_REGEX,
171+
split: bool = True,
172+
split_pattern: str = None,
173+
keep_pattern: str = None,
172174
suffix_indicator: str = "##",
173175
oov_token: str = "[UNK]",
174176
**kwargs,
@@ -201,9 +203,16 @@ def __init__(
201203
if oov_token is None:
202204
raise ValueError("`oov_token` cannot be None.")
203205

206+
if split_pattern is None:
207+
split_pattern = WHITESPACE_AND_PUNCTUATION_REGEX
208+
209+
if keep_pattern is None:
210+
keep_pattern = PUNCTUATION_REGEX
211+
204212
self.sequence_length = sequence_length
205213
self.lowercase = lowercase
206214
self.strip_accents = strip_accents
215+
self.split = split
207216
self.split_pattern = split_pattern
208217
self.keep_pattern = keep_pattern
209218
self.suffix_indicator = suffix_indicator
@@ -257,6 +266,7 @@ def get_config(self) -> Dict[str, Any]:
257266
"sequence_length": self.sequence_length,
258267
"lowercase": self.lowercase,
259268
"strip_accents": self.strip_accents,
269+
"split": self.split,
260270
"split_pattern": self.split_pattern,
261271
"keep_pattern": self.keep_pattern,
262272
"suffix_indicator": self.suffix_indicator,
@@ -280,7 +290,7 @@ def tokenize(self, inputs):
280290
inputs = tf_text.normalize_utf8(inputs, "NFD")
281291
# Remove the accent marks.
282292
inputs = tf.strings.regex_replace(inputs, r"\p{Mn}", "")
283-
if self.split_pattern:
293+
if self.split:
284294
inputs = tf_text.regex_split(
285295
inputs,
286296
delim_regex_pattern=self.split_pattern,

keras_nlp/tokenizers/word_piece_tokenizer_test.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,7 @@ def test_custom_spliting(self):
135135
def test_no_spliting(self):
136136
input_data = ["t o k e n", "m i s s i n g", "t o k e n"]
137137
vocab_data = ["[UNK]", "t o k e n"]
138-
tokenizer = WordPieceTokenizer(
139-
vocabulary=vocab_data, split_pattern=None
140-
)
138+
tokenizer = WordPieceTokenizer(vocabulary=vocab_data, split=False)
141139
call_output = tokenizer(input_data)
142140
self.assertAllEqual(call_output, [1, 0, 1])
143141

@@ -148,7 +146,7 @@ def test_word_piece_only(self):
148146
vocabulary=vocab_data,
149147
lowercase=False,
150148
strip_accents=False,
151-
split_pattern=None,
149+
split=False,
152150
)
153151
call_output = tokenizer(input_data)
154152
self.assertAllEqual(call_output, [1, 2, 3, 4, 5, 6])

0 commit comments

Comments
 (0)