Skip to content

Commit 7bc9347

Browse files
committed
from suggestions
1 parent 02a368b commit 7bc9347

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

src/diffusers/configuration_utils.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,12 @@ def load_config(
361361
)
362362
# Custom path for now
363363
if dduf_entries:
364-
config_file = cls._get_config_file_from_dduf(pretrained_model_name_or_path, subfolder, dduf_entries)
364+
if subfolder is not None:
365+
raise ValueError(
366+
"DDUF file only allow for 1 level of directory (e.g transformer/model1/model.safetentors is not allowed). "
367+
"Please check the DDUF structure"
368+
)
369+
config_file = cls._get_config_file_from_dduf(pretrained_model_name_or_path, dduf_entries)
365370
elif os.path.isfile(pretrained_model_name_or_path):
366371
config_file = pretrained_model_name_or_path
367372
elif os.path.isdir(pretrained_model_name_or_path):
@@ -623,14 +628,7 @@ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
623628
writer.write(self.to_json_string())
624629

625630
@classmethod
626-
def _get_config_file_from_dduf(
627-
cls, pretrained_model_name_or_path: str, subfolder: str, dduf_entries: Dict[str, DDUFEntry]
628-
):
629-
if subfolder is not None:
630-
raise ValueError(
631-
"DDUF file only allow for 1 level of directory (e.g transformer/model1/model.safetentors is not allowed). "
632-
"Please check the DDUF structure"
633-
)
631+
def _get_config_file_from_dduf(cls, pretrained_model_name_or_path: str, dduf_entries: Dict[str, DDUFEntry]):
634632
# paths inside a DDUF file must always be "/"
635633
config_file = (
636634
cls.config_name

src/diffusers/models/model_loading_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from huggingface_hub.utils import EntryNotFoundError
2929

3030
from ..utils import (
31+
GGUF_FILE_EXTENSION,
3132
SAFE_WEIGHTS_INDEX_NAME,
3233
SAFETENSORS_FILE_EXTENSION,
3334
WEIGHTS_INDEX_NAME,
@@ -152,7 +153,8 @@ def load_state_dict(
152153
return safetensors.torch.load(mm)
153154
else:
154155
return safetensors.torch.load_file(checkpoint_file, device="cpu")
155-
156+
elif file_extension == GGUF_FILE_EXTENSION:
157+
return load_gguf_checkpoint(checkpoint_file)
156158
else:
157159
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
158160
return torch.load(

0 commit comments

Comments
 (0)