Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
is_transformers_available,
logging,
)
from ..utils.hub_utils import _check_legacy_sharding_variant_format
from ..utils.torch_utils import is_compiled_module


Expand Down Expand Up @@ -838,3 +839,127 @@ def get_connected_passed_kwargs(prefix):
)

return init_kwargs


def _maybe_raise_warning_for_variant_checkpoint_format(
folder: Optional[str] = None, filenames: Optional[List[str]] = None, variant: Optional[str] = None
) -> None:
if variant is not None and _check_legacy_sharding_variant_format(
folder=folder, filenames=filenames, variant=variant
):
warn_msg = (
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
"Please check your files carefully:\n\n"
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
"If you find any files in the deprecated format:\n"
"1. Remove all existing checkpoint files for this variant.\n"
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
"This will ensure you're using the most up-to-date and compatible checkpoint format."
)
logger.warning(warn_msg)


def _get_custom_components_and_folders(
pretrained_model_name: str,
config_dict: Dict[str, Any],
filenames: Optional[List[str]] = None,
variant_filenames: Optional[List[str]] = None,
variant: Optional[str] = None,
):
config_dict = config_dict.copy()

# retrieve all folder_names that contain relevant files
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]

diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")

# optionally create a custom component <> custom file mapping
custom_components = {}
for component in folder_names:
module_candidate = config_dict[component][0]

if module_candidate is None or not isinstance(module_candidate, str):
continue

# We compute candidate file path on the Hub. Do not use `os.path.join`.
candidate_file = f"{component}/{module_candidate}.py"

if candidate_file in filenames:
custom_components[component] = module_candidate
elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
raise ValueError(
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
)

if len(variant_filenames) == 0 and variant is not None:
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
raise ValueError(error_message)

return custom_components, folder_names


def _get_ignore_patterns(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice!

passed_components,
model_folder_names: List[str],
model_filenames: List[str],
variant_filenames: List[str],
use_safetensors: bool,
from_flax: bool,
allow_pickle: bool,
use_onnx: bool,
is_onnx: bool,
variant: Optional[str] = None,
) -> List[str]:
if (
use_safetensors
and not allow_pickle
and not is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
)
):
raise EnvironmentError(
f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
)

if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]

elif use_safetensors and is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
):
ignore_patterns = ["*.bin", "*.msgpack"]

use_onnx = use_onnx if use_onnx is not None else is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]

safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
f"expected, please check your folder structure."
)

else:
ignore_patterns = ["*.safetensors", "*.msgpack"]

use_onnx = use_onnx if use_onnx is not None else is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]

bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
f"your folder structure."
)

return ignore_patterns
120 changes: 22 additions & 98 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
logging,
numpy_to_pil,
)
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
from ..utils.torch_utils import is_compiled_module


Expand All @@ -71,15 +71,17 @@
CUSTOM_PIPELINE_FILE_NAME,
LOADABLE_CLASSES,
_fetch_class_library_tuple,
_get_custom_components_and_folders,
_get_custom_pipeline_class,
_get_final_device_map,
_get_ignore_patterns,
_get_pipeline_class,
_identify_model_variants,
_maybe_raise_warning_for_inpainting,
_maybe_raise_warning_for_variant_checkpoint_format,
_resolve_custom_pipeline_and_cls,
_unwrap_model,
_update_init_kwargs_with_connected_pipeline,
is_safetensors_compatible,
load_sub_model,
maybe_raise_or_warn,
variant_compatible_siblings,
Expand Down Expand Up @@ -736,18 +738,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

# The variant filenames can have the legacy sharding checkpoint format that we check and throw
# a warning if detected.
if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant):
warn_msg = (
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
"Please check your files carefully:\n\n"
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
"If you find any files in the deprecated format:\n"
"1. Remove all existing checkpoint files for this variant.\n"
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
"This will ensure you're using the most up-to-date and compatible checkpoint format."
)
logger.warning(warn_msg)
_maybe_raise_warning_for_variant_checkpoint_format(folder=cached_folder, variant=variant)

config_dict = cls.load_config(cached_folder)

Expand Down Expand Up @@ -1270,18 +1261,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:

if not local_files_only:
filenames = {sibling.rfilename for sibling in info.siblings}
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
warn_msg = (
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
"Please check your files carefully:\n\n"
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
"If you find any files in the deprecated format:\n"
"1. Remove all existing checkpoint files for this variant.\n"
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
"This will ensure you're using the most up-to-date and compatible checkpoint format."
)
logger.warning(warn_msg)
_maybe_raise_warning_for_variant_checkpoint_format(filenames=filenames, variant=variant)

model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)

Expand All @@ -1298,44 +1278,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
config_dict = cls._dict_from_json_file(config_file)
ignore_filenames = config_dict.pop("_ignore_files", [])

# retrieve all folder_names that contain relevant files
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]

diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")

# optionally create a custom component <> custom file mapping
custom_components = {}
for component in folder_names:
module_candidate = config_dict[component][0]

if module_candidate is None or not isinstance(module_candidate, str):
continue

# We compute candidate file path on the Hub. Do not use `os.path.join`.
candidate_file = f"{component}/{module_candidate}.py"

if candidate_file in filenames:
custom_components[component] = module_candidate
elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
raise ValueError(
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
)

if len(variant_filenames) == 0 and variant is not None:
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
raise ValueError(error_message)

# remove ignored filenames
model_filenames = set(model_filenames) - set(ignore_filenames)
variant_filenames = set(variant_filenames) - set(ignore_filenames)

# if the whole pipeline is cached we don't have to ping the Hub
if revision in DEPRECATED_REVISION_ARGS and version.parse(
version.parse(__version__).base_version
) >= version.parse("0.22.0"):
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)

custom_components, folder_names = _get_custom_components_and_folders(
pretrained_model_name, config_dict, filenames, variant_filenames, variant
)
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}

custom_class_name = None
Expand Down Expand Up @@ -1395,49 +1349,19 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
expected_components, _ = cls._get_signature_keys(pipeline_class)
passed_components = [k for k in expected_components if k in kwargs]

if (
use_safetensors
and not allow_pickle
and not is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
)
):
raise EnvironmentError(
f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
)
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
elif use_safetensors and is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
):
ignore_patterns = ["*.bin", "*.msgpack"]

use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]

safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
if (
len(safetensors_variant_filenames) > 0
and safetensors_model_filenames != safetensors_variant_filenames
):
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
)
else:
ignore_patterns = ["*.safetensors", "*.msgpack"]

use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]

bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
)
# retrieve all patterns that should not be downloaded
ignore_patterns = _get_ignore_patterns(
passed_components,
model_folder_names,
model_filenames,
variant_filenames,
use_safetensors,
from_flax,
allow_pickle,
use_onnx,
pipeline_class._is_onnx,
variant,
)

# Don't download any objects that are passed
allow_patterns = [
Expand Down
2 changes: 1 addition & 1 deletion tests/pipelines/test_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.pipelines.pipeline_utils import is_safetensors_compatible
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible
from diffusers.utils.testing_utils import torch_device


Expand Down
2 changes: 1 addition & 1 deletion tests/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ def test_download_variants_with_sharded_checkpoints(self):

def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self):
repo_id = "hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds"
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
logger = logging.get_logger("diffusers.pipelines.pipeline_loading_utils")
deprecated_warning_msg = "Warning: The repository contains sharded checkpoints for variant"

for is_local in [True, False]:
Expand Down
Loading