Skip to content

Commit f75965f

Browse files
committed
Fix device scope issues (#1841)
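We want to always place tf preprocessing ops on the CPU, even when a GPU device is available. This broke because the `tf.device("cpu")` scope was entered when the decorator was applied rather than each time the wrapped function was called.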
1 parent 0f35d5e commit f75965f

File tree

5 files changed

+26
-9
lines changed

5 files changed

+26
-9
lines changed

keras_nlp/src/tokenizers/byte_pair_tokenizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,7 @@ def tokenize(self, inputs):
540540
if self.add_prefix_space:
541541
inputs = tf.strings.join([" ", inputs])
542542

543+
inputs = tf.convert_to_tensor(inputs)
543544
unbatched = inputs.shape.rank == 0
544545
if unbatched:
545546
inputs = tf.expand_dims(inputs, 0)

keras_nlp/src/tokenizers/sentence_piece_tokenizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ def _check_vocabulary(self):
238238
@preprocessing_function
239239
def tokenize(self, inputs):
240240
self._check_vocabulary()
241+
inputs = tf.convert_to_tensor(inputs)
241242
unbatched = inputs.shape.rank == 0
242243
if unbatched:
243244
inputs = tf.expand_dims(inputs, 0)

keras_nlp/src/tokenizers/word_piece_tokenizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,7 @@ def _check_vocabulary(self):
473473
@preprocessing_function
474474
def tokenize(self, inputs):
475475
self._check_vocabulary()
476+
inputs = tf.convert_to_tensor(inputs)
476477
unbatched = inputs.shape.rank == 0
477478
pattern = None
478479
if self.split and self.special_tokens_in_strings:

keras_nlp/src/utils/tensor_utils.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,28 +53,29 @@ def preprocessing_function(fn):
5353

5454
params = inspect.signature(fn).parameters
5555
accepts_labels = all(k in params for k in ("x", "y", "sample_weight"))
56-
with tf.device("cpu"):
57-
if not accepts_labels:
56+
if not accepts_labels:
5857

59-
@functools.wraps(fn)
60-
def wrapper(self, x, **kwargs):
58+
@functools.wraps(fn)
59+
def wrapper(self, x, **kwargs):
60+
with tf.device("cpu"):
6161
x = convert_preprocessing_inputs(x)
6262
with no_convert_scope():
6363
x = fn(self, x, **kwargs)
6464
return convert_preprocessing_outputs(x)
6565

66-
else:
66+
else:
6767

68-
@functools.wraps(fn)
69-
def wrapper(self, x, y=None, sample_weight=None, **kwargs):
68+
@functools.wraps(fn)
69+
def wrapper(self, x, y=None, sample_weight=None, **kwargs):
70+
with tf.device("cpu"):
7071
x, y, sample_weight = convert_preprocessing_inputs(
7172
(x, y, sample_weight)
7273
)
7374
with no_convert_scope():
7475
x = fn(self, x, y=y, sample_weight=sample_weight, **kwargs)
7576
return convert_preprocessing_outputs(x)
7677

77-
return wrapper
78+
return wrapper
7879

7980

8081
def convert_preprocessing_inputs(x):

keras_nlp/src/utils/tensor_utils_test.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@
2323
from keras_nlp.src.utils.tensor_utils import convert_preprocessing_outputs
2424
from keras_nlp.src.utils.tensor_utils import convert_to_ragged_batch
2525
from keras_nlp.src.utils.tensor_utils import is_tensor_type
26+
from keras_nlp.src.utils.tensor_utils import preprocessing_function
2627
from keras_nlp.src.utils.tensor_utils import tensor_to_list
2728

2829

2930
class ConvertHelpers(TestCase):
3031
def test_basics(self):
31-
inputs = ops.array([1, 2, 3])
32+
inputs = [1, 2, 3]
3233
# Convert to tf.
3334
outputs = convert_preprocessing_inputs(inputs)
3435
self.assertAllEqual(outputs, ops.array(inputs))
@@ -92,6 +93,18 @@ def to_list(x):
9293
inputs = tree.flatten(tree.map_structure(to_list, inputs))
9394
self.assertAllEqual(outputs, inputs)
9495

96+
def test_placement(self):
97+
# Make sure we always place preprocessing on the CPU on all backends.
98+
@preprocessing_function
99+
def test(self, inputs):
100+
for x in inputs:
101+
if isinstance(x, tf.Tensor):
102+
self.assertTrue("CPU" in x.device)
103+
self.assertFalse("GPU" in x.device)
104+
return inputs
105+
106+
test(self, ([1, 2, 3], ["foo", "bar"], "foo"))
107+
95108

96109
class TensorToListTest(TestCase):
97110
def test_ragged_input(self):

0 commit comments

Comments
 (0)