diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 9c838ac61476..f019a3cc67a6 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -134,19 +134,6 @@ def _fetch_remapped_cls_from_config(config, old_class): return old_class -def _check_archive_and_maybe_raise_error(checkpoint_file, format_list): - """ - Check format of the archive - """ - with safetensors.safe_open(checkpoint_file, framework="pt") as f: - metadata = f.metadata() - if metadata is not None and metadata.get("format") not in format_list: - raise OSError( - f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " - "you save your model with the `save_pretrained` method." - ) - - def _determine_param_device(param_name: str, device_map: Optional[Dict[str, Union[int, str, torch.device]]]): """ Find the device of param_name from the device_map. @@ -183,7 +170,6 @@ def load_state_dict( # tensors are loaded on cpu with dduf_entries[checkpoint_file].as_mmap() as mm: return safetensors.torch.load(mm) - _check_archive_and_maybe_raise_error(checkpoint_file, format_list=["pt", "flax"]) if disable_mmap: return safetensors.torch.load(open(checkpoint_file, "rb").read()) else: