Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions tests/modular_pipelines/test_modular_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
import torch
from huggingface_hub import hf_hub_download

import diffusers
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
Expand Down Expand Up @@ -32,6 +33,33 @@
)


def _get_specified_components(path_or_repo_id, cache_dir=None):
if os.path.isdir(path_or_repo_id):
config_path = os.path.join(path_or_repo_id, "modular_model_index.json")
else:
try:
config_path = hf_hub_download(
repo_id=path_or_repo_id,
filename="modular_model_index.json",
local_dir=cache_dir,
)
except Exception:
return None

with open(config_path) as f:
config = json.load(f)

components = set()
for k, v in config.items():
if isinstance(v, (str, int, float, bool)):
continue
for entry in v:
if isinstance(entry, dict) and (entry.get("repo") or entry.get("pretrained_model_name_or_path")):
components.add(k)
break
return components


class ModularPipelineTesterMixin:
"""
It provides a set of common tests for each modular pipeline,
Expand Down Expand Up @@ -360,6 +388,39 @@ def test_save_from_pretrained(self, tmp_path):

assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3

def test_load_expected_components_from_pretrained(self, tmp_path):
pipe = self.get_pipeline()
expected = _get_specified_components(self.pretrained_model_name_or_path, cache_dir=tmp_path)
if not expected:
pytest.skip("Skipping test as we couldn't fetch the expected components.")

actual = {
name
for name in pipe.components
if getattr(pipe, name, None) is not None
and getattr(getattr(pipe, name), "_diffusers_load_id", None) not in (None, "null")
}
assert expected == actual, f"Component mismatch: missing={expected - actual}, unexpected={actual - expected}"

def test_load_expected_components_from_save_pretrained(self, tmp_path):
pipe = self.get_pipeline()
save_dir = str(tmp_path / "saved-pipeline")
pipe.save_pretrained(save_dir)

expected = _get_specified_components(save_dir)
loaded_pipe = ModularPipeline.from_pretrained(save_dir)
loaded_pipe.load_components(torch_dtype=torch.float32)

actual = {
name
for name in loaded_pipe.components
if getattr(loaded_pipe, name, None) is not None
and getattr(getattr(loaded_pipe, name), "_diffusers_load_id", None) not in (None, "null")
}
assert expected == actual, (
f"Component mismatch after save/load: missing={expected - actual}, unexpected={actual - expected}"
)

def test_modular_index_consistency(self, tmp_path):
pipe = self.get_pipeline()
components_spec = pipe._component_specs
Expand Down
Loading