Skip to content

Commit bfc66f8

Browse files
committed
update
1 parent b5c2050 commit bfc66f8

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,21 +146,27 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
146146
components[component].append(component_filename)
147147

148148
# If there are no component folders check the main directory for safetensors files
149+
filtered_filenames = set()
149150
if not components:
150151
if variant is not None:
151152
filtered_filenames = filter_with_regex(filenames, variant_file_re)
152-
else:
153+
154+
# If no variant filenames exist check if non-variant files are available
155+
if not filtered_filenames:
153156
filtered_filenames = filter_with_regex(filenames, non_variant_file_re)
154157
return any(".safetensors" in filename for filename in filtered_filenames)
155158

156159
# iterate over all files of a component
157160
# check if safetensor files exist for that component
158-
# if variant is provided check if the variant of the safetensors exists
159161
for component, component_filenames in components.items():
160162
matches = []
163+
filtered_component_filenames = set()
164+
# if variant is provided check if the variant of the safetensors exists
161165
if variant is not None:
162166
filtered_component_filenames = filter_with_regex(component_filenames, variant_file_re)
163-
else:
167+
168+
# if variant safetensor files do not exist check for non-variants
169+
if not filtered_component_filenames:
164170
filtered_component_filenames = filter_with_regex(component_filenames, non_variant_file_re)
165171
for component_filename in filtered_component_filenames:
166172
filename, extension = os.path.splitext(component_filename)

tests/pipelines/test_pipeline_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,13 @@ def test_diffusers_is_compatible_no_components_only_variants(self):
217217
]
218218
self.assertFalse(is_safetensors_compatible(filenames))
219219

220+
def test_is_compatible_mixed_variants(self):
221+
filenames = [
222+
"unet/diffusion_pytorch_model.fp16.safetensors",
223+
"vae/diffusion_pytorch_model.safetensors",
224+
]
225+
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
226+
220227

221228
class VariantCompatibleSiblingsTest(unittest.TestCase):
222229
def test_only_non_variants_downloaded(self):

0 commit comments

Comments
 (0)