
Commit e988786

Fix: StringLookup returns torch native types for torch backend (#21614)
* Fix: StringLookup returns torch native types for torch backend
* Formatting and making the logic clean
* Handle backends other than TensorFlow and PyTorch
* Fixed backends other than torch and tensorflow
1 parent 124a258 commit e988786
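
As a quick illustration of the change described above, here is a minimal sketch, assuming the torch backend is selected via the KERAS_BACKEND environment variable before Keras is imported. The vocabulary and inputs mirror the new test; the expected outputs are inferred from the diff below.

import os

# Assumption: the torch backend must be selected before Keras is imported.
os.environ["KERAS_BACKEND"] = "torch"

import torch
from keras import layers

# Forward lookup: with the torch backend, the layer now returns a
# torch.Tensor instead of a tf.Tensor.
lookup = layers.StringLookup(vocabulary=["a", "b", "c"], oov_token="[OOV]")
indices = lookup(["a", "b", "d"])
print(type(indices))  # <class 'torch.Tensor'>
print(indices)        # tensor([1, 2, 0])  ("d" maps to the OOV index 0)

# Inverse lookup: indices map back to Python strings, so the layer returns
# a plain list of str rather than a tensor of bytes.
inverse = layers.StringLookup(
    vocabulary=["a", "b", "c"], oov_token="[OOV]", invert=True
)
print(inverse(torch.tensor([1, 2, 0])))  # ['a', 'b', '[OOV]']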

File tree

2 files changed (+69, -10 lines)
keras/src/layers/preprocessing/string_lookup.py

Lines changed: 39 additions & 10 deletions
@@ -6,6 +6,9 @@
 from keras.src.utils import backend_utils
 from keras.src.utils.module_utils import tensorflow as tf

+if backend.backend() == "torch":
+    import torch
+

 @keras_export("keras.layers.StringLookup")
 class StringLookup(IndexLookup):
@@ -382,13 +385,39 @@ def get_config(self):
         return {**base_config, **config}

     def call(self, inputs):
-        if isinstance(inputs, (tf.Tensor, tf.RaggedTensor, tf.SparseTensor)):
-            tf_inputs = True
-        else:
-            tf_inputs = False
-            if not isinstance(inputs, (np.ndarray, list, tuple)):
-                inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs))
-        outputs = super().call(inputs)
-        if not tf_inputs:
-            outputs = backend_utils.convert_tf_tensor(outputs)
-        return outputs
+        is_torch_backend = backend.backend() == "torch"
+
+        # Handle input conversion
+        inputs_for_processing = inputs
+        was_tf_input = isinstance(
+            inputs, (tf.Tensor, tf.RaggedTensor, tf.SparseTensor)
+        )
+
+        if is_torch_backend and isinstance(inputs, torch.Tensor):
+            inputs_for_processing = tf.convert_to_tensor(
+                inputs.detach().cpu().numpy()
+            )
+        elif isinstance(inputs, (np.ndarray, list, tuple)):
+            inputs_for_processing = tf.convert_to_tensor(inputs)
+        elif not was_tf_input:
+            inputs_for_processing = tf.convert_to_tensor(
+                backend.convert_to_numpy(inputs)
+            )
+
+        output = super().call(inputs_for_processing)
+
+        # Handle torch backend output conversion
+        if is_torch_backend and isinstance(
+            inputs, (torch.Tensor, np.ndarray, list, tuple)
+        ):
+            numpy_outputs = output.numpy()
+            if self.invert:
+                return [n.decode(self.encoding) for n in numpy_outputs]
+            else:
+                return torch.from_numpy(numpy_outputs)
+
+        # other backends
+        if not was_tf_input:
+            output = backend_utils.convert_tf_tensor(output)
+
+        return output
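
A side note on the input path in call() above: a torch tensor may live on an accelerator or carry autograd history, and tf.convert_to_tensor expects plain host memory, which is why the new code goes through detach().cpu().numpy() first. Below is a minimal standalone sketch of that round trip, independent of the layer; the tensor values are illustrative.

import tensorflow as tf
import torch

x = torch.tensor([1, 2, 0], dtype=torch.int64)

# detach() drops any autograd history, cpu() guarantees host memory, and
# numpy() exposes that buffer as a NumPy array TensorFlow can ingest.
x_tf = tf.convert_to_tensor(x.detach().cpu().numpy())

# And back again, as done for the non-inverted output: tf -> NumPy -> torch.
y_torch = torch.from_numpy(x_tf.numpy())
print(y_torch)  # tensor([1, 2, 0])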

keras/src/layers/preprocessing/string_lookup_test.py

Lines changed: 30 additions & 0 deletions
@@ -89,3 +89,33 @@ def test_tensor_as_vocab(self):
         )
         output = layer(data)
         self.assertAllClose(output, np.array([[1, 3, 4], [4, 0, 2]]))
+
+    @pytest.mark.skipif(backend.backend() != "torch", reason="Only torch")
+    def test_torch_backend_compatibility(self):
+        import torch
+
+        # Forward lookup: String -> number
+        forward_lookup = layers.StringLookup(
+            vocabulary=["a", "b", "c"], oov_token="[OOV]"
+        )
+        input_data_str = ["a", "b", "[OOV]", "d"]
+        output_numeric = forward_lookup(input_data_str)
+
+        # assert instance of output is torch.Tensor
+        self.assertIsInstance(output_numeric, torch.Tensor)
+        expected_numeric = torch.tensor([1, 2, 0, 0])
+        self.assertAllClose(output_numeric.cpu(), expected_numeric)
+
+        oov = "[OOV]"
+        # Inverse lookup: Number -> string
+        inverse_lookup = layers.StringLookup(
+            vocabulary=["a", "b", "c"], oov_token=oov, invert=True
+        )
+        input_data_int = torch.tensor([1, 2, 0], dtype=torch.int64)
+        output_string = inverse_lookup(input_data_int)
+        # Assert that the output is a list
+        # See: https://docs.pytorch.org/text/stable/_modules/torchtext/vocab/vocab.html#Vocab.lookup_tokens
+        # The torch equivalent implementation of this returns a list of strings
+        self.assertIsInstance(output_string, list)
+        expected_string = ["a", "b", "[OOV]"]
+        self.assertEqual(output_string, expected_string)

0 commit comments
