 import torch
 import torch.nn as nn
 from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
-from huggingface_hub import ModelCard, delete_repo, snapshot_download
+from huggingface_hub import ModelCard, delete_repo, snapshot_download, try_to_load_from_cache
 from huggingface_hub.utils import is_jinja_available
 from parameterized import parameterized
 from requests.exceptions import HTTPError

-from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel
+from diffusers.models import FluxTransformer2DModel, SD3Transformer2DModel, UNet2DConditionModel
 from diffusers.models.attention_processor import (
     AttnProcessor,
     AttnProcessor2_0,
@@ -291,6 +291,54 @@ def test_cached_files_are_used_when_no_internet(self):
             if p1.data.ne(p2.data).sum() > 0:
                 assert False, "Parameters not the same!"

+    def test_local_files_only_with_sharded_checkpoint(self):
+        repo_id = "hf-internal-testing/tiny-flux-sharded"
+        error_response = mock.Mock(
+            status_code=500,
+            headers={},
+            raise_for_status=mock.Mock(side_effect=HTTPError),
+            json=mock.Mock(return_value={}),
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=tmpdir)
+
+            with mock.patch("requests.Session.get", return_value=error_response):
+                # Should fail with local_files_only=False (network required):
+                # from_pretrained makes a network call via model_info
+                with self.assertRaises(OSError):
+                    FluxTransformer2DModel.from_pretrained(
+                        repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=False
+                    )
+
+                # Should succeed with local_files_only=True (uses the cache);
+                # the model_info call is skipped
+                local_model = FluxTransformer2DModel.from_pretrained(
+                    repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
+                )
+
+            assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), (
+                "Model parameters don't match!"
+            )
+
+            # Remove one shard file from the cache
+            cached_shard_file = try_to_load_from_cache(
+                repo_id, filename="transformer/diffusion_pytorch_model-00001-of-00002.safetensors", cache_dir=tmpdir
+            )
+            os.remove(cached_shard_file)
+
+            # Attempting to load from the cache should now raise an error
+            with self.assertRaises(OSError) as context:
+                FluxTransformer2DModel.from_pretrained(
+                    repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
+                )
+
+            # Verify the error mentions the missing shard
+            error_msg = str(context.exception)
+            assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, (
+                f"Expected error about missing shard, got: {error_msg}"
+            )
+
     @unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
     @unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
     def test_one_request_upon_cached(self):