diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 969eb5f5fa37..a22be61aba81 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -95,19 +95,19 @@ def _fetch_remapped_cls_from_config(config, old_class): return old_class -def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None): +def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, device="cpu"): """ Reads a checkpoint file, returning properly formatted errors if they arise. """ try: file_extension = os.path.basename(checkpoint_file).split(".")[-1] if file_extension == SAFETENSORS_FILE_EXTENSION: - return safetensors.torch.load_file(checkpoint_file, device="cpu") + return safetensors.torch.load_file(checkpoint_file, device=device) else: weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {} return torch.load( checkpoint_file, - map_location="cpu", + map_location=device, **weights_only_kwarg, ) except Exception as e: diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 53dc98aea698..63369e04c9f1 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -1052,6 +1052,8 @@ def stable_unclip_image_noising_components( If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. """ + from ...models.model_loading_utils import load_state_dict + noise_aug_config = original_config["model"]["params"]["noise_aug_config"] noise_aug_class = noise_aug_config["target"] noise_aug_class = noise_aug_class.split(".")[-1] @@ -1068,8 +1070,7 @@ def stable_unclip_image_noising_components( if "clip_stats_path" in noise_aug_config: if clip_stats_path is None: raise ValueError("This stable unclip config requires a `clip_stats_path`") - - clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) + clip_mean, clip_std = load_state_dict(clip_stats_path, device=device) clip_mean = clip_mean[None, :] clip_std = clip_std[None, :] @@ -1264,11 +1265,12 @@ def download_from_original_stable_diffusion_ckpt( checkpoint = safe_load(checkpoint_path_or_dict, device="cpu") else: + from ...models.model_loading_utils import load_state_dict + if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" - checkpoint = torch.load(checkpoint_path_or_dict, map_location=device) - else: - checkpoint = torch.load(checkpoint_path_or_dict, map_location=device) + checkpoint = load_state_dict(checkpoint_path_or_dict, device=device) + elif isinstance(checkpoint_path_or_dict, dict): checkpoint = checkpoint_path_or_dict @@ -1834,11 +1836,12 @@ def download_controlnet_from_original_ckpt( for key in f.keys(): checkpoint[key] = f.get_tensor(key) else: + from ...models.model_loading_utils import load_state_dict + if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" - checkpoint = torch.load(checkpoint_path, map_location=device) - else: - checkpoint = torch.load(checkpoint_path, map_location=device) + + checkpoint = load_state_dict(checkpoint_path, device=device) # NOTE: this while loop isn't great but this controlnet checkpoint has one additional # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index be3e9983c80f..e8069d2e47b4 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -428,9 +428,11 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) - def load_pt(url: str): + from ..models.model_loading_utils import load_state_dict + response = requests.get(url) response.raise_for_status() - arry = torch.load(BytesIO(response.content)) + arry = load_state_dict(BytesIO(response.content)) return arry