11import contextlib
22import io
33import re
4- from shutil import ignore_patterns
54import unittest
65
76import torch
@@ -213,7 +212,7 @@ def test_diffusers_is_compatible_no_components_only_variants(self):
213212
214213class VariantCompatibleSiblingsTest (unittest .TestCase ):
215214 def test_only_non_variants_downloaded (self ):
216- ignore_patterns = ["*.bin" , "*.msgpack" ]
215+ ignore_patterns = ["*.bin" ]
217216 variant = "fp16"
218217 filenames = [
219218 f"vae/diffusion_pytorch_model.{ variant } .safetensors" ,
@@ -230,7 +229,7 @@ def test_only_non_variants_downloaded(self):
230229 assert all (variant not in f for f in model_filenames )
231230
232231 def test_only_variants_downloaded (self ):
233- ignore_patterns = ["*.bin" , "*.msgpack" ]
232+ ignore_patterns = ["*.bin" ]
234233 variant = "fp16"
235234 filenames = [
236235 f"vae/diffusion_pytorch_model.{ variant } .safetensors" ,
@@ -247,7 +246,7 @@ def test_only_variants_downloaded(self):
247246 assert all (variant in f for f in model_filenames )
248247
249248 def test_mixed_variants_downloaded (self ):
250- ignore_patterns = ["*.bin" , "*.msgpack" ]
249+ ignore_patterns = ["*.bin" ]
251250 variant = "fp16"
252251 non_variant_file = "text_encoder/model.safetensors"
253252 filenames = [
@@ -263,7 +262,7 @@ def test_mixed_variants_downloaded(self):
263262 assert all (variant in f if f != non_variant_file else variant not in f for f in model_filenames )
264263
265264 def test_non_variants_in_main_dir_downloaded (self ):
266- ignore_patterns = ["*.bin" , "*.msgpack" ]
265+ ignore_patterns = ["*.bin" ]
267266 variant = "fp16"
268267 filenames = [
269268 f"diffusion_pytorch_model.{ variant } .safetensors" ,
@@ -277,6 +276,7 @@ def test_non_variants_in_main_dir_downloaded(self):
277276 assert all (variant not in f for f in model_filenames )
278277
279278 def test_variants_in_main_dir_downloaded (self ):
279+ ignore_patterns = ["*.bin" ]
280280 variant = "fp16"
281281 filenames = [
282282 f"diffusion_pytorch_model.{ variant } .safetensors" ,
@@ -292,7 +292,7 @@ def test_variants_in_main_dir_downloaded(self):
292292 assert all (variant in f for f in model_filenames )
293293
294294 def test_mixed_variants_in_main_dir_downloaded (self ):
295- use_safetensors = True
295+ ignore_patterns = [ "*.bin" ]
296296 variant = "fp16"
297297 non_variant_file = "model.safetensors"
298298 filenames = [
@@ -306,7 +306,7 @@ def test_mixed_variants_in_main_dir_downloaded(self):
306306 assert all (variant in f if f != non_variant_file else variant not in f for f in model_filenames )
307307
308308 def test_sharded_variants_in_main_dir_downloaded (self ):
309- use_safetensors = True
309+ ignore_patterns = [ "*.bin" ]
310310 variant = "fp16"
311311 filenames = [
312312 "diffusion_pytorch_model.safetensors.index.json" ,
@@ -323,7 +323,7 @@ def test_sharded_variants_in_main_dir_downloaded(self):
323323 assert all (variant in f for f in model_filenames )
324324
325325 def test_mixed_sharded_and_variant_in_main_dir_downloaded (self ):
326- use_safetensors = True
326+ ignore_patterns = [ "*.bin" ]
327327 variant = "fp16"
328328 filenames = [
329329 "diffusion_pytorch_model.safetensors.index.json" ,
@@ -338,7 +338,7 @@ def test_mixed_sharded_and_variant_in_main_dir_downloaded(self):
338338 assert all (variant in f for f in model_filenames )
339339
340340 def test_mixed_sharded_non_variants_in_main_dir_downloaded (self ):
341- use_safetensors = True
341+ ignore_patterns = [ "*.bin" ]
342342 variant = "fp16"
343343 filenames = [
344344 f"diffusion_pytorch_model.safetensors.index.{ variant } .json" ,
@@ -355,7 +355,7 @@ def test_mixed_sharded_non_variants_in_main_dir_downloaded(self):
355355 assert all (variant not in f for f in model_filenames )
356356
357357 def test_sharded_non_variants_downloaded (self ):
358- use_safetensors = True
358+ ignore_patterns = [ "*.bin" ]
359359 variant = "fp16"
360360 filenames = [
361361 f"unet/diffusion_pytorch_model.safetensors.index.{ variant } .json" ,
@@ -372,7 +372,7 @@ def test_sharded_non_variants_downloaded(self):
372372 assert all (variant not in f for f in model_filenames )
373373
374374 def test_sharded_variants_downloaded (self ):
375- use_safetensors = True
375+ ignore_patterns = [ "*.bin" ]
376376 variant = "fp16"
377377 filenames = [
378378 f"unet/diffusion_pytorch_model.safetensors.index.{ variant } .json" ,
@@ -387,9 +387,10 @@ def test_sharded_variants_downloaded(self):
387387 filenames , variant = variant , ignore_patterns = ignore_patterns
388388 )
389389 assert all (variant in f for f in model_filenames )
390+ assert model_filenames == variant_filenames
390391
391392 def test_single_variant_with_sharded_non_variant_downloaded (self ):
392- use_safetensors = True
393+ ignore_patterns = [ "*.bin" ]
393394 variant = "fp16"
394395 filenames = [
395396 "unet/diffusion_pytorch_model.safetensors.index.json" ,
@@ -404,7 +405,7 @@ def test_single_variant_with_sharded_non_variant_downloaded(self):
404405 assert all (variant in f for f in model_filenames )
405406
406407 def test_mixed_single_variant_with_sharded_non_variant_downloaded (self ):
407- use_safetensors = True
408+ ignore_patterns = [ "*.bin" ]
408409 variant = "fp16"
409410 allowed_non_variant = "unet"
410411 filenames = [
@@ -424,7 +425,7 @@ def test_mixed_single_variant_with_sharded_non_variant_downloaded(self):
424425 assert all (variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames )
425426
426427 def test_sharded_mixed_variants_downloaded (self ):
427- use_safetensors = True
428+ ignore_patterns = [ "*.bin" ]
428429 variant = "fp16"
429430 allowed_non_variant = "unet"
430431 filenames = [
@@ -445,56 +446,30 @@ def test_sharded_mixed_variants_downloaded(self):
445446 )
446447 assert all (variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames )
447448
448- def test_variant_ignored_if_use_safetensors (self ):
449- use_safetensors = True
450- variant = "fp16"
451- filenames = [
452- f"vae/diffusion_pytorch_model.{ variant } .bin" ,
453- f"text_encoder/model.{ variant } .bin" ,
454- f"unet/diffusion_pytorch_model.{ variant } .bin" ,
455- "vae/diffusion_pytorch_model.safetensors" ,
456- "text_encoder/model.safetensors" ,
457- "unet/diffusion_pytorch_model.safetensors" ,
458- ]
459- model_filenames , variant_filenames = variant_compatible_siblings (
460- filenames , variant = variant , ignore_patterns = ignore_patterns
461- )
462- assert all (variant not in f for f in model_filenames )
463-
464449 def test_downloading_when_no_variant_exists (self ):
465- use_safetensors = True
450+ ignore_patterns = [ "*.bin" ]
466451 variant = "fp16"
467452 filenames = ["model.safetensors" , "diffusion_pytorch_model.safetensors" ]
468- model_filenames , variant_filenames = variant_compatible_siblings (
469- filenames , variant = variant , ignore_patterns = ignore_patterns
470- )
471- assert len (model_filenames ) != 0
472-
473- def test_downloading_use_safetensors_no_variant_exists (self ):
474- use_safetensors = True
475- variant = "fp16"
476- filenames = ["text_encoder/model.bin" , "unet/diffusion_pytorch_model.bin" ]
477- model_filenames , variant_filenames = variant_compatible_siblings (
478- filenames , variant = variant , ignore_patterns = ignore_patterns
479- )
480- assert all (variant not in f for f in model_filenames )
453+ with self .assertRaisesRegex (ValueError , "but no such modeling files are available. " ):
454+ model_filenames , variant_filenames = variant_compatible_siblings (
455+ filenames , variant = variant , ignore_patterns = ignore_patterns
456+ )
481457
482458 def test_downloading_use_safetensors_false (self ):
483- use_safetensors = False
484- variant = "fp16"
459+ ignore_patterns = ["*.safetensors" ]
485460 filenames = [
486461 "text_encoder/model.bin" ,
487462 "unet/diffusion_pytorch_model.bin" ,
488463 "unet/diffusion_pytorch_model.safetensors" ,
489464 ]
490465 model_filenames , variant_filenames = variant_compatible_siblings (
491- filenames , variant = variant , ignore_patterns = ignore_patterns
466+ filenames , variant = None , ignore_patterns = ignore_patterns
492467 )
493468
494469 assert all (".safetensors" not in f for f in model_filenames )
495470
496471 def test_non_variant_in_main_dir_with_variant_in_subfolder (self ):
497- use_safetensors = True
472+ ignore_patterns = [ "*.bin" ]
498473 variant = "fp16"
499474 allowed_non_variant = "diffusion_pytorch_model.safetensors"
500475 filenames = [
@@ -506,8 +481,8 @@ def test_non_variant_in_main_dir_with_variant_in_subfolder(self):
506481 )
507482 assert all (variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames )
508483
509- def test_download_variants_when_component_has_no_variant (self ):
510- use_safetensors = True
484+ def test_download_variants_when_component_has_no_safetensors_variant (self ):
485+ ignore_patterns = None
511486 variant = "fp16"
512487 filenames = [
513488 f"unet/diffusion_pytorch_model.{ variant } .bin" ,
@@ -522,8 +497,8 @@ def test_download_variants_when_component_has_no_variant(self):
522497 f"vae/diffusion_pytorch_model.{ variant } .safetensors" ,
523498 } == model_filenames
524499
525- def test_download_sharded_variants_when_component_has_no_safetensors_variant (self ):
526- use_safetensors = True
500+ def test_error_when_download_sharded_variants_when_component_has_no_safetensors_variant (self ):
501+ ignore_patterns = [ "*.bin" ]
527502 variant = "fp16"
528503 filenames = [
529504 f"vae/diffusion_pytorch_model.bin.index.{ variant } .json" ,
@@ -538,13 +513,13 @@ def test_download_sharded_variants_when_component_has_no_safetensors_variant(sel
538513 "unet/diffusion_pytorch_model-00003-of-00003.safetensors" ,
539514 f"vae/diffusion_pytorch_model.{ variant } -00001-of-00002.bin" ,
540515 ]
541- model_filenames , variant_filenames = variant_compatible_siblings (
542- filenames , variant = variant , ignore_patterns = ignore_patterns
543- )
544- assert all ( variant not in f for f in model_filenames )
516+ with self . assertRaisesRegex ( ValueError , "but no such modeling files are available. " ):
517+ model_filenames , variant_filenames = variant_compatible_siblings (
518+ filenames , variant = variant , ignore_patterns = ignore_patterns
519+ )
545520
546521 def test_download_sharded_variants_when_component_has_no_safetensors_variant_and_safetensors_false (self ):
547- use_safetensors = False
522+ ignore_patterns = [ "*.safetensors" ]
548523 allowed_non_variant = "unet"
549524 variant = "fp16"
550525 filenames = [
0 commit comments