
Commit 84bed77

Add a SentencePiece tokenizer (#218)
* Add SentencePiece tokenizer
* Address review comments
* Update sentencepiece tokenizer with new arg name
* Fix serialization
* Improve accessors
* Address review comments
* Refactor docstrings
1 parent 5c87ada commit 84bed77

5 files changed: +405 -1 lines changed

keras_nlp/tokenizers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.

 from keras_nlp.tokenizers.byte_tokenizer import ByteTokenizer
+from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer
 from keras_nlp.tokenizers.tokenizer import Tokenizer
 from keras_nlp.tokenizers.unicode_character_tokenizer import (
     UnicodeCharacterTokenizer,

keras_nlp/tokenizers/sentence_piece_tokenizer.py

Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import binascii
from typing import List

import tensorflow as tf
import tensorflow_text as tf_text

from keras_nlp.tokenizers import tokenizer
from keras_nlp.utils.tensor_utils import tensor_to_string_list


class SentencePieceTokenizer(tokenizer.Tokenizer):
    """A SentencePiece tokenizer layer.

    This layer provides an implementation of SentencePiece tokenization
    as described in the [SentencePiece paper](https://arxiv.org/abs/1808.06226)
    and the [SentencePiece package](https://pypi.org/project/sentencepiece/).
    The tokenization will run entirely within the TensorFlow graph, and can
    be saved inside a `keras.Model`.

    By default, the layer will output a `tf.RaggedTensor` where the last
    dimension of the output is ragged after whitespace splitting and sub-word
    tokenizing. If `sequence_length` is set, the layer will output a dense
    `tf.Tensor` where all inputs have been padded or truncated to
    `sequence_length`. The output dtype can be controlled via the `dtype`
    argument, which should be either an integer or string type.

    Args:
        proto: Either a `string` path to a SentencePiece proto file, or a
            `bytes` object with a serialized SentencePiece proto. See the
            [SentencePiece repository](https://github.com/google/sentencepiece)
            for more details on the format.
        sequence_length: If set, the output will be converted to a dense
            tensor and padded/trimmed so all outputs are of `sequence_length`.

    References:
        - [Kudo and Richardson, 2018](https://arxiv.org/abs/1808.06226)

    Examples:

    From bytes.
    ```python
    def train_sentence_piece_bytes(ds, size):
        bytes_io = io.BytesIO()
        sentencepiece.SentencePieceTrainer.train(
            sentence_iterator=ds.as_numpy_iterator(),
            model_writer=bytes_io,
            vocab_size=size,
        )
        return bytes_io.getvalue()

    # Train a sentencepiece proto.
    ds = tf.data.Dataset.from_tensor_slices(["the quick brown fox."])
    proto = train_sentence_piece_bytes(ds, 20)
    # Tokenize inputs.
    tokenizer = keras_nlp.tokenizers.SentencePieceTokenizer(proto=proto)
    ds = ds.map(tokenizer)
    ```

    From a file.
    ```python
    def train_sentence_piece_file(ds, path, size):
        with open(path, "wb") as model_file:
            sentencepiece.SentencePieceTrainer.train(
                sentence_iterator=ds.as_numpy_iterator(),
                model_writer=model_file,
                vocab_size=size,
            )

    # Train a sentencepiece proto and write it to "model.spm".
    ds = tf.data.Dataset.from_tensor_slices(["the quick brown fox."])
    train_sentence_piece_file(ds, "model.spm", 20)
    # Tokenize inputs.
    tokenizer = keras_nlp.tokenizers.SentencePieceTokenizer(proto="model.spm")
    ds = ds.map(tokenizer)
    ```
    """

    def __init__(
        self,
        proto,
        sequence_length: int = None,
        **kwargs,
    ) -> None:
        # Check dtype and provide a default.
        if "dtype" not in kwargs or kwargs["dtype"] is None:
            kwargs["dtype"] = tf.int32
        else:
            dtype = tf.dtypes.as_dtype(kwargs["dtype"])
            if not dtype.is_integer and dtype != tf.string:
                raise ValueError(
                    "Output dtype must be one of `'string'`, `'int32'`, or "
                    f"`'int64'`. Received: dtype={dtype}"
                )

        super().__init__(**kwargs)

        if isinstance(proto, str):
            # A string could be either a filepath, or a base64 encoded byte
            # array (which we need for serialization). We will heuristically
            # try to distinguish the two by checking if the string is both
            # longer than 2048 characters and made up of valid base64
            # characters.
            is_base64 = False
            if len(proto) > 2048:
                try:
                    proto_bytes = base64.b64decode(proto, validate=True)
                    is_base64 = True
                except binascii.Error:
                    pass
            if not is_base64:
                proto_bytes = tf.io.gfile.GFile(proto, "rb").read()
        elif isinstance(proto, bytes):
            proto_bytes = proto
        else:
            raise ValueError(
                "SentencePiece `proto` argument should be either a `string` "
                "filepath or a `bytes` sequence. "
                f"Received unknown type: {type(proto)}"
            )

        self._sentence_piece = tf_text.SentencepieceTokenizer(
            model=proto_bytes,
            out_type=self.compute_dtype,
        )

        # Keras cannot serialize a bytestring, so we base64 encode the model
        # byte array as a string for saving.
        self.proto = base64.b64encode(proto_bytes).decode("ascii")
        self.sequence_length = sequence_length

    def vocabulary_size(self) -> int:
        """Get the size of the tokenizer vocabulary."""
        return int(self._sentence_piece.vocab_size().numpy())

    def get_vocabulary(self) -> List[str]:
        """Get the tokenizer vocabulary as a list of string tokens."""
        return tensor_to_string_list(
            self._sentence_piece.id_to_string(tf.range(self.vocabulary_size()))
        )

    def id_to_token(self, id: int) -> str:
        """Convert an integer id to a string token."""
        return tensor_to_string_list(self._sentence_piece.id_to_string(id))

    def token_to_id(self, token: str) -> int:
        """Convert a string token to an integer id."""
        return int(self._sentence_piece.string_to_id(token).numpy())

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                # Ideally the model would be saved as a file asset in
                # the saved model. We have no good way to support this
                # currently, so we save the model string in the config.
                "proto": self.proto,
                "sequence_length": self.sequence_length,
            }
        )
        return config

    def tokenize(self, inputs):
        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
            inputs = tf.convert_to_tensor(inputs)
        scalar_input = inputs.shape.rank == 0
        if scalar_input:
            inputs = tf.expand_dims(inputs, 0)

        tokens = self._sentence_piece.tokenize(inputs)

        # Convert to a dense output if `sequence_length` is set.
        if self.sequence_length:
            output_shape = tokens.shape.as_list()
            output_shape[-1] = self.sequence_length
            tokens = tokens.to_tensor(shape=output_shape)

        # Remove the extra batch dimension if the input was a scalar.
        if scalar_input:
            tokens = tf.squeeze(tokens, 0)
            tf.ensure_shape(tokens, shape=[self.sequence_length])

        return tokens

    def detokenize(self, inputs):
        return self._sentence_piece.detokenize(inputs)
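
Not part of the commit itself, but as a quick sanity check on the new API, here is a minimal usage sketch assembled from the docstring examples and method signatures above. It assumes `keras_nlp` (with this commit), `tensorflow`, `tensorflow_text`, and `sentencepiece` are installed; exact token ids and pieces depend on the trained proto.

```python
import io

import sentencepiece
import tensorflow as tf

import keras_nlp

# Train a toy SentencePiece proto in memory (mirrors the docstring example).
ds = tf.data.Dataset.from_tensor_slices(["the quick brown fox."])
bytes_io = io.BytesIO()
sentencepiece.SentencePieceTrainer.train(
    sentence_iterator=ds.as_numpy_iterator(),
    model_writer=bytes_io,
    vocab_size=20,
)
proto = bytes_io.getvalue()

# Without `sequence_length`, batched inputs produce a ragged output.
tokenizer = keras_nlp.tokenizers.SentencePieceTokenizer(proto=proto)
ragged = tokenizer.tokenize(["the quick brown fox."])
print(type(ragged))  # tf.RaggedTensor

# With `sequence_length`, the output is a dense tensor padded or truncated
# to that length; a scalar input yields a single padded sequence.
dense_tokenizer = keras_nlp.tokenizers.SentencePieceTokenizer(
    proto=proto, sequence_length=10
)
dense = dense_tokenizer.tokenize("the quick brown fox.")
print(dense.shape)  # (10,)

# The vocabulary accessors added in this commit.
print(tokenizer.vocabulary_size())  # 20
print(tokenizer.get_vocabulary()[:5])
token = tokenizer.id_to_token(5)
print(token, tokenizer.token_to_id(token))  # round trips through the accessors

# Detokenize back to a string.
print(tokenizer.detokenize(ragged))
```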
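
The dtype check in `__init__` also admits a string dtype, in which case `tf_text.SentencepieceTokenizer` emits the sub-word pieces themselves rather than integer ids. A small continuation of the first sketch above, reusing `proto`; the diff only validates the dtype value, so treat the string-dtype layer behavior as an assumption.

```python
# Continues from the first sketch above; `proto` is reused.
import keras_nlp

string_tokenizer = keras_nlp.tokenizers.SentencePieceTokenizer(
    proto=proto, dtype="string"
)
# Outputs the sub-word pieces as strings instead of integer ids.
print(string_tokenizer.tokenize("the quick brown fox."))
```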
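
Finally, a sketch of the serialization path that `get_config` enables. `from_config` is inherited from the Keras `Layer` base class (not part of this diff) and, by default, calls the constructor with the config dict, so the base64 branch of `__init__` recovers the proto bytes. This continues from the first sketch (reusing `proto` and `ds`) and assumes the serialized proto exceeds the 2048-character heuristic threshold, which holds for real SentencePiece models.

```python
# Continues from the first sketch above; `proto` and `ds` are reused.
import keras_nlp

tokenizer = keras_nlp.tokenizers.SentencePieceTokenizer(
    proto=proto, sequence_length=16
)

# The config stores the proto as a base64 string plus `sequence_length`.
config = tokenizer.get_config()

# Default `Layer.from_config` calls `SentencePieceTokenizer(**config)`, so
# the base64 branch of `__init__` restores the proto bytes.
restored = keras_nlp.tokenizers.SentencePieceTokenizer.from_config(config)

# The restored layer tokenizes just like the original, e.g. inside tf.data.
tokenized_ds = ds.map(restored)
print(next(iter(tokenized_ds)))
```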
