Skip to content

Commit ea0126d

Browse files
committed
add loading from the hub test
1 parent 73e81a5 commit ea0126d

File tree

5 files changed

+47
-5
lines changed

5 files changed

+47
-5
lines changed

src/diffusers/configuration_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,11 @@ def load_config(
364364
if subfolder is not None:
365365
raise ValueError("DDUF file only allow for 1 level of directory. Please check the DDUF structure")
366366
# paths inside a DDUF file must always be "/"
367-
config_file = cls.config_name if pretrained_model_name_or_path == "" else "/".join([pretrained_model_name_or_path, cls.config_name])
367+
config_file = (
368+
cls.config_name
369+
if pretrained_model_name_or_path == ""
370+
else "/".join([pretrained_model_name_or_path, cls.config_name])
371+
)
368372
if config_file not in dduf_entries:
369373
raise ValueError(
370374
f"We did not manage to find the file {config_file} in the dduf file. We only have the following files {dduf_entries.keys()}"

src/diffusers/utils/hub_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,11 @@ def _get_model_file(
299299
if dduf_entries:
300300
if subfolder is not None:
301301
raise ValueError("DDUF file only allow for 1 level of directory. Please check the DDUF structure")
302-
model_file = weights_name if pretrained_model_name_or_path == "" else "/".join([pretrained_model_name_or_path, weights_name])
302+
model_file = (
303+
weights_name
304+
if pretrained_model_name_or_path == ""
305+
else "/".join([pretrained_model_name_or_path, weights_name])
306+
)
303307
if model_file in dduf_entries:
304308
return model_file
305309
else:

src/diffusers/utils/testing_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,18 @@ def decorator(test_case):
476476
return decorator
477477

478478

479+
def require_hf_hub_version_greater(hf_hub_version):
480+
def decorator(test_case):
481+
correct_hf_hub_version = version.parse(
482+
version.parse(importlib.metadata.version("huggingface_hub")).base_version
483+
) > version.parse(hf_hub_version)
484+
return unittest.skipUnless(
485+
correct_hf_hub_version, f"Test requires huggingface_hub with the version greater than {hf_hub_version}."
486+
)(test_case)
487+
488+
return decorator
489+
490+
479491
def deprecate_after_peft_backend(test_case):
480492
"""
481493
Decorator marking a test that will be skipped after PEFT backend

tests/pipelines/test_pipelines.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,11 @@
7575
nightly,
7676
require_compel,
7777
require_flax,
78+
require_hf_hub_version_greater,
7879
require_onnxruntime,
7980
require_torch_2,
8081
require_torch_gpu,
82+
require_transformers_version_greater,
8183
run_test_in_subprocess,
8284
slow,
8385
torch_device,
@@ -1802,6 +1804,25 @@ def test_pipe_same_device_id_offload(self):
18021804
sd.maybe_free_model_hooks()
18031805
assert sd._offload_gpu_id == 5
18041806

1807+
@require_hf_hub_version_greater("0.26.5")
1808+
@require_transformers_version_greater("4.47.0")
1809+
@parameterized.expand([torch.float32, torch.float16])
1810+
def test_load_dduf_from_hub(self, dtype):
1811+
with tempfile.TemporaryDirectory() as tmpdir:
1812+
pipe = DiffusionPipeline.from_pretrained(
1813+
"DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", cache_dir=tmpdir, torch_dtype=dtype
1814+
).to(torch_device)
1815+
out_1 = pipe(prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np").images
1816+
1817+
pipe.save_pretrained(tmpdir)
1818+
loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir, torch_dtype=dtype).to(torch_device)
1819+
1820+
out_2 = loaded_pipe(
1821+
prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np"
1822+
).images
1823+
1824+
self.assertTrue(np.allclose(out_1, out_2, atol=1e-4, rtol=1e-4))
1825+
18051826

18061827
@slow
18071828
@require_torch_gpu

tests/pipelines/test_pipelines_common.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@
4343
CaptureLogger,
4444
require_accelerate_version_greater,
4545
require_accelerator,
46+
require_hf_hub_version_greater,
4647
require_torch,
48+
require_transformers_version_greater,
4749
skip_mps,
4850
torch_device,
4951
)
@@ -1902,8 +1904,8 @@ def test_StableDiffusionMixin_component(self):
19021904
)
19031905
)
19041906

1905-
# @pytest.mark.xfail(condition=not os.getenv("RUN_DDUF_TEST", False), strict=True)
1906-
# Should consider guarding the test with proper transformers and huggingface_hub versions.
1907+
@require_hf_hub_version_greater("0.26.5")
1908+
@require_transformers_version_greater("4.47.0")
19071909
def test_save_load_dduf(self):
19081910
from huggingface_hub import export_folder_as_dduf
19091911

@@ -1922,7 +1924,6 @@ def test_save_load_dduf(self):
19221924
dduf_filename = os.path.join(tmpdir, f"{pipe.__class__.__name__.lower()}.dduf")
19231925
pipe.save_pretrained(tmpdir, safe_serialization=True)
19241926
export_folder_as_dduf(dduf_filename, folder_path=tmpdir)
1925-
print(f"{os.listdir(tmpdir)=}")
19261927
loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, dduf_file=dduf_filename).to(torch_device)
19271928

19281929
inputs["generator"] = torch.manual_seed(0)

0 commit comments

Comments
 (0)