Skip to content

Commit 7e678e4

Browse files
authored
Don't accept a string dtype for unicode tokenizer (#147)
* Don't accept a string dtype for unicode tokenizer. This tokenizer cannot output strings.
* Cast the output to the layer dtype.
1 parent 4b7a417 commit 7e678e4

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

keras_nlp/tokenizers/unicode_character_tokenizer.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -175,9 +175,9 @@ def __init__(
175175
kwargs["dtype"] = tf.int32
176176
else:
177177
dtype = tf.dtypes.as_dtype(kwargs["dtype"])
178-
if not dtype.is_integer and dtype != tf.string:
178+
if not dtype.is_integer:
179179
raise ValueError(
180-
"Output dtype must be an integer type of a string. "
180+
"Output dtype must be an integer type. "
181181
f"Received: dtype={dtype}"
182182
)
183183

@@ -251,6 +251,7 @@ def tokenize(self, inputs):
251251
replacement_char=self.replacement_char,
252252
input_encoding=self.input_encoding,
253253
)
254+
tokens = tf.cast(tokens, self.compute_dtype)
254255

255256
if self.sequence_length:
256257
output_shape = tokens.shape.as_list()

0 commit comments

Comments
 (0)