# See the License for the specific language governing permissions and
# limitations under the License.

+ import os
+ from functools import lru_cache
+
+ import json
+ import shutil

from paddle.utils import try_import
- from .. import GPTTokenizer, AddedToken
+ from .. import PretrainedTokenizer, AddedToken


__all__ = ['BartTokenizer']

+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+     "bart-base": 1024,
+     "bart-large": 1024,
+ }
+
+
+ @lru_cache()
+ def bytes_to_unicode():
+     """
+     Returns a dict mapping utf-8 bytes to unicode strings.
+     The reversible bpe codes work on unicode strings.
+     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+     This is a significant percentage of your normal, say, 32K bpe vocab.
+     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+     This also avoids mapping to whitespace/control characters that the bpe code barfs on.
+     """
+     _chr = chr
+     # Printable byte ranges keep their own character; all other bytes are shifted past 255.
+     bs = list(range(ord("!"), ord("~") + 1)) + list(range(
+         ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+     cs = bs[:]
+     n = 0
+     for b in range(2**8):
+         if b not in bs:
+             bs.append(b)
+             cs.append(2**8 + n)
+             n += 1
+     cs = [_chr(n) for n in cs]
+     return dict(zip(bs, cs))
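+ # Example (illustrative): the space byte 0x20 is not in the printable ranges, so
+ # bytes_to_unicode() maps it to 'Ġ' (chr(288)); decoding simply reverses the table,
+ # so arbitrary text round-trips through the BPE vocabulary without <unk> tokens.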
+
- class BartTokenizer(GPTTokenizer):
+ def get_pairs(word):
+     """Return set of symbol pairs in a word.
+
+     Word is represented as tuple of symbols (symbols being variable-length strings).
+     """
+     pairs = set()
+     prev_char = word[0]
+     for char in word[1:]:
+         pairs.add((prev_char, char))
+         prev_char = char
+     return pairs
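+ # Example (illustrative): get_pairs(('h', 'e', 'l', 'l', 'o')) returns
+ # {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}, i.e. the candidate BPE merges.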
+
+
+ class BartTokenizer(PretrainedTokenizer):
    r"""
    Construct a BART tokenizer based on byte-level Byte-Pair-Encoding.
@@ -100,12 +151,12 @@ class BartTokenizer(GPTTokenizer):
        }
    }
    pretrained_init_configuration = {"bart-base": {}, "bart-large": {}}
+     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self,
                 vocab_file,
                 merges_file,
                 errors='replace',
-                  max_len=None,
                 bos_token="<s>",
                 eos_token="</s>",
                 cls_token="<s>",
@@ -115,9 +166,6 @@ def __init__(self,
                 mask_token="<mask>",
                 **kwargs):

-         super(BartTokenizer, self).__init__(vocab_file, merges_file, errors,
-                                             max_len, pad_token, eos_token)
-
        bos_token = AddedToken(bos_token,
                               lstrip=False, rstrip=False) if isinstance(
                                   bos_token, str) else bos_token
@@ -150,6 +198,33 @@ def __init__(self,
            pad_token=pad_token,
            mask_token=mask_token)

+         self._vocab_file = vocab_file
+         self._merges_file = merges_file
+         self.num_command_tokens = 2
+         self.num_type_tokens = 2
+
+         with open(vocab_file, 'r', encoding='utf-8') as f:
+             self.encoder = json.load(f)
+
+         self.decoder = {v: k for k, v in self.encoder.items()}
+
+         self.num_tokens = len(self.encoder)
+         self.num_text_tokens = self.num_tokens - 1
+         self.errors = errors  # how to handle errors in decoding
+         self.byte_encoder = bytes_to_unicode()
+         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+
+         # The merges file starts with a version header line, hence the [1:-1] slice.
+         with open(merges_file, encoding='utf-8') as f:
+             bpe_data = f.read().split('\n')[1:-1]
+
+         bpe_merges = [tuple(merge.split()) for merge in bpe_data]
+         self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+         self.cache = {}
+         re = try_import("regex")
+         self.pat = re.compile(
+             r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+         )
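+         # Illustrative: self.pat pre-splits text before BPE is applied, e.g.
+         # "don't panic!" becomes ['don', "'t", ' panic', '!'].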
+
    def _bpe_encode(self, text):
        bpe_tokens = []
        re = try_import("regex")
@@ -200,3 +275,134 @@ def create_token_type_ids_from_sequences(self,
        # BART does not use token type ids, so the mask is all zeros for both inputs.
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+     def get_vocab(self):
+         return dict(self.encoder, **self.added_tokens_encoder)
+
+     @property
+     def vocab_size(self):
+         """
+         Returns the size of vocabulary.
+
+         Returns:
+             int: The size of the base vocabulary (excluding added special tokens).
+
+         """
+
+         return len(self.encoder)
+
+     @property
+     def eol_token_id(self):
+         if self.eol_token is None:
+             return None
+         return self.convert_tokens_to_ids(self.eol_token)
+
+     def bpe(self, token):
+         if token in self.cache:
+             return self.cache[token]
+         word = tuple(token)
+         pairs = get_pairs(word)
+
+         if not pairs:
+             return token
+
+         while True:
+             bigram = min(
+                 pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+             if bigram not in self.bpe_ranks:
+                 break
+             first, second = bigram
+             new_word = []
+             i = 0
+             while i < len(word):
+                 try:
+                     j = word.index(first, i)
+                     new_word.extend(word[i:j])
+                     i = j
+                 except ValueError:
+                     new_word.extend(word[i:])
+                     break
+
+                 if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                     new_word.append(first + second)
+                     i += 2
+                 else:
+                     new_word.append(word[i])
+                     i += 1
+             new_word = tuple(new_word)
+             word = new_word
+             if len(word) == 1:
+                 break
+             else:
+                 pairs = get_pairs(word)
+         word = ' '.join(word)
+         self.cache[token] = word
+         return word
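+     # Example (illustrative): self.bpe('Ġhello') repeatedly merges the best-ranked
+     # pair from bpe_ranks and returns the final pieces joined by spaces
+     # (a single piece 'Ġhello' if the learned merges build it up fully).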
+
+     def _tokenize(self, text):
+         """ Tokenize a string. """
+         bpe_tokens = []
+         re = try_import("regex")
+         for token in re.findall(self.pat, text):
+             token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+             bpe_tokens.extend(bpe_token
+                               for bpe_token in self.bpe(token).split(' '))
+         return bpe_tokens
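+     # Example (illustrative): with the GPT-2 style merges used by BART,
+     # _tokenize("Hello world") typically yields ['Hello', 'Ġworld'],
+     # where 'Ġ' marks a leading space.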
+
+     def _convert_token_to_id(self, token):
+         return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+     def _convert_id_to_token(self, index):
+         return self.decoder[index]
+
+     def convert_ids_to_string(self, ids):
+         """
+         Converts a single index or a sequence of indices to texts.
+
+         Args:
+             ids (int|List[int]):
+                 The token id (or token ids) to be converted to text.
+
+         Returns:
+             str: The decoded text.
+
+         Example:
+             .. code-block::
+
+                 from paddlenlp.transformers import GPTTokenizer
+                 tokenizer = GPTTokenizer.from_pretrained('gpt2-medium-en')
+                 print(tokenizer.convert_ids_to_string([14618, 284, 779, 350, 37382, 47, 37382, 290, 350, 37382, 45, 19930]))
+                 # 'Welcome to use PaddlePaddle and PaddleNLP'
+
+         """
+
+         text = ''.join([self.decoder[id] for id in ids])
+         text = bytearray([self.byte_decoder[c]
+                           for c in text]).decode('utf-8', errors=self.errors)
+         return text
+
+     def save_resources(self, save_directory):
+         """
+         Saves the tokenizer's resource files (the vocabulary and merges files)
+         under `save_directory`.
+
+         Args:
+             save_directory (str): Directory to save files into.
+         """
+         for name, file_name in self.resource_files_names.items():
+             source_path = getattr(self, "_%s" % name)
+
+             save_path = os.path.join(save_directory, file_name)
+             if os.path.abspath(source_path) != os.path.abspath(save_path):
+                 shutil.copyfile(source_path, save_path)
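+     # Usage sketch (illustrative): tokenizer.save_resources('./bart_tok') copies the
+     # vocab and merges files named in resource_files_names into './bart_tok'.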
+
+     def convert_tokens_to_string(self, tokens):
+         """
+         Converts a sequence of tokens (string) into a single string.
+         """
+         text = "".join(tokens)
+         text = bytearray([self.byte_decoder[c]
+                           for c in text]).decode('utf-8', errors=self.errors)
+         return text
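+
+ # Minimal usage sketch (illustrative; assumes the 'bart-base' resources can be
+ # downloaded or are available locally):
+ #     tokenizer = BartTokenizer.from_pretrained('bart-base')
+ #     tokens = tokenizer._tokenize("Welcome to use PaddleNLP")
+ #     ids = tokenizer.convert_tokens_to_ids(tokens)
+ #     print(tokenizer.convert_ids_to_string(ids))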