forked from PaddlePaddle/ERNIE
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtokenizer.py
More file actions
425 lines (370 loc) · 15.6 KB
/
tokenizer.py
File metadata and controls
425 lines (370 loc) · 15.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Ernie4_5_Tokenizer
"""
import os
import re
from shutil import copyfile
from typing import Dict, List, Optional, Tuple
import numpy as np
import paddle
import sentencepiece as spm
from paddleformers.transformers import PretrainedTokenizer
from paddleformers.transformers.legacy.tokenizer_utils_base import (
PaddingStrategy,
TextInput,
)
from paddleformers.utils.log import logger
class Ernie4_5_Tokenizer(PretrainedTokenizer):
    """
    Ernie4_5_Tokenizer: a SentencePiece-backed tokenizer for ERNIE 4.5.

    Attributes:
        resource_files_names (dict): Mapping of resource file names.
        pretrained_resource_files_map (dict): Mapping of pretrained resources.
        pretrained_init_configuration (dict): Mapping of pretrained init configuration.
        model_input_names (list): Model input names expected by the tokenizer.
        padding_side (str): Padding side (where to add padding tokens).
        all_spec_tok (set): Instance-level cached ``set(self.all_special_tokens)``,
            used for fast membership tests in ``convert_tokens_to_string`` and
            ``tokenize``. Filled in ``__init__`` and refreshed by
            ``from_pretrained``.
    """

    resource_files_names = {
        "vocab_file": "tokenizer.model",
    }
    pretrained_resource_files_map = {"vocab_file": {"ernie-bot": None}}
    pretrained_init_configuration = {
        "ernie-bot": {},
    }
    model_input_names = ["input_ids", "position_ids", "attention_mask", "labels"]
    padding_side = "right"

    def __init__(
        self,
        vocab_file,
        bos_token="<s>",
        cls_token="<cls>",
        eos_token="</s>",
        mask_token="<mask:0>",
        pad_token="<pad>",
        sep_token="<sep>",
        unk_token="<unk>",
        additional_special_tokens=None,
        verbose=False,
        **kwargs,
    ):
        """
        Initialize the ERNIE tokenizer.

        Args:
            vocab_file (str): Path to the SentencePiece model file.
            bos_token (str, optional): Beginning of sentence token. Defaults to "<s>".
            cls_token (str, optional): Classification token. Defaults to "<cls>".
            eos_token (str, optional): End of sentence token. Defaults to "</s>".
            mask_token (str, optional): Mask token. Defaults to "<mask:0>".
            pad_token (str, optional): Padding token. Defaults to "<pad>".
            sep_token (str, optional): Separator token. Defaults to "<sep>".
            unk_token (str, optional): Unknown token. Defaults to "<unk>".
            additional_special_tokens (List[str], optional): Additional special tokens.
                Defaults to ["<mask:1>", "<mask:7>"].
            verbose (bool, optional): Whether to print detailed logs or progress
                information during execution.
            **kwargs: Additional keyword arguments passed to the parent class.
        """
        if additional_special_tokens is None:
            # Fresh list per call: avoids a shared mutable default.
            additional_special_tokens = ["<mask:1>", "<mask:7>"]
        super().__init__(
            bos_token=bos_token,
            cls_token=cls_token,
            eos_token=eos_token,
            mask_token=mask_token,
            pad_token=pad_token,
            sep_token=sep_token,
            unk_token=unk_token,
            additional_special_tokens=additional_special_tokens,
            verbose=verbose,
            **kwargs,
        )
        self.vocab_file = vocab_file
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)
        # Populate the special-token cache here (not only in `from_pretrained`)
        # so instances constructed directly can still decode and tokenize.
        self.all_spec_tok = set(self.all_special_tokens)

    @property
    def vocab_size(self):
        """Return the number of pieces in the SentencePiece model.

        Returns:
            int: ``self.sp_model.vocab_size()``.
        """
        return self.sp_model.vocab_size()

    def get_vocab(self):
        """Return the vocabulary as a token -> id dictionary.

        Covers every SentencePiece id in ``range(self.vocab_size)``, then
        overlays ``self.added_tokens_encoder``.

        Returns:
            dict: Mapping from token string to token id.
        """
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text):
        """Split ``text`` into SentencePiece pieces.

        Args:
            text (str): The text to tokenize.

        Returns:
            list: A list of piece strings.
        """
        return self.sp_model.encode_as_pieces(text)

    def _convert_token_to_id(self, token):
        """Convert a token (str) to an id using the SentencePiece model.

        Args:
            token (str): The token to convert.

        Returns:
            int: The corresponding token id.
        """
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, id):
        """Convert an id to a token (str) using the SentencePiece model.

        Args:
            id (int): The token id to convert.

        Returns:
            str: The corresponding token.
        """
        return self.sp_model.id_to_piece(id)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load a tokenizer via the parent class, then refresh the special-token cache."""
        tokenizer = super().from_pretrained(*args, **kwargs)
        # Pre-compute the special-token set once, for O(1) membership checks.
        tokenizer.all_spec_tok = set(tokenizer.all_special_tokens)
        return tokenizer

    def convert_tokens_to_string(self, tokens):
        """Convert a sequence of tokens back to a single string.

        Runs of ordinary pieces are decoded with SentencePiece. Special tokens
        (members of ``self.all_spec_tok``) are appended verbatim between runs.

        Args:
            tokens (List[str]): A list of tokens to convert.

        Returns:
            str: The reconstructed string (not stripped).
        """
        current_sub_tokens = []
        out_string = ""
        for token in tokens:
            # Make sure special tokens are not decoded by the SentencePiece model.
            if token in self.all_spec_tok:
                out_string += self.sp_model.decode(current_sub_tokens) + token
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string

    def prepare_for_model(self, *args, **kwargs):
        """Delegate to the parent ``prepare_for_model``, discarding ``add_special_tokens``.

        Any caller-supplied ``add_special_tokens`` is dropped before
        delegating, so the parent's default for that flag always applies.
        """
        kwargs.pop("add_special_tokens", None)
        return super().prepare_for_model(*args, **kwargs)

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """
        Save the SentencePiece model file to a directory.

        If ``self.vocab_file`` exists and is a different path from the target,
        it is copied there. If ``self.vocab_file`` does not exist, the
        in-memory model is serialized to the target instead.

        Args:
            save_directory (str): The directory in which to save the vocabulary.
            filename_prefix (Optional[str]): Optional prefix for the saved filename.

        Returns:
            Tuple[str]: Paths to the files saved, or ``None`` (after logging an
            error) if ``save_directory`` is not a directory.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "")
            + self.resource_files_names["vocab_file"],
        )
        if os.path.abspath(self.vocab_file) != os.path.abspath(
            out_vocab_file
        ) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            # Source file is gone; write the loaded model's serialized proto.
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)
        return (out_vocab_file,)

    def tokenize(self, text: TextInput, **kwargs) -> List[str]:
        """
        Converts a string into a sequence of tokens, using the tokenizer.

        The text is split on added / no-split tokens via ``self.tokens_trie``.
        Those tokens are kept whole; every other chunk goes through
        SentencePiece.

        Args:
            text (`str`):
                The sequence to be encoded.
            **kwargs (additional keyword arguments):
                Passed along to the model-specific `prepare_for_tokenization`
                preprocessing method.

        Returns:
            `List[str]`: The list of tokens.
        """
        text, kwargs = self.prepare_for_tokenization(text, **kwargs)

        # TODO: should this be in the base class?
        if getattr(self, "do_lower_case", False):
            # Lower-case everything except special / no-split tokens.
            # `unique_no_split_tokens` is a list while `all_spec_tok` is a set,
            # so merge them as sets. Empty strings are skipped: they would add
            # an empty alternative that always matches. Sorting longest-first
            # means a token that is a prefix of another cannot shadow it.
            protected = sorted(
                {tok for tok in self.unique_no_split_tokens if tok}
                | {tok for tok in self.all_spec_tok if tok},
                key=len,
                reverse=True,
            )
            if protected:
                pattern = (
                    r"(" + r"|".join(re.escape(tok) for tok in protected) + r")|(.+?)"
                )
                text = re.sub(
                    pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text
                )
            else:
                text = text.lower()

        no_split_token = set(self.unique_no_split_tokens)
        tokens = self.tokens_trie.split(text)

        tokenized_text = []
        for token in tokens:
            # Need to skip eventual empty (fully stripped) tokens
            if not token:
                continue
            if token in no_split_token:
                tokenized_text.append(token)
            else:
                tokenized_text.extend(self._tokenize(token))
        # e.g. ["This", " is", " something", "<special_token_1>", "else"]
        return tokenized_text

    def _decode(self, *args, **kwargs):
        """Decode via the parent, always forcing both
        ``clean_up_tokenization_spaces`` and ``spaces_between_special_tokens``
        to ``False``.
        """
        kwargs.pop("clean_up_tokenization_spaces", None)
        kwargs.pop("spaces_between_special_tokens", None)
        return super()._decode(
            *args,
            **kwargs,
            clean_up_tokenization_spaces=False,
            spaces_between_special_tokens=False,
        )

    def _pad(
        self,
        encoded_inputs: Dict,
        max_length: Optional[int] = None,
        padding_strategy=PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
    ) -> dict:
        """
        Pad encoded inputs according to the specified strategy.

        The attention mask is handled here and the other fields by the parent
        ``_pad``, which is called with ``return_attention_mask=False``.

        Mask handling:

        * If no mask is supplied (or it is ``None``), a lower-triangular mask
          of shape ``(1, seq_len, seq_len)`` is generated.
        * A 1-D mask is padded along its only axis.
        * A mask with ``ndim >= 2`` is padded on its last two axes.
        * Padded positions are filled with 0.

        Args:
            encoded_inputs (Dict): Dictionary of encoded inputs.
            max_length (Optional[int]): Maximum length to pad to.
            padding_strategy (PaddingStrategy): Strategy for padding.
            pad_to_multiple_of (Optional[int]): Pad to a multiple of this value.
            return_attention_mask (Optional[bool]): Whether to return the
                attention mask. Defaults to whether "attention_mask" is in
                ``self.model_input_names``.

        Returns:
            dict: Dictionary with padded inputs and optional attention mask.
            The mask is returned as a nested list.

        Raises:
            ValueError: If attention_mask has an unexpected type or
                ``self.padding_side`` is neither "right" nor "left".
        """
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names
        if return_attention_mask:
            required_input = encoded_inputs[self.model_input_names[0]]
            seq_len = len(required_input)
            if padding_strategy == PaddingStrategy.LONGEST:
                max_length = seq_len
            if (
                max_length is not None
                and pad_to_multiple_of is not None
                and max_length % pad_to_multiple_of != 0
            ):
                # Round max_length up to the next multiple of pad_to_multiple_of.
                max_length = (max_length // pad_to_multiple_of + 1) * pad_to_multiple_of
            needs_to_be_padded = (
                padding_strategy != PaddingStrategy.DO_NOT_PAD
                and seq_len != max_length
            )

            # Remove the mask from encoded_inputs so the parent `_pad` does not
            # see it. A key holding None is removed as well and replaced by the
            # default mask below.
            attention_mask = encoded_inputs.pop("attention_mask", None)
            if attention_mask is None:
                attention_mask = np.expand_dims(
                    np.tril(np.ones((seq_len, seq_len), dtype=np.int64)), axis=0
                )
            elif isinstance(attention_mask, paddle.Tensor):
                attention_mask = attention_mask.numpy()
            elif isinstance(attention_mask, list):
                attention_mask = np.array(attention_mask)
            elif not isinstance(attention_mask, np.ndarray):
                raise ValueError(
                    f"Unexpected type {type(attention_mask)} of attention_mask, "
                )

            if needs_to_be_padded:
                difference = max_length - seq_len
                if self.padding_side == "right":
                    seq_pad = (0, difference)
                elif self.padding_side == "left":
                    seq_pad = (difference, 0)
                else:
                    raise ValueError(
                        "Invalid padding strategy:" + str(self.padding_side)
                    )
                if attention_mask.ndim == 1:
                    pad_width = [seq_pad]
                else:
                    # Leave leading axes untouched; pad both trailing
                    # (seq, seq) axes.
                    pad_width = [(0, 0)] * (attention_mask.ndim - 2) + [
                        seq_pad,
                        seq_pad,
                    ]
                attention_mask = np.pad(
                    attention_mask,
                    pad_width=pad_width,
                    mode="constant",
                    constant_values=0,
                )

        encoded_inputs = super()._pad(
            encoded_inputs,
            max_length,
            padding_strategy=padding_strategy,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=False,
        )
        if return_attention_mask:
            encoded_inputs["attention_mask"] = attention_mask.tolist()
        return encoded_inputs
def add_special_tokens(
    tokenizer,
    special_tokens_info,
    use_ocr_specialtoken=False,
    special_token_ids_start=254208,
    special_token_ids_end=256256,
):
    """
    Register multimodal (and optionally OCR) special tokens on ``tokenizer``.

    With the defaults, ``[special_token_ids_start, special_token_ids_end)``
    spans 256256 - 254208 = 2048 ids. The tokens added, in this order, are:

    * ``special_tokens_info["image_placeholder"]``
    * ``special_tokens_info["audio_placeholder"]``
    * if ``use_ocr_specialtoken`` is true, every entry of
      ``special_tokens_info["ocr_coor"]``, then every entry of
      ``special_tokens_info["ocr_begin_end"]``

    NOTE(review): no video placeholder or modality begin/end markers
    (<|BOI|>, <|EOI|>, <|BOA|>, ...) are added by this function. Confirm
    whether callers register them elsewhere.

    Args:
        tokenizer: Tokenizer object providing ``add_special_tokens(dict)``,
            ``encode(str)`` (returning a mapping with an "input_ids" list), and
            ``get_vocab()``.
        special_tokens_info (dict): Source of the token strings. It must have
            the keys "image_placeholder" and "audio_placeholder", plus
            "ocr_coor" and "ocr_begin_end" (iterables of strings) when
            ``use_ocr_specialtoken`` is true.
        use_ocr_specialtoken (bool): Whether to add the OCR special tokens.
        special_token_ids_start (int, optional): Id the first added token is
            required to receive. Defaults to 254208.
        special_token_ids_end (int, optional): Exclusive upper bound on the
            vocabulary size after adding. Defaults to 256256.

    Raises:
        AssertionError: If the first added token does not encode to
            ``special_token_ids_start``, or if the vocabulary size reaches
            ``special_token_ids_end``.
    """
    special_tokens = [
        special_tokens_info["image_placeholder"],
        special_tokens_info["audio_placeholder"],
    ]
    if use_ocr_specialtoken:
        special_tokens.extend(special_tokens_info["ocr_coor"])
        special_tokens.extend(special_tokens_info["ocr_begin_end"])

    tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

    # Sanity checks. These are explicit raises rather than `assert` statements,
    # so they still run under `python -O`. AssertionError is kept so existing
    # handlers keep working.
    first_special_tokens = tokenizer.encode(special_tokens[0])["input_ids"]
    if first_special_tokens[0] != special_token_ids_start:
        raise AssertionError(f"[ERROR] first_special_tokens={first_special_tokens}")

    # get_vocab() may build the whole token->id mapping (Ernie4_5_Tokenizer's
    # does), so call it only once.
    vocab_size = len(tokenizer.get_vocab())
    if vocab_size >= special_token_ids_end:
        # Message text means: "too many special tokens were added!"
        raise AssertionError(
            f"[ERROR] vocab_size = {vocab_size} >= {special_token_ids_end} 增加过多special token了!"
        )