Skip to content

Commit 7fe0f98

Browse files
committed
Fix byte pair detokenization of 2d arrays (#423)
Before this fix, the detokenize function would collapse everything into a single string, so it did not preserve the structure of the input (e.g. a 2D array of token IDs would not detokenize back to a list of strings).
1 parent 82e4914 commit 7fe0f98

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

keras_nlp/tokenizers/byte_pair_tokenizer.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -516,16 +516,13 @@ def detokenize(self, inputs):
516516
inputs = tf.expand_dims(inputs, 0)
517517

518518
unicode_text = tf.strings.reduce_join(
519-
self.id_to_token_map.lookup(inputs), axis=1
519+
self.id_to_token_map.lookup(inputs), axis=-1
520520
)
521521
split_unicode_text = tf.strings.unicode_split(unicode_text, "UTF-8")
522522
byte_text = tf.strings.reduce_join(
523-
self.unicode2byte.lookup(split_unicode_text)
523+
self.unicode2byte.lookup(split_unicode_text), axis=-1
524524
)
525525

526-
if not scalar_input:
527-
byte_text = tf.expand_dims(byte_text, 0)
528-
529526
return byte_text
530527

531528
def _transform_bytes(self, tokens):

keras_nlp/tokenizers/byte_pair_tokenizer_test.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,14 @@ def test_tokenize_scalar_input(self):
5757
encoded = self.tokenizer.tokenize(input_data)
5858
self.assertAllEqual(encoded, [31876, 4])
5959

60-
def test_detokenize(self):
61-
input_data = ["brown."]
60+
def test_detokenize_scalar_input(self):
61+
input_data = ["quick brown fox."]
62+
encoded = self.tokenizer.tokenize(input_data)
63+
decoded = self.tokenizer.detokenize(encoded)
64+
self.assertAllEqual(input_data, decoded)
65+
66+
def test_detokenize_list_input(self):
67+
input_data = ["quick brown fox.", "slow black bear."]
6268
encoded = self.tokenizer.tokenize(input_data)
6369
decoded = self.tokenizer.detokenize(encoded)
6470
self.assertAllEqual(input_data, decoded)

0 commit comments

Comments
 (0)