Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,3 +838,108 @@ def get_connected_passed_kwargs(prefix):
)

return init_kwargs


def _get_custom_components_and_folders(
pretrained_model_name: str,
config_dict: Dict[str, Any],
filenames: Optional[List[str]] = None,
variant_filenames: Optional[List[str]] = None,
variant: Optional[str] = None,
):
config_dict = config_dict.copy()

# retrieve all folder_names that contain relevant files
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]

diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")

# optionally create a custom component <> custom file mapping
custom_components = {}
for component in folder_names:
module_candidate = config_dict[component][0]

if module_candidate is None or not isinstance(module_candidate, str):
continue

# We compute candidate file path on the Hub. Do not use `os.path.join`.
candidate_file = f"{component}/{module_candidate}.py"

if candidate_file in filenames:
custom_components[component] = module_candidate
elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
raise ValueError(
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
)

if len(variant_filenames) == 0 and variant is not None:
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
raise ValueError(error_message)

return custom_components, folder_names


def _get_ignore_patterns(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice!

passed_components,
model_folder_names: List[str],
model_filenames: List[str],
variant_filenames: List[str],
use_safetensors: bool,
from_flax: bool,
allow_pickle: bool,
use_onnx: bool,
is_onnx: bool,
variant: Optional[str] = None,
) -> List[str]:
if (
use_safetensors
and not allow_pickle
and not is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
)
):
raise EnvironmentError(
f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
)

if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]

elif use_safetensors and is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
):
ignore_patterns = ["*.bin", "*.msgpack"]

use_onnx = use_onnx if use_onnx is not None else is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]

safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
f"expected, please check your folder structure."
)

else:
ignore_patterns = ["*.safetensors", "*.msgpack"]

use_onnx = use_onnx if use_onnx is not None else is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]

bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
f"your folder structure."
)

return ignore_patterns
91 changes: 18 additions & 73 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,16 @@
CUSTOM_PIPELINE_FILE_NAME,
LOADABLE_CLASSES,
_fetch_class_library_tuple,
_get_custom_components_and_folders,
_get_custom_pipeline_class,
_get_final_device_map,
_get_ignore_patterns,
_get_pipeline_class,
_identify_model_variants,
_maybe_raise_warning_for_inpainting,
_resolve_custom_pipeline_and_cls,
_unwrap_model,
_update_init_kwargs_with_connected_pipeline,
is_safetensors_compatible,
load_sub_model,
maybe_raise_or_warn,
variant_compatible_siblings,
Expand Down Expand Up @@ -1298,44 +1299,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
config_dict = cls._dict_from_json_file(config_file)
ignore_filenames = config_dict.pop("_ignore_files", [])

# retrieve all folder_names that contain relevant files
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]

diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")

# optionally create a custom component <> custom file mapping
custom_components = {}
for component in folder_names:
module_candidate = config_dict[component][0]

if module_candidate is None or not isinstance(module_candidate, str):
continue

# We compute candidate file path on the Hub. Do not use `os.path.join`.
candidate_file = f"{component}/{module_candidate}.py"

if candidate_file in filenames:
custom_components[component] = module_candidate
elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
raise ValueError(
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
)

if len(variant_filenames) == 0 and variant is not None:
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
raise ValueError(error_message)

# remove ignored filenames
model_filenames = set(model_filenames) - set(ignore_filenames)
variant_filenames = set(variant_filenames) - set(ignore_filenames)

# if the whole pipeline is cached we don't have to ping the Hub
if revision in DEPRECATED_REVISION_ARGS and version.parse(
version.parse(__version__).base_version
) >= version.parse("0.22.0"):
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)

custom_components, folder_names = _get_custom_components_and_folders(
pretrained_model_name, config_dict, filenames, variant_filenames, variant
)
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}

custom_class_name = None
Expand Down Expand Up @@ -1395,49 +1370,19 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
expected_components, _ = cls._get_signature_keys(pipeline_class)
passed_components = [k for k in expected_components if k in kwargs]

if (
use_safetensors
and not allow_pickle
and not is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
)
):
raise EnvironmentError(
f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
)
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
elif use_safetensors and is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
):
ignore_patterns = ["*.bin", "*.msgpack"]

use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]

safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
if (
len(safetensors_variant_filenames) > 0
and safetensors_model_filenames != safetensors_variant_filenames
):
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
)
else:
ignore_patterns = ["*.safetensors", "*.msgpack"]

use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]

bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
)
# retrieve all patterns that should not be downloaded and error out when needed
ignore_patterns = _get_ignore_patterns(
passed_components,
model_folder_names,
model_filenames,
variant_filenames,
use_safetensors,
from_flax,
allow_pickle,
use_onnx,
pipeline_class._is_onnx,
variant,
)

# Don't download any objects that are passed
allow_patterns = [
Expand Down
2 changes: 1 addition & 1 deletion tests/pipelines/test_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.pipelines.pipeline_utils import is_safetensors_compatible
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible
from diffusers.utils.testing_utils import torch_device


Expand Down
Loading