Skip to content

Commit aa0d497

Browse files
committed
fix tests
1 parent 59929a5 commit aa0d497

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def is_saveable_module(name, value):
313313
dduf_file_path = os.path.join(save_directory, dduf_file)
314314
dir_to_archive = os.path.join(save_directory, pipeline_component_name)
315315
if os.path.isdir(dir_to_archive):
316-
export_folder_as_dduf(dduf_file_path, dir_to_archive, append=True, retain_base_folder=True)
316+
export_folder_as_dduf(dduf_file_path, dir_to_archive)
317317
shutil.rmtree(dir_to_archive)
318318

319319
# finally save the config

tests/pipelines/test_pipelines_common.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,6 +1905,8 @@ def test_StableDiffusionMixin_component(self):
19051905
# @pytest.mark.xfail(condition=not os.getenv("RUN_DDUF_TEST", False), strict=True)
19061906
# Should consider guarding the test with proper transformers and huggingface_hub versions.
19071907
def test_save_load_dduf(self):
1908+
from huggingface_hub import export_folder_as_dduf
1909+
19081910
components = self.get_dummy_components()
19091911
pipe = self.pipeline_class(**components)
19101912
pipe = pipe.to(torch_device)
@@ -1917,8 +1919,10 @@ def test_save_load_dduf(self):
19171919
pipeline_out = pipe(**inputs).images
19181920

19191921
with tempfile.TemporaryDirectory() as tmpdir:
1920-
dduf_filename = f"{pipe.__class__.__name__.lower()}.dduf"
1921-
pipe.save_pretrained(tmpdir, dduf_file=dduf_filename)
1922+
dduf_filename = os.path.join(tmpdir, f"{pipe.__class__.__name__.lower()}.dduf")
1923+
pipe.save_pretrained(tmpdir, safe_serialization=True)
1924+
export_folder_as_dduf(dduf_filename, folder_path=tmpdir)
1925+
print(f"{os.listdir(tmpdir)=}")
19221926
loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, dduf_file=dduf_filename).to(torch_device)
19231927

19241928
inputs["generator"] = torch.manual_seed(0)

0 commit comments

Comments
 (0)