@@ -6,6 +6,9 @@
 from keras.src.utils import backend_utils
 from keras.src.utils.module_utils import tensorflow as tf
 
+if backend.backend() == "torch":
+    import torch
+
 
 @keras_export("keras.layers.StringLookup")
 class StringLookup(IndexLookup):
@@ -382,13 +385,39 @@ def get_config(self):
         return {**base_config, **config}
 
     def call(self, inputs):
-        if isinstance(inputs, (tf.Tensor, tf.RaggedTensor, tf.SparseTensor)):
-            tf_inputs = True
-        else:
-            tf_inputs = False
-            if not isinstance(inputs, (np.ndarray, list, tuple)):
-                inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs))
-        outputs = super().call(inputs)
-        if not tf_inputs:
-            outputs = backend_utils.convert_tf_tensor(outputs)
-        return outputs
+        is_torch_backend = backend.backend() == "torch"
+
+        # Handle input conversion
+        inputs_for_processing = inputs
+        was_tf_input = isinstance(
+            inputs, (tf.Tensor, tf.RaggedTensor, tf.SparseTensor)
+        )
+
+        if is_torch_backend and isinstance(inputs, torch.Tensor):
+            inputs_for_processing = tf.convert_to_tensor(
+                inputs.detach().cpu().numpy()
+            )
+        elif isinstance(inputs, (np.ndarray, list, tuple)):
+            inputs_for_processing = tf.convert_to_tensor(inputs)
+        elif not was_tf_input:
+            inputs_for_processing = tf.convert_to_tensor(
+                backend.convert_to_numpy(inputs)
+            )
+
+        output = super().call(inputs_for_processing)
+
+        # Handle torch backend output conversion
+        if is_torch_backend and isinstance(
+            inputs, (torch.Tensor, np.ndarray, list, tuple)
+        ):
+            numpy_outputs = output.numpy()
+            if self.invert:
+                return [n.decode(self.encoding) for n in numpy_outputs]
+            else:
+                return torch.from_numpy(numpy_outputs)
+
+        # other backends
+        if not was_tf_input:
+            output = backend_utils.convert_tf_tensor(output)
+
+        return output
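For context, a minimal usage sketch of what the diff above enables under the torch backend. This is not part of the patch; the vocabulary, variable names, and example index values are illustrative, and it assumes KERAS_BACKEND="torch" is set before keras is imported.

# Usage sketch (assumption: torch backend active; names and values are illustrative).
import os
os.environ["KERAS_BACKEND"] = "torch"  # must be set before importing keras

import numpy as np
import torch
from keras import layers

vocab = ["a", "b", "c", "d"]

# Forward lookup: list / numpy string inputs now come back as a torch.Tensor
# of integer indices instead of a tf.Tensor (index 0 is reserved for OOV by default).
lookup = layers.StringLookup(vocabulary=vocab)
indices = lookup(np.array([["a", "c"], ["d", "b"]]))  # e.g. tensor([[1, 3], [4, 2]])

# Inverse lookup: integer indices (including torch tensors) now come back as
# decoded Python strings rather than a tf.Tensor of byte strings.
inverse = layers.StringLookup(vocabulary=vocab, invert=True)
strings = inverse(torch.tensor([1, 3, 4]))  # e.g. ["a", "c", "d"]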