Skip to content

Commit 375082e

Browse files
aflah02 and mattdangerw authored
Adding Utility to Detokenize as list of Strings to Tokenizer Base Class (#124)
* Added Functions to Base Class * Tightened Logic started Work on Tests * Added tests * Updated Docstring * Fixing Tokenizer * Fixed Broken Tests * Ran format and lint * Fix docstring summary to fit on single line Adds a little more description as well * Remove trailing whitespace * fix * Ported tensor_to_string_list to tensor_utils Co-authored-by: Matt Watson <[email protected]> Co-authored-by: Matt Watson <[email protected]>
1 parent 7e678e4 commit 375082e

File tree

3 files changed

+81
-0
lines changed

3 files changed

+81
-0
lines changed

keras_nlp/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@
1515
from keras_nlp import layers
1616
from keras_nlp import metrics
1717
from keras_nlp import tokenizers
18+
from keras_nlp import utils
1819

1920
__version__ = "0.2.0-dev.1"

keras_nlp/utils/tensor_utils.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2022 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import tensorflow as tf
16+
17+
18+
def _decode_strings_to_utf8(inputs):
19+
"""Recursively decodes to list of strings with 'utf-8' encoding."""
20+
if isinstance(inputs, bytes):
21+
# Handles the case when the input is a scalar string.
22+
return inputs.decode("utf-8")
23+
else:
24+
# Recursively iterate when input is a list.
25+
return [_decode_strings_to_utf8(x) for x in inputs]
26+
27+
28+
def tensor_to_string_list(inputs):
    """Convert a tensor of byte strings to nested lists of python strings.

    This is a convenience method which converts each byte string in a
    string tensor to a python string with `"utf-8"` encoding, preserving
    the (possibly ragged) nesting structure of the input.

    Args:
        inputs: Input tensor, or anything convertible to a string tensor
            (e.g. a nested python list of byte strings).

    Returns:
        Nested python lists of python strings matching the shape of
        `inputs`, or a single python string if `inputs` is a scalar.
    """
    if not isinstance(inputs, (tf.RaggedTensor, tf.Tensor)):
        inputs = tf.convert_to_tensor(inputs)
    if isinstance(inputs, tf.RaggedTensor):
        # Ragged tensors convert directly to nested python lists of bytes.
        list_outputs = inputs.to_list()
    elif isinstance(inputs, tf.Tensor):
        list_outputs = inputs.numpy()
        if inputs.shape.rank != 0:
            # A rank-0 tensor yields a single bytes object from `.numpy()`;
            # anything higher-rank is a numpy array we flatten to lists.
            list_outputs = list_outputs.tolist()
    return _decode_strings_to_utf8(list_outputs)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2022 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import tensorflow as tf

# Use the package-qualified path so the test runs from the repo root;
# a bare `from tensor_utils import ...` only resolves when the working
# directory is keras_nlp/utils/.
from keras_nlp.utils.tensor_utils import tensor_to_string_list


class TensorToStringListTest(tf.test.TestCase):
    """Tests for converting string tensors to nested python string lists."""

    def test_detokenize_to_strings_for_ragged(self):
        ragged_input = tf.ragged.constant([["▀▁▂▃", "samurai"]])
        output = tensor_to_string_list(ragged_input)
        self.assertAllEqual(output, [["▀▁▂▃", "samurai"]])

    def test_detokenize_to_strings_for_dense(self):
        dense_input = tf.constant([["▀▁▂▃", "samurai"]])
        output = tensor_to_string_list(dense_input)
        self.assertAllEqual(output, [["▀▁▂▃", "samurai"]])

    def test_detokenize_to_strings_for_scalar(self):
        # A rank-0 input should come back as a single python string,
        # not a list.
        scalar_input = tf.constant("▀▁▂▃")
        output = tensor_to_string_list(scalar_input)
        self.assertEqual(output, "▀▁▂▃")

0 commit comments

Comments
 (0)