2 changes: 1 addition & 1 deletion setup.py
@@ -101,7 +101,7 @@
"filelock",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
"huggingface-hub>=0.23.2",
"huggingface-hub>=0.27.0",
"requests-mock==1.10.0",
"importlib_metadata",
"invisible-watermark>=0.2.0",
2 changes: 1 addition & 1 deletion src/diffusers/dependency_versions_table.py
@@ -9,7 +9,7 @@
"filelock": "filelock",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.23.2",
"huggingface-hub": "huggingface-hub>=0.27.0",
"requests-mock": "requests-mock==1.10.0",
"importlib_metadata": "importlib_metadata",
"invisible-watermark": "invisible-watermark>=0.2.0",
25 changes: 20 additions & 5 deletions src/diffusers/pipelines/pipeline_loading_utils.py
@@ -12,14 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
 import importlib
 import os
 import re
 import warnings
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 from huggingface_hub import ModelCard, model_info
@@ -41,11 +39,12 @@
     logging,
 )
 from ..utils.torch_utils import is_compiled_module
+from .transformers_loading_utils import load_tokenizer_from_dduf, load_transformers_model_from_dduf
 
 
 if is_transformers_available():
     import transformers
-    from transformers import PreTrainedModel
+    from transformers import PreTrainedModel, PreTrainedTokenizerBase
     from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
     from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
     from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
@@ -664,7 +663,7 @@ def load_sub_model(
f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
)

load_method = getattr(class_obj, load_method_name)
load_method = _get_load_method(class_obj, load_method_name, is_dduf=dduf_entries is not None)

# add kwargs to loading method
diffusers_module = importlib.import_module(__name__.split(".")[0])
@@ -750,6 +749,22 @@ def load_sub_model(
     return loaded_sub_model
 
 
+def _get_load_method(class_obj: object, load_method_name: str, is_dduf: bool) -> Callable:
+    """
+    Return the method to use to load the sub-model.
+
+    In practice, this returns the `load_method_name` method of the class object (usually `from_pretrained`), except
+    when loading from a DDUF checkpoint. In that case, transformers models and tokenizers have dedicated loading
+    methods that must be used instead of `from_pretrained`.
+    """
+    if is_dduf:
+        if issubclass(class_obj, PreTrainedTokenizerBase):
+            return lambda *args, **kwargs: load_tokenizer_from_dduf(class_obj, *args, **kwargs)
+        if issubclass(class_obj, PreTrainedModel):
+            return lambda *args, **kwargs: load_transformers_model_from_dduf(class_obj, *args, **kwargs)
+    return getattr(class_obj, load_method_name)
+
+
 def _fetch_class_library_tuple(module):
     # import it here to avoid circular import
     diffusers_module = importlib.import_module(__name__.split(".")[0])
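For context, a minimal sketch of how this dispatch behaves once load_sub_model resolves a loader. The concrete classes below are only examples; any PreTrainedTokenizerBase or PreTrainedModel subclass would dispatch the same way:

# Illustrative sketch, not part of the diff: how _get_load_method dispatches.
from transformers import CLIPTextModel, CLIPTokenizer

# Non-DDUF loading falls through to getattr, i.e. the usual bound method.
load_method = _get_load_method(CLIPTextModel, "from_pretrained", is_dduf=False)
# load_method is CLIPTextModel.from_pretrained

# DDUF loading returns a wrapper around the dedicated DDUF loaders instead.
load_method = _get_load_method(CLIPTokenizer, "from_pretrained", is_dduf=True)
# load_method now forwards to load_tokenizer_from_dduf(CLIPTokenizer, ...)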
14 changes: 3 additions & 11 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -32,6 +32,7 @@
     create_repo,
     hf_hub_download,
     model_info,
+    read_dduf_file,
     snapshot_download,
 )
 from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
@@ -53,7 +54,6 @@
     PushToHubMixin,
     is_accelerate_available,
     is_accelerate_version,
-    is_huggingface_hub_version,
     is_torch_npu_available,
     is_torch_version,
     is_transformers_version,
@@ -677,7 +677,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                 loading `from_flax`.
             dduf_file(`str`, *optional*):
-                Load weights from the specified dduf file
+                Load weights from the specified dduf file.
 
         <Tip>
 
@@ -822,15 +822,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

         dduf_entries = None
         if dduf_file:
-            if not is_huggingface_hub_version(">", "0.26.3"):
-                (">=", "0.17.0.dev0")
-                raise RuntimeError(
-                    "To load a dduf file, you need to install huggingface_hub>0.26.3. "
-                    "You can install it with the following: `pip install --upgrade huggingface_hub`."
-                )
-
-            from huggingface_hub import read_dduf_file
-
             dduf_file_path = os.path.join(cached_folder, dduf_file)
             dduf_entries = read_dduf_file(dduf_file_path)
             # The reader contains already all the files needed, no need to check it again
@@ -845,6 +836,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         # We retrieve the information by matching whether variant model checkpoints exist in the subfolders.
         # Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
         # with variant being `"fp16"`.
+        # TODO: adapt logic for DDUF files (at the moment, scans the local directory which doesn't make sense in DDUF context)
         model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)
         if len(model_variants) == 0 and variant is not None:
             error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
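Taken together, these changes let a pipeline be loaded from a single DDUF archive through the public API. A minimal usage sketch (the repo id and file name below are placeholders, not taken from this PR):

# Hypothetical usage; repo id and DDUF file name are made-up placeholders.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "some-org/some-dduf-repo",   # assumed repo hosting a DDUF archive
    dduf_file="pipeline.dduf",   # parsed internally with huggingface_hub.read_dduf_file
    torch_dtype=torch.float16,
)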
97 changes: 97 additions & 0 deletions src/diffusers/pipelines/transformers_loading_utils.py
@@ -0,0 +1,97 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Dict

from huggingface_hub import DDUFEntry

from ..utils import is_safetensors_available, is_transformers_available


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer

if is_transformers_available():
    from transformers import PreTrainedModel, PreTrainedTokenizer

if is_safetensors_available():
    import safetensors.torch


def load_tokenizer_from_dduf(
    cls: "PreTrainedTokenizer", name: str, dduf_entries: Dict[str, DDUFEntry], **kwargs
) -> "PreTrainedTokenizer":
    """
    Load a tokenizer from a DDUF archive.

    In practice, `transformers` does not provide a way to load a tokenizer from a DDUF archive. This function works
    around that by extracting the tokenizer files from the DDUF archive and loading the tokenizer from the extracted
    files. Extraction has an extra I/O cost, but its impact is limited since tokenizer files are usually small.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        for entry_name, entry in dduf_entries.items():
            if entry_name.startswith(name + "/"):
                tmp_entry_path = os.path.join(tmp_dir, *entry_name.split("/"))
                # entries are nested under "<name>/", so the parent folder must exist before writing
                Path(tmp_entry_path).parent.mkdir(parents=True, exist_ok=True)
                with open(tmp_entry_path, "wb") as f:
                    with entry.as_mmap() as mm:
                        f.write(mm)
        # the extracted files live under "<tmp_dir>/<name>"
        return cls.from_pretrained(os.path.join(tmp_dir, name), **kwargs)


def load_transformers_model_from_dduf(
    cls: "PreTrainedModel", name: str, dduf_entries: Dict[str, DDUFEntry], **kwargs
) -> "PreTrainedModel":
    """
    Load a transformers model from a DDUF archive.

    In practice, `transformers` does not provide a way to load a model from a DDUF archive. This function works
    around that by instantiating the model from its config file and loading the weights directly from the DDUF
    archive.
    """

Review thread on lines +61 to +62 (the docstring above):

Member: What if we could extract the folder whose entries we're using here and pass the folder name to from_pretrained()?

Author (Collaborator): We could do it, but that would mean extracting a much larger amount of data (encoders can easily be 10GB). For tokenizers it's less of a problem to duplicate I/O processing since we are talking about only ~300MB.

Member: #10171 (comment) could work but I am not sure about its consequences. I will let @ArthurZucker comment here. But regardless, I think this can be merged and it will unblock @SunMarc a bit to make further progress on his PR.
    config_file = dduf_entries.get(f"{name}/config.json")
    if config_file is None:
        raise EnvironmentError(
            f"Could not find a config.json file for component {name} in DDUF file (contains {dduf_entries.keys()})."
        )

    weight_files = [
        entry
        for entry_name, entry in dduf_entries.items()
        if entry_name.startswith(f"{name}/") and entry_name.endswith(".safetensors")
    ]
    if not weight_files:
        raise EnvironmentError(
            f"Could not find any weight file for component {name} in DDUF file (contains {dduf_entries.keys()})."
        )
    if not is_safetensors_available():
        raise EnvironmentError(
            "Safetensors is not available, cannot load model from DDUF. Please `pip install safetensors`."
        )

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_config_file = os.path.join(tmp_dir, "config.json")
        with open(tmp_config_file, "w") as f:
            f.write(config_file.read_text())

        with contextlib.ExitStack() as stack:
            state_dict = {
                key: tensor
                for entry in weight_files  # loop over safetensors files
                for key, tensor in safetensors.torch.load(  # load tensors from mmap-ed bytes
                    stack.enter_context(entry.as_mmap())  # use enter_context to close the mmap when done
                ).items()
            }
            return cls.from_pretrained(tmp_dir, state_dict=state_dict, **kwargs)
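For orientation, a rough sketch of how these two helpers plug into DDUF loading end to end (the archive path and component names are placeholders, not from this PR):

# Illustrative wiring, not part of the diff; names are placeholders.
from huggingface_hub import read_dduf_file
from transformers import CLIPTextModel, CLIPTokenizer

dduf_entries = read_dduf_file("pipeline.dduf")  # dict of entry name -> DDUFEntry

# Tokenizer files are small, so they are extracted to a temp dir and loaded normally.
tokenizer = load_tokenizer_from_dduf(CLIPTokenizer, "tokenizer", dduf_entries)

# Model weights can be tens of GB, so they are mmap-ed straight from the archive.
text_encoder = load_transformers_model_from_dduf(CLIPTextModel, "text_encoder", dduf_entries)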