|
70 | 70 | CONNECTED_PIPES_KEYS, |
71 | 71 | CUSTOM_PIPELINE_FILE_NAME, |
72 | 72 | LOADABLE_CLASSES, |
| 73 | + _check_legacy_sharding_variant_format, |
73 | 74 | _fetch_class_library_tuple, |
74 | 75 | _get_custom_components_and_folders, |
75 | 76 | _get_custom_pipeline_class, |
|
78 | 79 | _get_pipeline_class, |
79 | 80 | _identify_model_variants, |
80 | 81 | _maybe_raise_warning_for_inpainting, |
81 | | - _maybe_raise_warning_for_variant_checkpoint_format, |
82 | 82 | _resolve_custom_pipeline_and_cls, |
83 | 83 | _unwrap_model, |
84 | 84 | _update_init_kwargs_with_connected_pipeline, |
@@ -738,7 +738,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P |
738 | 738 |
|
739 | 739 | # The variant filenames can have the legacy sharding checkpoint format that we check and throw |
740 | 740 | # a warning if detected. |
741 | | - _maybe_raise_warning_for_variant_checkpoint_format(folder=cached_folder, variant=variant) |
| 741 | + if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant): |
| 742 | + warn_msg = ( |
| 743 | + f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. " |
| 744 | + "Please check your files carefully:\n\n" |
| 745 | + "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n" |
| 746 | + "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n" |
| 747 | + "If you find any files in the deprecated format:\n" |
| 748 | + "1. Remove all existing checkpoint files for this variant.\n" |
| 749 | + "2. Re-obtain the correct files by running `save_pretrained()`.\n\n" |
| 750 | + "This will ensure you're using the most up-to-date and compatible checkpoint format." |
| 751 | + ) |
| 752 | + logger.warning(warn_msg) |
742 | 753 |
|
743 | 754 | config_dict = cls.load_config(cached_folder) |
744 | 755 |
|
@@ -1261,7 +1272,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: |
1261 | 1272 |
|
1262 | 1273 | if not local_files_only: |
1263 | 1274 | filenames = {sibling.rfilename for sibling in info.siblings} |
1264 | | - _maybe_raise_warning_for_variant_checkpoint_format(filenames=filenames, variant=variant) |
| 1275 | + if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant): |
| 1276 | + warn_msg = ( |
| 1277 | + f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. " |
| 1278 | + "Please check your files carefully:\n\n" |
| 1279 | + "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n" |
| 1280 | + "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n" |
| 1281 | + "If you find any files in the deprecated format:\n" |
| 1282 | + "1. Remove all existing checkpoint files for this variant.\n" |
| 1283 | + "2. Re-obtain the correct files by running `save_pretrained()`.\n\n" |
| 1284 | + "This will ensure you're using the most up-to-date and compatible checkpoint format." |
| 1285 | + ) |
| 1286 | + logger.warning(warn_msg) |
1265 | 1287 |
|
1266 | 1288 | model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) |
1267 | 1289 |
|
|
0 commit comments