Skip to content

Commit 3db5a69

Browse files
committed
update
1 parent 6899f40 commit 3db5a69

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,9 @@ def filter_model_files(filenames):
154154
if is_transformers_available():
155155
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
156156

157-
return [f for f in filenames if any(f.endswith(wn) for wn in weight_names)]
157+
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
158+
159+
return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]
158160

159161

160162
def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1432,6 +1432,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14321432
expected_components, _ = cls._get_signature_keys(pipeline_class)
14331433
passed_components = [k for k in expected_components if k in kwargs]
14341434

1435+
# retrieve the names of the folders containing model weights
14351436
model_folder_names = {
14361437
os.path.split(f)[0] for f in filter_model_files(filenames) if os.path.split(f)[0] in folder_names
14371438
}

0 commit comments

Comments
 (0)