Skip to content

Commit 1e5ebf5

Browse files
SunMarcWauplin
andauthored
Apply suggestions from code review
Co-authored-by: Lucain <[email protected]>
1 parent 63575af commit 1e5ebf5

File tree

5 files changed

+32
-56
lines changed

5 files changed

+32
-56
lines changed

src/diffusers/models/model_loading_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def _fetch_index_file(
281281
revision,
282282
user_agent,
283283
commit_hash,
284-
dduf_entries=None,
284+
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
285285
):
286286
if is_local:
287287
index_file = Path(
@@ -359,7 +359,7 @@ def _fetch_index_file_legacy(
359359
revision,
360360
user_agent,
361361
commit_hash,
362-
dduf_entries=None,
362+
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
363363
):
364364
if is_local:
365365
index_file = Path(

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
586586
variant = kwargs.pop("variant", None)
587587
use_safetensors = kwargs.pop("use_safetensors", None)
588588
quantization_config = kwargs.pop("quantization_config", None)
589-
dduf_entries = kwargs.pop("dduf_entries", None)
589+
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
590590

591591
allow_pickle = False
592592
if use_safetensors is None:

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,7 +1291,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
12911291
use_onnx = kwargs.pop("use_onnx", None)
12921292
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
12931293
trust_remote_code = kwargs.pop("trust_remote_code", False)
1294-
dduf_file = kwargs.pop("dduf_file", None)
1294+
dduf_file: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_file", None)
12951295

12961296
allow_pickle = False
12971297
if use_safetensors is None:
@@ -1310,11 +1310,11 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
13101310
local_files_only = True
13111311
model_info_call_error = e # save error to reraise it if model is not cached locally
13121312

1313-
if dduf_file is not None and not local_files_only:
1314-
dduf_available = False
1315-
for sibling in info.siblings:
1316-
dduf_available = dduf_file in sibling.rfilename
1317-
if not dduf_available:
1313+
if (
1314+
not local_files_only
1315+
and dduf_file is not None
1316+
and dduf_file not in (sibling.rfilename for sibling in info.siblings)
1317+
):
13181318
raise ValueError(f"Requested {dduf_file} file is not available in {pretrained_model_name}.")
13191319

13201320
if not local_files_only and not dduf_file:
@@ -1478,27 +1478,29 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14781478
)
14791479

14801480
# retrieve pipeline class from local file
1481-
if not dduf_file:
1482-
cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
1483-
cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
1484-
1485-
diffusers_module = importlib.import_module(__name__.split(".")[0])
1486-
pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None
1487-
1488-
if pipeline_class is not None and pipeline_class._load_connected_pipes:
1489-
modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
1490-
connected_pipes = sum([getattr(modelcard.data, k, []) for k in CONNECTED_PIPES_KEYS], [])
1491-
for connected_pipe_repo_id in connected_pipes:
1492-
download_kwargs = {
1493-
"cache_dir": cache_dir,
1494-
"force_download": force_download,
1495-
"proxies": proxies,
1496-
"local_files_only": local_files_only,
1497-
"token": token,
1498-
"variant": variant,
1499-
"use_safetensors": use_safetensors,
1500-
}
1501-
DiffusionPipeline.download(connected_pipe_repo_id, **download_kwargs)
1481+
if dduf_file:
1482+
return cached_folder
1483+
1484+
cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
1485+
cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
1486+
1487+
diffusers_module = importlib.import_module(__name__.split(".")[0])
1488+
pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None
1489+
1490+
if pipeline_class is not None and pipeline_class._load_connected_pipes:
1491+
modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
1492+
connected_pipes = sum([getattr(modelcard.data, k, []) for k in CONNECTED_PIPES_KEYS], [])
1493+
for connected_pipe_repo_id in connected_pipes:
1494+
download_kwargs = {
1495+
"cache_dir": cache_dir,
1496+
"force_download": force_download,
1497+
"proxies": proxies,
1498+
"local_files_only": local_files_only,
1499+
"token": token,
1500+
"variant": variant,
1501+
"use_safetensors": use_safetensors,
1502+
}
1503+
DiffusionPipeline.download(connected_pipe_repo_id, **download_kwargs)
15021504

15031505
return cached_folder
15041506

src/diffusers/utils/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
WEIGHTS_INDEX_NAME,
3636
WEIGHTS_NAME,
3737
)
38-
from .dduf import DDUFReader
3938
from .deprecation_utils import deprecate
4039
from .doc_utils import replace_example_docstring
4140
from .dynamic_modules_utils import get_class_from_dynamic_module
@@ -68,7 +67,6 @@
6867
is_flax_available,
6968
is_ftfy_available,
7069
is_google_colab,
71-
is_huggingface_hub_version,
7270
is_inflect_available,
7371
is_invisible_watermark_available,
7472
is_k_diffusion_available,

src/diffusers/utils/import_utils.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -317,15 +317,6 @@
317317
_timm_available = False
318318

319319

320-
_huggingface_hub_available = importlib.util.find_spec("huggingface_hub") is not None
321-
if _huggingface_hub_available:
322-
try:
323-
_huggingface_hub_version = importlib_metadata.version("huggingface_hub")
324-
logger.info(f"huggingface_hub version {_huggingface_hub_version} available.")
325-
except importlib_metadata.PackageNotFoundError:
326-
_huggingface_hub_available = False
327-
328-
329320
def is_timm_available():
330321
return _timm_available
331322

@@ -798,21 +789,6 @@ def is_k_diffusion_version(operation: str, version: str):
798789
return compare_versions(parse(_k_diffusion_version), operation, version)
799790

800791

801-
def is_huggingface_hub_version(operation: str, version: str):
802-
"""
803-
Compares the current huggingface_hub version to a given reference with an operation.
804-
805-
Args:
806-
operation (`str`):
807-
A string representation of an operator, such as `">"` or `"<="`
808-
version (`str`):
809-
A version string
810-
"""
811-
if not _huggingface_hub_available:
812-
return False
813-
return compare_versions(parse(_huggingface_hub_version), operation, version)
814-
815-
816792
def get_objects_from_module(module):
817793
"""
818794
Returns a dict of object names and values in a module, while skipping private/internal objects

0 commit comments

Comments
 (0)