Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/diffusers/models/model_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,19 @@ def _fetch_remapped_cls_from_config(config, old_class):
return old_class


def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, device="cpu"):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
try:
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
return safetensors.torch.load_file(checkpoint_file, device=device)
else:
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
return torch.load(
checkpoint_file,
map_location="cpu",
map_location=device,
**weights_only_kwarg,
)
except Exception as e:
Expand Down
19 changes: 11 additions & 8 deletions src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,8 @@ def stable_unclip_image_noising_components(

If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
"""
from ...models.model_loading_utils import load_state_dict

noise_aug_config = original_config["model"]["params"]["noise_aug_config"]
noise_aug_class = noise_aug_config["target"]
noise_aug_class = noise_aug_class.split(".")[-1]
Expand All @@ -1068,8 +1070,7 @@ def stable_unclip_image_noising_components(
if "clip_stats_path" in noise_aug_config:
if clip_stats_path is None:
raise ValueError("This stable unclip config requires a `clip_stats_path`")

clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
clip_mean, clip_std = load_state_dict(clip_stats_path, device=device)
clip_mean = clip_mean[None, :]
clip_std = clip_std[None, :]

Expand Down Expand Up @@ -1264,11 +1265,12 @@ def download_from_original_stable_diffusion_ckpt(

checkpoint = safe_load(checkpoint_path_or_dict, device="cpu")
else:
from ...models.model_loading_utils import load_state_dict

if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
else:
checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
checkpoint = load_state_dict(checkpoint_path_or_dict, device=device)

elif isinstance(checkpoint_path_or_dict, dict):
checkpoint = checkpoint_path_or_dict

Expand Down Expand Up @@ -1834,11 +1836,12 @@ def download_controlnet_from_original_ckpt(
for key in f.keys():
checkpoint[key] = f.get_tensor(key)
else:
from ...models.model_loading_utils import load_state_dict

if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)
else:
checkpoint = torch.load(checkpoint_path, map_location=device)

checkpoint = load_state_dict(checkpoint_path, device=device)

# NOTE: this while loop isn't great but this controlnet checkpoint has one additional
# "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
Expand Down
4 changes: 3 additions & 1 deletion src/diffusers/utils/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,11 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -


def load_pt(url: str):
from ..models.model_loading_utils import load_state_dict

response = requests.get(url)
response.raise_for_status()
arry = torch.load(BytesIO(response.content))
arry = load_state_dict(BytesIO(response.content))
return arry


Expand Down
Loading