
Commit 95e5ee5

Fix BytePairTokenizer
1 parent 658ded0 commit 95e5ee5

File tree

1 file changed: 80 additions, 81 deletions


keras_cv/models/feature_extractor/clip/clip_tokenizer.py

Lines changed: 80 additions & 81 deletions
@@ -15,10 +15,12 @@
 import tensorflow as tf
 import tensorflow_text as tf_text
 
+from keras_cv.utils.conditional_imports import assert_keras_nlp_installed
+
 try:
     from keras_nlp.tokenizers import BytePairTokenizer
 except ImportError:
-    BytePairTokenizer = None
+    BytePairTokenizer = object
 
 # As python and TF handles special spaces differently, we need to
 # manually handle special spaces during string split.
@@ -103,83 +105,80 @@ def remove_strings_from_inputs(tensor, string_to_remove):
     return result
 
 
-if BytePairTokenizer:
-    class CLIPTokenizer(BytePairTokenizer):
-        def __init__(self, **kwargs):
-            super().__init__(**kwargs)
-
-        def _bpe_merge_and_update_cache(self, tokens):
-            """Process unseen tokens and add to cache."""
-            words = self._transform_bytes(tokens)
-            tokenized_words = self._bpe_merge(words)
-
-            # For each word, join all its token by a whitespace,
-            # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose.
-            tokenized_words = tf.strings.reduce_join(
-                tokenized_words,
-                axis=1,
-            )
-            self.cache.insert(tokens, tokenized_words)
-
-        def tokenize(self, inputs):
-            self._check_vocabulary()
-            if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
-                inputs = tf.convert_to_tensor(inputs)
-
-            if self.add_prefix_space:
-                inputs = tf.strings.join([" ", inputs])
-
-            scalar_input = inputs.shape.rank == 0
-            if scalar_input:
-                inputs = tf.expand_dims(inputs, 0)
-
-            raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens)
-            token_row_splits = raw_tokens.row_splits
-            flat_tokens = raw_tokens.flat_values
-            # Check cache.
-            cache_lookup = self.cache.lookup(flat_tokens)
-            cache_mask = cache_lookup == ""
-
-            has_unseen_words = tf.math.reduce_any(
-                (cache_lookup == "") & (flat_tokens != "")
-            )
-
-            def process_unseen_tokens():
-                unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask)
-                self._bpe_merge_and_update_cache(unseen_tokens)
-                return self.cache.lookup(flat_tokens)
-
-            # If `has_unseen_words == True`, it means not all tokens are,
-            # in cache we will process the unseen tokens. Otherwise
-            # return the cache lookup.
-            tokenized_words = tf.cond(
-                has_unseen_words,
-                process_unseen_tokens,
-                lambda: cache_lookup,
-            )
-            tokens = tf.strings.split(tokenized_words, sep=" ")
-            if self.compute_dtype != tf.string:
-                # Encode merged tokens.
-                tokens = self.token_to_id_map.lookup(tokens)
-
-            # Unflatten to match input.
-            tokens = tf.RaggedTensor.from_row_splits(
-                tokens.flat_values,
-                tf.gather(tokens.row_splits, token_row_splits),
-            )
-
-            # Convert to a dense output if `sequence_length` is set.
-            if self.sequence_length:
-                output_shape = tokens.shape.as_list()
-                output_shape[-1] = self.sequence_length
-                tokens = tokens.to_tensor(shape=output_shape)
-
-            # Convert to a dense output if input in scalar
-            if scalar_input:
-                tokens = tf.squeeze(tokens, 0)
-                tf.ensure_shape(tokens, shape=[self.sequence_length])
-
-            return tokens
-
-else:
-    CLIPTokenizer = None
+class CLIPTokenizer(BytePairTokenizer):
+    def __init__(self, **kwargs):
+        assert_keras_nlp_installed("CLIPTokenizer")
+        super().__init__(**kwargs)
+
+    def _bpe_merge_and_update_cache(self, tokens):
+        """Process unseen tokens and add to cache."""
+        words = self._transform_bytes(tokens)
+        tokenized_words = self._bpe_merge(words)
+
+        # For each word, join all its token by a whitespace,
+        # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose.
+        tokenized_words = tf.strings.reduce_join(
+            tokenized_words,
+            axis=1,
+        )
+        self.cache.insert(tokens, tokenized_words)
+
+    def tokenize(self, inputs):
+        self._check_vocabulary()
+        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
+            inputs = tf.convert_to_tensor(inputs)
+
+        if self.add_prefix_space:
+            inputs = tf.strings.join([" ", inputs])
+
+        scalar_input = inputs.shape.rank == 0
+        if scalar_input:
+            inputs = tf.expand_dims(inputs, 0)
+
+        raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens)
+        token_row_splits = raw_tokens.row_splits
+        flat_tokens = raw_tokens.flat_values
+        # Check cache.
+        cache_lookup = self.cache.lookup(flat_tokens)
+        cache_mask = cache_lookup == ""
+
+        has_unseen_words = tf.math.reduce_any(
+            (cache_lookup == "") & (flat_tokens != "")
+        )
+
+        def process_unseen_tokens():
+            unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask)
+            self._bpe_merge_and_update_cache(unseen_tokens)
+            return self.cache.lookup(flat_tokens)
+
+        # If `has_unseen_words == True`, it means not all tokens are,
+        # in cache we will process the unseen tokens. Otherwise
+        # return the cache lookup.
+        tokenized_words = tf.cond(
+            has_unseen_words,
+            process_unseen_tokens,
+            lambda: cache_lookup,
+        )
+        tokens = tf.strings.split(tokenized_words, sep=" ")
+        if self.compute_dtype != tf.string:
+            # Encode merged tokens.
+            tokens = self.token_to_id_map.lookup(tokens)
+
+        # Unflatten to match input.
+        tokens = tf.RaggedTensor.from_row_splits(
+            tokens.flat_values,
+            tf.gather(tokens.row_splits, token_row_splits),
+        )
+
+        # Convert to a dense output if `sequence_length` is set.
+        if self.sequence_length:
+            output_shape = tokens.shape.as_list()
+            output_shape[-1] = self.sequence_length
+            tokens = tokens.to_tensor(shape=output_shape)
+
+        # Convert to a dense output if input in scalar
+        if scalar_input:
+            tokens = tf.squeeze(tokens, 0)
+            tf.ensure_shape(tokens, shape=[self.sequence_length])
+
+        return tokens
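
Note: the fix drops the module-level `if BytePairTokenizer:` guard (which left `CLIPTokenizer = None` when keras_nlp was absent) in favor of an `object` fallback base class plus a runtime check in `__init__`. A minimal sketch of that conditional-import pattern, assuming a helper along the lines of the imported `assert_keras_nlp_installed` (the actual implementation in keras_cv/utils/conditional_imports.py may differ):

# Sketch only: an assumed shape for the conditional-import helper this commit
# relies on; the real assert_keras_nlp_installed in keras_cv may differ.
try:
    import keras_nlp
except ImportError:
    keras_nlp = None


def assert_keras_nlp_installed(symbol_name):
    # Fail fast with an actionable message when the optional dependency is missing.
    if keras_nlp is None:
        raise ImportError(
            f"{symbol_name} requires the `keras_nlp` package. "
            "Install it with `pip install keras-nlp`."
        )

Previously, importing CLIPTokenizer without keras_nlp bound the name to None, so call sites failed with an unhelpful "'NoneType' object is not callable"; with this change the class always exists and construction fails fast with an explicit dependency error from `assert_keras_nlp_installed`.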
