Skip to content

Commit 5ec3951

Browse files
committed
quality
1 parent 0cb1b98 commit 5ec3951

File tree

8 files changed

+27
-15
lines changed

8 files changed

+27
-15
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@
101101
"filelock",
102102
"flax>=0.4.1",
103103
"hf-doc-builder>=0.3.0",
104-
"huggingface-hub>=0.27.0",
104+
"huggingface-hub>=0.26.0",
105105
"requests-mock==1.10.0",
106106
"importlib_metadata",
107107
"invisible-watermark>=0.2.0",

src/diffusers/configuration_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
import re
2525
from collections import OrderedDict
2626
from pathlib import Path
27-
from typing import Any, Dict, Tuple, Union, Optional, Dict
27+
from typing import Any, Dict, Optional, Tuple, Union
2828

2929
import numpy as np
30-
from huggingface_hub import create_repo, hf_hub_download, DDUFEntry
30+
from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
3131
from huggingface_hub.utils import (
3232
EntryNotFoundError,
3333
RepositoryNotFoundError,
@@ -560,7 +560,9 @@ def extract_init_dict(cls, config_dict, **kwargs):
560560
return init_dict, unused_kwargs, hidden_config_dict
561561

562562
@classmethod
563-
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike], dduf_entries: Optional[Dict[str, DDUFEntry]] = None):
563+
def _dict_from_json_file(
564+
cls, json_file: Union[str, os.PathLike], dduf_entries: Optional[Dict[str, DDUFEntry]] = None
565+
):
564566
if dduf_entries:
565567
text = dduf_entries[json_file].read_text()
566568
else:

src/diffusers/dependency_versions_table.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"filelock": "filelock",
1010
"flax": "flax>=0.4.1",
1111
"hf-doc-builder": "hf-doc-builder>=0.3.0",
12-
"huggingface-hub": "huggingface-hub>=0.27.0",
12+
"huggingface-hub": "huggingface-hub>=0.26.0",
1313
"requests-mock": "requests-mock==1.10.0",
1414
"importlib_metadata": "importlib_metadata",
1515
"invisible-watermark": "invisible-watermark>=0.2.0",

src/diffusers/models/model_loading_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,11 @@ def _fetch_remapped_cls_from_config(config, old_class):
129129
return old_class
130130

131131

132-
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, dduf_entries: Optional[Dict[str, DDUFEntry]]=None):
132+
def load_state_dict(
133+
checkpoint_file: Union[str, os.PathLike],
134+
variant: Optional[str] = None,
135+
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
136+
):
133137
"""
134138
Reads a checkpoint file, returning properly formatted errors if they arise.
135139
"""

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
import safetensors
2929
import torch
30-
from huggingface_hub import create_repo, split_torch_state_dict_into_shards
30+
from huggingface_hub import DDUFEntry, Dict, create_repo, split_torch_state_dict_into_shards
3131
from huggingface_hub.utils import validate_hf_hub_args
3232
from torch import Tensor, nn
3333

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import requests
2929
import torch
3030
from huggingface_hub import (
31+
DDUFEntry,
3132
ModelCard,
3233
create_repo,
3334
hf_hub_download,
@@ -1313,8 +1314,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
13131314
not local_files_only
13141315
and dduf_file is not None
13151316
and dduf_file not in (sibling.rfilename for sibling in info.siblings)
1316-
):
1317-
raise ValueError(f"Requested {dduf_file} file is not available in {pretrained_model_name}.")
1317+
):
1318+
raise ValueError(f"Requested {dduf_file} file is not available in {pretrained_model_name}.")
13181319

13191320
if not local_files_only and not dduf_file:
13201321
filenames = {sibling.rfilename for sibling in info.siblings}

src/diffusers/pipelines/transformers_loading_utils.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from huggingface_hub import DDUFEntry
2121
from tqdm import tqdm
2222

23-
from ..utils import is_safetensors_available, is_transformers_version
23+
from ..utils import is_safetensors_available, is_transformers_available, is_transformers_version
2424

2525

2626
if TYPE_CHECKING:
@@ -93,15 +93,16 @@ def load_transformers_model_from_dduf(
9393

9494
with tempfile.TemporaryDirectory() as tmp_dir:
9595
from transformers import AutoConfig, GenerationConfig
96+
9697
tmp_config_file = os.path.join(tmp_dir, "config.json")
9798
with open(tmp_config_file, "w") as f:
9899
f.write(config_file.read_text())
99100
config = AutoConfig.from_pretrained(tmp_config_file)
100101
if generation_config is not None:
101-
tmp_generation_config_file = os.path.join(tmp_generation_config_file, "generation_config.json")
102+
tmp_generation_config_file = os.path.join(tmp_dir, "generation_config.json")
102103
with open(tmp_generation_config_file, "w") as f:
103104
f.write(generation_config.read_text())
104-
generation_config = GenerationConfig.from_pretrained(tmp_config_file)
105+
generation_config = GenerationConfig.from_pretrained(tmp_generation_config_file)
105106
state_dict = {}
106107
with contextlib.ExitStack() as stack:
107108
for entry in tqdm(weight_files, desc="Loading state_dict"): # Loop over safetensors files
@@ -112,5 +113,9 @@ def load_transformers_model_from_dduf(
112113
# Update the state dictionary with tensors
113114
state_dict.update(tensors)
114115
return cls.from_pretrained(
115-
pretrained_model_name_or_path=None, config=config, generation_config=generation_config, state_dict=state_dict, **kwargs
116-
)
116+
pretrained_model_name_or_path=None,
117+
config=config,
118+
generation_config=generation_config,
119+
state_dict=state_dict,
120+
**kwargs,
121+
)

src/diffusers/utils/hub_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def _get_model_file(
292292
user_agent: Optional[Union[Dict, str]] = None,
293293
revision: Optional[str] = None,
294294
commit_hash: Optional[str] = None,
295-
dduf_entries: Optional[Dict[str, DDUFEntry]]=None,
295+
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
296296
):
297297
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
298298

0 commit comments

Comments
 (0)