@@ -193,7 +193,6 @@ def save_pretrained(
193193 variant : Optional [str ] = None ,
194194 max_shard_size : Optional [Union [int , str ]] = None ,
195195 push_to_hub : bool = False ,
196- dduf_format : bool = False ,
197196 dduf_filename : Optional [Union [str , os .PathLike ]] = None ,
198197 ** kwargs ,
199198 ):
@@ -229,9 +228,6 @@ class implements both a save and loading method. The pipeline is easily reloaded
229228 model_index_dict .pop ("_module" , None )
230229 model_index_dict .pop ("_name_or_path" , None )
231230
232- if dduf_format and dduf_filename is None :
233- raise RuntimeError ("You need set dduf_filename if you want to save your model in DDUF format." )
234-
235231 if push_to_hub :
236232 commit_message = kwargs .pop ("commit_message" , None )
237233 private = kwargs .pop ("private" , False )
@@ -306,9 +302,19 @@ def is_saveable_module(name, value):
306302
307303 save_method (os .path .join (save_directory , pipeline_component_name ), ** save_kwargs )
308304
309- if dduf_format :
305+ if dduf_filename :
310306 import shutil
311- import tarfile
307+ import zipfile
308+
309+ def zipdir (dir_to_archive , zipf ):
310+ "zip a directory"
311+ for root , dirs , files in os .walk (dir_to_archive ):
312+ for file in files :
313+ file_path = os .path .join (root , file )
314+ arcname = os .path .join (
315+ os .path .basename (dir_to_archive ), os .path .relpath (file_path , start = dir_to_archive )
316+ )
317+ zipf .write (file_path , arcname = arcname )
312318
313319 dduf_file_path = os .path .join (save_directory , dduf_filename )
314320
@@ -320,23 +326,30 @@ def is_saveable_module(name, value):
320326 if (
321327 os .path .exists (dduf_file_path )
322328 and os .path .isfile (dduf_file_path )
323- and tarfile . is_tarfile (dduf_file_path )
329+ and zipfile . is_zipfile (dduf_file_path )
324330 ):
325331 # Open in append mode if the file exists
326332 mode = "a"
327333 else :
328334 # Open in write mode to create it if it doesn't exist
329- mode = "w: "
330- with tarfile . open (dduf_file_path , mode ) as tar :
335+ mode = "w"
336+ with zipfile . ZipFile (dduf_file_path , mode = mode , compression = zipfile . ZIP_STORED ) as zipf :
331337 dir_to_archive = os .path .join (save_directory , pipeline_component_name )
332338 if os .path .isdir (dir_to_archive ):
333- tar .add (dir_to_archive , arcname = os .path .basename (dir_to_archive ))
334- # remove from save_directory after we added it to the archive
339+ zipdir (dir_to_archive , zipf )
335340 shutil .rmtree (dir_to_archive )
336341
337342 # finally save the config
338343 self .save_config (save_directory )
339344
345+ if dduf_filename :
346+ import zipfile
347+
348+ with zipfile .ZipFile (dduf_file_path , mode = "a" , compression = zipfile .ZIP_STORED ) as zipf :
349+ config_path = os .path .join (save_directory , self .config_name )
350+ zipf .write (config_path , arcname = os .path .basename (config_path ))
351+ os .remove (config_path )
352+
340353 if push_to_hub :
341354 # Create a new empty model card and eventually tag it
342355 model_card = load_or_create_model_card (repo_id , token = token , is_pipeline = True )
@@ -652,7 +665,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
652665 variant (`str`, *optional*):
653666 Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
654667 loading `from_flax`.
655- dduf(`str`, *optional*):
668+ dduf (`str`, *optional*):
656669 Load weights from the specified dduf archive or folder.
657670
658671 <Tip>
@@ -796,29 +809,29 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
796809 )
797810 logger .warning (warn_msg )
798811
799- config_dict = cls .load_config (cached_folder )
800-
801- # pop out "_ignore_files" as it is only needed for download
802- config_dict .pop ("_ignore_files" , None )
803-
804812 if dduf :
805- import tarfile
813+ import zipfile
806814
807- tar_file_path = os .path .join (cached_folder , dduf )
815+ zip_file_path = os .path .join (cached_folder , dduf )
808816 extract_to = os .path .join (cached_folder , f"{ dduf } _extracted" )
809- # if tar file, we need to extract the tarfile and remove it
810- if os .path .isfile (tar_file_path ):
811- if tarfile . is_tarfile ( tar_file_path ):
812- with tarfile . open ( tar_file_path , "r" ) as tar :
813- tar .extractall (extract_to )
814- # remove tar archive to free memory
815- os .remove (tar_file_path )
817+ # if zip file, we need to extract the zipfile and remove it
818+ if os .path .isfile (zip_file_path ):
819+ if zipfile . is_zipfile ( zip_file_path ):
820+ with zipfile . ZipFile ( zip_file_path , "r" ) as zipf :
821+ zipf .extractall (extract_to )
822+ # remove zip archive to free disk space
823+ os .remove (zip_file_path )
816824 # rename folder to match the name of the dduf archive
817- os .rename (extract_to , tar_file_path )
825+ os .rename (extract_to , zip_file_path )
818826 else :
819- raise RuntimeError ("The dduf path passed is not a tar archive" )
827+ raise RuntimeError ("The dduf path passed is not a zip archive" )
820828 # update cached folder location as the dduf content is in a separate folder
821- cached_folder = tar_file_path
829+ cached_folder = zip_file_path
830+
831+ config_dict = cls .load_config (cached_folder )
832+
833+ # pop out "_ignore_files" as it is only needed for download
834+ config_dict .pop ("_ignore_files" , None )
822835
823836 # 2. Define which model components should load variants
824837 # We retrieve the information by matching whether variant model checkpoints exist in the subfolders.
0 commit comments