
Commit 08636b3

Add unit tests for generation models (#3018)
1 parent f1634e5 commit 08636b3

File tree

13 files changed: 2470 additions, 58 deletions


model_zoo/ernie-gen/encode.py

Lines changed: 0 additions & 1 deletion
@@ -83,7 +83,6 @@ def gen_mask(batch_ids, mask_type='bidi', query_len=None, pad_value=0):
         mask = np.tril(mask, -1)
     elif mask_type == 'diag':
         assert query_len == batch_ids.shape[1]
-        # import pdb; pdb.set_trace()
         mask = np.stack([np.diag(np.diag(m)) for m in mask], 0)

     else:
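
For context, a minimal numpy sketch (not part of this commit) of what the `diag` branch above computes: `np.diag(np.diag(m))` keeps only the main diagonal of each per-example mask.

```python
import numpy as np

# A batch of two fully-"visible" 3x3 attention masks.
mask = np.ones((2, 3, 3), dtype="float32")

# Same expression as gen_mask's 'diag' branch: zero everything
# except the main diagonal of each mask in the batch.
diag_mask = np.stack([np.diag(np.diag(m)) for m in mask], 0)

print(diag_mask[0])
# [[1. 0. 0.]
#  [0. 1. 0.]
#  [0. 0. 1.]]
```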

paddlenlp/transformers/bart/modeling.py

Lines changed: 6 additions & 0 deletions
@@ -428,6 +428,12 @@ def get_encoder(self):
     def get_decoder(self):
         return self.decoder

+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, value):
+        self.shared = value
+
     def forward(self,
                 input_ids,
                 attention_mask=None,
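
A hedged usage sketch of the new accessors (not part of the commit): they expose BartModel's shared token-embedding table so callers can read or replace it. The checkpoint name 'bart-base' is an assumption about what from_pretrained can resolve.

```python
import paddle
from paddlenlp.transformers import BartModel

# Assumption: the 'bart-base' weights are cached locally or downloadable.
model = BartModel.from_pretrained('bart-base')

# The new getter returns the shared paddle.nn.Embedding used by encoder and decoder.
emb = model.get_input_embeddings()
print(emb.weight.shape)  # [vocab_size, d_model]

# The new setter swaps in a replacement table of the same shape,
# e.g. after re-initializing the vocabulary.
new_emb = paddle.nn.Embedding(emb.weight.shape[0], emb.weight.shape[1])
model.set_input_embeddings(new_emb)
assert model.get_input_embeddings() is new_emb
```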

paddlenlp/transformers/bart/tokenizer.py

Lines changed: 212 additions & 6 deletions
@@ -13,13 +13,64 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+from functools import lru_cache
+
+import json
+import shutil
 from paddle.utils import try_import
-from .. import GPTTokenizer, AddedToken
+from .. import PretrainedTokenizer, AddedToken

 __all__ = ['BartTokenizer']

+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+    "bart-base": 1024,
+    "bart-large": 1024,
+}
+
+
+@lru_cache()
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 byte and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    _chr = chr
+    bs = list(range(ord("!"),
+                    ord("~") + 1)) + list(range(
+                        ord("¡"),
+                        ord("¬") + 1)) + list(range(ord("®"),
+                                                    ord("ÿ") + 1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [_chr(n) for n in cs]
+    return dict(zip(bs, cs))
+

-class BartTokenizer(GPTTokenizer):
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+class BartTokenizer(PretrainedTokenizer):
     r"""
     Construct a BART tokenizer based on byte-level Byte-Pair-Encoding.

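
A quick sanity check of the two module-level helpers added above, assuming they remain importable from paddlenlp.transformers.bart.tokenizer after this change:

```python
from paddlenlp.transformers.bart.tokenizer import bytes_to_unicode, get_pairs

byte_encoder = bytes_to_unicode()
print(len(byte_encoder))       # 256 -- every byte gets a printable unicode stand-in
print(byte_encoder[ord(" ")])  # 'Ġ' -- the space byte is remapped instead of dropped

# Adjacent symbol pairs of a word, as consumed by the BPE merge loop.
print(get_pairs(("h", "e", "l", "l", "o")))
# {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}  (set order may vary)
```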
@@ -100,12 +151,12 @@ class BartTokenizer(GPTTokenizer):
         }
     }
     pretrained_init_configuration = {"bart-base": {}, "bart-large": {}}
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

     def __init__(self,
                  vocab_file,
                  merges_file,
                  errors='replace',
-                 max_len=None,
                  bos_token="<s>",
                  eos_token="</s>",
                  cls_token="<s>",
@@ -115,9 +166,6 @@ def __init__(self,
                  mask_token="<mask>",
                  **kwargs):

-        super(BartTokenizer, self).__init__(vocab_file, merges_file, errors,
-                                            max_len, pad_token, eos_token)
-
         bos_token = AddedToken(bos_token,
                                lstrip=False, rstrip=False) if isinstance(
                                    bos_token, str) else bos_token
@@ -150,6 +198,33 @@ def __init__(self,
                                                 pad_token=pad_token,
                                                 mask_token=mask_token)

+        self._vocab_file = vocab_file
+        self._merges_file = merges_file
+        self.num_command_tokens = 2
+        self.num_type_tokens = 2
+
+        with open(vocab_file, 'r', encoding='utf-8') as f:
+            self.encoder = json.load(f)
+
+        self.decoder = {v: k for k, v in self.encoder.items()}
+
+        self.num_tokens = len(self.encoder)
+        self.num_text_tokens = self.num_tokens - 1
+        self.errors = errors  # how to handle errors in decoding
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+
+        with open(merges_file, encoding='utf-8') as f:
+            bpe_data = f.read().split('\n')[1:-1]
+
+        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        self.cache = {}
+        re = try_import("regex")
+        self.pat = re.compile(
+            r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+        )
+
     def _bpe_encode(self, text):
         bpe_tokens = []
         re = try_import("regex")
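
The regex compiled at the end of __init__ is the GPT-2 style pre-tokenizer: it splits text into word, number, and punctuation chunks that keep their leading space before byte-level BPE runs. A standalone sketch using the same third-party regex package that try_import loads:

```python
import regex

pat = regex.compile(
    r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)

print(regex.findall(pat, "PaddleNLP's tokenizer, since 2021!"))
# ['PaddleNLP', "'s", ' tokenizer', ',', ' since', ' 2021', '!']
```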
@@ -200,3 +275,134 @@ def create_token_type_ids_from_sequences(self,
         if token_ids_1 is None:
             return len(cls + token_ids_0 + sep) * [0]
         return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
+    @property
+    def vocab_size(self):
+        """
+        Returns the size of vocabulary.
+
+        Returns:
+            int: The size of the vocabulary.
+
+        """
+
+        return len(self.encoder)
+
+    @property
+    def eol_token_id(self):
+        if self.eol_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.eol_token)
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token
+
+        while True:
+            bigram = min(
+                pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word) - 1 and word[i +
+                                                                   1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def _tokenize(self, text):
+        """ Tokenize a string. """
+        bpe_tokens = []
+        re = try_import("regex")
+        for token in re.findall(self.pat, text):
+            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+            bpe_tokens.extend(bpe_token
+                              for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+    def _convert_token_to_id(self, token):
+        return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+    def _convert_id_to_token(self, index):
+        return self.decoder[index]
+
+    def convert_ids_to_string(self, ids):
+        """
+        Converts a single index or a sequence of indices to texts.
+
+        Args:
+            ids (int|List[int]):
+                The token id (or token ids) to be converted to text.
+
+        Returns:
+            str: The decoded text.
+
+        Example:
+            .. code-block::
+
+                from paddlenlp.transformers import GPTTokenizer
+                tokenizer = GPTTokenizer.from_pretrained('gpt2-medium-en')
+                print(tokenizer.convert_ids_to_string([14618, 284, 779, 350, 37382, 47, 37382, 290, 350, 37382, 45, 19930]))
+                # 'Welcome to use PaddlePaddle and PaddleNLP'
+
+        """
+
+        text = ''.join([self.decoder[id] for id in ids])
+        text = bytearray([self.byte_decoder[c]
+                          for c in text]).decode('utf-8', errors=self.errors)
+        return text
+
+    def save_resources(self, save_directory):
+        """
+        Saves the tokenizer's vocabulary and merges files under
+        `save_directory`.
+
+        Args:
+            save_directory (str): Directory to save files into.
+        """
+        for name, file_name in self.resource_files_names.items():
+            source_path = getattr(self, "_%s" % name)
+
+            save_path = os.path.join(save_directory, file_name)
+            if os.path.abspath(source_path) != os.path.abspath(save_path):
+                shutil.copyfile(source_path, save_path)
+
+    def convert_tokens_to_string(self, tokens):
+        """
+        Converts a sequence of tokens (string) into a single string.
+        """
+        text = "".join(tokens)
+        text = bytearray([self.byte_decoder[c]
+                          for c in text]).decode('utf-8', errors=self.errors)
+        return text
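
Finally, a hedged end-to-end sketch (not part of the commit) exercising the encode/decode path implemented above; 'bart-base' is assumed to resolve through from_pretrained, and exact token ids depend on the downloaded vocab and merges files.

```python
import os
from paddlenlp.transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained('bart-base')  # assumed checkpoint name

text = "Welcome to use PaddlePaddle and PaddleNLP"
tokens = tokenizer.tokenize(text)              # byte-level BPE pieces
ids = tokenizer.convert_tokens_to_ids(tokens)

# Byte-level BPE round-trips plain text exactly, including leading spaces.
assert tokenizer.convert_ids_to_string(ids) == text
assert tokenizer.convert_tokens_to_string(tokens) == text

# save_resources copies the vocab and merges files into the target directory.
os.makedirs("saved_bart_tokenizer", exist_ok=True)
tokenizer.save_resources("saved_bart_tokenizer")
```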
