Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 50 additions & 33 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,39 +103,55 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
extension is replaced with ".safetensors"
"""
# Fast path: nothing to check
if not filenames:
return False

passed_components = passed_components or []
if folder_names:
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
passed_components_set = set(passed_components)

# extract all components of the pipeline and their associated files
# If folder filter, do it only once and only on valid "/" files
if folder_names:
folder_names_set = set(folder_names)
# Fast path: skip os.path.split in tight loop, use partition
filenames_filtered = []
for f in filenames:
idx = f.find("/")
if idx != -1 and f[:idx] in folder_names_set:
filenames_filtered.append(f)
filenames = filenames_filtered

# Build components mapping in one fast scan
components = {}
for filename in filenames:
if not len(filename.split("/")) == 2:
append = components.setdefault
# Avoid repeated splits
for f in filenames:
idx = f.find("/")
if idx == -1 or idx != f.rfind("/"): # only accept one slash
continue

component, component_filename = filename.split("/")
if component in passed_components:
component = f[:idx]
if component in passed_components_set:
continue

components.setdefault(component, [])
components[component].append(component_filename)
component_filename = f[idx+1:]
append(component, []).append(component_filename)

# If there are no component folders check the main directory for safetensors files
if not components:
return any(".safetensors" in filename for filename in filenames)

# iterate over all files of a component
# check if safetensor files exist for that component
# if variant is provided check if the variant of the safetensors exists
for component, component_filenames in components.items():
matches = []
for component_filename in component_filenames:
filename, extension = os.path.splitext(component_filename)

match_exists = extension == ".safetensors"
matches.append(match_exists)

if not any(matches):
# Use any generator not list comp
for filename in filenames:
if ".safetensors" in filename:
return True
return False

# For each component, does any of its files have the .safetensors extension?
for component_filenames in components.values():
found = False
# Instead of os.path.splitext, just check if endswith
for name in component_filenames:
if name.endswith(".safetensors"):
found = True
break
if not found:
return False

return True
Expand Down Expand Up @@ -980,13 +996,15 @@ def _get_ignore_patterns(
is_onnx: bool,
variant: Optional[str] = None,
) -> List[str]:
if (
# This function is not a hotspot, but keep fast paths and function call reductions
safetensors_ok = (
use_safetensors
and not allow_pickle
and not is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
)
):
)
if safetensors_ok:
raise EnvironmentError(
f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
)
Expand All @@ -998,16 +1016,15 @@ def _get_ignore_patterns(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
):
ignore_patterns = ["*.bin", "*.msgpack"]

use_onnx = use_onnx if use_onnx is not None else is_onnx
if not use_onnx:
# Faster local assignment than repeated ternary ops in both branches
onnx_enabled = use_onnx if use_onnx is not None else is_onnx
if not onnx_enabled:
ignore_patterns += ["*.onnx", "*.pb"]

else:
ignore_patterns = ["*.safetensors", "*.msgpack"]

use_onnx = use_onnx if use_onnx is not None else is_onnx
if not use_onnx:
onnx_enabled = use_onnx if use_onnx is not None else is_onnx
if not onnx_enabled:
ignore_patterns += ["*.onnx", "*.pb"]

return ignore_patterns
Expand Down