
Commit 1b48db4

sayakpaul and DN6 authored
[core] respect local_files_only=True when using sharded checkpoints (#12005)
* tighten compilation tests for quantization
* feat: model_info but local.
* up
* Revert "tighten compilation tests for quantization"
  This reverts commit 8d431dc.
* up
* reviewer feedback.
* reviewer feedback.
* up
* up
* empty
* update

---------

Co-authored-by: DN6 <[email protected]>
1 parent 46a0c6a · commit 1b48db4

File tree

2 files changed: +66 −11 lines

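What the fix means in practice: once a sharded checkpoint is already in the local cache, it can be loaded with no network access at all, because the hub `model_info` lookup is now skipped when `local_files_only=True`. A minimal sketch of that usage (the repo id below is a hypothetical placeholder, not one used in this commit):

from diffusers import FluxTransformer2DModel

# First load downloads and caches the shards (network required).
model = FluxTransformer2DModel.from_pretrained(
    "some-org/sharded-flux-model",  # hypothetical repo id
    subfolder="transformer",
)

# Later loads can run fully offline: with local_files_only=True the hub
# `model_info` lookup is skipped and the cached shards are used directly.
model = FluxTransformer2DModel.from_pretrained(
    "some-org/sharded-flux-model",  # hypothetical repo id
    subfolder="transformer",
    local_files_only=True,
)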

src/diffusers/utils/hub_utils.py

Lines changed: 16 additions & 9 deletions
@@ -402,15 +402,17 @@ def _get_checkpoint_shard_files(
         allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]
 
     ignore_patterns = ["*.json", "*.md"]
-    # `model_info` call must guarded with the above condition.
-    model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
-    for shard_file in original_shard_filenames:
-        shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
-        if not shard_file_present:
-            raise EnvironmentError(
-                f"{shards_path} does not appear to have a file named {shard_file} which is "
-                "required according to the checkpoint index."
-            )
+
+    # If the repo doesn't have the required shards, error out early even before downloading anything.
+    if not local_files_only:
+        model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
+        for shard_file in original_shard_filenames:
+            shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
+            if not shard_file_present:
+                raise EnvironmentError(
+                    f"{shards_path} does not appear to have a file named {shard_file} which is "
+                    "required according to the checkpoint index."
+                )
 
     try:
         # Load from URL
@@ -437,6 +439,11 @@ def _get_checkpoint_shard_files(
         ) from e
 
     cached_filenames = [os.path.join(cached_folder, f) for f in original_shard_filenames]
+    for cached_file in cached_filenames:
+        if not os.path.isfile(cached_file):
+            raise EnvironmentError(
+                f"{cached_folder} does not have a file named {cached_file} which is required according to the checkpoint index."
+            )
 
     return cached_filenames, sharded_metadata
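Taken together, the two hunks above form a two-stage check. A condensed sketch of the resulting control flow follows; `_verify_shards` is a hypothetical helper invented here to summarize the diff, not a function in the codebase:

import os

def _verify_shards(shards_path, cached_folder, shard_filenames, model_files_info, local_files_only):
    # Stage 1 (remote, skipped when offline): confirm every shard named in
    # the index appears in the repo listing before downloading anything.
    if not local_files_only:
        for shard_file in shard_filenames:
            if not any(shard_file in s.rfilename for s in model_files_info.siblings):
                raise EnvironmentError(
                    f"{shards_path} does not appear to have a file named {shard_file} which is "
                    "required according to the checkpoint index."
                )

    # Stage 2 (local, always runs): confirm every resolved shard actually
    # exists on disk, so a pruned cache fails loudly rather than cryptically.
    cached_filenames = [os.path.join(cached_folder, f) for f in shard_filenames]
    for cached_file in cached_filenames:
        if not os.path.isfile(cached_file):
            raise EnvironmentError(
                f"{cached_folder} does not have a file named {cached_file} "
                "which is required according to the checkpoint index."
            )
    return cached_filenames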

tests/models/test_modeling_common.py

Lines changed: 50 additions & 2 deletions
@@ -36,12 +36,12 @@
 import torch
 import torch.nn as nn
 from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
-from huggingface_hub import ModelCard, delete_repo, snapshot_download
+from huggingface_hub import ModelCard, delete_repo, snapshot_download, try_to_load_from_cache
 from huggingface_hub.utils import is_jinja_available
 from parameterized import parameterized
 from requests.exceptions import HTTPError
 
-from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel
+from diffusers.models import FluxTransformer2DModel, SD3Transformer2DModel, UNet2DConditionModel
 from diffusers.models.attention_processor import (
     AttnProcessor,
     AttnProcessor2_0,
@@ -291,6 +291,54 @@ def test_cached_files_are_used_when_no_internet(self):
             if p1.data.ne(p2.data).sum() > 0:
                 assert False, "Parameters not the same!"
 
+    def test_local_files_only_with_sharded_checkpoint(self):
+        repo_id = "hf-internal-testing/tiny-flux-sharded"
+        error_response = mock.Mock(
+            status_code=500,
+            headers={},
+            raise_for_status=mock.Mock(side_effect=HTTPError),
+            json=mock.Mock(return_value={}),
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=tmpdir)
+
+            with mock.patch("requests.Session.get", return_value=error_response):
+                # Should fail with local_files_only=False (network required)
+                # We would make a network call with model_info
+                with self.assertRaises(OSError):
+                    FluxTransformer2DModel.from_pretrained(
+                        repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=False
+                    )
+
+                # Should succeed with local_files_only=True (uses cache)
+                # model_info call skipped
+                local_model = FluxTransformer2DModel.from_pretrained(
+                    repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
+                )
+
+                assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), (
+                    "Model parameters don't match!"
+                )
+
+            # Remove a shard file
+            cached_shard_file = try_to_load_from_cache(
+                repo_id, filename="transformer/diffusion_pytorch_model-00001-of-00002.safetensors", cache_dir=tmpdir
+            )
+            os.remove(cached_shard_file)
+
+            # Attempting to load from cache should raise an error
+            with self.assertRaises(OSError) as context:
+                FluxTransformer2DModel.from_pretrained(
+                    repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
+                )
+
+            # Verify error mentions the missing shard
+            error_msg = str(context.exception)
+            assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, (
+                f"Expected error about missing shard, got: {error_msg}"
+            )
+
     @unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
     @unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
     def test_one_request_upon_cached(self):
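The test's newly imported `try_to_load_from_cache` resolves a cached file's on-disk path without touching the network, which is what lets the test delete exactly one shard. A small usage sketch (the repo id is a hypothetical placeholder; the filename mirrors the test):

from huggingface_hub import try_to_load_from_cache

# Returns the local file path (str) if the file is cached; returns a special
# sentinel if the file is known not to exist, and None if it isn't cached yet.
path = try_to_load_from_cache(
    "some-org/sharded-flux-model",  # hypothetical repo id
    filename="transformer/diffusion_pytorch_model-00001-of-00002.safetensors",
)
if isinstance(path, str):
    print(f"shard cached at: {path}")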
