Commit b14bffe

first draft
1 parent e66c4d0 commit b14bffe

File tree

9 files changed: +116 -38 lines changed

src/diffusers/configuration_utils.py
Lines changed: 27 additions & 8 deletions

@@ -347,6 +347,7 @@ def load_config(
         _ = kwargs.pop("mirror", None)
         subfolder = kwargs.pop("subfolder", None)
         user_agent = kwargs.pop("user_agent", {})
+        dduf_reader = kwargs.pop("dduf_reader", None)

         user_agent = {**user_agent, "file_type": "config"}
         user_agent = http_user_agent(user_agent)
@@ -358,8 +359,25 @@ def load_config(
                 "`self.config_name` is not defined. Note that one should not load a config from "
                 "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
             )
-
-        if os.path.isfile(pretrained_model_name_or_path):
+        # Custom path for now
+        if dduf_reader:
+            if subfolder is not None:
+                if dduf_reader.has_file(os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)):
+                    config_file = os.path.join(subfolder, cls.config_name)
+                else:
+                    raise ValueError(
+                        f"We did not manage to find the file {os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)} in the archive. We only have the following files {dduf_reader.files}"
+                    )
+            elif dduf_reader.has_file(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+                config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+            else:
+                raise ValueError(
+                    f"We did not manage to find the file {os.path.join(pretrained_model_name_or_path, cls.config_name)} in the archive. We only have the following files {dduf_reader.files}"
+                )
+            print(f"File found: {config_file}")
+        elif not dduf_reader:
+            print("not dduf")
+        elif os.path.isfile(pretrained_model_name_or_path):
             config_file = pretrained_model_name_or_path
         elif os.path.isdir(pretrained_model_name_or_path):
             if subfolder is not None and os.path.isfile(
@@ -426,10 +444,8 @@ def load_config(
                     f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                     f"containing a {cls.config_name} file"
                 )
-
         try:
-            # Load config dict
-            config_dict = cls._dict_from_json_file(config_file)
+            config_dict = cls._dict_from_json_file(config_file, dduf_reader=dduf_reader)

             commit_hash = extract_commit_hash(config_file)
         except (json.JSONDecodeError, UnicodeDecodeError):
@@ -552,9 +568,12 @@ def extract_init_dict(cls, config_dict, **kwargs):
         return init_dict, unused_kwargs, hidden_config_dict

     @classmethod
-    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
-        with open(json_file, "r", encoding="utf-8") as reader:
-            text = reader.read()
+    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike], dduf_reader=None):
+        if dduf_reader:
+            text = dduf_reader.read_file(json_file, encoding="utf-8")
+        else:
+            with open(json_file, "r", encoding="utf-8") as reader:
+                text = reader.read()
         return json.loads(text)

     def __repr__(self):
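
A minimal usage sketch of the new config path (the archive name, the "unet" subfolder and the empty root path are illustrative assumptions, not part of this commit):

    # Hypothetical: read a sub-model config directly out of a DDUF archive.
    from diffusers import UNet2DConditionModel
    from diffusers.utils import DDUFReader

    reader = DDUFReader("pipeline.dduf")  # assumed archive name
    # With the patch above, load_config resolves "unet/config.json" through the reader
    # instead of the filesystem; "" mirrors the cached_folder="" used by the pipeline code.
    config = UNet2DConditionModel.load_config("", subfolder="unet", dduf_reader=reader)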

src/diffusers/models/model_loading_utils.py
Lines changed: 11 additions & 2 deletions

@@ -128,7 +128,7 @@ def _fetch_remapped_cls_from_config(config, old_class):
     return old_class


-def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
+def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, dduf_reader=None):
     """
     Reads a checkpoint file, returning properly formatted errors if they arise.
     """
@@ -138,8 +138,15 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
         return checkpoint_file
     try:
         file_extension = os.path.basename(checkpoint_file).split(".")[-1]
+        if dduf_reader:
+            checkpoint_file = dduf_reader.read_file(checkpoint_file)
         if file_extension == SAFETENSORS_FILE_EXTENSION:
-            return safetensors.torch.load_file(checkpoint_file, device="cpu")
+            if dduf_reader:
+                # tensors are loaded on cpu
+                return safetensors.torch.load(checkpoint_file)
+            else:
+                return safetensors.torch.load_file(checkpoint_file, device="cpu")
+
         else:
             weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
             return torch.load(
@@ -272,6 +279,7 @@ def _fetch_index_file(
     revision,
     user_agent,
     commit_hash,
+    dduf_reader=None,
 ):
     if is_local:
         index_file = Path(
@@ -297,6 +305,7 @@
             subfolder=None,
             user_agent=user_agent,
             commit_hash=commit_hash,
+            dduf_reader=dduf_reader,
         )
         index_file = Path(index_file)
     except (EntryNotFoundError, EnvironmentError):
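
A sketch of the in-memory safetensors path this change enables (the archive and entry names are assumptions):

    # Hypothetical: deserialize weights straight from archive bytes, no extraction to disk.
    import safetensors.torch
    from diffusers.utils import DDUFReader

    reader = DDUFReader("pipeline.dduf")
    raw = reader.read_file("unet/diffusion_pytorch_model.safetensors")  # bytes, not a path
    state_dict = safetensors.torch.load(raw)  # tensors are created on CPU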

src/diffusers/models/modeling_utils.py
Lines changed: 9 additions & 3 deletions

@@ -557,6 +557,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
         quantization_config = kwargs.pop("quantization_config", None)
+        dduf_reader = kwargs.pop("dduf_reader", None)

         allow_pickle = False
         if use_safetensors is None:
@@ -649,6 +650,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             revision=revision,
             subfolder=subfolder,
             user_agent=user_agent,
+            dduf_reader=dduf_reader,
             **kwargs,
         )
         # no in-place modification of the original config.
@@ -724,6 +726,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 "revision": revision,
                 "user_agent": user_agent,
                 "commit_hash": commit_hash,
+                "dduf_reader": dduf_reader,
             }
             index_file = _fetch_index_file(**index_file_kwargs)
             # In case the index file was not found we still have to consider the legacy format.
@@ -759,7 +762,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

                 model = load_flax_checkpoint_in_pytorch_model(model, model_file)
             else:
-                if is_sharded:
+                # in the case it is sharded, we have already the index
+                if is_sharded and not dduf_reader:
                     sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
                         pretrained_model_name_or_path,
                         index_file,
@@ -790,6 +794,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                         subfolder=subfolder,
                         user_agent=user_agent,
                         commit_hash=commit_hash,
+                        dduf_reader=dduf_reader,
                     )

             except IOError as e:
@@ -813,6 +818,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                         subfolder=subfolder,
                         user_agent=user_agent,
                         commit_hash=commit_hash,
+                        dduf_reader=dduf_reader,
                     )

         if low_cpu_mem_usage:
@@ -837,7 +843,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 # TODO (sayakpaul, SunMarc): remove this after model loading refactor
                 elif is_quant_method_bnb:
                     param_device = torch.cuda.current_device()
-                state_dict = load_state_dict(model_file, variant=variant)
+                state_dict = load_state_dict(model_file, variant=variant, dduf_reader=dduf_reader)
                 model._convert_deprecated_attention_blocks(state_dict)

                 # move the params from meta device to cpu
@@ -937,7 +943,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             else:
                 model = cls.from_config(config, **unused_kwargs)

-            state_dict = load_state_dict(model_file, variant=variant)
+            state_dict = load_state_dict(model_file, variant=variant, dduf_reader=dduf_reader)
             model._convert_deprecated_attention_blocks(state_dict)

             model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
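
For a single component, the threaded-through kwarg could be exercised roughly as below (component and archive names are assumptions; this mirrors how load_sub_model calls from_pretrained in pipeline_loading_utils.py further down):

    # Hypothetical: load one sub-model straight from the archive.
    from diffusers import AutoencoderKL
    from diffusers.utils import DDUFReader

    reader = DDUFReader("pipeline.dduf")
    # "vae" is the component name inside the archive, not a local directory.
    vae = AutoencoderKL.from_pretrained("vae", dduf_reader=reader)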

src/diffusers/pipelines/pipeline_loading_utils.py
Lines changed: 5 additions & 1 deletion

@@ -627,6 +627,7 @@ def load_sub_model(
     low_cpu_mem_usage: bool,
     cached_folder: Union[str, os.PathLike],
     use_safetensors: bool,
+    dduf_reader,
 ):
     """Helper method to load the module `name` from `library_name` and `class_name`"""

@@ -721,7 +722,10 @@ def load_sub_model(
         loading_kwargs["low_cpu_mem_usage"] = False

     # check if the module is in a subdirectory
-    if os.path.isdir(os.path.join(cached_folder, name)):
+    if dduf_reader:
+        loading_kwargs["dduf_reader"] = dduf_reader
+        loaded_sub_model = load_method(name, **loading_kwargs)
+    elif os.path.isdir(os.path.join(cached_folder, name)):
         loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
     else:
         # else load from the root directory

src/diffusers/pipelines/pipeline_utils.py
Lines changed: 8 additions & 22 deletions

@@ -50,6 +50,7 @@
     CONFIG_NAME,
     DEPRECATED_REVISION_ARGS,
     BaseOutput,
+    DDUFReader,
     PushToHubMixin,
     is_accelerate_available,
     is_accelerate_version,
@@ -343,7 +344,7 @@ def zipdir(dir_to_archive, zipf):
         self.save_config(save_directory)

         # Takes care of including the "model_index.json" inside the ZIP.
-        # TODO: Include a DDUF a metadata file.
+        # TODO: Include a DDUF a metadata file.
         if dduf_filename:
             import zipfile

@@ -811,30 +812,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             )
             logger.warning(warn_msg)

+        dduf_reader = None
         if dduf:
-            import zipfile
-
             zip_file_path = os.path.join(cached_folder, dduf)
-            extract_to = os.path.join(cached_folder, f"{dduf}_extracted")
-            # if zip file, we need to extract the zipfile and remove it
-            if os.path.isfile(zip_file_path):
-                if zipfile.is_zipfile(zip_file_path):
-                    # with zipfile.ZipFile(zip_file_path, "r") as zipf:
-                    #     zipf.extractall(extract_to)
-                    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
-                        file_list = zip_ref.infolist()
-                        for file_info in tqdm(file_list, desc="Extracting files"):
-                            zip_ref.extract(file_info, extract_to)
-                    # remove zip archive to free memory
-                    os.remove(zip_file_path)
-                    # rename folder to match the name of the dduf archive
-                    os.rename(extract_to, zip_file_path)
-                else:
-                    raise RuntimeError("The dduf path passed is not a zip archive")
-            # udapte cached folder location as the dduf content is in a seperate folder
-            cached_folder = zip_file_path
+            dduf_reader = DDUFReader(zip_file_path)
+            # The reader contains already all the files needed, no need to check it again
+            cached_folder = ""

-        config_dict = cls.load_config(cached_folder)
+        config_dict = cls.load_config(cached_folder, dduf_reader=dduf_reader)

         # pop out "_ignore_files" as it is only needed for download
         config_dict.pop("_ignore_files", None)
@@ -991,6 +976,7 @@ def load_module(name, value):
                     low_cpu_mem_usage=low_cpu_mem_usage,
                     cached_folder=cached_folder,
                     use_safetensors=use_safetensors,
+                    dduf_reader=dduf_reader,
                 )
                 logger.info(
                     f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."

src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
Lines changed: 2 additions & 1 deletion (whitespace-only changes)

@@ -34,6 +34,7 @@
 from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
 from .modeling_stable_audio import StableAudioProjectionModel

+
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm

@@ -732,7 +733,7 @@ def __call__(
                 if callback is not None and i % callback_steps == 0:
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
-
+
                 if XLA_AVAILABLE:
                     xm.mark_step()

src/diffusers/utils/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -35,6 +35,7 @@
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
 )
+from .dduf import DDUFReader
 from .deprecation_utils import deprecate
 from .doc_utils import replace_example_docstring
 from .dynamic_modules_utils import get_class_from_dynamic_module

src/diffusers/utils/dduf.py
Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@
+import zipfile
+
+
+class DDUFReader:
+    def __init__(self, dduf_file):
+        self.dduf_file = dduf_file
+        self.files = []
+        self.post_init()
+
+    def post_init(self):
+        """
+        Check that the DDUF file is valid
+        """
+        if not zipfile.is_zipfile(self.dduf_file):
+            raise ValueError(f"The file '{self.dduf_file}' is not a valid ZIP archive.")
+
+        try:
+            with zipfile.ZipFile(self.dduf_file, "r") as zf:
+                # Check integrity and store file list
+                zf.testzip()  # Returns None if no corrupt files are found
+                self.files = zf.namelist()
+        except zipfile.BadZipFile:
+            raise ValueError(f"The file '{self.dduf_file}' is not a valid ZIP archive.")
+        except Exception as e:
+            raise RuntimeError(f"An error occurred while validating the ZIP file: {e}")
+
+    def has_file(self, file):
+        return file in self.files
+
+    def read_file(self, file_name, encoding=None):
+        """
+        Reads the content of a specific file in the ZIP archive without extracting.
+        """
+        if file_name not in self.files:
+            raise ValueError(f"{file_name} is not in the list of files {self.files}")
+        with zipfile.ZipFile(self.dduf_file, "r") as zf:
+            with zf.open(file_name) as file:
+                file = file.read()
+                if encoding is not None:
+                    file = file.decode(encoding)
+        return file
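
A quick sketch of the reader on its own (the archive and entry names are illustrative):

    # Hypothetical: inspect and read entries of a DDUF archive in place.
    from diffusers.utils import DDUFReader

    reader = DDUFReader("pipeline.dduf")
    print(reader.files)  # every entry name stored in the archive
    if reader.has_file("model_index.json"):
        text = reader.read_file("model_index.json", encoding="utf-8")  # decoded str
        raw = reader.read_file("model_index.json")  # raw bytes when no encoding is given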

src/diffusers/utils/hub_utils.py
Lines changed: 12 additions & 1 deletion

@@ -291,9 +291,20 @@ def _get_model_file(
     user_agent: Optional[Union[Dict, str]] = None,
     revision: Optional[str] = None,
     commit_hash: Optional[str] = None,
+    dduf_reader=None,
 ):
     pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-    if os.path.isfile(pretrained_model_name_or_path):
+
+    if dduf_reader:
+        if dduf_reader.has_file(os.path.join(pretrained_model_name_or_path, weights_name)):
+            return os.path.join(pretrained_model_name_or_path, weights_name)
+        elif subfolder is not None and os.path.isfile(
+            os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
+        ):
+            return os.path.join(pretrained_model_name_or_path, weights_name)
+        else:
+            raise EnvironmentError(f"Error no file named {weights_name} found in archive {dduf_reader.files}.")
+    elif os.path.isfile(pretrained_model_name_or_path):
         return pretrained_model_name_or_path
     elif os.path.isdir(pretrained_model_name_or_path):
         if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
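
The archive-side resolution can be sketched as follows (names are illustrative; this mirrors the branch added above):

    # Hypothetical: where _get_model_file would look for a component's weights in the archive.
    import os
    from diffusers.utils import DDUFReader

    reader = DDUFReader("pipeline.dduf")
    weights_name = "diffusion_pytorch_model.safetensors"
    candidate = os.path.join("unet", weights_name)  # "<component>/<weights_name>" inside the archive
    if reader.has_file(candidate):
        print("load", candidate)
    else:
        raise EnvironmentError(f"no file named {weights_name} found in archive {reader.files}")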
