Skip to content

Commit 7fd1a82

Browse files
committed
update
1 parent 09e063c commit 7fd1a82

File tree

2 files changed: +63 additions, −36 deletions

src/diffusers/utils/hub_utils.py

Lines changed: 13 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -405,20 +405,14 @@ def _get_checkpoint_shard_files(
405405

406406
# If the repo doesn't have the required shards, error out early even before downloading anything.
407407
if not local_files_only:
408-
try:
409-
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
410-
for shard_file in original_shard_filenames:
411-
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
412-
if not shard_file_present:
413-
raise EnvironmentError(
414-
f"{shards_path} does not appear to have a file named {shard_file} which is "
415-
"required according to the checkpoint index."
416-
)
417-
except ConnectionError as e:
418-
raise EnvironmentError(
419-
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
420-
" again after checking your internet connection."
421-
) from e
408+
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
409+
for shard_file in original_shard_filenames:
410+
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
411+
if not shard_file_present:
412+
raise EnvironmentError(
413+
f"{shards_path} does not appear to have a file named {shard_file} which is "
414+
"required according to the checkpoint index."
415+
)
422416

423417
try:
424418
# Load from URL
@@ -436,16 +430,6 @@ def _get_checkpoint_shard_files(
436430
if subfolder is not None:
437431
cached_folder = os.path.join(cached_folder, subfolder)
438432

439-
# Check again after downloading/loading from the cache.
440-
model_files_info = _get_filepaths_for_folder(cached_folder)
441-
for shard_file in original_shard_filenames:
442-
shard_file_present = any(shard_file in k for k in model_files_info)
443-
if not shard_file_present:
444-
raise EnvironmentError(
445-
f"{shards_path} does not appear to have a file named {shard_file} which is "
446-
"required according to the checkpoint index."
447-
)
448-
449433
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
450434
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
451435
except HTTPError as e:
@@ -455,20 +439,15 @@ def _get_checkpoint_shard_files(
455439
) from e
456440

457441
cached_filenames = [os.path.join(cached_folder, f) for f in original_shard_filenames]
442+
for cached_file in cached_filenames:
443+
if not os.path.isfile(cached_file):
444+
raise EnvironmentError(
445+
f"{cached_folder} does not have a file named {cached_file} which is required according to the checkpoint index."
446+
)
458447

459448
return cached_filenames, sharded_metadata
460449

461450

462-
def _get_filepaths_for_folder(folder):
463-
relative_paths = []
464-
for root, dirs, files in os.walk(folder):
465-
for fname in files:
466-
abs_path = os.path.join(root, fname)
467-
rel_path = os.path.relpath(abs_path, start=folder)
468-
relative_paths.append(rel_path)
469-
return relative_paths
470-
471-
472451
def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None):
473452
if filenames and folder:
474453
raise ValueError("Both `filenames` and `folder` cannot be provided.")

tests/models/test_modeling_common.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@
3636
import torch
3737
import torch.nn as nn
3838
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
39-
from huggingface_hub import ModelCard, delete_repo, snapshot_download
39+
from huggingface_hub import ModelCard, delete_repo, snapshot_download, try_to_load_from_cache
4040
from huggingface_hub.utils import is_jinja_available
4141
from parameterized import parameterized
4242
from requests.exceptions import HTTPError
4343

44-
from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel
44+
from diffusers.models import FluxTransformer2DModel, SD3Transformer2DModel, UNet2DConditionModel
4545
from diffusers.models.attention_processor import (
4646
AttnProcessor,
4747
AttnProcessor2_0,
@@ -291,6 +291,54 @@ def test_cached_files_are_used_when_no_internet(self):
291291
if p1.data.ne(p2.data).sum() > 0:
292292
assert False, "Parameters not the same!"
293293

294+
def test_local_files_only_with_sharded_checkpoint(self):
295+
repo_id = "hf-internal-testing/tiny-flux-sharded"
296+
error_response = mock.Mock(
297+
status_code=500,
298+
headers={},
299+
raise_for_status=mock.Mock(side_effect=HTTPError),
300+
json=mock.Mock(return_value={}),
301+
)
302+
303+
with tempfile.TemporaryDirectory() as tmpdir:
304+
model = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=tmpdir)
305+
306+
with mock.patch("requests.Session.get", return_value=error_response):
307+
# Should fail with local_files_only=False (network required)
308+
# We would make a network call with model_info
309+
with self.assertRaises(OSError):
310+
FluxTransformer2DModel.from_pretrained(
311+
repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=False
312+
)
313+
314+
# Should succeed with local_files_only=True (uses cache)
315+
# model_info call skipped
316+
local_model = FluxTransformer2DModel.from_pretrained(
317+
repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
318+
)
319+
320+
assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), (
321+
"Model parameters don't match!"
322+
)
323+
324+
# Remove a shard file
325+
cached_shard_file = try_to_load_from_cache(
326+
repo_id, filename="transformer/diffusion_pytorch_model-00001-of-00002.safetensors", cache_dir=tmpdir
327+
)
328+
os.remove(cached_shard_file)
329+
330+
# Attempting to load from cache should raise an error
331+
with self.assertRaises(OSError) as context:
332+
FluxTransformer2DModel.from_pretrained(
333+
repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
334+
)
335+
336+
# Verify error mentions the missing shard
337+
error_msg = str(context.exception)
338+
assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, (
339+
f"Expected error about missing shard, got: {error_msg}"
340+
)
341+
294342
@unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
295343
@unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
296344
def test_one_request_upon_cached(self):

Comments (0)