Skip to content

Commit 6899f40

Browse files
committed
update
1 parent abba8e0 commit 6899f40

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,22 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
141141
return True
142142

143143

144+
def filter_model_files(filenames):
145+
"""Filter model repo files for just files/folders that contain model weights"""
146+
weight_names = [
147+
WEIGHTS_NAME,
148+
SAFETENSORS_WEIGHTS_NAME,
149+
FLAX_WEIGHTS_NAME,
150+
ONNX_WEIGHTS_NAME,
151+
ONNX_EXTERNAL_WEIGHTS_NAME,
152+
]
153+
154+
if is_transformers_available():
155+
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
156+
157+
return [f for f in filenames if any(f.endswith(wn) for wn in weight_names)]
158+
159+
144160
def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
145161
weight_names = [
146162
WEIGHTS_NAME,

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
_resolve_custom_pipeline_and_cls,
8888
_unwrap_model,
8989
_update_init_kwargs_with_connected_pipeline,
90+
filter_model_files,
9091
load_sub_model,
9192
maybe_raise_or_warn,
9293
variant_compatible_siblings,
@@ -1415,7 +1416,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14151416
f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n"
14161417
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
14171418
)
1418-
model_folder_names = {os.path.split(f)[0] for f in filenames if os.path.split(f)[0] in folder_names}
14191419

14201420
# retrieve passed components that should not be downloaded
14211421
pipeline_class = _get_pipeline_class(
@@ -1432,6 +1432,9 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14321432
expected_components, _ = cls._get_signature_keys(pipeline_class)
14331433
passed_components = [k for k in expected_components if k in kwargs]
14341434

1435+
model_folder_names = {
1436+
os.path.split(f)[0] for f in filter_model_files(filenames) if os.path.split(f)[0] in folder_names
1437+
}
14351438
# retrieve all patterns that should not be downloaded and error out when needed
14361439
ignore_patterns = _get_ignore_patterns(
14371440
passed_components,

0 commit comments

Comments
 (0)