|
15 | 15 | import tensorflow as tf |
16 | 16 | import tensorflow_text as tf_text |
17 | 17 |
|
| 18 | +from keras_cv.utils.conditional_imports import assert_keras_nlp_installed |
| 19 | + |
18 | 20 | try: |
19 | 21 | from keras_nlp.tokenizers import BytePairTokenizer |
20 | 22 | except ImportError: |
21 | | - BytePairTokenizer = None |
| 23 | + BytePairTokenizer = object |
22 | 24 |
|
23 | 25 | # As python and TF handles special spaces differently, we need to |
24 | 26 | # manually handle special spaces during string split. |
@@ -103,83 +105,80 @@ def remove_strings_from_inputs(tensor, string_to_remove): |
103 | 105 | return result |
104 | 106 |
|
105 | 107 |
|
106 | | -if BytePairTokenizer: |
107 | | - class CLIPTokenizer(BytePairTokenizer): |
108 | | - def __init__(self, **kwargs): |
109 | | - super().__init__(**kwargs) |
110 | | - |
111 | | - def _bpe_merge_and_update_cache(self, tokens): |
112 | | - """Process unseen tokens and add to cache.""" |
113 | | - words = self._transform_bytes(tokens) |
114 | | - tokenized_words = self._bpe_merge(words) |
115 | | - |
116 | | - # For each word, join all its token by a whitespace, |
117 | | - # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose. |
118 | | - tokenized_words = tf.strings.reduce_join( |
119 | | - tokenized_words, |
120 | | - axis=1, |
121 | | - ) |
122 | | - self.cache.insert(tokens, tokenized_words) |
123 | | - |
124 | | - def tokenize(self, inputs): |
125 | | - self._check_vocabulary() |
126 | | - if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): |
127 | | - inputs = tf.convert_to_tensor(inputs) |
128 | | - |
129 | | - if self.add_prefix_space: |
130 | | - inputs = tf.strings.join([" ", inputs]) |
131 | | - |
132 | | - scalar_input = inputs.shape.rank == 0 |
133 | | - if scalar_input: |
134 | | - inputs = tf.expand_dims(inputs, 0) |
135 | | - |
136 | | - raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens) |
137 | | - token_row_splits = raw_tokens.row_splits |
138 | | - flat_tokens = raw_tokens.flat_values |
139 | | - # Check cache. |
140 | | - cache_lookup = self.cache.lookup(flat_tokens) |
141 | | - cache_mask = cache_lookup == "" |
142 | | - |
143 | | - has_unseen_words = tf.math.reduce_any( |
144 | | - (cache_lookup == "") & (flat_tokens != "") |
145 | | - ) |
146 | | - |
147 | | - def process_unseen_tokens(): |
148 | | - unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask) |
149 | | - self._bpe_merge_and_update_cache(unseen_tokens) |
150 | | - return self.cache.lookup(flat_tokens) |
151 | | - |
152 | | - # If `has_unseen_words == True`, it means not all tokens are, |
153 | | - # in cache we will process the unseen tokens. Otherwise |
154 | | - # return the cache lookup. |
155 | | - tokenized_words = tf.cond( |
156 | | - has_unseen_words, |
157 | | - process_unseen_tokens, |
158 | | - lambda: cache_lookup, |
159 | | - ) |
160 | | - tokens = tf.strings.split(tokenized_words, sep=" ") |
161 | | - if self.compute_dtype != tf.string: |
162 | | - # Encode merged tokens. |
163 | | - tokens = self.token_to_id_map.lookup(tokens) |
164 | | - |
165 | | - # Unflatten to match input. |
166 | | - tokens = tf.RaggedTensor.from_row_splits( |
167 | | - tokens.flat_values, |
168 | | - tf.gather(tokens.row_splits, token_row_splits), |
169 | | - ) |
170 | | - |
171 | | - # Convert to a dense output if `sequence_length` is set. |
172 | | - if self.sequence_length: |
173 | | - output_shape = tokens.shape.as_list() |
174 | | - output_shape[-1] = self.sequence_length |
175 | | - tokens = tokens.to_tensor(shape=output_shape) |
176 | | - |
177 | | - # Convert to a dense output if input in scalar |
178 | | - if scalar_input: |
179 | | - tokens = tf.squeeze(tokens, 0) |
180 | | - tf.ensure_shape(tokens, shape=[self.sequence_length]) |
181 | | - |
182 | | - return tokens |
183 | | - |
184 | | -else: |
185 | | - CLIPTokenizer = None |
| 108 | +class CLIPTokenizer(BytePairTokenizer): |
| 109 | + def __init__(self, **kwargs): |
| 110 | + assert_keras_nlp_installed("CLIPTokenizer") |
| 111 | + super().__init__(**kwargs) |
| 112 | + |
| 113 | + def _bpe_merge_and_update_cache(self, tokens): |
| 114 | + """Process unseen tokens and add to cache.""" |
| 115 | + words = self._transform_bytes(tokens) |
| 116 | + tokenized_words = self._bpe_merge(words) |
| 117 | + |
| 118 | + # For each word, join all its token by a whitespace, |
| 119 | + # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose. |
| 120 | + tokenized_words = tf.strings.reduce_join( |
| 121 | + tokenized_words, |
| 122 | + axis=1, |
| 123 | + ) |
| 124 | + self.cache.insert(tokens, tokenized_words) |
| 125 | + |
| 126 | + def tokenize(self, inputs): |
| 127 | + self._check_vocabulary() |
| 128 | + if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): |
| 129 | + inputs = tf.convert_to_tensor(inputs) |
| 130 | + |
| 131 | + if self.add_prefix_space: |
| 132 | + inputs = tf.strings.join([" ", inputs]) |
| 133 | + |
| 134 | + scalar_input = inputs.shape.rank == 0 |
| 135 | + if scalar_input: |
| 136 | + inputs = tf.expand_dims(inputs, 0) |
| 137 | + |
| 138 | + raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens) |
| 139 | + token_row_splits = raw_tokens.row_splits |
| 140 | + flat_tokens = raw_tokens.flat_values |
| 141 | + # Check cache. |
| 142 | + cache_lookup = self.cache.lookup(flat_tokens) |
| 143 | + cache_mask = cache_lookup == "" |
| 144 | + |
| 145 | + has_unseen_words = tf.math.reduce_any( |
| 146 | + (cache_lookup == "") & (flat_tokens != "") |
| 147 | + ) |
| 148 | + |
| 149 | + def process_unseen_tokens(): |
| 150 | + unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask) |
| 151 | + self._bpe_merge_and_update_cache(unseen_tokens) |
| 152 | + return self.cache.lookup(flat_tokens) |
| 153 | + |
| 154 | + # If `has_unseen_words == True`, it means not all tokens are, |
| 155 | + # in cache we will process the unseen tokens. Otherwise |
| 156 | + # return the cache lookup. |
| 157 | + tokenized_words = tf.cond( |
| 158 | + has_unseen_words, |
| 159 | + process_unseen_tokens, |
| 160 | + lambda: cache_lookup, |
| 161 | + ) |
| 162 | + tokens = tf.strings.split(tokenized_words, sep=" ") |
| 163 | + if self.compute_dtype != tf.string: |
| 164 | + # Encode merged tokens. |
| 165 | + tokens = self.token_to_id_map.lookup(tokens) |
| 166 | + |
| 167 | + # Unflatten to match input. |
| 168 | + tokens = tf.RaggedTensor.from_row_splits( |
| 169 | + tokens.flat_values, |
| 170 | + tf.gather(tokens.row_splits, token_row_splits), |
| 171 | + ) |
| 172 | + |
| 173 | + # Convert to a dense output if `sequence_length` is set. |
| 174 | + if self.sequence_length: |
| 175 | + output_shape = tokens.shape.as_list() |
| 176 | + output_shape[-1] = self.sequence_length |
| 177 | + tokens = tokens.to_tensor(shape=output_shape) |
| 178 | + |
| 179 | + # Convert to a dense output if input in scalar |
| 180 | + if scalar_input: |
| 181 | + tokens = tf.squeeze(tokens, 0) |
| 182 | + tf.ensure_shape(tokens, shape=[self.sequence_length]) |
| 183 | + |
| 184 | + return tokens |
0 commit comments