
[core] respect local_files_only=True when using sharded checkpoints #12005


Merged (18 commits, Aug 14, 2025)
25 changes: 16 additions & 9 deletions src/diffusers/utils/hub_utils.py
@@ -402,15 +402,17 @@ def _get_checkpoint_shard_files(
        allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]

    ignore_patterns = ["*.json", "*.md"]
    # `model_info` call must guarded with the above condition.
    model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
Collaborator

So the purpose of this check is to verify if the necessary sharded files are present in the model repo before attempting a download, presumably to avoid a large download if all files aren't present. If we cannot connect to the hub, we just have to assume the necessary shard files are already present locally.

I think we can just skip this check if local_files_only=True and then check if all the shard filenames are present in the cached_folder
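A minimal sketch of that flow (illustrative only, assuming the helper's existing variables local_files_only, cached_folder, original_shard_filenames, pretrained_model_name_or_path, revision, and token are in scope; not the merged code):

import os

from huggingface_hub import model_info

if not local_files_only:
    # Only query the Hub when going online is allowed.
    model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)

# However the cache was populated, verify that every expected shard is actually on disk.
for shard_file in original_shard_filenames:
    if not os.path.isfile(os.path.join(cached_folder, shard_file)):
        raise EnvironmentError(f"{cached_folder} is missing {shard_file}, which the checkpoint index requires.")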

Member Author

How about now?

Collaborator
@DN6 Aug 13, 2025

I think just this is sufficient

if not local_files_only:
    # run model_info check

Run snapshot download

Then, after cached_filenames is created, iterate over the files to verify they exist

for filename in cached_filenames:
    if not os.path.exists(filename):
        raise EnvironmentError(f"expected file not present in {cached_folder}")

Member Author
@sayakpaul Aug 13, 2025

  1. We don't have to run snapshot_download() when local_files_only=False, as that might be unnecessary.
  2. Why run snapshot_download() after also running model_info()?
  3. Even if we run snapshot_download() regardless of the local_files_only var, I think we should have it inside a try-except in case the endpoint cannot be pinged for some reason, and raise the ConnectionError as before (sketched below).
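A rough sketch of the try-except guard from point 3 (hypothetical; variable names follow _get_checkpoint_shard_files, and the exact exception surfaced when the endpoint cannot be pinged is an assumption):

from huggingface_hub import model_info

if not local_files_only:
    try:
        model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
    except Exception as e:
        # Assumed handling: re-raise as a connection error rather than leaking the low-level request failure.
        raise ConnectionError(
            f"Couldn't reach the Hub to verify the shards of {pretrained_model_name_or_path}."
        ) from e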

Member Author

See if b7af511 resolves this.

Member Author
@sayakpaul Aug 13, 2025

Ah I see what you mean. Let me update. Sorry about the back and forth.

    for shard_file in original_shard_filenames:
        shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
        if not shard_file_present:
            raise EnvironmentError(
                f"{shards_path} does not appear to have a file named {shard_file} which is "
                "required according to the checkpoint index."
            )

    # If the repo doesn't have the required shards, error out early even before downloading anything.
    if not local_files_only:
        model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
        for shard_file in original_shard_filenames:
            shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
            if not shard_file_present:
                raise EnvironmentError(
                    f"{shards_path} does not appear to have a file named {shard_file} which is "
                    "required according to the checkpoint index."
                )

    try:
        # Load from URL
@@ -437,6 +439,11 @@ def _get_checkpoint_shard_files(
        ) from e

    cached_filenames = [os.path.join(cached_folder, f) for f in original_shard_filenames]
    for cached_file in cached_filenames:
        if not os.path.isfile(cached_file):
            raise EnvironmentError(
                f"{cached_folder} does not have a file named {cached_file} which is required according to the checkpoint index."
            )

    return cached_filenames, sharded_metadata
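With the early check guarded and the cache verified afterwards, a fully cached sharded checkpoint should load without touching the Hub. A hypothetical usage sketch (the repo id below is a placeholder):

from diffusers import FluxTransformer2DModel

# Every shard is already in the local cache, so no network call should be needed.
model = FluxTransformer2DModel.from_pretrained(
    "some-org/some-sharded-checkpoint", subfolder="transformer", local_files_only=True
)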

52 changes: 50 additions & 2 deletions tests/models/test_modeling_common.py
@@ -36,12 +36,12 @@
import torch
import torch.nn as nn
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
from huggingface_hub import ModelCard, delete_repo, snapshot_download
from huggingface_hub import ModelCard, delete_repo, snapshot_download, try_to_load_from_cache
from huggingface_hub.utils import is_jinja_available
from parameterized import parameterized
from requests.exceptions import HTTPError

from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel
from diffusers.models import FluxTransformer2DModel, SD3Transformer2DModel, UNet2DConditionModel
from diffusers.models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
@@ -291,6 +291,54 @@ def test_cached_files_are_used_when_no_internet(self):
            if p1.data.ne(p2.data).sum() > 0:
                assert False, "Parameters not the same!"

    def test_local_files_only_with_sharded_checkpoint(self):
        repo_id = "hf-internal-testing/tiny-flux-sharded"
        error_response = mock.Mock(
            status_code=500,
            headers={},
            raise_for_status=mock.Mock(side_effect=HTTPError),
            json=mock.Mock(return_value={}),
        )

        with tempfile.TemporaryDirectory() as tmpdir:
            model = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=tmpdir)

            with mock.patch("requests.Session.get", return_value=error_response):
                # Should fail with local_files_only=False (network required)
                # We would make a network call with model_info
                with self.assertRaises(OSError):
                    FluxTransformer2DModel.from_pretrained(
                        repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=False
                    )

                # Should succeed with local_files_only=True (uses cache)
                # model_info call skipped
                local_model = FluxTransformer2DModel.from_pretrained(
                    repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
                )

                assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), (
                    "Model parameters don't match!"
                )

                # Remove a shard file
                cached_shard_file = try_to_load_from_cache(
                    repo_id, filename="transformer/diffusion_pytorch_model-00001-of-00002.safetensors", cache_dir=tmpdir
                )
                os.remove(cached_shard_file)

                # Attempting to load from cache should raise an error
                with self.assertRaises(OSError) as context:
                    FluxTransformer2DModel.from_pretrained(
                        repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
                    )

                # Verify error mentions the missing shard
                error_msg = str(context.exception)
                assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, (
                    f"Expected error about missing shard, got: {error_msg}"
                )

    @unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
    @unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
    def test_one_request_upon_cached(self):