Skip to content

Commit 54991ac

Browse files
committed
Merge remote-tracking branch 'origin/dduf' into dduf
2 parents 53e100b + 1e5ebf5 commit 54991ac

File tree

5 files changed

+32
-56
lines changed

5 files changed

+32
-56
lines changed

src/diffusers/models/model_loading_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def _fetch_index_file(
281281
revision,
282282
user_agent,
283283
commit_hash,
284-
dduf_entries=None,
284+
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
285285
):
286286
if is_local:
287287
index_file = Path(
@@ -359,7 +359,7 @@ def _fetch_index_file_legacy(
359359
revision,
360360
user_agent,
361361
commit_hash,
362-
dduf_entries=None,
362+
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
363363
):
364364
if is_local:
365365
index_file = Path(

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
586586
variant = kwargs.pop("variant", None)
587587
use_safetensors = kwargs.pop("use_safetensors", None)
588588
quantization_config = kwargs.pop("quantization_config", None)
589-
dduf_entries = kwargs.pop("dduf_entries", None)
589+
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
590590

591591
allow_pickle = False
592592
if use_safetensors is None:

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,7 +1290,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
12901290
use_onnx = kwargs.pop("use_onnx", None)
12911291
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
12921292
trust_remote_code = kwargs.pop("trust_remote_code", False)
1293-
dduf_file = kwargs.pop("dduf_file", None)
1293+
dduf_file: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_file", None)
12941294

12951295
allow_pickle = False
12961296
if use_safetensors is None:
@@ -1309,11 +1309,11 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
13091309
local_files_only = True
13101310
model_info_call_error = e # save error to reraise it if model is not cached locally
13111311

1312-
if dduf_file is not None and not local_files_only:
1313-
dduf_available = False
1314-
for sibling in info.siblings:
1315-
dduf_available = dduf_file in sibling.rfilename
1316-
if not dduf_available:
1312+
if (
1313+
not local_files_only
1314+
and dduf_file is not None
1315+
and dduf_file not in (sibling.rfilename for sibling in info.siblings)
1316+
):
13171317
raise ValueError(f"Requested {dduf_file} file is not available in {pretrained_model_name}.")
13181318

13191319
if not local_files_only and not dduf_file:
@@ -1477,27 +1477,29 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14771477
)
14781478

14791479
# retrieve pipeline class from local file
1480-
if not dduf_file:
1481-
cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
1482-
cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
1483-
1484-
diffusers_module = importlib.import_module(__name__.split(".")[0])
1485-
pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None
1486-
1487-
if pipeline_class is not None and pipeline_class._load_connected_pipes:
1488-
modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
1489-
connected_pipes = sum([getattr(modelcard.data, k, []) for k in CONNECTED_PIPES_KEYS], [])
1490-
for connected_pipe_repo_id in connected_pipes:
1491-
download_kwargs = {
1492-
"cache_dir": cache_dir,
1493-
"force_download": force_download,
1494-
"proxies": proxies,
1495-
"local_files_only": local_files_only,
1496-
"token": token,
1497-
"variant": variant,
1498-
"use_safetensors": use_safetensors,
1499-
}
1500-
DiffusionPipeline.download(connected_pipe_repo_id, **download_kwargs)
1480+
if dduf_file:
1481+
return cached_folder
1482+
1483+
cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
1484+
cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
1485+
1486+
diffusers_module = importlib.import_module(__name__.split(".")[0])
1487+
pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None
1488+
1489+
if pipeline_class is not None and pipeline_class._load_connected_pipes:
1490+
modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
1491+
connected_pipes = sum([getattr(modelcard.data, k, []) for k in CONNECTED_PIPES_KEYS], [])
1492+
for connected_pipe_repo_id in connected_pipes:
1493+
download_kwargs = {
1494+
"cache_dir": cache_dir,
1495+
"force_download": force_download,
1496+
"proxies": proxies,
1497+
"local_files_only": local_files_only,
1498+
"token": token,
1499+
"variant": variant,
1500+
"use_safetensors": use_safetensors,
1501+
}
1502+
DiffusionPipeline.download(connected_pipe_repo_id, **download_kwargs)
15011503

15021504
return cached_folder
15031505

src/diffusers/utils/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
WEIGHTS_INDEX_NAME,
3636
WEIGHTS_NAME,
3737
)
38-
from .dduf import DDUFReader
3938
from .deprecation_utils import deprecate
4039
from .doc_utils import replace_example_docstring
4140
from .dynamic_modules_utils import get_class_from_dynamic_module
@@ -68,7 +67,6 @@
6867
is_flax_available,
6968
is_ftfy_available,
7069
is_google_colab,
71-
is_huggingface_hub_version,
7270
is_inflect_available,
7371
is_invisible_watermark_available,
7472
is_k_diffusion_available,

src/diffusers/utils/import_utils.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -317,15 +317,6 @@
317317
_timm_available = False
318318

319319

320-
_huggingface_hub_available = importlib.util.find_spec("huggingface_hub") is not None
321-
if _huggingface_hub_available:
322-
try:
323-
_huggingface_hub_version = importlib_metadata.version("huggingface_hub")
324-
logger.info(f"huggingface_hub version {_huggingface_hub_version} available.")
325-
except importlib_metadata.PackageNotFoundError:
326-
_huggingface_hub_available = False
327-
328-
329320
def is_timm_available():
330321
return _timm_available
331322

@@ -798,21 +789,6 @@ def is_k_diffusion_version(operation: str, version: str):
798789
return compare_versions(parse(_k_diffusion_version), operation, version)
799790

800791

801-
def is_huggingface_hub_version(operation: str, version: str):
802-
"""
803-
Compares the current huggingface_hub version to a given reference with an operation.
804-
805-
Args:
806-
operation (`str`):
807-
A string representation of an operator, such as `">"` or `"<="`
808-
version (`str`):
809-
A version string
810-
"""
811-
if not _huggingface_hub_available:
812-
return False
813-
return compare_versions(parse(_huggingface_hub_version), operation, version)
814-
815-
816792
def get_objects_from_module(module):
817793
"""
818794
Returns a dict of object names and values in a module, while skipping private/internal objects

0 commit comments

Comments
 (0)