5050 DEPRECATED_REVISION_ARGS ,
5151 BaseOutput ,
5252 PushToHubMixin ,
53- deprecate ,
5453 is_accelerate_available ,
5554 is_accelerate_version ,
5655 is_torch_npu_available ,
5756 is_torch_version ,
5857 logging ,
5958 numpy_to_pil ,
6059)
61- from ..utils .hub_utils import load_or_create_model_card , populate_model_card
60+ from ..utils .hub_utils import _check_legacy_sharding_variant_format , load_or_create_model_card , populate_model_card
6261from ..utils .torch_utils import is_compiled_module
6362
6463
@@ -735,6 +734,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
735734 else :
736735 cached_folder = pretrained_model_name_or_path
737736
737+ # The variant filenames can have the legacy sharding checkpoint format that we check and throw
738+ # a warning if detected.
739+ if variant is not None and _check_legacy_sharding_variant_format (folder = cached_folder , variant = variant ):
740+ warn_msg = (
741+ f"Warning: The repository contains sharded checkpoints for variant '{ variant } ' maybe in a deprecated format. "
742+ "Please check your files carefully:\n \n "
743+ "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n "
744+ "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n \n "
745+ "If you find any files in the deprecated format:\n "
746+ "1. Remove all existing checkpoint files for this variant.\n "
747+ "2. Re-obtain the correct files by running `save_pretrained()`.\n \n "
748+ "This will ensure you're using the most up-to-date and compatible checkpoint format."
749+ )
750+ logger .warning (warn_msg )
751+
738752 config_dict = cls .load_config (cached_folder )
739753
740754 # pop out "_ignore_files" as it is only needed for download
@@ -745,6 +759,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
745759 # Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
746760 # with variant being `"fp16"`.
747761 model_variants = _identify_model_variants (folder = cached_folder , variant = variant , config = config_dict )
762+ if len (model_variants ) == 0 and variant is not None :
763+ error_message = f"You are trying to load the model files of the `variant={ variant } `, but no such modeling files are available."
764+ raise ValueError (error_message )
748765
749766 # 3. Load the pipeline class, if using custom module then load it from the hub
750767 # if we load from explicit class, let's use it
@@ -1251,6 +1268,22 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
12511268 model_info_call_error = e # save error to reraise it if model is not cached locally
12521269
12531270 if not local_files_only :
1271+ filenames = {sibling .rfilename for sibling in info .siblings }
1272+ if variant is not None and _check_legacy_sharding_variant_format (filenames = filenames , variant = variant ):
1273+ warn_msg = (
1274+ f"Warning: The repository contains sharded checkpoints for variant '{ variant } ' maybe in a deprecated format. "
1275+ "Please check your files carefully:\n \n "
1276+ "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n "
1277+ "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n \n "
1278+ "If you find any files in the deprecated format:\n "
1279+ "1. Remove all existing checkpoint files for this variant.\n "
1280+ "2. Re-obtain the correct files by running `save_pretrained()`.\n \n "
1281+ "This will ensure you're using the most up-to-date and compatible checkpoint format."
1282+ )
1283+ logger .warning (warn_msg )
1284+
1285+ model_filenames , variant_filenames = variant_compatible_siblings (filenames , variant = variant )
1286+
12541287 config_file = hf_hub_download (
12551288 pretrained_model_name ,
12561289 cls .config_name ,
@@ -1267,9 +1300,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
12671300 # retrieve all folder_names that contain relevant files
12681301 folder_names = [k for k , v in config_dict .items () if isinstance (v , list ) and k != "_class_name" ]
12691302
1270- filenames = {sibling .rfilename for sibling in info .siblings }
1271- model_filenames , variant_filenames = variant_compatible_siblings (filenames , variant = variant )
1272-
12731303 diffusers_module = importlib .import_module (__name__ .split ("." )[0 ])
12741304 pipelines = getattr (diffusers_module , "pipelines" )
12751305
@@ -1292,13 +1322,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
12921322 )
12931323
12941324 if len (variant_filenames ) == 0 and variant is not None :
1295- deprecation_message = (
1296- f"You are trying to load the model files of the `variant={ variant } `, but no such modeling files are available."
1297- f"The default model files: { model_filenames } will be loaded instead. Make sure to not load from `variant={ variant } `"
1298- "if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant"
1299- "modeling files is deprecated."
1300- )
1301- deprecate ("no variant default" , "0.24.0" , deprecation_message , standard_warn = False )
1325+ error_message = f"You are trying to load the model files of the `variant={ variant } `, but no such modeling files are available."
1326+ raise ValueError (error_message )
13021327
13031328 # remove ignored filenames
13041329 model_filenames = set (model_filenames ) - set (ignore_filenames )
0 commit comments