From 9d57f792cc6bf3fc67a46f6d39debd2ddf029846 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 30 Sep 2024 07:33:19 +0200 Subject: [PATCH 1/7] update --- .../pipelines/pipeline_loading_utils.py | 107 ++++++++++++++++++ src/diffusers/pipelines/pipeline_utils.py | 104 ++--------------- 2 files changed, 115 insertions(+), 96 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 318599f56063..6b99ca52564e 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -40,6 +40,7 @@ is_transformers_available, logging, ) +from ..utils.hub_utils import _check_legacy_sharding_variant_format from ..utils.torch_utils import is_compiled_module @@ -836,3 +837,109 @@ def get_connected_passed_kwargs(prefix): ) return init_kwargs + + +def _maybe_raise_warning_for_variant_checkpoint_format(folder: Optional[str] = None, filenames: Optional[List[str]] = None, variant: Optional[str] = None) -> None: + if variant is not None and _check_legacy_sharding_variant_format(folder=folder, filenames=filenames, variant=variant): + warn_msg = ( + f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. " + "Please check your files carefully:\n\n" + "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n" + "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n" + "If you find any files in the deprecated format:\n" + "1. Remove all existing checkpoint files for this variant.\n" + "2. Re-obtain the correct files by running `save_pretrained()`.\n\n" + "This will ensure you're using the most up-to-date and compatible checkpoint format." + ) + logger.warning(warn_msg) + + +def _get_custom_components_and_folders(pretrained_model_name: str, config_dict: Dict[str, Any], filenames: Optional[List[str]] = None, variant_filenames: Optional[List[str]] = None, variant: Optional[str] = None): + config_dict = config_dict.copy() + + # retrieve all folder_names that contain relevant files + folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"] + + diffusers_module = importlib.import_module(__name__.split(".")[0]) + pipelines = getattr(diffusers_module, "pipelines") + + # optionally create a custom component <> custom file mapping + custom_components = {} + for component in folder_names: + module_candidate = config_dict[component][0] + + if module_candidate is None or not isinstance(module_candidate, str): + continue + + # We compute candidate file path on the Hub. Do not use `os.path.join`. + candidate_file = f"{component}/{module_candidate}.py" + + if candidate_file in filenames: + custom_components[component] = module_candidate + elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate): + raise ValueError( + f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'." + ) + + if len(variant_filenames) == 0 and variant is not None: + error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." + raise ValueError(error_message) + + return custom_components, folder_names + + +def _get_ignore_patterns(passed_components, model_folder_names: List[str], model_filenames: List[str], variant_filenames: List[str], use_safetensors: bool, from_flax: bool, allow_pickle: bool, use_onnx: bool, is_onnx: bool, variant: Optional[str] = None) -> List[str]: + if ( + use_safetensors + and not allow_pickle + and not is_safetensors_compatible( + model_filenames, passed_components=passed_components, folder_names=model_folder_names + ) + ): + raise EnvironmentError( + f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})" + ) + + if from_flax: + ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] + + elif use_safetensors and is_safetensors_compatible( + model_filenames, passed_components=passed_components, folder_names=model_folder_names + ): + ignore_patterns = ["*.bin", "*.msgpack"] + + use_onnx = use_onnx if use_onnx is not None else is_onnx + if not use_onnx: + ignore_patterns += ["*.onnx", "*.pb"] + + safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} + safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} + if ( + len(safetensors_variant_filenames) > 0 + and safetensors_model_filenames != safetensors_variant_filenames + ): + logger.warning( + f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n" + f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n" + f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not " + f"expected, please check your folder structure." + ) + + else: + ignore_patterns = ["*.safetensors", "*.msgpack"] + + use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx + if not use_onnx: + ignore_patterns += ["*.onnx", "*.pb"] + + bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")} + bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")} + if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: + logger.warning( + f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n" + f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n" + f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check " + f"your folder structure." + ) + + return ignore_patterns diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 6721706b5689..38560b1e7633 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -76,6 +76,7 @@ _get_pipeline_class, _identify_model_variants, _maybe_raise_warning_for_inpainting, + _maybe_raise_warning_for_variant_checkpoint_format, _resolve_custom_pipeline_and_cls, _unwrap_model, _update_init_kwargs_with_connected_pipeline, @@ -84,6 +85,8 @@ maybe_raise_or_warn, variant_compatible_siblings, warn_deprecated_model_variant, + _get_custom_components_and_folders, + _get_ignore_patterns, ) @@ -736,18 +739,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # The variant filenames can have the legacy sharding checkpoint format that we check and throw # a warning if detected. - if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant): - warn_msg = ( - f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. " - "Please check your files carefully:\n\n" - "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n" - "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n" - "If you find any files in the deprecated format:\n" - "1. Remove all existing checkpoint files for this variant.\n" - "2. Re-obtain the correct files by running `save_pretrained()`.\n\n" - "This will ensure you're using the most up-to-date and compatible checkpoint format." - ) - logger.warning(warn_msg) + _maybe_raise_warning_for_variant_checkpoint_format(folder=cached_folder, variant=variant) config_dict = cls.load_config(cached_folder) @@ -1269,18 +1261,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: if not local_files_only: filenames = {sibling.rfilename for sibling in info.siblings} - if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant): - warn_msg = ( - f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. " - "Please check your files carefully:\n\n" - "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n" - "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n" - "If you find any files in the deprecated format:\n" - "1. Remove all existing checkpoint files for this variant.\n" - "2. Re-obtain the correct files by running `save_pretrained()`.\n\n" - "This will ensure you're using the most up-to-date and compatible checkpoint format." - ) - logger.warning(warn_msg) + _maybe_raise_warning_for_variant_checkpoint_format(filenames=filenames, variant=variant) model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) @@ -1297,44 +1278,16 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: config_dict = cls._dict_from_json_file(config_file) ignore_filenames = config_dict.pop("_ignore_files", []) - # retrieve all folder_names that contain relevant files - folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"] - - diffusers_module = importlib.import_module(__name__.split(".")[0]) - pipelines = getattr(diffusers_module, "pipelines") - - # optionally create a custom component <> custom file mapping - custom_components = {} - for component in folder_names: - module_candidate = config_dict[component][0] - - if module_candidate is None or not isinstance(module_candidate, str): - continue - - # We compute candidate file path on the Hub. Do not use `os.path.join`. - candidate_file = f"{component}/{module_candidate}.py" - - if candidate_file in filenames: - custom_components[component] = module_candidate - elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate): - raise ValueError( - f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'." - ) - - if len(variant_filenames) == 0 and variant is not None: - error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." - raise ValueError(error_message) - # remove ignored filenames model_filenames = set(model_filenames) - set(ignore_filenames) variant_filenames = set(variant_filenames) - set(ignore_filenames) - # if the whole pipeline is cached we don't have to ping the Hub if revision in DEPRECATED_REVISION_ARGS and version.parse( version.parse(__version__).base_version ) >= version.parse("0.22.0"): warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames) + custom_components, folder_names = _get_custom_components_and_folders(pretrained_model_name, config_dict, filenames, variant_filenames, variant) model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names} custom_class_name = None @@ -1394,49 +1347,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: expected_components, _ = cls._get_signature_keys(pipeline_class) passed_components = [k for k in expected_components if k in kwargs] - if ( - use_safetensors - and not allow_pickle - and not is_safetensors_compatible( - model_filenames, passed_components=passed_components, folder_names=model_folder_names - ) - ): - raise EnvironmentError( - f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})" - ) - if from_flax: - ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] - elif use_safetensors and is_safetensors_compatible( - model_filenames, passed_components=passed_components, folder_names=model_folder_names - ): - ignore_patterns = ["*.bin", "*.msgpack"] - - use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx - if not use_onnx: - ignore_patterns += ["*.onnx", "*.pb"] - - safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} - safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} - if ( - len(safetensors_variant_filenames) > 0 - and safetensors_model_filenames != safetensors_variant_filenames - ): - logger.warning( - f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." - ) - else: - ignore_patterns = ["*.safetensors", "*.msgpack"] - - use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx - if not use_onnx: - ignore_patterns += ["*.onnx", "*.pb"] - - bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")} - bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")} - if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: - logger.warning( - f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." - ) + # retrieve all patterns that should not be downloaded + ignore_patterns = _get_ignore_patterns(passed_components, model_folder_names, model_filenames, variant_filenames, use_safetensors, from_flax, allow_pickle, use_onnx, pipeline_class._is_onnx, variant) # Don't download any objects that are passed allow_patterns = [ From 5db35067f2fc6cdd45b7eb5e05a3d17276540449 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 30 Sep 2024 07:33:58 +0200 Subject: [PATCH 2/7] make style --- .../pipelines/pipeline_loading_utils.py | 44 +++++++++++++------ src/diffusers/pipelines/pipeline_utils.py | 24 +++++++--- 2 files changed, 49 insertions(+), 19 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 6b99ca52564e..c0a06451a350 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -839,8 +839,12 @@ def get_connected_passed_kwargs(prefix): return init_kwargs -def _maybe_raise_warning_for_variant_checkpoint_format(folder: Optional[str] = None, filenames: Optional[List[str]] = None, variant: Optional[str] = None) -> None: - if variant is not None and _check_legacy_sharding_variant_format(folder=folder, filenames=filenames, variant=variant): +def _maybe_raise_warning_for_variant_checkpoint_format( + folder: Optional[str] = None, filenames: Optional[List[str]] = None, variant: Optional[str] = None +) -> None: + if variant is not None and _check_legacy_sharding_variant_format( + folder=folder, filenames=filenames, variant=variant + ): warn_msg = ( f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. " "Please check your files carefully:\n\n" @@ -854,7 +858,13 @@ def _maybe_raise_warning_for_variant_checkpoint_format(folder: Optional[str] = N logger.warning(warn_msg) -def _get_custom_components_and_folders(pretrained_model_name: str, config_dict: Dict[str, Any], filenames: Optional[List[str]] = None, variant_filenames: Optional[List[str]] = None, variant: Optional[str] = None): +def _get_custom_components_and_folders( + pretrained_model_name: str, + config_dict: Dict[str, Any], + filenames: Optional[List[str]] = None, + variant_filenames: Optional[List[str]] = None, + variant: Optional[str] = None, +): config_dict = config_dict.copy() # retrieve all folder_names that contain relevant files @@ -888,7 +898,18 @@ def _get_custom_components_and_folders(pretrained_model_name: str, config_dict: return custom_components, folder_names -def _get_ignore_patterns(passed_components, model_folder_names: List[str], model_filenames: List[str], variant_filenames: List[str], use_safetensors: bool, from_flax: bool, allow_pickle: bool, use_onnx: bool, is_onnx: bool, variant: Optional[str] = None) -> List[str]: +def _get_ignore_patterns( + passed_components, + model_folder_names: List[str], + model_filenames: List[str], + variant_filenames: List[str], + use_safetensors: bool, + from_flax: bool, + allow_pickle: bool, + use_onnx: bool, + is_onnx: bool, + variant: Optional[str] = None, +) -> List[str]: if ( use_safetensors and not allow_pickle @@ -899,10 +920,10 @@ def _get_ignore_patterns(passed_components, model_folder_names: List[str], model raise EnvironmentError( f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})" ) - + if from_flax: ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] - + elif use_safetensors and is_safetensors_compatible( model_filenames, passed_components=passed_components, folder_names=model_folder_names ): @@ -914,21 +935,18 @@ def _get_ignore_patterns(passed_components, model_folder_names: List[str], model safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} - if ( - len(safetensors_variant_filenames) > 0 - and safetensors_model_filenames != safetensors_variant_filenames - ): + if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames: logger.warning( f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n" f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n" f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not " f"expected, please check your folder structure." ) - + else: ignore_patterns = ["*.safetensors", "*.msgpack"] - use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx + use_onnx = use_onnx if use_onnx is not None else is_onnx if not use_onnx: ignore_patterns += ["*.onnx", "*.pb"] @@ -941,5 +959,5 @@ def _get_ignore_patterns(passed_components, model_folder_names: List[str], model f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check " f"your folder structure." ) - + return ignore_patterns diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 38560b1e7633..43002aadc08e 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -57,7 +57,7 @@ logging, numpy_to_pil, ) -from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card +from ..utils.hub_utils import load_or_create_model_card, populate_model_card from ..utils.torch_utils import is_compiled_module @@ -71,8 +71,10 @@ CUSTOM_PIPELINE_FILE_NAME, LOADABLE_CLASSES, _fetch_class_library_tuple, + _get_custom_components_and_folders, _get_custom_pipeline_class, _get_final_device_map, + _get_ignore_patterns, _get_pipeline_class, _identify_model_variants, _maybe_raise_warning_for_inpainting, @@ -80,13 +82,10 @@ _resolve_custom_pipeline_and_cls, _unwrap_model, _update_init_kwargs_with_connected_pipeline, - is_safetensors_compatible, load_sub_model, maybe_raise_or_warn, variant_compatible_siblings, warn_deprecated_model_variant, - _get_custom_components_and_folders, - _get_ignore_patterns, ) @@ -1287,7 +1286,9 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: ) >= version.parse("0.22.0"): warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames) - custom_components, folder_names = _get_custom_components_and_folders(pretrained_model_name, config_dict, filenames, variant_filenames, variant) + custom_components, folder_names = _get_custom_components_and_folders( + pretrained_model_name, config_dict, filenames, variant_filenames, variant + ) model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names} custom_class_name = None @@ -1348,7 +1349,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: passed_components = [k for k in expected_components if k in kwargs] # retrieve all patterns that should not be downloaded - ignore_patterns = _get_ignore_patterns(passed_components, model_folder_names, model_filenames, variant_filenames, use_safetensors, from_flax, allow_pickle, use_onnx, pipeline_class._is_onnx, variant) + ignore_patterns = _get_ignore_patterns( + passed_components, + model_folder_names, + model_filenames, + variant_filenames, + use_safetensors, + from_flax, + allow_pickle, + use_onnx, + pipeline_class._is_onnx, + variant, + ) # Don't download any objects that are passed allow_patterns = [ From 87805a5b4d4468f3e6878e0a91f64047f0a26aa5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 30 Sep 2024 12:50:38 +0200 Subject: [PATCH 3/7] fix import --- tests/pipelines/test_pipeline_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index 697244dcb105..5eedd393c8f8 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -18,7 +18,7 @@ StableDiffusionPipeline, UNet2DConditionModel, ) -from diffusers.pipelines.pipeline_utils import is_safetensors_compatible +from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible from diffusers.utils.testing_utils import torch_device From 20c08f5400ba2ae842b65497bf37c49f4c028f30 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 30 Sep 2024 13:00:21 +0200 Subject: [PATCH 4/7] update test --- tests/pipelines/test_pipelines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 8b087db6726e..c318ad491ad9 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -578,7 +578,7 @@ def test_download_variants_with_sharded_checkpoints(self): def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self): repo_id = "hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds" - logger = logging.get_logger("diffusers.pipelines.pipeline_utils") + logger = logging.get_logger("diffusers.pipelines.pipeline_loading_utils") deprecated_warning_msg = "Warning: The repository contains sharded checkpoints for variant" for is_local in [True, False]: From 09fa592ff7aad03d291c47c7573c31de86e63ac7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Oct 2024 16:51:47 +0530 Subject: [PATCH 5/7] Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Sayak Paul --- src/diffusers/pipelines/pipeline_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 4c2ad6d92064..80006a2ea088 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1349,7 +1349,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: expected_components, _ = cls._get_signature_keys(pipeline_class) passed_components = [k for k in expected_components if k in kwargs] - # retrieve all patterns that should not be downloaded + # retrieve all patterns that should not be downloaded and error out when needed ignore_patterns = _get_ignore_patterns( passed_components, model_folder_names, From 8d1314c240b315ab38571947acf3a9b2d5991f6c Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Oct 2024 13:27:04 +0200 Subject: [PATCH 6/7] apply suggestions from review --- .../pipelines/pipeline_loading_utils.py | 20 ------------- src/diffusers/pipelines/pipeline_utils.py | 28 +++++++++++++++++-- tests/pipelines/test_pipelines.py | 2 +- 3 files changed, 26 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index bfa525f8e536..c16bd8ac2069 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -40,7 +40,6 @@ is_transformers_available, logging, ) -from ..utils.hub_utils import _check_legacy_sharding_variant_format from ..utils.torch_utils import is_compiled_module @@ -841,25 +840,6 @@ def get_connected_passed_kwargs(prefix): return init_kwargs -def _maybe_raise_warning_for_variant_checkpoint_format( - folder: Optional[str] = None, filenames: Optional[List[str]] = None, variant: Optional[str] = None -) -> None: - if variant is not None and _check_legacy_sharding_variant_format( - folder=folder, filenames=filenames, variant=variant - ): - warn_msg = ( - f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. " - "Please check your files carefully:\n\n" - "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n" - "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n" - "If you find any files in the deprecated format:\n" - "1. Remove all existing checkpoint files for this variant.\n" - "2. Re-obtain the correct files by running `save_pretrained()`.\n\n" - "This will ensure you're using the most up-to-date and compatible checkpoint format." - ) - logger.warning(warn_msg) - - def _get_custom_components_and_folders( pretrained_model_name: str, config_dict: Dict[str, Any], diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 80006a2ea088..f551c7a6628b 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -70,6 +70,7 @@ CONNECTED_PIPES_KEYS, CUSTOM_PIPELINE_FILE_NAME, LOADABLE_CLASSES, + _check_legacy_sharding_variant_format, _fetch_class_library_tuple, _get_custom_components_and_folders, _get_custom_pipeline_class, @@ -78,7 +79,6 @@ _get_pipeline_class, _identify_model_variants, _maybe_raise_warning_for_inpainting, - _maybe_raise_warning_for_variant_checkpoint_format, _resolve_custom_pipeline_and_cls, _unwrap_model, _update_init_kwargs_with_connected_pipeline, @@ -738,7 +738,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # The variant filenames can have the legacy sharding checkpoint format that we check and throw # a warning if detected. - _maybe_raise_warning_for_variant_checkpoint_format(folder=cached_folder, variant=variant) + if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant): + warn_msg = ( + f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. " + "Please check your files carefully:\n\n" + "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n" + "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n" + "If you find any files in the deprecated format:\n" + "1. Remove all existing checkpoint files for this variant.\n" + "2. Re-obtain the correct files by running `save_pretrained()`.\n\n" + "This will ensure you're using the most up-to-date and compatible checkpoint format." + ) + logger.warning(warn_msg) config_dict = cls.load_config(cached_folder) @@ -1261,7 +1272,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: if not local_files_only: filenames = {sibling.rfilename for sibling in info.siblings} - _maybe_raise_warning_for_variant_checkpoint_format(filenames=filenames, variant=variant) + if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant): + warn_msg = ( + f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. " + "Please check your files carefully:\n\n" + "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n" + "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n" + "If you find any files in the deprecated format:\n" + "1. Remove all existing checkpoint files for this variant.\n" + "2. Re-obtain the correct files by running `save_pretrained()`.\n\n" + "This will ensure you're using the most up-to-date and compatible checkpoint format." + ) + logger.warning(warn_msg) model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index c318ad491ad9..8b087db6726e 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -578,7 +578,7 @@ def test_download_variants_with_sharded_checkpoints(self): def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self): repo_id = "hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds" - logger = logging.get_logger("diffusers.pipelines.pipeline_loading_utils") + logger = logging.get_logger("diffusers.pipelines.pipeline_utils") deprecated_warning_msg = "Warning: The repository contains sharded checkpoints for variant" for is_local in [True, False]: From e19e800d9674c9b4ef142c919e16132c9332ad71 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Oct 2024 13:29:27 +0200 Subject: [PATCH 7/7] fix import --- src/diffusers/pipelines/pipeline_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index f551c7a6628b..2be0c5e7310c 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -57,7 +57,7 @@ logging, numpy_to_pil, ) -from ..utils.hub_utils import load_or_create_model_card, populate_model_card +from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card from ..utils.torch_utils import is_compiled_module @@ -70,7 +70,6 @@ CONNECTED_PIPES_KEYS, CUSTOM_PIPELINE_FILE_NAME, LOADABLE_CLASSES, - _check_legacy_sharding_variant_format, _fetch_class_library_tuple, _get_custom_components_and_folders, _get_custom_pipeline_class,