Skip to content

Commit 0205cc8

Browse files
committed
add test for sharded checkpoint
1 parent a032025 commit 0205cc8

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -825,7 +825,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
825825
dduf_entries=dduf_entries,
826826
)
827827
# TODO: https://github.com/huggingface/diffusers/issues/10013
828-
if hf_quantizer is not None:
828+
if hf_quantizer is not None or dduf_entries:
829829
model_file = _merge_sharded_checkpoints(
830830
sharded_ckpt_cached_folder, sharded_metadata, dduf_entries=dduf_entries
831831
)

tests/pipelines/test_pipelines.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1875,6 +1875,18 @@ def test_wrong_model(self):
18751875
assert "is of type" in str(error_context.exception)
18761876
assert "but should be" in str(error_context.exception)
18771877

1878+
@require_hf_hub_version_greater("0.26.5")
1879+
@require_transformers_version_greater("4.47.1")
1880+
def test_dduf_load_sharded_checkpoint_diffusion_model(self):
1881+
with tempfile.TemporaryDirectory() as tmpdir:
1882+
pipe = DiffusionPipeline.from_pretrained(
1883+
"hf-internal-testing/tiny-flux-dev-pipe-sharded-checkpoint-DDUF",
1884+
dduf_file="tiny-flux-dev-pipe-sharded-checkpoint.dduf",
1885+
cache_dir=tmpdir,
1886+
).to(torch_device)
1887+
1888+
pipe(prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np").images
1889+
18781890

18791891
@slow
18801892
@require_torch_gpu

0 commit comments

Comments
 (0)