Skip to content

Commit 8d1314c

Browse files
committed
apply suggestions from review
1 parent 09fa592 commit 8d1314c

File tree

3 files changed

+26
-24
lines changed

3 files changed

+26
-24
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
is_transformers_available,
4141
logging,
4242
)
43-
from ..utils.hub_utils import _check_legacy_sharding_variant_format
4443
from ..utils.torch_utils import is_compiled_module
4544

4645

@@ -841,25 +840,6 @@ def get_connected_passed_kwargs(prefix):
841840
return init_kwargs
842841

843842

844-
def _maybe_raise_warning_for_variant_checkpoint_format(
845-
folder: Optional[str] = None, filenames: Optional[List[str]] = None, variant: Optional[str] = None
846-
) -> None:
847-
if variant is not None and _check_legacy_sharding_variant_format(
848-
folder=folder, filenames=filenames, variant=variant
849-
):
850-
warn_msg = (
851-
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
852-
"Please check your files carefully:\n\n"
853-
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
854-
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
855-
"If you find any files in the deprecated format:\n"
856-
"1. Remove all existing checkpoint files for this variant.\n"
857-
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
858-
"This will ensure you're using the most up-to-date and compatible checkpoint format."
859-
)
860-
logger.warning(warn_msg)
861-
862-
863843
def _get_custom_components_and_folders(
864844
pretrained_model_name: str,
865845
config_dict: Dict[str, Any],

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
CONNECTED_PIPES_KEYS,
7171
CUSTOM_PIPELINE_FILE_NAME,
7272
LOADABLE_CLASSES,
73+
_check_legacy_sharding_variant_format,
7374
_fetch_class_library_tuple,
7475
_get_custom_components_and_folders,
7576
_get_custom_pipeline_class,
@@ -78,7 +79,6 @@
7879
_get_pipeline_class,
7980
_identify_model_variants,
8081
_maybe_raise_warning_for_inpainting,
81-
_maybe_raise_warning_for_variant_checkpoint_format,
8282
_resolve_custom_pipeline_and_cls,
8383
_unwrap_model,
8484
_update_init_kwargs_with_connected_pipeline,
@@ -738,7 +738,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
738738

739739
# The variant filenames can have the legacy sharding checkpoint format that we check and throw
740740
# a warning if detected.
741-
_maybe_raise_warning_for_variant_checkpoint_format(folder=cached_folder, variant=variant)
741+
if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant):
742+
warn_msg = (
743+
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
744+
"Please check your files carefully:\n\n"
745+
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
746+
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
747+
"If you find any files in the deprecated format:\n"
748+
"1. Remove all existing checkpoint files for this variant.\n"
749+
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
750+
"This will ensure you're using the most up-to-date and compatible checkpoint format."
751+
)
752+
logger.warning(warn_msg)
742753

743754
config_dict = cls.load_config(cached_folder)
744755

@@ -1261,7 +1272,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
12611272

12621273
if not local_files_only:
12631274
filenames = {sibling.rfilename for sibling in info.siblings}
1264-
_maybe_raise_warning_for_variant_checkpoint_format(filenames=filenames, variant=variant)
1275+
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
1276+
warn_msg = (
1277+
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
1278+
"Please check your files carefully:\n\n"
1279+
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
1280+
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
1281+
"If you find any files in the deprecated format:\n"
1282+
"1. Remove all existing checkpoint files for this variant.\n"
1283+
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
1284+
"This will ensure you're using the most up-to-date and compatible checkpoint format."
1285+
)
1286+
logger.warning(warn_msg)
12651287

12661288
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
12671289

tests/pipelines/test_pipelines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ def test_download_variants_with_sharded_checkpoints(self):
578578

579579
def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self):
580580
repo_id = "hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds"
581-
logger = logging.get_logger("diffusers.pipelines.pipeline_loading_utils")
581+
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
582582
deprecated_warning_msg = "Warning: The repository contains sharded checkpoints for variant"
583583

584584
for is_local in [True, False]:

0 commit comments

Comments
 (0)