Skip to content

Commit 5a4f57c

Browse files
authored
Fix PretrainedTokenizer saving. (#648)
1 parent 3038913 commit 5a4f57c

File tree

3 files changed

+19
-15
lines changed

3 files changed

+19
-15
lines changed

paddlenlp/data/vocab.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def to_tokens(self, indices):
215215

216216
tokens = []
217217
for idx in indices:
218-
if not isinstance(idx, int):
218+
if not isinstance(idx, (int, np.integer)):
219219
warnings.warn(
220220
"The type of `to_tokens()`'s input `indices` is not `int` which will be forcibly transfered to `int`. "
221221
)

paddlenlp/transformers/model_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,12 @@ def save_pretrained(self, save_dir):
379379
# reload from save_directory
380380
model = BertForSequenceClassification.from_pretrained('./trained_model/')
381381
"""
382-
assert os.path.isdir(
383-
save_dir), "save_dir ({}) is not available.".format(save_dir)
384-
# Save model config
382+
assert not os.path.isfile(
383+
save_dir
384+
), "Saving directory ({}) should be a directory, not a file".format(
385+
save_dir)
386+
os.makedirs(save_dir, exist_ok=True)
387+
# Save model config
385388
self.save_model_config(save_dir)
386389
# Save model
387390
file_name = os.path.join(save_dir,

paddlenlp/transformers/tokenizer_utils.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import os
2222
import six
2323
import unicodedata
24+
from shutil import copyfile
2425
from typing import Iterable, Iterator, Optional, List, Any, Callable, Union
2526

2627
from paddlenlp.utils.downloader import get_path_from_url
@@ -525,9 +526,12 @@ def save_pretrained(self, save_directory):
525526
# reload from save_directory
526527
tokenizer = BertTokenizer.from_pretrained('trained_model')
527528
"""
528-
assert os.path.isdir(
529+
assert not os.path.isfile(
529530
save_directory
530-
), "Saving directory ({}) should be a directory".format(save_directory)
531+
), "Saving directory ({}) should be a directory, not a file".format(
532+
save_directory)
533+
os.makedirs(save_directory, exist_ok=True)
534+
531535
tokenizer_config_file = os.path.join(save_directory,
532536
self.tokenizer_config_file)
533537
# init_config is set in metaclass created `__init__`,
@@ -540,19 +544,16 @@ def save_pretrained(self, save_directory):
540544
def save_resources(self, save_directory):
541545
"""
542546
Save tokenizer related resources to `resource_files_names` indicating
543-
files under `save_directory`.
544-
545-
Currently, it only can support saving `vocab` of tokenizer by using
546-
`self.save_vocabulary(file_name, self.vocab)`. Override it if necessary.
547+
files under `save_directory` by copying directly. Override it if necessary.
547548
548549
Args:
549550
save_directory (str): Directory to save files into.
550551
"""
551-
assert hasattr(self, 'vocab') and len(
552-
self.resource_files_names) == 1, "Must overwrite `save_resources`"
553-
file_name = os.path.join(save_directory,
554-
list(self.resource_files_names.values())[0])
555-
self.save_vocabulary(file_name, self.vocab)
552+
for name, file_name in self.resource_files_names.items():
553+
src_path = self.init_config[name]
554+
dst_path = os.path.join(save_directory, file_name)
555+
if os.path.abspath(src_path) != os.path.abspath(dst_path):
556+
copyfile(src_path, dst_path)
556557

557558
@staticmethod
558559
def load_vocabulary(filepath,

0 commit comments

Comments
 (0)