@@ -212,6 +212,7 @@ def test_diffusers_is_compatible_no_components_only_variants(self):
212212
213213class  VariantCompatibleSiblingsTest (unittest .TestCase ):
214214    def  test_only_non_variants_downloaded (self ):
215+         use_safetensors  =  True 
215216        variant  =  "fp16" 
216217        filenames  =  [
217218            f"vae/diffusion_pytorch_model.{ variant }  ,
@@ -222,10 +223,13 @@ def test_only_non_variants_downloaded(self):
222223            "unet/diffusion_pytorch_model.safetensors" ,
223224        ]
224225
225-         model_filenames , variant_filenames  =  variant_compatible_siblings (filenames , variant = None )
226+         model_filenames , variant_filenames  =  variant_compatible_siblings (
227+             filenames , variant = None , use_safetensors = use_safetensors 
228+         )
226229        assert  all (variant  not  in f  for  f  in  model_filenames )
227230
228231    def  test_only_variants_downloaded (self ):
232+         use_safetensors  =  True 
229233        variant  =  "fp16" 
230234        filenames  =  [
231235            f"vae/diffusion_pytorch_model.{ variant }  ,
@@ -236,10 +240,13 @@ def test_only_variants_downloaded(self):
236240            "unet/diffusion_pytorch_model.safetensors" ,
237241        ]
238242
239-         model_filenames , variant_filenames  =  variant_compatible_siblings (filenames , variant = variant )
243+         model_filenames , variant_filenames  =  variant_compatible_siblings (
244+             filenames , variant = variant , use_safetensors = use_safetensors 
245+         )
240246        assert  all (variant  in  f  for  f  in  model_filenames )
241247
242248    def  test_mixed_variants_downloaded (self ):
249+         use_safetensors  =  True 
243250        variant  =  "fp16" 
244251        non_variant_file  =  "text_encoder/model.safetensors" 
245252        filenames  =  [
@@ -249,21 +256,27 @@ def test_mixed_variants_downloaded(self):
249256            f"unet/diffusion_pytorch_model.{ variant }  ,
250257            "unet/diffusion_pytorch_model.safetensors" ,
251258        ]
252-         model_filenames , variant_filenames  =  variant_compatible_siblings (filenames , variant = variant )
259+         model_filenames , variant_filenames  =  variant_compatible_siblings (
260+             filenames , variant = variant , use_safetensors = use_safetensors 
261+         )
253262        assert  all (variant  in  f  if  f  !=  non_variant_file  else  variant  not  in f  for  f  in  model_filenames )
254263
255264    def  test_non_variants_in_main_dir_downloaded (self ):
265+         use_safetensors  =  True 
256266        variant  =  "fp16" 
257267        filenames  =  [
258268            f"diffusion_pytorch_model.{ variant }  ,
259269            "diffusion_pytorch_model.safetensors" ,
260270            "model.safetensors" ,
261271            f"model.{ variant }  ,
262272        ]
263-         model_filenames , variant_filenames  =  variant_compatible_siblings (filenames , variant = None )
273+         model_filenames , variant_filenames  =  variant_compatible_siblings (
274+             filenames , variant = None , use_safetensors = use_safetensors 
275+         )
264276        assert  all (variant  not  in f  for  f  in  model_filenames )
265277
266278    def  test_variants_in_main_dir_downloaded (self ):
279+         use_safetensors  =  True 
267280        variant  =  "fp16" 
268281        filenames  =  [
269282            f"diffusion_pytorch_model.{ variant }  ,
@@ -273,21 +286,27 @@ def test_variants_in_main_dir_downloaded(self):
273286            f"diffusion_pytorch_model.{ variant }  ,
274287            "diffusion_pytorch_model.safetensors" ,
275288        ]
276-         model_filenames , variant_filenames  =  variant_compatible_siblings (filenames , variant = variant )
289+         model_filenames , variant_filenames  =  variant_compatible_siblings (
290+             filenames , variant = variant , use_safetensors = use_safetensors 
291+         )
277292        assert  all (variant  in  f  for  f  in  model_filenames )
278293
279294    def  test_mixed_variants_in_main_dir_downloaded (self ):
295+         use_safetensors  =  True 
280296        variant  =  "fp16" 
281297        non_variant_file  =  "model.safetensors" 
282298        filenames  =  [
283299            f"diffusion_pytorch_model.{ variant }  ,
284300            "diffusion_pytorch_model.safetensors" ,
285301            "model.safetensors" ,
286302        ]
287-         model_filenames , variant_filenames  =  variant_compatible_siblings (filenames , variant = variant )
303+         model_filenames , variant_filenames  =  variant_compatible_siblings (
304+             filenames , variant = variant , use_safetensors = use_safetensors 
305+         )
288306        assert  all (variant  in  f  if  f  !=  non_variant_file  else  variant  not  in f  for  f  in  model_filenames )
289307
290308    def  test_sharded_variants_in_main_dir_downloaded (self ):
309+         use_safetensors  =  True 
291310        variant  =  "fp16" 
292311        filenames  =  [
293312            "diffusion_pytorch_model.safetensors.index.json" ,
@@ -298,10 +317,13 @@ def test_sharded_variants_in_main_dir_downloaded(self):
298317            f"diffusion_pytorch_model.{ variant }  ,
299318            f"diffusion_pytorch_model.safetensors.index.{ variant }  ,
300319        ]
301-         model_filenames , variant_filenames  =  variant_compatible_siblings (filenames , variant = variant )
320+         model_filenames , variant_filenames  =  variant_compatible_siblings (
321+             filenames , variant = variant , use_safetensors = use_safetensors 
322+         )
302323        assert  all (variant  in  f  for  f  in  model_filenames )
303324
304325    def  test_mixed_sharded_and_variant_in_main_dir_downloaded (self ):
326+         use_safetensors  =  True 
305327        variant  =  "fp16" 
306328        filenames  =  [
307329            "diffusion_pytorch_model.safetensors.index.json" ,
@@ -310,10 +332,13 @@ def test_mixed_sharded_and_variant_in_main_dir_downloaded(self):
310332            "diffusion_pytorch_model-00003-of-00003.safetensors" ,
311333            f"diffusion_pytorch_model.{ variant }  ,
312334        ]
313-         model_filenames , variant_filenames  =  variant_compatible_siblings (filenames , variant = variant )
335+         model_filenames , variant_filenames  =  variant_compatible_siblings (
336+             filenames , variant = variant , use_safetensors = use_safetensors 
337+         )
314338        assert  all (variant  in  f  for  f  in  model_filenames )
315339
316340    def  test_mixed_sharded_non_variants_in_main_dir_downloaded (self ):
341+         use_safetensors  =  True 
317342        variant  =  "fp16" 
318343        filenames  =  [
319344            f"diffusion_pytorch_model.safetensors.index.{ variant }  ,
@@ -324,10 +349,13 @@ def test_mixed_sharded_non_variants_in_main_dir_downloaded(self):
324349            f"diffusion_pytorch_model.{ variant }  ,
325350            f"diffusion_pytorch_model.{ variant }  ,
326351        ]
327-         model_filenames , variant_filenames  =  variant_compatible_siblings (filenames , variant = None )
352+         model_filenames , variant_filenames  =  variant_compatible_siblings (
353+             filenames , variant = None , use_safetensors = use_safetensors 
354+         )
328355        assert  all (variant  not  in f  for  f  in  model_filenames )
329356
330357    def  test_sharded_non_variants_downloaded (self ):
358+         use_safetensors  =  True 
331359        variant  =  "fp16" 
332360        filenames  =  [
333361            f"unet/diffusion_pytorch_model.safetensors.index.{ variant }  ,
@@ -338,10 +366,13 @@ def test_sharded_non_variants_downloaded(self):
338366            f"unet/diffusion_pytorch_model.{ variant }  ,
339367            f"unet/diffusion_pytorch_model.{ variant }  ,
340368        ]
341-         model_filenames , variant_filenames  =  variant_compatible_siblings (filenames , variant = None )
369+         model_filenames , variant_filenames  =  variant_compatible_siblings (
370+             filenames , variant = None , use_safetensors = use_safetensors 
371+         )
342372        assert  all (variant  not  in f  for  f  in  model_filenames )
343373
344374    def  test_sharded_variants_downloaded (self ):
375+         use_safetensors  =  True 
345376        variant  =  "fp16" 
346377        filenames  =  [
347378            f"unet/diffusion_pytorch_model.safetensors.index.{ variant }  ,
@@ -352,10 +383,13 @@ def test_sharded_variants_downloaded(self):
352383            f"unet/diffusion_pytorch_model.{ variant }  ,
353384            f"unet/diffusion_pytorch_model.{ variant }  ,
354385        ]
355-         model_filenames , variant_filenames  =  variant_compatible_siblings (filenames , variant = variant )
386+         model_filenames , variant_filenames  =  variant_compatible_siblings (
387+             filenames , variant = variant , use_safetensors = use_safetensors 
388+         )
356389        assert  all (variant  in  f  for  f  in  model_filenames )
357390
358391    def  test_single_variant_with_sharded_non_variant_downloaded (self ):
392+         use_safetensors  =  True 
359393        variant  =  "fp16" 
360394        filenames  =  [
361395            "unet/diffusion_pytorch_model.safetensors.index.json" ,
@@ -364,10 +398,13 @@ def test_single_variant_with_sharded_non_variant_downloaded(self):
364398            "unet/diffusion_pytorch_model-00003-of-00003.safetensors" ,
365399            f"unet/diffusion_pytorch_model.{ variant }  ,
366400        ]
367-         model_filenames , variant_filenames  =  variant_compatible_siblings (filenames , variant = variant )
401+         model_filenames , variant_filenames  =  variant_compatible_siblings (
402+             filenames , variant = variant , use_safetensors = use_safetensors 
403+         )
368404        assert  all (variant  in  f  for  f  in  model_filenames )
369405
370406    def  test_mixed_single_variant_with_sharded_non_variant_downloaded (self ):
407+         use_safetensors  =  True 
371408        variant  =  "fp16" 
372409        allowed_non_variant  =  "unet" 
373410        filenames  =  [
@@ -381,10 +418,13 @@ def test_mixed_single_variant_with_sharded_non_variant_downloaded(self):
381418            "unet/diffusion_pytorch_model-00002-of-00003.safetensors" ,
382419            "unet/diffusion_pytorch_model-00003-of-00003.safetensors" ,
383420        ]
384-         model_filenames , variant_filenames  =  variant_compatible_siblings (filenames , variant = variant )
421+         model_filenames , variant_filenames  =  variant_compatible_siblings (
422+             filenames , variant = variant , use_safetensors = use_safetensors 
423+         )
385424        assert  all (variant  in  f  if  allowed_non_variant  not  in f  else  variant  not  in f  for  f  in  model_filenames )
386425
387426    def  test_sharded_mixed_variants_downloaded (self ):
427+         use_safetensors  =  True 
388428        variant  =  "fp16" 
389429        allowed_non_variant  =  "unet" 
390430        filenames  =  [
@@ -400,15 +440,72 @@ def test_sharded_mixed_variants_downloaded(self):
400440            "vae/diffusion_pytorch_model-00002-of-00003.safetensors" ,
401441            "vae/diffusion_pytorch_model-00003-of-00003.safetensors" ,
402442        ]
403-         model_filenames , variant_filenames  =  variant_compatible_siblings (filenames , variant = variant )
443+         model_filenames , variant_filenames  =  variant_compatible_siblings (
444+             filenames , variant = variant , use_safetensors = use_safetensors 
445+         )
404446        assert  all (variant  in  f  if  allowed_non_variant  not  in f  else  variant  not  in f  for  f  in  model_filenames )
405447
448+     def  test_variant_ignored_if_use_safetensors (self ):
449+         use_safetensors  =  True 
450+         variant  =  "fp16" 
451+         filenames  =  [
452+             f"vae/diffusion_pytorch_model.{ variant }  ,
453+             f"text_encoder/model.{ variant }  ,
454+             f"unet/diffusion_pytorch_model.{ variant }  ,
455+             "vae/diffusion_pytorch_model.safetensors" ,
456+             "text_encoder/model.safetensors" ,
457+             "unet/diffusion_pytorch_model.safetensors" ,
458+         ]
459+         model_filenames , variant_filenames  =  variant_compatible_siblings (
460+             filenames , variant = variant , use_safetensors = use_safetensors 
461+         )
462+         assert  all (variant  not  in f  for  f  in  model_filenames )
463+ 
406464    def  test_downloading_when_no_variant_exists (self ):
465+         use_safetensors  =  True 
407466        variant  =  "fp16" 
408467        filenames  =  ["model.safetensors" , "diffusion_pytorch_model.safetensors" ]
409-         model_filenames , variant_filenames  =  variant_compatible_siblings (filenames , variant = variant )
468+         model_filenames , variant_filenames  =  variant_compatible_siblings (
469+             filenames , variant = variant , use_safetensors = use_safetensors 
470+         )
410471        assert  len (model_filenames ) !=  0 
411472
473+     def  test_downloading_use_safetensors_no_variant_exists (self ):
474+         use_safetensors  =  True 
475+         variant  =  "fp16" 
476+         filenames  =  ["text_encoder/model.bin" , "unet/diffusion_pytorch_model.bin" ]
477+         model_filenames , variant_filenames  =  variant_compatible_siblings (
478+             filenames , variant = variant , use_safetensors = use_safetensors 
479+         )
480+         assert  all (variant  not  in f  for  f  in  model_filenames )
481+ 
482+     def  test_downloading_use_safetensors_false (self ):
483+         use_safetensors  =  False 
484+         variant  =  "fp16" 
485+         filenames  =  [
486+             "text_encoder/model.bin" ,
487+             "unet/diffusion_pytorch_model.bin" ,
488+             "unet/diffusion_pytorch_model.safetensors" ,
489+         ]
490+         model_filenames , variant_filenames  =  variant_compatible_siblings (
491+             filenames , variant = variant , use_safetensors = use_safetensors 
492+         )
493+ 
494+         assert  all (".safetensors"  not  in f  for  f  in  model_filenames )
495+ 
496+     def  test_non_variant_in_main_dir_with_variant_in_subfolder (self ):
497+         use_safetensors  =  True 
498+         variant  =  "fp16" 
499+         allowed_non_variant  =  "diffusion_pytorch_model.safetensors" 
500+         filenames  =  [
501+             f"unet/diffusion_pytorch_model.{ variant }  ,
502+             "diffusion_pytorch_model.safetensors" ,
503+         ]
504+         model_filenames , variant_filenames  =  variant_compatible_siblings (
505+             filenames , variant = variant , use_safetensors = use_safetensors 
506+         )
507+         assert  all (variant  in  f  if  allowed_non_variant  not  in f  else  variant  not  in f  for  f  in  model_filenames )
508+ 
412509
413510class  ProgressBarTests (unittest .TestCase ):
414511    def  get_dummy_components_image_generation (self ):
0 commit comments