@@ -538,38 +538,26 @@ def test_download_variant_partly(self):
538538            variant  =  "no_ema" 
539539
540540            with  tempfile .TemporaryDirectory () as  tmpdirname :
541-                 if  use_safetensors :
542-                     with  self .assertRaises (OSError ) as  error_context :
543-                         tmpdirname  =  StableDiffusionPipeline .download (
544-                             "hf-internal-testing/stable-diffusion-all-variants" ,
545-                             cache_dir = tmpdirname ,
546-                             variant = variant ,
547-                             use_safetensors = use_safetensors ,
548-                         )
549-                     assert  "Could not find the necessary `safetensors` weights"  in  str (error_context .exception )
550-                 else :
551-                     tmpdirname  =  StableDiffusionPipeline .download (
552-                         "hf-internal-testing/stable-diffusion-all-variants" ,
553-                         cache_dir = tmpdirname ,
554-                         variant = variant ,
555-                         use_safetensors = use_safetensors ,
556-                     )
557-                     all_root_files  =  [t [- 1 ] for  t  in  os .walk (tmpdirname )]
558-                     files  =  [item  for  sublist  in  all_root_files  for  item  in  sublist ]
541+                 tmpdirname  =  StableDiffusionPipeline .download (
542+                     "hf-internal-testing/stable-diffusion-all-variants" ,
543+                     cache_dir = tmpdirname ,
544+                     variant = variant ,
545+                     use_safetensors = use_safetensors ,
546+                 )
547+                 all_root_files  =  [t [- 1 ] for  t  in  os .walk (tmpdirname )]
548+                 files  =  [item  for  sublist  in  all_root_files  for  item  in  sublist ]
559549
560-                     unet_files  =  os .listdir (os .path .join (tmpdirname , "unet" ))
561- 
562-                     # Some of the downloaded files should be a non-variant file, check: 
563-                     # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet 
564-                     assert  len (files ) ==  15 , f"We should only download 15 files, not { len (files )}  
565-                     # only unet has "no_ema" variant 
566-                     assert  f"diffusion_pytorch_model.{ variant } { this_format }   in  unet_files 
567-                     assert  len ([f  for  f  in  files  if  f .endswith (f"{ variant } { this_format }  )]) ==  1 
568-                     # vae, safety_checker and text_encoder should have no variant 
569-                     assert  (
570-                         sum (f .endswith (this_format ) and  not  f .endswith (f"{ variant } { this_format }  ) for  f  in  files ) ==  3 
571-                     )
572-                     assert  not  any (f .endswith (other_format ) for  f  in  files )
550+                 unet_files  =  os .listdir (os .path .join (tmpdirname , "unet" ))
551+ 
552+                 # Some of the downloaded files should be a non-variant file, check: 
553+                 # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet 
554+                 assert  len (files ) ==  15 , f"We should only download 15 files, not { len (files )}  
555+                 # only unet has "no_ema" variant 
556+                 assert  f"diffusion_pytorch_model.{ variant } { this_format }   in  unet_files 
557+                 assert  len ([f  for  f  in  files  if  f .endswith (f"{ variant } { this_format }  )]) ==  1 
558+                 # vae, safety_checker and text_encoder should have no variant 
559+                 assert  sum (f .endswith (this_format ) and  not  f .endswith (f"{ variant } { this_format }  ) for  f  in  files ) ==  3 
560+                 assert  not  any (f .endswith (other_format ) for  f  in  files )
573561
574562    def  test_download_variants_with_sharded_checkpoints (self ):
575563        # Here we test for downloading of "variant" files belonging to the `unet` and 
0 commit comments