1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- """ Byte-pair encoder implementation.
15+ """Byte-pair encoder implementation.
1616
1717This file implements the same logic as openai BPE:
1818https://github.com/openai/gpt-2/blob/master/src/encoder.py,
@@ -159,12 +159,12 @@ def create_static_hashtable(keys, values, default):
159159class BytePairTokenizer (tokenizer .Tokenizer ):
160160 """Byte-pair encoding tokenizer layer.
161161
162- This BPE tokenizer provides the same funtionality as official GPT2
162+ This BPE tokenizer provides the same functionality as the official GPT-2
163163 tokenizer. Given the same `vocabulary` which maps tokens to ids, and `merges`
164164 which describes BPE merge rules, it should provide the same output
165- as openai implementation (https://github.com/openai/gpt-2/blob/master/src/encoder.py).
166- Different from openai , this implementation is graph-compatible, so you can
167- use it within a tf.data pipeline.
165+ as the OpenAI implementation (https://github.com/openai/gpt-2/blob/master/src/encoder.py).
166+ Different from OpenAI , this implementation is graph-compatible, so you can
167+ use it within a ` tf.data` pipeline.
168168
169169 If input is a batch of strings (rank > 0):
170170 By default, the layer will output a `tf.RaggedTensor` where the last
@@ -187,7 +187,7 @@ class BytePairTokenizer(tokenizer.Tokenizer):
187187
188188 Examples:
189189
190- Use in-momery vocabulary and merge list.
190+ Use in-memory vocabulary and merge list.
191191
192192 >>> vocab = {"butter": 1, "fly": 2}
193193 >>> merge = ["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"]
@@ -244,7 +244,7 @@ def __init__(
244244 kwargs ["dtype" ] = tf .int32
245245 else :
246246 dtype = tf .dtypes .as_dtype (kwargs ["dtype" ])
247- if not dtype .is_integer :
247+ if not dtype .is_integer and dtype != tf . string :
248248 raise ValueError (
249249 "Output dtype must be an integer type or a string. "
250250 f"Received: `dtype={ dtype } `"
@@ -484,28 +484,29 @@ def process_unseen_tokens():
484484 lambda : cache_lookup ,
485485 )
486486
487- # Encode merged tokens.
488- tokenized_words = tf .strings .split (tokenized_words , sep = " " )
489- encoding = self .token_to_id_map .lookup (tokenized_words )
487+ tokens = tf .strings .split (tokenized_words , sep = " " )
488+ if self .compute_dtype != tf .string :
489+ # Encode merged tokens.
490+ tokens = self .token_to_id_map .lookup (tokens )
490491
491492 # Unflatten to match input.
492- encoding = tf .RaggedTensor .from_row_splits (
493- encoding .flat_values ,
494- tf .gather (encoding .row_splits , token_row_splits ),
493+ tokens = tf .RaggedTensor .from_row_splits (
494+ tokens .flat_values ,
495+ tf .gather (tokens .row_splits , token_row_splits ),
495496 )
496497
497498 # Convert to a dense output if `sequence_length` is set.
498499 if self .sequence_length :
499- output_shape = encoding .shape .as_list ()
500+ output_shape = tokens .shape .as_list ()
500501 output_shape [- 1 ] = self .sequence_length
501- encoding = encoding .to_tensor (shape = output_shape )
502+ tokens = tokens .to_tensor (shape = output_shape )
502503
503504 # Convert to a dense output if input is scalar
504505 if scalar_input :
505- encoding = tf .squeeze (encoding , 0 )
506- tf .ensure_shape (encoding , shape = [self .sequence_length ])
506+ tokens = tf .squeeze (tokens , 0 )
507+ tf .ensure_shape (tokens , shape = [self .sequence_length ])
507508
508- return encoding
509+ return tokens
509510
510511 def detokenize (self , inputs ):
511512 if not isinstance (inputs , (tf .Tensor , tf .RaggedTensor )):
0 commit comments