Skip to content

Commit e575e98

Browse files
abheesht17 authored and mattdangerw committed
Support String Output for BytePairTokenizer (#438)
* Support string output for BytePairTokenizer * Add unit test * Minor edit
1 parent 7fe0f98 commit e575e98

File tree

2 files changed

+33
-18
lines changed

2 files changed

+33
-18
lines changed

keras_nlp/tokenizers/byte_pair_tokenizer.py

Lines changed: 19 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
""" Byte-pair encoder implementation.
15+
"""Byte-pair encoder implementation.
1616
1717
This file implements the same logic as openai BPE:
1818
https://github.com/openai/gpt-2/blob/master/src/encoder.py,
@@ -159,12 +159,12 @@ def create_static_hashtable(keys, values, default):
159159
class BytePairTokenizer(tokenizer.Tokenizer):
160160
"""Bype-pair encoding tokenizer layer.
161161
162-
This BPE tokenizer provides the same funtionality as official GPT2
162+
This BPE tokenizer provides the same functionality as the official GPT-2
163163
tokenizer. Given the same `vocabulary` which maps tokens to ids, and `merges`
164164
which describes BPE merge rules, it should provide the same output
165-
as openai implementation (https://github.com/openai/gpt-2/blob/master/src/encoder.py).
166-
Different from openai, this implementation is graph-compatible, so you can
167-
use it within a tf.data pipeline.
165+
as OpenAI implementation (https://github.com/openai/gpt-2/blob/master/src/encoder.py).
166+
Different from OpenAI, this implementation is graph-compatible, so you can
167+
use it within a `tf.data` pipeline.
168168
169169
If input is a batch of strings (rank > 0):
170170
By default, the layer will output a `tf.RaggedTensor` where the last
@@ -187,7 +187,7 @@ class BytePairTokenizer(tokenizer.Tokenizer):
187187
188188
Examples:
189189
190-
Use in-momery vocabulary and merge list.
190+
Use in-memory vocabulary and merge list.
191191
192192
>>> vocab = {"butter": 1, "fly": 2}
193193
>>> merge = ["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"]
@@ -244,7 +244,7 @@ def __init__(
244244
kwargs["dtype"] = tf.int32
245245
else:
246246
dtype = tf.dtypes.as_dtype(kwargs["dtype"])
247-
if not dtype.is_integer:
247+
if not dtype.is_integer and dtype != tf.string:
248248
raise ValueError(
249249
"Output dtype must be an integer type or a string. "
250250
f"Received: `dtype={dtype}`"
@@ -484,28 +484,29 @@ def process_unseen_tokens():
484484
lambda: cache_lookup,
485485
)
486486

487-
# Encode merged tokens.
488-
tokenized_words = tf.strings.split(tokenized_words, sep=" ")
489-
encoding = self.token_to_id_map.lookup(tokenized_words)
487+
tokens = tf.strings.split(tokenized_words, sep=" ")
488+
if self.compute_dtype != tf.string:
489+
# Encode merged tokens.
490+
tokens = self.token_to_id_map.lookup(tokens)
490491

491492
# Unflatten to match input.
492-
encoding = tf.RaggedTensor.from_row_splits(
493-
encoding.flat_values,
494-
tf.gather(encoding.row_splits, token_row_splits),
493+
tokens = tf.RaggedTensor.from_row_splits(
494+
tokens.flat_values,
495+
tf.gather(tokens.row_splits, token_row_splits),
495496
)
496497

497498
# Convert to a dense output if `sequence_length` is set.
498499
if self.sequence_length:
499-
output_shape = encoding.shape.as_list()
500+
output_shape = tokens.shape.as_list()
500501
output_shape[-1] = self.sequence_length
501-
encoding = encoding.to_tensor(shape=output_shape)
502+
tokens = tokens.to_tensor(shape=output_shape)
502503

503504
# Convert to a dense output if input in scalar
504505
if scalar_input:
505-
encoding = tf.squeeze(encoding, 0)
506-
tf.ensure_shape(encoding, shape=[self.sequence_length])
506+
tokens = tf.squeeze(tokens, 0)
507+
tf.ensure_shape(tokens, shape=[self.sequence_length])
507508

508-
return encoding
509+
return tokens
509510

510511
def detokenize(self, inputs):
511512
if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):

keras_nlp/tokenizers/byte_pair_tokenizer_test.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,20 @@ def test_tokenize_list_input(self):
5252
encoded = self.tokenizer(input_data)
5353
self.assertAllEqual(encoded, expected)
5454

55+
def test_tokenize_string_output(self):
56+
input_data = ["quick brown fox.", "slow black bear."]
57+
tokenizer = BytePairTokenizer(
58+
vocabulary=VOCAB_PATH, merges=MERGE_PATH, dtype=tf.string
59+
)
60+
call_output = tokenizer(input_data)
61+
expected = tf.ragged.constant(
62+
[
63+
["quick", "Ġbrown", "Ġfox", "."],
64+
["slow", "Ġblack", "Ġbear", "."],
65+
]
66+
)
67+
self.assertAllEqual(call_output, expected)
68+
5569
def test_tokenize_scalar_input(self):
5670
input_data = "brown."
5771
encoded = self.tokenizer.tokenize(input_data)

0 commit comments

Comments (0)