Skip to content

[core] respect local_files_only=True when using sharded checkpoints #12005

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Aug 14, 2025
Merged
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 36 additions & 8 deletions src/diffusers/utils/hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,15 +402,23 @@ def _get_checkpoint_shard_files(
allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]

ignore_patterns = ["*.json", "*.md"]
# `model_info` call must guarded with the above condition.
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the purpose of this check is to verify if the necessary sharded files are present in the model repo before attempting a download, presumably to avoid a large download if all files aren't present. If we cannot connect to the hub, we just have to assume the necessary shard files are already present locally.

I think we can just skip this check if local_files_only=True and then check if all the shard filenames are present in the cached_folder

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about now?

Copy link
Collaborator

@DN6 DN6 Aug 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think just this is sufficient

if not local_files_only:
    # run model_info check

Run snapshot download

Then after the cached_filenames is created, iterate over the files to verify they exist

for filename in cached_filename:
      if not if not os.path.exists(filename):
           raise EnvironmentError("expected file not present in {cached_folder}")

Copy link
Member Author

@sayakpaul sayakpaul Aug 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. We don't have to run snapshot_download() when local_files_only=False, that might be unnecessary.
  2. Why run snapshot_download() after also running model_info()?
  3. Even if we run snapshot_download() regardless of local_files_only var, I think we should have it inside try-except in case the endpoint cannot be pinged for some reason and raise the ConnectionError as before.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See if b7af511 resolves this.

Copy link
Member Author

@sayakpaul sayakpaul Aug 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see what you mean. Let me update. Sorry about the back and forth.

for shard_file in original_shard_filenames:
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
if not shard_file_present:

# If the repo doesn't have the required shards, error out early even before downloading anything.
if not local_files_only:
try:
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
for shard_file in original_shard_filenames:
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
if not shard_file_present:
raise EnvironmentError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
"required according to the checkpoint index."
)
except ConnectionError as e:
raise EnvironmentError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
"required according to the checkpoint index."
)
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
" again after checking your internet connection."
) from e

try:
# Load from URL
Expand All @@ -428,6 +436,16 @@ def _get_checkpoint_shard_files(
if subfolder is not None:
cached_folder = os.path.join(cached_folder, subfolder)

# Check again after downloading/loading from the cache.
model_files_info = _get_filepaths_for_folder(cached_folder)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need a new function to walk over the cached folder. If you just iterate over cached_filenames and check if the file exits. Avoid looping over files multiple times this way.

for cached_file in cached_filenames:
      if not os.path.exists(cached_file):
                raise EnvironmentError(f"{cached_file} not found in {cached_folder which is required..." 

for shard_file in original_shard_filenames:
shard_file_present = any(shard_file in k for k in model_files_info)
if not shard_file_present:
raise EnvironmentError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
"required according to the checkpoint index."
)

# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
except HTTPError as e:
Expand All @@ -441,6 +459,16 @@ def _get_checkpoint_shard_files(
return cached_filenames, sharded_metadata


def _get_filepaths_for_folder(folder):
relative_paths = []
for root, dirs, files in os.walk(folder):
for fname in files:
abs_path = os.path.join(root, fname)
rel_path = os.path.relpath(abs_path, start=folder)
relative_paths.append(rel_path)
return relative_paths


def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None):
if filenames and folder:
raise ValueError("Both `filenames` and `folder` cannot be provided.")
Expand Down
Loading