# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import binascii
from typing import List

import tensorflow as tf
import tensorflow_text as tf_text

from keras_nlp.tokenizers import tokenizer
from keras_nlp.utils.tensor_utils import tensor_to_string_list

class SentencePieceTokenizer(tokenizer.Tokenizer):
    """A SentencePiece tokenizer layer.

    This layer provides an implementation of SentencePiece tokenization
    as described in the [SentencePiece paper](https://arxiv.org/abs/1808.06226)
    and the [SentencePiece package](https://pypi.org/project/sentencepiece/).
    The tokenization will run entirely within the TensorFlow graph, and can
    be saved inside a `keras.Model`.

    By default, the layer will output a `tf.RaggedTensor` where the last
    dimension of the output is ragged after whitespace splitting and sub-word
    tokenizing. If `sequence_length` is set, the layer will output a dense
    `tf.Tensor` where all inputs have been padded or truncated to
    `sequence_length`. The output dtype can be controlled via the `dtype`
    argument, which should be either an integer or string type.

    Args:
        proto: Either a `string` path to a SentencePiece proto file, or a
            `bytes` object with a serialized SentencePiece proto. See the
            [SentencePiece repository](https://github.com/google/sentencepiece)
            for more details on the format.
        sequence_length: If set, the output will be converted to a dense
            tensor and padded/trimmed so all outputs are of `sequence_length`.

    References:
        - [Kudo and Richardson, 2018](https://arxiv.org/abs/1808.06226)

    Examples:

    From bytes.
    ```python
    def train_sentence_piece_bytes(ds, size):
        bytes_io = io.BytesIO()
        sentencepiece.SentencePieceTrainer.train(
            sentence_iterator=ds.as_numpy_iterator(),
            model_writer=bytes_io,
            vocab_size=size,
        )
        return bytes_io.getvalue()

    # Train a sentencepiece proto.
    ds = tf.data.Dataset.from_tensor_slices(["the quick brown fox."])
    proto = train_sentence_piece_bytes(ds, 20)
    # Tokenize inputs.
    tokenizer = keras_nlp.tokenizers.SentencePieceTokenizer(proto=proto)
    ds = ds.map(tokenizer)
    ```

    From a file.
    ```python
    def train_sentence_piece_file(ds, path, size):
        with open(path, "wb") as model_file:
            sentencepiece.SentencePieceTrainer.train(
                sentence_iterator=ds.as_numpy_iterator(),
                model_writer=model_file,
                vocab_size=size,
            )

    # Train a sentencepiece proto and write it to a file.
    ds = tf.data.Dataset.from_tensor_slices(["the quick brown fox."])
    train_sentence_piece_file(ds, "model.spm", 20)
    # Tokenize inputs.
    tokenizer = keras_nlp.tokenizers.SentencePieceTokenizer(proto="model.spm")
    ds = ds.map(tokenizer)
    ```
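
    Setting `sequence_length` and `dtype`. A short sketch reusing the proto
    trained in the examples above; the exact token values depend on the
    trained vocabulary.
    ```python
    # Dense, fixed-length integer outputs.
    tokenizer = keras_nlp.tokenizers.SentencePieceTokenizer(
        proto=proto,
        sequence_length=10,
    )
    # Each output is padded or truncated to shape `(10,)`.
    tokenizer("the quick brown fox.")

    # String outputs of the sub-word pieces instead of integer ids.
    tokenizer = keras_nlp.tokenizers.SentencePieceTokenizer(
        proto=proto,
        dtype="string",
    )
    tokenizer("the quick brown fox.")
    ```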
    """

    def __init__(
        self,
        proto,
        sequence_length: int = None,
        **kwargs,
    ) -> None:
        # Check dtype and provide a default.
        if "dtype" not in kwargs or kwargs["dtype"] is None:
            kwargs["dtype"] = tf.int32
        else:
            dtype = tf.dtypes.as_dtype(kwargs["dtype"])
            if not dtype.is_integer and dtype != tf.string:
                raise ValueError(
                    "Output dtype must be one of `'string'`, `'int32'`, or "
                    f"`'int64'`. Received: dtype={dtype}"
                )

        super().__init__(**kwargs)

        if isinstance(proto, str):
            # A string could be either a filepath, or a base64 encoded byte
            # array (which we need for serialization). We will heuristically
            # try to distinguish the two by checking if the string is longer
            # than 2048 characters and contains only valid base64 characters.
            is_base64 = False
            if len(proto) > 2048:
                try:
                    proto_bytes = base64.b64decode(proto, validate=True)
                    is_base64 = True
                except binascii.Error:
                    pass
            if not is_base64:
                proto_bytes = tf.io.gfile.GFile(proto, "rb").read()
        elif isinstance(proto, bytes):
            proto_bytes = proto
        else:
            raise ValueError(
                "SentencePiece `proto` argument should be either a `string` "
                "filepath or a `bytes` sequence. "
                f"Received unknown type: {type(proto)}"
            )

        self._sentence_piece = tf_text.SentencepieceTokenizer(
            model=proto_bytes,
            out_type=self.compute_dtype,
        )

        # Keras cannot serialize a bytestring, so we base64 encode the model
        # byte array as a string for saving.
        self.proto = base64.b64encode(proto_bytes).decode("ascii")
        self.sequence_length = sequence_length

    def vocabulary_size(self) -> int:
        """Get the size of the tokenizer vocabulary."""
        return int(self._sentence_piece.vocab_size().numpy())

    def get_vocabulary(self) -> List[str]:
        """Get the tokenizer vocabulary as a list of string tokens."""
        return tensor_to_string_list(
            self._sentence_piece.id_to_string(tf.range(self.vocabulary_size()))
        )

    def id_to_token(self, id: int) -> str:
        """Convert an integer id to a string token."""
        return tensor_to_string_list(self._sentence_piece.id_to_string(id))

    def token_to_id(self, token: str) -> int:
        """Convert a string token to an integer id."""
        return int(self._sentence_piece.string_to_id(token).numpy())

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                # Ideally the model would be saved as a file asset in
                # the saved model. We have no good way to support this
                # currently, so we save the model string in the config.
                "proto": self.proto,
                "sequence_length": self.sequence_length,
            }
        )
        return config

    def tokenize(self, inputs):
        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
            inputs = tf.convert_to_tensor(inputs)
        scalar_input = inputs.shape.rank == 0
        if scalar_input:
            inputs = tf.expand_dims(inputs, 0)

        tokens = self._sentence_piece.tokenize(inputs)

        # Convert to a dense output if `sequence_length` is set.
        if self.sequence_length:
            output_shape = tokens.shape.as_list()
            output_shape[-1] = self.sequence_length
            tokens = tokens.to_tensor(shape=output_shape)

        # Convert to a dense output if input was a scalar.
        if scalar_input:
            tokens = tf.squeeze(tokens, 0)
            tokens = tf.ensure_shape(tokens, shape=[self.sequence_length])

        return tokens

    def detokenize(self, inputs):
        return self._sentence_piece.detokenize(inputs)