Skip to content

Commit 6bf814f

Browse files
committed
adjust
Signed-off-by: Liu, Kaixuan <[email protected]>
1 parent b0d17ad commit 6bf814f

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,9 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
121121
# -00001-of-00002
122122
transformers_index_format = r"\d{5}-of-\d{5}"
123123
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
124+
variant_file_re = re.compile(
125+
rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
126+
)
124127
non_variant_file_re = re.compile(
125128
rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
126129
)
@@ -151,9 +154,11 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
151154
# if variant is provided check if the variant of the safetensors exists
152155
for component, component_filenames in components.items():
153156
matches = []
157+
if variant is not None:
158+
component_filenames = filter_with_regex(component_filenames, variant_file_re)
159+
else:
160+
component_filenames = filter_with_regex(component_filenames, non_variant_file_re)
154161
for component_filename in component_filenames:
155-
if variant is None:
156-
component_filename = filter_with_regex(component_filename, non_variant_file_re)
157162
filename, extension = os.path.splitext(component_filename)
158163

159164
match_exists = extension == ".safetensors"

0 commit comments

Comments
 (0)