|
36 | 36 | import torch
|
37 | 37 | import torch.nn as nn
|
38 | 38 | from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
|
39 |
| -from huggingface_hub import ModelCard, delete_repo, snapshot_download |
| 39 | +from huggingface_hub import ModelCard, delete_repo, snapshot_download, try_to_load_from_cache |
40 | 40 | from huggingface_hub.utils import is_jinja_available
|
41 | 41 | from parameterized import parameterized
|
42 | 42 | from requests.exceptions import HTTPError
|
43 | 43 |
|
44 |
| -from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel |
| 44 | +from diffusers.models import FluxTransformer2DModel, SD3Transformer2DModel, UNet2DConditionModel |
45 | 45 | from diffusers.models.attention_processor import (
|
46 | 46 | AttnProcessor,
|
47 | 47 | AttnProcessor2_0,
|
@@ -291,6 +291,54 @@ def test_cached_files_are_used_when_no_internet(self):
|
291 | 291 | if p1.data.ne(p2.data).sum() > 0:
|
292 | 292 | assert False, "Parameters not the same!"
|
293 | 293 |
|
def test_local_files_only_with_sharded_checkpoint(self):
    """Check `local_files_only` handling for a sharded checkpoint held in a local cache.

    The test runs in three phases:

    1. Load the sharded checkpoint once with `cache_dir` pointing at a temporary directory.
    2. With `requests.Session.get` patched to return an HTTP 500 response, loading with
       `local_files_only=False` must raise `OSError`, while `local_files_only=True` must
       succeed and yield parameters equal to those of the first load.
    3. After deleting one cached shard, loading with `local_files_only=True` must raise an
       `OSError` whose message names the missing file or states that it is required
       according to the checkpoint index.
    """
    repo_id = "hf-internal-testing/tiny-flux-sharded"
    shard_filename = "transformer/diffusion_pytorch_model-00001-of-00002.safetensors"
    error_response = mock.Mock(
        status_code=500,
        headers={},
        raise_for_status=mock.Mock(side_effect=HTTPError),
        json=mock.Mock(return_value={}),
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        model = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=tmpdir)

        with mock.patch("requests.Session.get", return_value=error_response):
            # Should fail with local_files_only=False (network required)
            # We would make a network call with model_info
            with self.assertRaises(OSError):
                FluxTransformer2DModel.from_pretrained(
                    repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=False
                )

            # Should succeed with local_files_only=True (uses cache)
            # model_info call skipped
            local_model = FluxTransformer2DModel.from_pretrained(
                repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
            )

        # Compare lengths explicitly: `zip` stops at the shorter iterable, so a model with
        # missing (or no) parameters would otherwise pass the element-wise check vacuously.
        original_params = list(model.parameters())
        reloaded_params = list(local_model.parameters())
        assert len(original_params) == len(reloaded_params) and all(
            torch.equal(p1, p2) for p1, p2 in zip(original_params, reloaded_params)
        ), "Model parameters don't match!"

        # Remove a shard file
        cached_shard_file = try_to_load_from_cache(repo_id, filename=shard_filename, cache_dir=tmpdir)
        # `try_to_load_from_cache` returns a non-path value when the file is not cached;
        # fail with a clear message here instead of a TypeError inside `os.remove`.
        assert isinstance(cached_shard_file, str), (
            f"Expected {shard_filename} to be present in the cache, got: {cached_shard_file!r}"
        )
        os.remove(cached_shard_file)

        # Attempting to load from cache should raise an error
        with self.assertRaises(OSError) as context:
            FluxTransformer2DModel.from_pretrained(
                repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
            )

        # Verify error mentions the missing shard
        error_msg = str(context.exception)
        assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, (
            f"Expected error about missing shard, got: {error_msg}"
        )
| 341 | + |
294 | 342 | @unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
|
295 | 343 | @unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
|
296 | 344 | def test_one_request_upon_cached(self):
|
|
0 commit comments