Skip to content

Commit 2eeda25

Browse files
committed
switch to zip uncompressed
1 parent 0389333 commit 2eeda25

File tree

1 file changed

+42
-29
lines changed

1 file changed

+42
-29
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,6 @@ def save_pretrained(
193193
variant: Optional[str] = None,
194194
max_shard_size: Optional[Union[int, str]] = None,
195195
push_to_hub: bool = False,
196-
dduf_format: bool = False,
197196
dduf_filename: Optional[Union[str, os.PathLike]] = None,
198197
**kwargs,
199198
):
@@ -229,9 +228,6 @@ class implements both a save and loading method. The pipeline is easily reloaded
229228
model_index_dict.pop("_module", None)
230229
model_index_dict.pop("_name_or_path", None)
231230

232-
if dduf_format and dduf_filename is None:
233-
raise RuntimeError("You need set dduf_filename if you want to save your model in DDUF format.")
234-
235231
if push_to_hub:
236232
commit_message = kwargs.pop("commit_message", None)
237233
private = kwargs.pop("private", False)
@@ -306,9 +302,19 @@ def is_saveable_module(name, value):
306302

307303
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
308304

309-
if dduf_format:
305+
if dduf_filename:
310306
import shutil
311-
import tarfile
307+
import zipfile
308+
309+
def zipdir(dir_to_archive, zipf):
310+
"zip a directory"
311+
for root, dirs, files in os.walk(dir_to_archive):
312+
for file in files:
313+
file_path = os.path.join(root, file)
314+
arcname = os.path.join(
315+
os.path.basename(dir_to_archive), os.path.relpath(file_path, start=dir_to_archive)
316+
)
317+
zipf.write(file_path, arcname=arcname)
312318

313319
dduf_file_path = os.path.join(save_directory, dduf_filename)
314320

@@ -320,23 +326,30 @@ def is_saveable_module(name, value):
320326
if (
321327
os.path.exists(dduf_file_path)
322328
and os.path.isfile(dduf_file_path)
323-
and tarfile.is_tarfile(dduf_file_path)
329+
and zipfile.is_zipfile(dduf_file_path)
324330
):
325331
# Open in append mode if the file exists
326332
mode = "a"
327333
else:
328334
# Open in write mode to create it if it doesn't exist
329-
mode = "w:"
330-
with tarfile.open(dduf_file_path, mode) as tar:
335+
mode = "w"
336+
with zipfile.ZipFile(dduf_file_path, mode=mode, compression=zipfile.ZIP_STORED) as zipf:
331337
dir_to_archive = os.path.join(save_directory, pipeline_component_name)
332338
if os.path.isdir(dir_to_archive):
333-
tar.add(dir_to_archive, arcname=os.path.basename(dir_to_archive))
334-
# remove from save_directory after we added it to the archive
339+
zipdir(dir_to_archive, zipf)
335340
shutil.rmtree(dir_to_archive)
336341

337342
# finally save the config
338343
self.save_config(save_directory)
339344

345+
if dduf_filename:
346+
import zipfile
347+
348+
with zipfile.ZipFile(dduf_file_path, mode="a", compression=zipfile.ZIP_STORED) as zipf:
349+
config_path = os.path.join(save_directory, self.config_name)
350+
zipf.write(config_path, arcname=os.path.basename(config_path))
351+
os.remove(config_path)
352+
340353
if push_to_hub:
341354
# Create a new empty model card and eventually tag it
342355
model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
@@ -652,7 +665,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
652665
variant (`str`, *optional*):
653666
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
654667
loading `from_flax`.
655-
dduf(`str`, *optional*):
668+
dduf (`str`, *optional*):
656669
Load weights from the specified dduf archive or folder.
657670
658671
<Tip>
@@ -796,29 +809,29 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
796809
)
797810
logger.warning(warn_msg)
798811

799-
config_dict = cls.load_config(cached_folder)
800-
801-
# pop out "_ignore_files" as it is only needed for download
802-
config_dict.pop("_ignore_files", None)
803-
804812
if dduf:
805-
import tarfile
813+
import zipfile
806814

807-
tar_file_path = os.path.join(cached_folder, dduf)
815+
zip_file_path = os.path.join(cached_folder, dduf)
808816
extract_to = os.path.join(cached_folder, f"{dduf}_extracted")
809-
# if tar file, we need to extract the tarfile and remove it
810-
if os.path.isfile(tar_file_path):
811-
if tarfile.is_tarfile(tar_file_path):
812-
with tarfile.open(tar_file_path, "r") as tar:
813-
tar.extractall(extract_to)
814-
# remove tar archive to free memory
815-
os.remove(tar_file_path)
817+
# if zip file, we need to extract the zipfile and remove it
818+
if os.path.isfile(zip_file_path):
819+
if zipfile.is_zipfile(zip_file_path):
820+
with zipfile.ZipFile(zip_file_path, "r") as zipf:
821+
zipf.extractall(extract_to)
822+
# remove zip archive to free memory
823+
os.remove(zip_file_path)
816824
# rename folder to match the name of the dduf archive
817-
os.rename(extract_to, tar_file_path)
825+
os.rename(extract_to, zip_file_path)
818826
else:
819-
raise RuntimeError("The dduf path passed is not a tar archive")
827+
raise RuntimeError("The dduf path passed is not a zip archive")
820828
# udapte cached folder location as the dduf content is in a seperate folder
821-
cached_folder = tar_file_path
829+
cached_folder = zip_file_path
830+
831+
config_dict = cls.load_config(cached_folder)
832+
833+
# pop out "_ignore_files" as it is only needed for download
834+
config_dict.pop("_ignore_files", None)
822835

823836
# 2. Define which model components should load variants
824837
# We retrieve the information by matching whether variant model checkpoints exist in the subfolders.

0 commit comments

Comments
 (0)