Skip to content

Commit 67b617e

Browse files
committed
updates
1 parent 6a163c7 commit 67b617e

File tree

11 files changed

+227
-59
lines changed

11 files changed

+227
-59
lines changed

src/diffusers/configuration_utils.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -361,21 +361,7 @@ def load_config(
361361
)
362362
# Custom path for now
363363
if dduf_entries:
364-
if subfolder is not None:
365-
raise ValueError(
366-
"DDUF file only allow for 1 level of directory (e.g transformer/model1/model.safetentors is not allowed). "
367-
"Please check the DDUF structure"
368-
)
369-
# paths inside a DDUF file must always be "/"
370-
config_file = (
371-
cls.config_name
372-
if pretrained_model_name_or_path == ""
373-
else "/".join([pretrained_model_name_or_path, cls.config_name])
374-
)
375-
if config_file not in dduf_entries:
376-
raise ValueError(
377-
f"We did not manage to find the file {config_file} in the dduf file. We only have the following files {dduf_entries.keys()}"
378-
)
364+
config_file = cls._get_config_file_from_dduf(pretrained_model_name_or_path, subfolder, dduf_entries)
379365
elif os.path.isfile(pretrained_model_name_or_path):
380366
config_file = pretrained_model_name_or_path
381367
elif os.path.isdir(pretrained_model_name_or_path):
@@ -636,6 +622,27 @@ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
636622
with open(json_file_path, "w", encoding="utf-8") as writer:
637623
writer.write(self.to_json_string())
638624

625+
@classmethod
626+
def _get_config_file_from_dduf(
627+
cls, pretrained_model_name_or_path: str, subfolder: str, dduf_entries: Dict[str, DDUFEntry]
628+
):
629+
if subfolder is not None:
630+
raise ValueError(
631+
"DDUF file only allow for 1 level of directory (e.g transformer/model1/model.safetentors is not allowed). "
632+
"Please check the DDUF structure"
633+
)
634+
# paths inside a DDUF file must always be "/"
635+
config_file = (
636+
cls.config_name
637+
if pretrained_model_name_or_path == ""
638+
else "/".join([pretrained_model_name_or_path, cls.config_name])
639+
)
640+
if config_file not in dduf_entries:
641+
raise ValueError(
642+
f"We did not manage to find the file {config_file} in the dduf file. We only have the following files {dduf_entries.keys()}"
643+
)
644+
return config_file
645+
639646

640647
def register_to_config(init):
641648
r"""

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
from pathlib import Path
2020
from typing import Any, Callable, Dict, List, Optional, Union
2121

22+
import requests
2223
import torch
23-
from huggingface_hub import DDUFEntry, ModelCard, model_info
24-
from huggingface_hub.utils import validate_hf_hub_args
24+
from huggingface_hub import DDUFEntry, ModelCard, model_info, snapshot_download
25+
from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
2526
from packaging import version
27+
from requests.exceptions import HTTPError
2628

2729
from .. import __version__
2830
from ..utils import (
@@ -36,6 +38,7 @@
3638
is_accelerate_available,
3739
is_peft_available,
3840
is_transformers_available,
41+
is_transformers_version,
3942
logging,
4043
)
4144
from ..utils.torch_utils import is_compiled_module
@@ -987,3 +990,70 @@ def _get_ignore_patterns(
987990
)
988991

989992
return ignore_patterns
993+
994+
995+
def _download_dduf_file(
    pretrained_model_name: str,
    dduf_file: str,
    pipeline_class_name: str,
    cache_dir: Optional[str],
    proxies: Optional[Dict[str, str]],
    local_files_only: bool,
    token: Optional[Union[bool, str]],
    revision: Optional[str],
):
    """
    Download a DDUF archive from the Hub (or resolve it from the local cache).

    Args:
        pretrained_model_name (`str`): Repo id on the Hub.
        dduf_file (`str`): Name of the DDUF file to fetch from the repo.
        pipeline_class_name (`str`): Reported in the `user_agent` for download telemetry.
        cache_dir (`str`, *optional*): Custom cache directory passed to `snapshot_download`.
        proxies (`Dict[str, str]`, *optional*): Proxy configuration passed to `snapshot_download`.
        local_files_only (`bool`): If `True`, never hit the network; resolve from cache only.
        token (`bool` or `str`, *optional*): Hub auth token.
        revision (`str`, *optional*): Git revision (branch, tag, or commit hash).

    Returns:
        `str`: path of the local snapshot folder containing `dduf_file`.

    Raises:
        ValueError: if the repo is reachable but does not contain `dduf_file`.
        EnvironmentError: if the Hub could not be reached and the file is not cached locally.
    """
    model_info_call_error = None
    if not local_files_only:
        try:
            info = model_info(pretrained_model_name, token=token, revision=revision)
        except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
            logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
            # Fall back to the cache; keep the original error so it can be re-raised
            # if the files turn out not to be cached either.
            local_files_only = True
            model_info_call_error = e  # save error to reraise it if model is not cached locally

    # Only checkable when we actually reached the Hub (`info` is bound above).
    if (
        not local_files_only
        and dduf_file is not None
        and dduf_file not in (sibling.rfilename for sibling in info.siblings)
    ):
        raise ValueError(f"Requested {dduf_file} file is not available in {pretrained_model_name}.")

    try:
        user_agent = {"pipeline_class": pipeline_class_name, "dduf": True}
        cached_folder = snapshot_download(
            pretrained_model_name,
            cache_dir=cache_dir,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            allow_patterns=[dduf_file],
            user_agent=user_agent,
        )
        return cached_folder
    except FileNotFoundError:
        # Means we tried to load pipeline with `local_files_only=True` but the files have not been found in local cache.
        # This can happen in two cases:
        # 1. If the user passed `local_files_only=True` => we raise the error directly
        # 2. If we forced `local_files_only=True` when `model_info` failed => we raise the initial error
        if model_info_call_error is None:
            # 1. user passed `local_files_only=True`
            raise
        else:
            # 2. we forced `local_files_only=True` when `model_info` failed
            raise EnvironmentError(
                f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occurred"
                " while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace"
                " above."
            ) from model_info_call_error
1049+
1050+
1051+
def _maybe_raise_error_for_incorrect_transformers(config_dict):
1052+
has_transformers_component = False
1053+
for k in config_dict:
1054+
if isinstance(config_dict[k], list):
1055+
has_transformers_component = config_dict[k][0] == "transformers"
1056+
if has_transformers_component:
1057+
break
1058+
if has_transformers_component and not is_transformers_version(">", "4.47.1"):
1059+
raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,15 @@
7474
CONNECTED_PIPES_KEYS,
7575
CUSTOM_PIPELINE_FILE_NAME,
7676
LOADABLE_CLASSES,
77+
_download_dduf_file,
7778
_fetch_class_library_tuple,
7879
_get_custom_components_and_folders,
7980
_get_custom_pipeline_class,
8081
_get_final_device_map,
8182
_get_ignore_patterns,
8283
_get_pipeline_class,
8384
_identify_model_variants,
85+
_maybe_raise_error_for_incorrect_transformers,
8486
_maybe_raise_warning_for_inpainting,
8587
_resolve_custom_pipeline_and_cls,
8688
_unwrap_model,
@@ -728,8 +730,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
728730
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
729731
)
730732

731-
if dduf_file and custom_pipeline:
732-
raise NotImplementedError("Custom pipelines are not supported with DDUF at the moment.")
733+
if dduf_file:
734+
if custom_pipeline:
735+
raise NotImplementedError("Custom pipelines are not supported with DDUF at the moment.")
736+
if load_connected_pipeline:
737+
raise NotImplementedError("Connected pipelines are not supported with DDUF at the moment.")
733738

734739
# 1. Download the checkpoints and configs
735740
# use snapshot download here to get it working from from_pretrained
@@ -785,14 +790,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
785790
config_dict = cls.load_config(cached_folder, dduf_entries=dduf_entries)
786791

787792
if dduf_file:
788-
has_transformers_component = False
789-
for k in config_dict:
790-
if isinstance(config_dict[k], list):
791-
has_transformers_component = config_dict[k][0] == "transformers"
792-
if has_transformers_component:
793-
break
794-
if has_transformers_component and not is_transformers_version(">", "4.47.1"):
795-
raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
793+
_maybe_raise_error_for_incorrect_transformers(config_dict)
796794

797795
# pop out "_ignore_files" as it is only needed for download
798796
config_dict.pop("_ignore_files", None)
@@ -1328,8 +1326,21 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
13281326
trust_remote_code = kwargs.pop("trust_remote_code", False)
13291327
dduf_file: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_file", None)
13301328

1331-
if dduf_file and custom_pipeline:
1332-
raise NotImplementedError("Custom pipelines are not supported with DDUF at the moment.")
1329+
if dduf_file:
1330+
if custom_pipeline:
1331+
raise NotImplementedError("Custom pipelines are not supported with DDUF at the moment.")
1332+
if load_connected_pipeline:
1333+
raise NotImplementedError("Connected pipelines are not supported with DDUF at the moment.")
1334+
return _download_dduf_file(
1335+
pretrained_model_name=pretrained_model_name,
1336+
dduf_file=dduf_file,
1337+
pipeline_class_name=cls.__name__,
1338+
cache_dir=cache_dir,
1339+
proxies=proxies,
1340+
local_files_only=local_files_only,
1341+
token=token,
1342+
revision=revision,
1343+
)
13331344

13341345
allow_pickle = False
13351346
if use_safetensors is None:
@@ -1348,14 +1359,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
13481359
local_files_only = True
13491360
model_info_call_error = e # save error to reraise it if model is not cached locally
13501361

1351-
if (
1352-
not local_files_only
1353-
and dduf_file is not None
1354-
and dduf_file not in (sibling.rfilename for sibling in info.siblings)
1355-
):
1356-
raise ValueError(f"Requested {dduf_file} file is not available in {pretrained_model_name}.")
1357-
1358-
if not local_files_only and not dduf_file:
1362+
if not local_files_only:
13591363
filenames = {sibling.rfilename for sibling in info.siblings}
13601364
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
13611365
warn_msg = (
@@ -1498,10 +1502,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14981502
user_agent["custom_pipeline"] = custom_pipeline
14991503

15001504
# download all allow_patterns - ignore_patterns
1501-
# also allow downloading the dduf_file
1502-
if dduf_file is not None:
1503-
allow_patterns = [dduf_file]
1504-
ignore_patterns = []
15051505
try:
15061506
cached_folder = snapshot_download(
15071507
pretrained_model_name,
@@ -1515,10 +1515,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
15151515
user_agent=user_agent,
15161516
)
15171517

1518-
# retrieve pipeline class from local file
1519-
if dduf_file:
1520-
return cached_folder
1521-
15221518
cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
15231519
cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
15241520

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
is_gguf_available,
7171
is_gguf_version,
7272
is_google_colab,
73+
is_hf_hub_version,
7374
is_inflect_available,
7475
is_invisible_watermark_available,
7576
is_k_diffusion_available,

src/diffusers/utils/import_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,13 @@
115115
except importlib_metadata.PackageNotFoundError:
116116
_transformers_available = False
117117

118+
# Detect the installed `huggingface_hub` version (used by `is_hf_hub_version` for DDUF checks).
_hf_hub_available = importlib.util.find_spec("huggingface_hub") is not None
try:
    _hf_hub_version = importlib_metadata.version("huggingface_hub")
    # Bug fix: log the huggingface_hub version, not `_transformers_version`.
    logger.debug(f"Successfully imported huggingface_hub version {_hf_hub_version}")
except importlib_metadata.PackageNotFoundError:
    _hf_hub_available = False
124+
118125

119126
_inflect_available = importlib.util.find_spec("inflect") is not None
120127
try:
@@ -767,6 +774,21 @@ def is_transformers_version(operation: str, version: str):
767774
return compare_versions(parse(_transformers_version), operation, version)
768775

769776

777+
def is_hf_hub_version(operation: str, version: str):
    """
    Compares the current Hugging Face Hub version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    # When huggingface_hub is not installed, no comparison can succeed.
    if _hf_hub_available:
        return compare_versions(parse(_hf_hub_version), operation, version)
    return False
790+
791+
770792
def is_accelerate_version(operation: str, version: str):
771793
"""
772794
Compares the current Accelerate version to a given reference with an operation.

tests/pipelines/kandinsky/test_kandinsky_combined.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,10 @@ def test_float16_inference(self):
139139
def test_dict_tuple_outputs_equivalent(self):
140140
super().test_dict_tuple_outputs_equivalent(expected_max_difference=5e-4)
141141

142+
@unittest.skip("Test not supported.")
143+
def test_save_load_dduf(self):
144+
pass
145+
142146

143147
class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
144148
pipeline_class = KandinskyImg2ImgCombinedPipeline
@@ -248,6 +252,10 @@ def test_dict_tuple_outputs_equivalent(self):
248252
def test_save_load_optional_components(self):
249253
super().test_save_load_optional_components(expected_max_difference=5e-4)
250254

255+
@unittest.skip("Test not supported.")
256+
def test_save_load_dduf(self):
257+
pass
258+
251259

252260
class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
253261
pipeline_class = KandinskyInpaintCombinedPipeline
@@ -363,3 +371,7 @@ def test_save_load_optional_components(self):
363371

364372
def test_save_load_local(self):
365373
super().test_save_load_local(expected_max_difference=5e-3)
374+
375+
@unittest.skip("Test not supported.")
376+
def test_save_load_dduf(self):
377+
pass

tests/pipelines/kandinsky2_2/test_kandinsky_combined.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ def test_callback_inputs(self):
159159
def test_callback_cfg(self):
160160
pass
161161

162+
@unittest.skip("Test not supported.")
163+
def test_save_load_dduf(self):
164+
pass
165+
162166

163167
class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
164168
pipeline_class = KandinskyV22Img2ImgCombinedPipeline
@@ -281,6 +285,10 @@ def test_callback_inputs(self):
281285
def test_callback_cfg(self):
282286
pass
283287

288+
@unittest.skip("Test not supported.")
289+
def test_save_load_dduf(self):
290+
pass
291+
284292

285293
class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
286294
pipeline_class = KandinskyV22InpaintCombinedPipeline
@@ -404,3 +412,7 @@ def test_callback_inputs(self):
404412

405413
def test_callback_cfg(self):
406414
pass
415+
416+
@unittest.skip("Test not supported.")
417+
def test_save_load_dduf(self):
418+
pass

tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,10 @@ def test_adapter_lcm_custom_timesteps(self):
336336

337337
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
338338

339+
@unittest.skip("Test not supported.")
340+
def test_save_load_dduf(self):
341+
pass
342+
339343

340344
class StableDiffusionFullAdapterPipelineFastTests(
341345
AdapterTests, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase

tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,3 +671,7 @@ def test_adapter_sdxl_lcm_custom_timesteps(self):
671671
print(",".join(debug))
672672

673673
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
674+
675+
@unittest.skip("Test not supported.")
676+
def test_save_load_dduf(self):
677+
pass

0 commit comments

Comments
 (0)