21
21
import os
22
22
import six
23
23
import unicodedata
24
+ from shutil import copyfile
24
25
from typing import Iterable , Iterator , Optional , List , Any , Callable , Union
25
26
26
27
from paddlenlp .utils .downloader import get_path_from_url
@@ -525,9 +526,12 @@ def save_pretrained(self, save_directory):
525
526
# reload from save_directory
526
527
tokenizer = BertTokenizer.from_pretrained('trained_model')
527
528
"""
528
- assert os .path .isdir (
529
+ assert not os .path .isfile (
529
530
save_directory
530
- ), "Saving directory ({}) should be a directory" .format (save_directory )
531
+ ), "Saving directory ({}) should be a directory, not a file" .format (
532
+ save_directory )
533
+ os .makedirs (save_directory , exist_ok = True )
534
+
531
535
tokenizer_config_file = os .path .join (save_directory ,
532
536
self .tokenizer_config_file )
533
537
# init_config is set in metaclass created `__init__`,
@@ -540,19 +544,16 @@ def save_pretrained(self, save_directory):
540
544
def save_resources (self , save_directory ):
541
545
"""
542
546
Save tokenizer related resources to `resource_files_names` indicating
543
- files under `save_directory`.
544
-
545
- Currently, it only can support saving `vocab` of tokenizer by using
546
- `self.save_vocabulary(file_name, self.vocab)`. Override it if necessary.
547
+ files under `save_directory` by copying directly. Override it if necessary.
547
548
548
549
Args:
549
550
save_directory (str): Directory to save files into.
550
551
"""
551
- assert hasattr ( self , 'vocab' ) and len (
552
- self . resource_files_names ) == 1 , "Must overwrite `save_resources`"
553
- file_name = os .path .join (save_directory ,
554
- list ( self . resource_files_names . values ())[ 0 ])
555
- self . save_vocabulary ( file_name , self . vocab )
552
+ for name , file_name in self . resource_files_names . items ():
553
+ src_path = self . init_config [ name ]
554
+ dst_path = os .path .join (save_directory , file_name )
555
+ if os . path . abspath ( src_path ) != os . path . abspath ( dst_path ):
556
+ copyfile ( src_path , dst_path )
556
557
557
558
@staticmethod
558
559
def load_vocabulary (filepath ,
0 commit comments