Skip to content

Commit 0709650

Browse files
[Local loading] Correct bug with local files only (#4318)
* [Local loading] Correct bug with local files only
* file not found error
* fix
* finish
1 parent a982916 commit 0709650

File tree

8 files changed

+66
-16
lines changed

8 files changed

+66
-16
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 16 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -1474,11 +1474,25 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14741474
user_agent=user_agent,
14751475
)
14761476

1477-
if pipeline_class._load_connected_pipes:
1477+
# retrieve pipeline class from local file
1478+
cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
1479+
pipeline_class = getattr(diffusers, cls_name, None)
1480+
1481+
if pipeline_class is not None and pipeline_class._load_connected_pipes:
14781482
modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
14791483
connected_pipes = sum([getattr(modelcard.data, k, []) for k in CONNECTED_PIPES_KEYS], [])
14801484
for connected_pipe_repo_id in connected_pipes:
1481-
DiffusionPipeline.download(connected_pipe_repo_id)
1485+
download_kwargs = {
1486+
"cache_dir": cache_dir,
1487+
"resume_download": resume_download,
1488+
"force_download": force_download,
1489+
"proxies": proxies,
1490+
"local_files_only": local_files_only,
1491+
"use_auth_token": use_auth_token,
1492+
"variant": variant,
1493+
"use_safetensors": use_safetensors,
1494+
}
1495+
DiffusionPipeline.download(connected_pipe_repo_id, **download_kwargs)
14821496

14831497
return cached_folder
14841498

tests/models/test_lora_layers.py

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -563,10 +563,10 @@ def get_dummy_components(self):
563563
projection_dim=32,
564564
)
565565
text_encoder = CLIPTextModel(text_encoder_config)
566-
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
566+
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
567567

568568
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
569-
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
569+
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
570570

571571
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
572572
text_encoder_one_lora_layers = create_text_encoder_lora_layers(text_encoder)

tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -100,10 +100,10 @@ def get_dummy_components(self):
100100
projection_dim=32,
101101
)
102102
text_encoder = CLIPTextModel(text_encoder_config)
103-
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
103+
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
104104

105105
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
106-
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
106+
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
107107

108108
components = {
109109
"unet": unet,

tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -100,10 +100,10 @@ def get_dummy_components(self, skip_first_text_encoder=False):
100100
projection_dim=32,
101101
)
102102
text_encoder = CLIPTextModel(text_encoder_config)
103-
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
103+
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
104104

105105
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
106-
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
106+
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
107107

108108
components = {
109109
"unet": unet,

tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,10 @@ def get_dummy_components(self, skip_first_text_encoder=False):
102102
projection_dim=32,
103103
)
104104
text_encoder = CLIPTextModel(text_encoder_config)
105-
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
105+
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
106106

107107
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
108-
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
108+
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
109109

110110
components = {
111111
"unet": unet,

tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,10 @@ def get_dummy_components(self):
105105
projection_dim=32,
106106
)
107107
text_encoder = CLIPTextModel(text_encoder_config)
108-
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
108+
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
109109

110110
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
111-
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
111+
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
112112

113113
components = {
114114
"unet": unet,

tests/pipelines/test_pipelines.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,15 +374,15 @@ def test_cached_files_are_used_when_no_internet(self):
374374
response_mock.json.return_value = {}
375375

376376
# Download this model to make sure it's in the cache.
377-
orig_pipe = StableDiffusionPipeline.from_pretrained(
377+
orig_pipe = DiffusionPipeline.from_pretrained(
378378
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
379379
)
380380
orig_comps = {k: v for k, v in orig_pipe.components.items() if hasattr(v, "parameters")}
381381

382382
# Under the mock environment we get a 500 error when trying to reach the model.
383383
with mock.patch("requests.request", return_value=response_mock):
384384
# Download this model to make sure it's in the cache.
385-
pipe = StableDiffusionPipeline.from_pretrained(
385+
pipe = DiffusionPipeline.from_pretrained(
386386
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
387387
)
388388
comps = {k: v for k, v in pipe.components.items() if hasattr(v, "parameters")}
@@ -392,6 +392,42 @@ def test_cached_files_are_used_when_no_internet(self):
392392
if p1.data.ne(p2.data).sum() > 0:
393393
assert False, "Parameters not the same!"
394394

395+
def test_local_files_only_are_used_when_no_internet(self):
396+
# A mock response for an HTTP head request to emulate server down
397+
response_mock = mock.Mock()
398+
response_mock.status_code = 500
399+
response_mock.headers = {}
400+
response_mock.raise_for_status.side_effect = HTTPError
401+
response_mock.json.return_value = {}
402+
403+
# first check that with local files only the pipeline can only be used if cached
404+
with self.assertRaises(FileNotFoundError):
405+
with tempfile.TemporaryDirectory() as tmpdirname:
406+
orig_pipe = DiffusionPipeline.from_pretrained(
407+
"hf-internal-testing/tiny-stable-diffusion-torch", local_files_only=True, cache_dir=tmpdirname
408+
)
409+
410+
# now download
411+
orig_pipe = DiffusionPipeline.download("hf-internal-testing/tiny-stable-diffusion-torch")
412+
413+
# make sure it can be loaded with local_files_only
414+
orig_pipe = DiffusionPipeline.from_pretrained(
415+
"hf-internal-testing/tiny-stable-diffusion-torch", local_files_only=True
416+
)
417+
orig_comps = {k: v for k, v in orig_pipe.components.items() if hasattr(v, "parameters")}
418+
419+
# Under the mock environment we get a 500 error when trying to connect to the internet.
420+
# Make sure it works local_files_only only works here!
421+
with mock.patch("requests.request", return_value=response_mock):
422+
# Download this model to make sure it's in the cache.
423+
pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
424+
comps = {k: v for k, v in pipe.components.items() if hasattr(v, "parameters")}
425+
426+
for m1, m2 in zip(orig_comps.values(), comps.values()):
427+
for p1, p2 in zip(m1.parameters(), m2.parameters()):
428+
if p1.data.ne(p2.data).sum() > 0:
429+
assert False, "Parameters not the same!"
430+
395431
def test_download_from_variant_folder(self):
396432
for safe_avail in [False, True]:
397433
import diffusers

tests/pipelines/test_pipelines_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def _test_inference_batch_consistent(
387387
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
388388

389389
# make last batch super long
390-
batched_inputs[name][-1] = 2000 * "very long"
390+
batched_inputs[name][-1] = 100 * "very long"
391391
# or else we have images
392392
else:
393393
batched_inputs[name] = batch_size * [value]
@@ -462,7 +462,7 @@ def _test_inference_batch_single_identical(
462462
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
463463

464464
# make last batch super long
465-
batched_inputs[name][-1] = 2000 * "very long"
465+
batched_inputs[name][-1] = 100 * "very long"
466466
# or else we have images
467467
else:
468468
batched_inputs[name] = batch_size * [value]

0 commit comments

Comments
 (0)