Skip to content

Commit 9d59fdc

Browse files
authored
Merge branch 'main' into main
2 parents 0e8adaf + 5704376 commit 9d59fdc

File tree

3 files changed

+124
-74
lines changed

3 files changed

+124
-74
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,3 +838,108 @@ def get_connected_passed_kwargs(prefix):
838838
)
839839

840840
return init_kwargs
841+
842+
843+
def _get_custom_components_and_folders(
844+
pretrained_model_name: str,
845+
config_dict: Dict[str, Any],
846+
filenames: Optional[List[str]] = None,
847+
variant_filenames: Optional[List[str]] = None,
848+
variant: Optional[str] = None,
849+
):
850+
config_dict = config_dict.copy()
851+
852+
# retrieve all folder_names that contain relevant files
853+
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
854+
855+
diffusers_module = importlib.import_module(__name__.split(".")[0])
856+
pipelines = getattr(diffusers_module, "pipelines")
857+
858+
# optionally create a custom component <> custom file mapping
859+
custom_components = {}
860+
for component in folder_names:
861+
module_candidate = config_dict[component][0]
862+
863+
if module_candidate is None or not isinstance(module_candidate, str):
864+
continue
865+
866+
# We compute candidate file path on the Hub. Do not use `os.path.join`.
867+
candidate_file = f"{component}/{module_candidate}.py"
868+
869+
if candidate_file in filenames:
870+
custom_components[component] = module_candidate
871+
elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
872+
raise ValueError(
873+
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
874+
)
875+
876+
if len(variant_filenames) == 0 and variant is not None:
877+
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
878+
raise ValueError(error_message)
879+
880+
return custom_components, folder_names
881+
882+
883+
def _get_ignore_patterns(
884+
passed_components,
885+
model_folder_names: List[str],
886+
model_filenames: List[str],
887+
variant_filenames: List[str],
888+
use_safetensors: bool,
889+
from_flax: bool,
890+
allow_pickle: bool,
891+
use_onnx: bool,
892+
is_onnx: bool,
893+
variant: Optional[str] = None,
894+
) -> List[str]:
895+
if (
896+
use_safetensors
897+
and not allow_pickle
898+
and not is_safetensors_compatible(
899+
model_filenames, passed_components=passed_components, folder_names=model_folder_names
900+
)
901+
):
902+
raise EnvironmentError(
903+
f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
904+
)
905+
906+
if from_flax:
907+
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
908+
909+
elif use_safetensors and is_safetensors_compatible(
910+
model_filenames, passed_components=passed_components, folder_names=model_folder_names
911+
):
912+
ignore_patterns = ["*.bin", "*.msgpack"]
913+
914+
use_onnx = use_onnx if use_onnx is not None else is_onnx
915+
if not use_onnx:
916+
ignore_patterns += ["*.onnx", "*.pb"]
917+
918+
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
919+
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
920+
if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
921+
logger.warning(
922+
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
923+
f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
924+
f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
925+
f"expected, please check your folder structure."
926+
)
927+
928+
else:
929+
ignore_patterns = ["*.safetensors", "*.msgpack"]
930+
931+
use_onnx = use_onnx if use_onnx is not None else is_onnx
932+
if not use_onnx:
933+
ignore_patterns += ["*.onnx", "*.pb"]
934+
935+
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
936+
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
937+
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
938+
logger.warning(
939+
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
940+
f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
941+
f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
942+
f"your folder structure."
943+
)
944+
945+
return ignore_patterns

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 18 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,16 @@
7171
CUSTOM_PIPELINE_FILE_NAME,
7272
LOADABLE_CLASSES,
7373
_fetch_class_library_tuple,
74+
_get_custom_components_and_folders,
7475
_get_custom_pipeline_class,
7576
_get_final_device_map,
77+
_get_ignore_patterns,
7678
_get_pipeline_class,
7779
_identify_model_variants,
7880
_maybe_raise_warning_for_inpainting,
7981
_resolve_custom_pipeline_and_cls,
8082
_unwrap_model,
8183
_update_init_kwargs_with_connected_pipeline,
82-
is_safetensors_compatible,
8384
load_sub_model,
8485
maybe_raise_or_warn,
8586
variant_compatible_siblings,
@@ -1298,44 +1299,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
12981299
config_dict = cls._dict_from_json_file(config_file)
12991300
ignore_filenames = config_dict.pop("_ignore_files", [])
13001301

1301-
# retrieve all folder_names that contain relevant files
1302-
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
1303-
1304-
diffusers_module = importlib.import_module(__name__.split(".")[0])
1305-
pipelines = getattr(diffusers_module, "pipelines")
1306-
1307-
# optionally create a custom component <> custom file mapping
1308-
custom_components = {}
1309-
for component in folder_names:
1310-
module_candidate = config_dict[component][0]
1311-
1312-
if module_candidate is None or not isinstance(module_candidate, str):
1313-
continue
1314-
1315-
# We compute candidate file path on the Hub. Do not use `os.path.join`.
1316-
candidate_file = f"{component}/{module_candidate}.py"
1317-
1318-
if candidate_file in filenames:
1319-
custom_components[component] = module_candidate
1320-
elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
1321-
raise ValueError(
1322-
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
1323-
)
1324-
1325-
if len(variant_filenames) == 0 and variant is not None:
1326-
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
1327-
raise ValueError(error_message)
1328-
13291302
# remove ignored filenames
13301303
model_filenames = set(model_filenames) - set(ignore_filenames)
13311304
variant_filenames = set(variant_filenames) - set(ignore_filenames)
13321305

1333-
# if the whole pipeline is cached we don't have to ping the Hub
13341306
if revision in DEPRECATED_REVISION_ARGS and version.parse(
13351307
version.parse(__version__).base_version
13361308
) >= version.parse("0.22.0"):
13371309
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)
13381310

1311+
custom_components, folder_names = _get_custom_components_and_folders(
1312+
pretrained_model_name, config_dict, filenames, variant_filenames, variant
1313+
)
13391314
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
13401315

13411316
custom_class_name = None
@@ -1395,49 +1370,19 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
13951370
expected_components, _ = cls._get_signature_keys(pipeline_class)
13961371
passed_components = [k for k in expected_components if k in kwargs]
13971372

1398-
if (
1399-
use_safetensors
1400-
and not allow_pickle
1401-
and not is_safetensors_compatible(
1402-
model_filenames, passed_components=passed_components, folder_names=model_folder_names
1403-
)
1404-
):
1405-
raise EnvironmentError(
1406-
f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
1407-
)
1408-
if from_flax:
1409-
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
1410-
elif use_safetensors and is_safetensors_compatible(
1411-
model_filenames, passed_components=passed_components, folder_names=model_folder_names
1412-
):
1413-
ignore_patterns = ["*.bin", "*.msgpack"]
1414-
1415-
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
1416-
if not use_onnx:
1417-
ignore_patterns += ["*.onnx", "*.pb"]
1418-
1419-
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
1420-
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
1421-
if (
1422-
len(safetensors_variant_filenames) > 0
1423-
and safetensors_model_filenames != safetensors_variant_filenames
1424-
):
1425-
logger.warning(
1426-
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
1427-
)
1428-
else:
1429-
ignore_patterns = ["*.safetensors", "*.msgpack"]
1430-
1431-
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
1432-
if not use_onnx:
1433-
ignore_patterns += ["*.onnx", "*.pb"]
1434-
1435-
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
1436-
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
1437-
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
1438-
logger.warning(
1439-
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
1440-
)
1373+
# retrieve all patterns that should not be downloaded and error out when needed
1374+
ignore_patterns = _get_ignore_patterns(
1375+
passed_components,
1376+
model_folder_names,
1377+
model_filenames,
1378+
variant_filenames,
1379+
use_safetensors,
1380+
from_flax,
1381+
allow_pickle,
1382+
use_onnx,
1383+
pipeline_class._is_onnx,
1384+
variant,
1385+
)
14411386

14421387
# Don't download any objects that are passed
14431388
allow_patterns = [

tests/pipelines/test_pipeline_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
StableDiffusionPipeline,
1919
UNet2DConditionModel,
2020
)
21-
from diffusers.pipelines.pipeline_utils import is_safetensors_compatible
21+
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible
2222
from diffusers.utils.testing_utils import torch_device
2323

2424

0 commit comments

Comments
 (0)