@@ -167,9 +167,9 @@ def test_one_request_upon_cached(self):
167167 download_requests = [r .method for r in m .request_history ]
168168 assert download_requests .count ("HEAD" ) == 15 , "15 calls to files"
169169 assert download_requests .count ("GET" ) == 17 , "15 calls to files + model_info + model_index.json"
170- assert len ( download_requests ) == 32 , (
171- "2 calls per file (15 files) + send_telemetry, model_info and model_index.json"
172- )
170+ assert (
171+ len ( download_requests ) == 32
172+ ), "2 calls per file (15 files) + send_telemetry, model_info and model_index.json"
173173
174174 with requests_mock .mock (real_http = True ) as m :
175175 DiffusionPipeline .download (
@@ -179,9 +179,9 @@ def test_one_request_upon_cached(self):
179179 cache_requests = [r .method for r in m .request_history ]
180180 assert cache_requests .count ("HEAD" ) == 1 , "model_index.json is only HEAD"
181181 assert cache_requests .count ("GET" ) == 1 , "model info is only GET"
182- assert len ( cache_requests ) == 2 , (
183- "We should call only `model_info` to check for _commit hash and `send_telemetry`"
184- )
182+ assert (
183+ len ( cache_requests ) == 2
184+ ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
185185
186186 def test_less_downloads_passed_object (self ):
187187 with tempfile .TemporaryDirectory () as tmpdirname :
@@ -217,9 +217,9 @@ def test_less_downloads_passed_object_calls(self):
217217 assert download_requests .count ("HEAD" ) == 13 , "13 calls to files"
218218 # 17 - 2 because no call to config or model file for `safety_checker`
219219 assert download_requests .count ("GET" ) == 15 , "13 calls to files + model_info + model_index.json"
220- assert len ( download_requests ) == 28 , (
221- "2 calls per file (13 files) + send_telemetry, model_info and model_index.json"
222- )
220+ assert (
221+ len ( download_requests ) == 28
222+ ), "2 calls per file (13 files) + send_telemetry, model_info and model_index.json"
223223
224224 with requests_mock .mock (real_http = True ) as m :
225225 DiffusionPipeline .download (
@@ -229,9 +229,9 @@ def test_less_downloads_passed_object_calls(self):
229229 cache_requests = [r .method for r in m .request_history ]
230230 assert cache_requests .count ("HEAD" ) == 1 , "model_index.json is only HEAD"
231231 assert cache_requests .count ("GET" ) == 1 , "model info is only GET"
232- assert len ( cache_requests ) == 2 , (
233- "We should call only `model_info` to check for _commit hash and `send_telemetry`"
234- )
232+ assert (
233+ len ( cache_requests ) == 2
234+ ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
235235
236236 def test_download_only_pytorch (self ):
237237 with tempfile .TemporaryDirectory () as tmpdirname :
@@ -538,26 +538,38 @@ def test_download_variant_partly(self):
538538 variant = "no_ema"
539539
540540 with tempfile .TemporaryDirectory () as tmpdirname :
541- tmpdirname = StableDiffusionPipeline .download (
542- "hf-internal-testing/stable-diffusion-all-variants" ,
543- cache_dir = tmpdirname ,
544- variant = variant ,
545- use_safetensors = use_safetensors ,
546- )
547- all_root_files = [t [- 1 ] for t in os .walk (tmpdirname )]
548- files = [item for sublist in all_root_files for item in sublist ]
549-
550- unet_files = os .listdir (os .path .join (tmpdirname , "unet" ))
541+ if use_safetensors :
542+ with self .assertRaises (OSError ) as error_context :
543+ tmpdirname = StableDiffusionPipeline .download (
544+ "hf-internal-testing/stable-diffusion-all-variants" ,
545+ cache_dir = tmpdirname ,
546+ variant = variant ,
547+ use_safetensors = use_safetensors ,
548+ )
549+ assert "Could not find the necessary `safetensors` weights" in str (error_context .exception )
550+ else :
551+ tmpdirname = StableDiffusionPipeline .download (
552+ "hf-internal-testing/stable-diffusion-all-variants" ,
553+ cache_dir = tmpdirname ,
554+ variant = variant ,
555+ use_safetensors = use_safetensors ,
556+ )
557+ all_root_files = [t [- 1 ] for t in os .walk (tmpdirname )]
558+ files = [item for sublist in all_root_files for item in sublist ]
551559
552- # Some of the downloaded files should be a non-variant file, check:
553- # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
554- assert len (files ) == 15 , f"We should only download 15 files, not { len (files )} "
555- # only unet has "no_ema" variant
556- assert f"diffusion_pytorch_model.{ variant } { this_format } " in unet_files
557- assert len ([f for f in files if f .endswith (f"{ variant } { this_format } " )]) == 1
558- # vae, safety_checker and text_encoder should have no variant
559- assert sum (f .endswith (this_format ) and not f .endswith (f"{ variant } { this_format } " ) for f in files ) == 3
560- assert not any (f .endswith (other_format ) for f in files )
560+ unet_files = os .listdir (os .path .join (tmpdirname , "unet" ))
561+
562+ # Some of the downloaded files should be a non-variant file, check:
563+ # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
564+ assert len (files ) == 15 , f"We should only download 15 files, not { len (files )} "
565+ # only unet has "no_ema" variant
566+ assert f"diffusion_pytorch_model.{ variant } { this_format } " in unet_files
567+ assert len ([f for f in files if f .endswith (f"{ variant } { this_format } " )]) == 1
568+ # vae, safety_checker and text_encoder should have no variant
569+ assert (
570+ sum (f .endswith (this_format ) and not f .endswith (f"{ variant } { this_format } " ) for f in files ) == 3
571+ )
572+ assert not any (f .endswith (other_format ) for f in files )
561573
562574 def test_download_variants_with_sharded_checkpoints (self ):
563575 # Here we test for downloading of "variant" files belonging to the `unet` and
@@ -588,19 +600,16 @@ def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self):
588600 logger = logging .get_logger ("diffusers.pipelines.pipeline_utils" )
589601 deprecated_warning_msg = "Warning: The repository contains sharded checkpoints for variant"
590602
591- for is_local in [True , False ]:
592- with CaptureLogger (logger ) as cap_logger :
593- with tempfile .TemporaryDirectory () as tmpdirname :
594- local_repo_id = repo_id
595- if is_local :
596- local_repo_id = snapshot_download (repo_id , cache_dir = tmpdirname )
603+ with CaptureLogger (logger ) as cap_logger :
604+ with tempfile .TemporaryDirectory () as tmpdirname :
605+ local_repo_id = snapshot_download (repo_id , cache_dir = tmpdirname )
597606
598- _ = DiffusionPipeline .from_pretrained (
599- local_repo_id ,
600- safety_checker = None ,
601- variant = "fp16" ,
602- use_safetensors = True ,
603- )
607+ _ = DiffusionPipeline .from_pretrained (
608+ local_repo_id ,
609+ safety_checker = None ,
610+ variant = "fp16" ,
611+ use_safetensors = True ,
612+ )
604613 assert deprecated_warning_msg in str (cap_logger ), "Deprecation warning not found in logs"
605614
606615 def test_download_safetensors_only_variant_exists_for_model (self ):
@@ -616,7 +625,7 @@ def test_download_safetensors_only_variant_exists_for_model(self):
616625 variant = variant ,
617626 use_safetensors = use_safetensors ,
618627 )
619- assert "Error no file name " in str (error_context .exception )
628+ assert "Could not find the necessary `safetensors` weights " in str (error_context .exception )
620629
621630 # text encoder has fp16 variants so we can load it
622631 with tempfile .TemporaryDirectory () as tmpdirname :
@@ -675,7 +684,7 @@ def test_download_safetensors_variant_does_not_exist_for_model(self):
675684 use_safetensors = use_safetensors ,
676685 )
677686
678- assert "Error no file name " in str (error_context .exception )
687+ assert "Could not find the necessary `safetensors` weights " in str (error_context .exception )
679688
680689 def test_download_bin_variant_does_not_exist_for_model (self ):
681690 variant = "no_ema"
0 commit comments