Commit 825b192

Fix lowercase bug in wordpiece tokenizer (#1543)
* Fix lowercase bug
* Add a comment to explain
* Change mask builder
* Revert "Change mask builder". This reverts commit 5c9f61e.
1 parent 4b6970c commit 825b192

File tree: 2 files changed, +27 −2 lines changed

keras_nlp/tokenizers/word_piece_tokenizer.py

Lines changed: 12 additions & 2 deletions
@@ -151,8 +151,6 @@ def pretokenize(
         text = tf.expand_dims(text, 0)
     if split_on_cjk and split:
         text = tf.strings.regex_replace(text, CJK_REGEX, r" \0 ")
-    if lowercase:
-        text = tf_text.case_fold_utf8(text)
     if strip_accents:
         # Normalize unicode to NFD, which splits out accent mark characters.
         text = tf_text.normalize_utf8(text, "NFD")
@@ -187,6 +185,18 @@ def pretokenize(
             delim_regex_pattern=split_pattern,
             keep_delim_regex_pattern=keep_split_pattern,
         )
+    if lowercase:
+        if special_tokens_pattern is not None:
+            # Do not lowercase special tokens in string space. They often
+            # contain capital letters, e.g. `"[CLS]"`.
+            mask = (
+                tf.strings.regex_replace(text, special_tokens_pattern, "६")
+                == "६"
+            )
+            text = tf.where(mask, text, tf_text.case_fold_utf8(text))
+        else:
+            text = tf_text.case_fold_utf8(text)
+
     return text

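The core of the fix is the mask trick added above: each pretokenized piece is run through tf.strings.regex_replace with the special-token pattern, and only pieces that collapse entirely to the sentinel character "६" are treated as special and left untouched; everything else is case-folded. A minimal standalone sketch of the same idea, using an illustrative pattern and inputs rather than the tokenizer's internals:

import tensorflow as tf
import tensorflow_text as tf_text

# Illustrative pattern; the tokenizer builds its own from the configured
# special tokens (e.g. "[CLS]", "[SEP]", "[UNK]").
special_tokens_pattern = r"\[CLS\]|\[SEP\]|\[UNK\]"
text = tf.constant(["[CLS]", "THE", "QUICK", "[SEP]"])

# A piece counts as special if replacing every pattern match with the
# sentinel leaves nothing but the sentinel itself.
mask = tf.strings.regex_replace(text, special_tokens_pattern, "६") == "६"

# Case-fold only the non-special pieces; special tokens keep their casing.
text = tf.where(mask, text, tf_text.case_fold_utf8(text))
# -> [b"[CLS]", b"the", b"quick", b"[SEP]"]

The sentinel is simply a character that is very unlikely to be the sole content of a piece, so the equality check is a cheap way to ask "was this piece just a special token?".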
keras_nlp/tokenizers/word_piece_tokenizer_test.py

Lines changed: 15 additions & 0 deletions
@@ -110,6 +110,21 @@ def test_special_tokens_int_dtype(self):
         output = tokenizer(input_data)
         self.assertAllEqual(output, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])
 
+    def test_special_tokens_with_lowercase(self):
+        input_data = ["[UNK] [MASK] [SEP] [PAD] [CLS] THE QUICK BROWN FOX."]
+        special_tokens = ["[UNK]", "[MASK]", "[SEP]", "[PAD]", "[CLS]"]
+        vocab_data = ["the", "qu", "##ick", "br", "##own", "fox", "."]
+        vocab_data = [*special_tokens, *vocab_data]
+
+        tokenizer = WordPieceTokenizer(
+            vocabulary=vocab_data,
+            lowercase=True,
+            special_tokens=special_tokens,
+            special_tokens_in_strings=True,
+        )
+        output = tokenizer(input_data)
+        self.assertAllEqual(output, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])
+
     def test_cjk_tokens(self):
         input_data = ["ah半推zz"]
         vocab_data = ["[UNK]", "推", "敐", "乐", "半", "偷", "匕", "ah", "zz"]
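The new test covers the regression the commit fixes: with lowercase=True, special tokens appearing in the input string used to be case-folded along with everything else ("[CLS]" becoming "[cls]"), so they no longer matched their vocabulary entries. A sketch of the same scenario from a user's perspective, reusing the test's vocabulary and input (it assumes the public keras_nlp.tokenizers.WordPieceTokenizer API that the test exercises; the expected ids follow the vocabulary order shown above):

from keras_nlp.tokenizers import WordPieceTokenizer

special_tokens = ["[UNK]", "[MASK]", "[SEP]", "[PAD]", "[CLS]"]
vocab = [*special_tokens, "the", "qu", "##ick", "br", "##own", "fox", "."]

tokenizer = WordPieceTokenizer(
    vocabulary=vocab,
    lowercase=True,
    special_tokens=special_tokens,
    special_tokens_in_strings=True,
)

# Special tokens keep their capitalization and map to ids 0-4; the rest of
# the sentence is lowercased before wordpiece matching.
ids = tokenizer(["[UNK] [MASK] [SEP] [PAD] [CLS] THE QUICK BROWN FOX."])
# Expected, per the test above: [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]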
