Skip to content
95 changes: 54 additions & 41 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
extension is replaced with ".safetensors"
"""
passed_components = passed_components or []
if folder_names is not None:
if folder_names:
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}

# extract all components of the pipeline and their associated files
Expand Down Expand Up @@ -141,7 +141,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
return True


def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
def variant_compatible_siblings(filenames, variant=None, use_safetensors=True) -> Union[List[os.PathLike], str]:
weight_names = [
WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
Expand Down Expand Up @@ -177,54 +177,67 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
# `text_encoder/pytorch_model.bin.index.json`
non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")

if variant is not None:
variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
variant_filenames = variant_weights | variant_indexes
else:
variant_filenames = set()

non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
non_variant_filenames = non_variant_weights | non_variant_indexes
def filter_for_compatible_extensions(filenames, variant=None, use_safetensors=True):
def is_safetensors(filename):
return ".safetensors" in filename

# all variant filenames will be used by default
usable_filenames = set(variant_filenames)
def is_not_safetensors(filename):
return ".safetensors" not in filename

def convert_to_variant(filename):
if "index" in filename:
variant_filename = filename.replace("index", f"index.{variant}")
elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
if use_safetensors and is_safetensors_compatible(filenames):
extension_filter = is_safetensors
else:
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
return variant_filename
extension_filter = is_not_safetensors

tensor_files = {f for f in filenames if extension_filter(f)}
non_variant_indexes = {
f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None and extension_filter(f)
}
variant_indexes = {
f
for f in filenames
if variant is not None and variant_index_re.match(f.split("/")[-1]) is not None and extension_filter(f)
}

return tensor_files | non_variant_indexes | variant_indexes

def filter_for_weights_and_indexes(filenames, file_re, index_re):
weights = {f for f in filenames if file_re.match(f.split("/")[-1]) is not None}
indexes = {f for f in filenames if index_re.match(f.split("/")[-1]) is not None}
filtered_filenames = weights | indexes

def find_component(filename):
return filtered_filenames

# Group files by component
components = {}
for filename in filenames:
if not len(filename.split("/")) == 2:
return
component = filename.split("/")[0]
return component

def has_sharded_variant(component, variant, variant_filenames):
# If component exists check for sharded variant index filename
# If component doesn't exist check main dir for sharded variant index filename
component = component + "/" if component else ""
variant_index_re = re.compile(
rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
components.setdefault("", []).append(filename)
continue

component, _ = filename.split("/")
components.setdefault(component, []).append(filename)

usable_filenames = set()
variant_filenames = set()
for component, component_filenames in components.items():
component_filenames = filter_for_compatible_extensions(
component_filenames, variant=variant, use_safetensors=use_safetensors
)
return any(f for f in variant_filenames if variant_index_re.match(f) is not None)

for filename in non_variant_filenames:
if convert_to_variant(filename) in variant_filenames:
continue
component_variants = set()
if variant is not None:
component_variants = filter_for_weights_and_indexes(component_filenames, variant_file_re, variant_index_re)

component = find_component(filename)
# If a sharded variant exists skip adding to allowed patterns
if has_sharded_variant(component, variant, variant_filenames):
continue
if component_variants:
variant_filenames.update(component_variants)
usable_filenames.update(component_variants)

usable_filenames.add(filename)
else:
component_non_variants = filter_for_weights_and_indexes(
component_filenames, non_variant_file_re, non_variant_index_re
)
usable_filenames.update(component_non_variants)

return usable_filenames, variant_filenames

Expand Down
4 changes: 3 additions & 1 deletion src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,7 +1375,9 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
)
logger.warning(warn_msg)

model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, use_safetensors=use_safetensors
)

config_file = hf_hub_download(
pretrained_model_name,
Expand Down
Loading