Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
components.setdefault(component, [])
components[component].append(component_filename)

# If there are no component folders check the main directory for safetensors files
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this sufficient as a check? 👁️

if not components:
return any(".safetensors" in filename for filename in filenames)
Comment on lines +122 to +123
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would also mention the description of this function because this is not accounted for there.

Converting default pytorch serialized filenames to safetensors serialized filenames:


# iterate over all files of a component
# check if safetensor files exist for that component
# if variant is provided check if the variant of the safetensors exists
Expand Down
12 changes: 12 additions & 0 deletions tests/pipelines/test_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,18 @@ def test_diffusers_is_compatible_only_variants(self):
]
self.assertTrue(is_safetensors_compatible(filenames))

def test_diffusers_is_compatible_no_components(self):
filenames = [
"diffusion_pytorch_model.bin",
]
self.assertFalse(is_safetensors_compatible(filenames))

def test_diffusers_is_compatible_no_components_only_variants(self):
filenames = [
"diffusion_pytorch_model.fp16.bin",
]
self.assertFalse(is_safetensors_compatible(filenames))
Comment on lines +200 to +210
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it also make sense to check the dummy repo here?



class ProgressBarTests(unittest.TestCase):
def get_dummy_components_image_generation(self):
Expand Down
Loading